├── .gitignore ├── CBN ├── README.md ├── config.py ├── frameworks │ ├── __init__.py │ ├── evaluating │ │ ├── __init__.py │ │ ├── base.py │ │ ├── evaluator_manager.py │ │ └── frame_evaluator.py │ ├── models │ │ ├── __init__.py │ │ ├── backbone │ │ │ ├── __init__.py │ │ │ └── resnet_backbone.py │ │ ├── model_arch.py │ │ └── weight_utils.py │ └── training │ │ ├── __init__.py │ │ ├── base.py │ │ ├── data_parallel.py │ │ ├── optimizers.py │ │ └── trainer.py ├── io_stream │ ├── __init__.py │ ├── data_manager.py │ ├── data_utils.py │ ├── datasets │ │ ├── combine.py │ │ ├── duke.py │ │ ├── market.py │ │ ├── msmt.py │ │ ├── personx.py │ │ ├── randperson.py │ │ ├── soma.py │ │ ├── syri.py │ │ └── unreal.py │ └── samplers.py ├── requirements.txt ├── test_model.py ├── train_model.py └── utils │ ├── __init__.py │ ├── loss.py │ ├── meters.py │ ├── serialization.py │ └── transforms.py ├── Download.md ├── Experiments.md ├── JVTC ├── README.md ├── config.py ├── list_duke │ ├── list_duke_test.txt │ └── list_duke_train.txt ├── list_market │ ├── list_market_test.txt │ └── list_market_train.txt ├── list_msmt │ ├── list_msmt_test.txt │ ├── list_msmt_train.txt │ ├── list_test.txt │ ├── list_train.txt │ ├── rename.py │ └── rename_test.py ├── list_unreal │ └── list_unreal_train.txt ├── multi_train_cbn.py ├── test_cbn.py ├── train_cbn.py └── utils │ ├── __init__.py │ ├── dataset.py │ ├── evaluate_joint_sim.py │ ├── evaluators.py │ ├── logger.py │ ├── losses.py │ ├── losses_msmt.py │ ├── lr_adjust.py │ ├── ranking.py │ ├── rerank.py │ ├── resnet.py │ ├── st_distribution.py │ └── util.py ├── LICENSE ├── README.md ├── SynthesisToolkit.md ├── UnrealPerson-DataSynthesisToolkit ├── 9_massproduce │ ├── __init__.py │ ├── humanstate.py │ ├── massproduce.py │ ├── modifiergroups.py │ ├── randomizationsettings.py │ └── randomizeaction.py ├── caminfo_collect.py ├── generate_datasets.py ├── interplate_video.py ├── levelbp.txt ├── levelbp_preview.png ├── other_scripts.py ├── postprocess.py ├── postprocess_video.py ├── script_clothingcoparsing_clothing_patch.py ├── script_deepfashion_clothing_patch.py ├── script_makehuman_asset_download.py ├── unrealcv │ ├── unrealcv_plugin-4.24-mac.zip │ └── unrealcv_plugin-4.24-win64.zip └── utils.py └── imgs └── unrealperson.jpg /.gitignore: -------------------------------------------------------------------------------- 1 | # my file 2 | .idea 3 | *DS_Store 4 | # folders 5 | data 6 | pytorch-ckpt 7 | # others 8 | *.h5 9 | *.pkl 10 | *.npy 11 | *.pdf 12 | *.pickle 13 | __pycache__ 14 | *.pyc 15 | 16 | # Byte-compiled / optimized / DLL files 17 | __pycache__/ 18 | *.py[cod] 19 | *$py.class 20 | 21 | # C extensions 22 | *.so 23 | 24 | # Distribution / packaging 25 | .Python 26 | build/ 27 | develop-eggs/ 28 | dist/ 29 | downloads/ 30 | eggs/ 31 | .eggs/ 32 | lib/ 33 | lib64/ 34 | parts/ 35 | sdist/ 36 | var/ 37 | wheels/ 38 | *.egg-info/ 39 | .installed.cfg 40 | *.egg 41 | MANIFEST 42 | 43 | # PyInstaller 44 | # Usually these files are written by a python script from a template 45 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
46 | *.manifest 47 | *.spec 48 | 49 | # Installer logs 50 | pip-log.txt 51 | pip-delete-this-directory.txt 52 | 53 | # Unit test / coverage reports 54 | htmlcov/ 55 | .tox/ 56 | .coverage 57 | .coverage.* 58 | .cache 59 | nosetests.xml 60 | coverage.xml 61 | *.cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # celery beat schedule file 94 | celerybeat-schedule 95 | 96 | # SageMath parsed files 97 | *.sage.py 98 | 99 | # Environments 100 | .env 101 | .venv 102 | env/ 103 | venv/ 104 | ENV/ 105 | env.bak/ 106 | venv.bak/ 107 | 108 | # Spyder project settings 109 | .spyderproject 110 | .spyproject 111 | 112 | # Rope project settings 113 | .ropeproject 114 | 115 | # mkdocs documentation 116 | /site 117 | 118 | # mypy 119 | .mypy_cache/ 120 | -------------------------------------------------------------------------------- /CBN/README.md: -------------------------------------------------------------------------------- 1 | This code is based on CBN. 2 | 3 | The original repo: [Camera-based-Person-ReID](https://github.com/automan000/Camera-based-Person-ReID). 4 | 5 | The original paper: [Rethinking the Distribution Gap of Person Re-identification with Camera-based Batch Normalization](https://arxiv.org/abs/2001.08680). -------------------------------------------------------------------------------- /CBN/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | import warnings 6 | 7 | 8 | class DefaultConfig(object): 9 | seed = 0 10 | share_cam = False 11 | num_pids = 800 12 | num_cams = 0 13 | img_per_person = 60 14 | model_path = None 15 | # dataset options 16 | trainset_name = 'market' 17 | testset_name = 'duke' 18 | height = 256 19 | width = 128 20 | # sampler 21 | workers = 8 22 | num_instances = 4 23 | # default optimization params 24 | train_batch = 64 25 | test_batch = 32 26 | max_epoch = 15 27 | decay_epoch = 10 28 | save_step = max_epoch 29 | # estimate bn statistics 30 | batch_num_bn_estimatation = 10 31 | # io 32 | datasets = 'market,unreal_v6+v4+v7' 33 | dataset = 'unreal_v4.1,unreal_v6.1,unreal_v7.1,unreal_v8.1' 34 | print_freq = 50 35 | loss = 'softmax' 36 | save_dir = './pytorch-ckpt/market' 37 | cam_bal = False 38 | testepoch = 'model_best' 39 | def _parse(self, kwargs): 40 | for k, v in kwargs.items(): 41 | if not hasattr(self, k): 42 | warnings.warn("Warning: opt has no attribute %s" % k) 43 | setattr(self, k, v) 44 | 45 | def _state_dict(self): 46 | return {k: getattr(self, k) for k, _ in DefaultConfig.__dict__.items() 47 | if not k.startswith('_')} 48 | 49 | 50 | opt = DefaultConfig() 51 | -------------------------------------------------------------------------------- /CBN/frameworks/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 
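For clarity, a small usage sketch (hypothetical, not a file in this repo) of how the scripts are expected to drive DefaultConfig; the real call sites live in train_model.py and test_model.py:

# Hypothetical driver snippet. _parse only warns on unknown keys before
# setting them, so typos in command-line flags are easy to miss.
from config import opt

opt._parse({'trainset_name': 'msmt', 'max_epoch': 20, 'loss': 'triplet'})
print(opt._state_dict()['max_epoch'])  # 20 -- _state_dict reflects overrides of class-level keys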
-------------------------------------------------------------------------------- /CBN/frameworks/evaluating/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | __all__ = [ 7 | 'init_evaluator' 8 | ] 9 | 10 | from .evaluator_manager import init_evaluator 11 | -------------------------------------------------------------------------------- /CBN/frameworks/evaluating/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | import os 6 | import matplotlib 7 | import pickle 8 | matplotlib.use('Agg') 9 | from tqdm import tqdm 10 | import numpy as np 11 | from multiprocessing import cpu_count 12 | from multiprocessing import Pool 13 | from collections import defaultdict 14 | 15 | def mt_eval_func(query_id, query_cam, gallery_ids, gallery_cams, order, matches, max_rank): 16 | 17 | remove = (gallery_ids[order] == query_id) & (gallery_cams[order] == query_cam) 18 | keep = np.invert(remove) 19 | orig_cmc = matches[keep] 20 | if not np.any(orig_cmc): 21 | return -1, -1 22 | 23 | cmc = orig_cmc.cumsum() 24 | cmc[cmc > 1] = 1 25 | single_cmc = cmc[:max_rank] 26 | 27 | num_rel = orig_cmc.sum() 28 | tmp_cmc = orig_cmc.cumsum() 29 | tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)] 30 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 31 | single_ap = tmp_cmc.sum() / num_rel 32 | return single_ap, single_cmc, orig_cmc[:max_rank] 33 | 34 | 35 | class BaseEvaluator(object): 36 | def __init__(self, model): 37 | self.model = model 38 | self.eval_func = self.eval_func1 39 | 40 | def _parse_data(self, inputs): 41 | raise NotImplementedError 42 | 43 | def _forward(self, inputs): 44 | raise NotImplementedError 45 | 46 | def evaluate(self, queryloader, galleryloader, ranks): 47 | raise NotImplementedError 48 | 49 | 50 | def eval_func1(self, distmat, q_pids, g_pids, q_camids, g_camids, q_paths,g_paths,max_rank=50): 51 | num_q, num_g = distmat.shape 52 | if num_g < max_rank: 53 | max_rank = num_g 54 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 55 | indices = np.argsort(distmat, axis=1) 56 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 57 | 58 | mt_pool = Pool(cpu_count()) 59 | results = [] 60 | path_dict = dict() 61 | for q_idx in range(num_q): 62 | path_dict[q_idx] = indices[q_idx][:30] 63 | params = (q_pids[q_idx], q_camids[q_idx], g_pids, g_camids, indices[q_idx], matches[q_idx], max_rank) 64 | results.append(mt_pool.apply_async(mt_eval_func, params)) 65 | 66 | mt_pool.close() 67 | mt_pool.join() 68 | res=[] 69 | for x in results: 70 | if len(x.get())==3: 71 | res.append(x.get()) 72 | results = res #[x.get() for x in results] 73 | 74 | all_AP = np.array([x[0] for x in results]) 75 | valid_index = all_AP > -1 76 | all_AP = all_AP[valid_index] 77 | all_cmc = np.array([x[1] for x in results]) 78 | all_cmc = all_cmc[valid_index, ...] 
79 | num_valid_q = len(all_AP) 80 | try: 81 | all_ranks = np.array([x[2] for x in results]) 82 | except ValueError: 83 | # per-query rank lists can be ragged; fall back without the rank matrix 84 | all_ranks = None 85 | all_cmc = np.asarray(all_cmc).astype(np.float32) 86 | all_cmc = all_cmc.sum(0) / num_valid_q 87 | mAP = np.mean(all_AP) 88 | 89 | pickle.dump(path_dict, open('path_dict.pkl', 'wb')) 90 | print(num_valid_q) 91 | return all_cmc, mAP, all_ranks 92 | -------------------------------------------------------------------------------- /CBN/frameworks/evaluating/evaluator_manager.py: -------------------------------------------------------------------------------- 1 | from frameworks.evaluating.frame_evaluator import FrameEvaluator 2 | 3 | __data_factory = { 4 | 'market': FrameEvaluator, 5 | 'msmt': FrameEvaluator, 6 | 'duke': FrameEvaluator, 7 | } 8 | 9 | 10 | def init_evaluator(name, model, flip): 11 | if name not in __data_factory.keys(): 12 | return FrameEvaluator(model, flip) 13 | return __data_factory[name](model, flip) 14 | 
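As a sanity check on the evaluation code above, here is the average-precision arithmetic of mt_eval_func on a toy match vector (standalone numpy, not repo code):

import numpy as np

# Gallery matches for one query, sorted by ascending distance, after
# same-id/same-camera entries have been removed (this is orig_cmc above).
orig_cmc = np.array([0, 1, 0, 1, 1])

cmc = orig_cmc.cumsum()
cmc[cmc > 1] = 1                                    # rank-k hit indicator
prec_at_k = orig_cmc.cumsum() / (np.arange(len(orig_cmc)) + 1.0)
ap = (prec_at_k * orig_cmc).sum() / orig_cmc.sum()  # mean precision at the hit positions
print(cmc, round(ap, 4))                            # [0 1 1 1 1] 0.5333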
-------------------------------------------------------------------------------- /CBN/frameworks/evaluating/frame_evaluator.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import numpy as np 7 | import torch 8 | import random 9 | from frameworks.evaluating.base import BaseEvaluator 10 | from tqdm import tqdm 11 | import pickle 12 | 13 | class FrameEvaluator(BaseEvaluator): 14 | def __init__(self, model, flip=True): 15 | super().__init__(model) 16 | self.loop = 2 if flip else 1 17 | self.evaluator_prints = [] 18 | 19 | def _parse_data(self, inputs): 20 | imgs, pids, camids, paths = inputs 21 | return imgs.cuda(), pids, camids, paths 22 | 23 | def flip_tensor_lr(self, img): 24 | inv_idx = torch.arange(img.size(3) - 1, -1, -1).long().cuda() 25 | img_flip = img.index_select(3, inv_idx) 26 | return img_flip 27 | 28 | def _forward(self, inputs): 29 | with torch.no_grad(): 30 | feature, _ = self.model(inputs) 31 | if isinstance(feature, tuple) or isinstance(feature, list): 32 | output = [] 33 | for x in feature: 34 | if isinstance(x, tuple) or isinstance(x, list): 35 | output.append([item.cpu() for item in x]) 36 | else: 37 | output.append(x.cpu()) 38 | return output 39 | else: 40 | return feature.cpu() 41 | 42 | def produce_features(self, dataloader, normalize=True): 43 | self.model.eval() 44 | all_feature_norm = [] 45 | qf, q_pids, q_camids, q_paths = [], [], [], [] 46 | for batch_idx, inputs in enumerate(dataloader): 47 | inputs, pids, camids, img_paths = self._parse_data(inputs) 48 | feature = None 49 | for i in range(self.loop): 50 | if i == 1: 51 | inputs = self.flip_tensor_lr(inputs) 52 | global_f = self._forward(inputs) 53 | if feature is None: 54 | feature = global_f 55 | else: 56 | feature += global_f 57 | if normalize: 58 | fnorm = torch.norm(feature, p=2, dim=1, keepdim=True) 59 | all_feature_norm.extend(list(fnorm.cpu().numpy()[:, 0])) 60 | feature = feature.div(fnorm.expand_as(feature)) 61 | else: 62 | feature = feature / 2 63 | 64 | qf.append(feature) 65 | q_pids.extend(pids) 66 | q_camids.extend(camids) 67 | q_paths.extend(img_paths) 68 | 69 | if len(qf) > 1: 70 | qf = torch.cat(qf, 0) 71 | else: 72 | qf = qf[0] 73 | q_pids = np.asarray(q_pids) 74 | q_camids = np.asarray(q_camids) 75 | 76 | return qf, q_pids, q_camids, q_paths 77 | 78 | def get_final_results_with_features(self, qf, q_pids, q_camids, gf, g_pids, g_camids, q_paths, g_paths, target_ranks=[1, 5, 10, 20]): 79 | m, n = qf.size(0), gf.size(0) 80 | distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() 81 | distmat.addmm_(qf, gf.t(), beta=1, alpha=-2) 82 | distmat = distmat.numpy() 83 | 84 | cmc, mAP, ranks = self.eval_func(distmat, q_pids, g_pids, q_camids, g_camids, q_paths, g_paths) 85 | print("Results ----------") 86 | print("mAP: {:.1%}".format(mAP)) 87 | self.evaluator_prints.append("mAP: {:.1%}".format(mAP)) 88 | print("CMC curve") 89 | self.evaluator_prints.append("CMC curve") 90 | for r in target_ranks: 91 | print("Rank-{:<3}: {:.1%}".format(r, cmc[r - 1])) 92 | self.evaluator_prints.append("Rank-{:<3}: {:.1%}".format(r, cmc[r - 1])) 93 | print("------------------") 94 | return cmc[0] 95 | 96 | def collect_sim_bn_info(self, dataloader): 97 | network_bns = [x for x in list(self.model.modules()) if isinstance(x, torch.nn.BatchNorm2d) or isinstance(x, torch.nn.BatchNorm1d)] 98 | for bn in network_bns: 99 | bn.running_mean = torch.zeros(bn.running_mean.size()).float().cuda() 100 | bn.running_var = torch.ones(bn.running_var.size()).float().cuda() 101 | bn.num_batches_tracked = torch.tensor(0).cuda().long() 102 | 103 | self.model.train() 104 | for batch_idx, inputs in enumerate(dataloader): 105 | # each camera should have at least 2 images for estimating BN statistics 106 | assert len(inputs[0].size()) == 4 and inputs[0].size(0) > 1, 'Cannot estimate BN statistics. Each camera should have at least 2 images' 107 | inputs, pids, camids, _ = self._parse_data(inputs) 108 | for i in range(self.loop): 109 | if i == 1: 110 | inputs = self.flip_tensor_lr(inputs) 111 | self._forward(inputs) 112 | self.model.eval() 113 | 114 | bn_info = list() 115 | for bn in network_bns: 116 | bn_info.append([bn.running_mean.cpu().numpy(), bn.running_var.cpu().numpy()]) 117 | return bn_info 118 | -------------------------------------------------------------------------------- /CBN/frameworks/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | __all__ = [ 7 | 'ResNetBuilder', 8 | ] 9 | 10 | from .model_arch import ResNetBuilder 11 | 12 | -------------------------------------------------------------------------------- /CBN/frameworks/models/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | __all__ = [ 7 | 'ResNet_Backbone', 8 | ] 9 | 10 | from .resnet_backbone import ResNet as ResNet_Backbone 11 | -------------------------------------------------------------------------------- /CBN/frameworks/models/backbone/resnet_backbone.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import math 7 | 8 | import torch as th 9 | from torch import nn 10 | 11 | 12 | class Bottleneck(nn.Module): 13 | expansion = 4 14 | 15 | def __init__(self, inplanes, planes, stride=1, downsample=None): 16 | super(Bottleneck, self).__init__() 17 | 
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes, momentum=None) 19 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 20 | padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes, momentum=None) 22 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 23 | self.bn3 = nn.BatchNorm2d(planes * 4, momentum=None) 24 | self.relu = nn.ReLU(inplace=True) 25 | self.downsample = downsample 26 | self.stride = stride 27 | 28 | def forward(self, x): 29 | residual = x 30 | out = self.conv1(x) 31 | out = self.bn1(out) 32 | out = self.relu(out) 33 | 34 | out = self.conv2(out) 35 | out = self.bn2(out) 36 | out = self.relu(out) 37 | 38 | out = self.conv3(out) 39 | out = self.bn3(out) 40 | 41 | if self.downsample is not None: 42 | residual = self.downsample(x) 43 | 44 | out += residual 45 | out = self.relu(out) 46 | 47 | return out 48 | 49 | 50 | class ResNet(nn.Module): 51 | def __init__(self, last_stride=2, block=Bottleneck, layers=[3, 4, 6, 3]): 52 | self.inplanes = 64 53 | super().__init__() 54 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 55 | bias=False) 56 | self.bn1 = nn.BatchNorm2d(64, momentum=None) 57 | self.relu = nn.ReLU(inplace=True) 58 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 59 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1) 60 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 61 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 62 | self.layer4 = self._make_layer( 63 | block, 512, layers[3], stride=last_stride) 64 | 65 | def _make_layer(self, block, planes, blocks, stride): 66 | downsample = None 67 | if stride != 1 or self.inplanes != planes * block.expansion: 68 | downsample = nn.Sequential( 69 | nn.Conv2d(self.inplanes, planes * block.expansion, 70 | kernel_size=1, stride=stride, bias=False), 71 | nn.BatchNorm2d(planes * block.expansion, momentum=None) 72 | ) 73 | 74 | layers = [] 75 | layers.append(block(self.inplanes, planes, stride, downsample)) 76 | self.inplanes = planes * block.expansion 77 | for i in range(1, blocks): 78 | layers.append(block(self.inplanes, planes)) 79 | 80 | return nn.Sequential(*layers) 81 | 82 | def forward(self, x): 83 | x = self.conv1(x) 84 | x = self.bn1(x) 85 | x = self.relu(x) 86 | x = self.maxpool(x) 87 | x = self.layer1(x) 88 | x = self.layer2(x) 89 | x = self.layer3(x) 90 | x = self.layer4(x) 91 | return x 92 | 93 | def load_param(self, model_path, ): 94 | param_dict = th.load(model_path) 95 | for param_name in param_dict: 96 | if 'fc' in param_name: 97 | continue 98 | if param_name in self.state_dict(): 99 | self.state_dict()[param_name].copy_(param_dict[param_name]) 100 | 101 | def random_init(self): 102 | for m in self.modules(): 103 | if isinstance(m, nn.Conv2d): 104 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 105 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 106 | elif isinstance(m, nn.BatchNorm2d): 107 | m.weight.data.fill_(1) 108 | m.bias.data.zero_() 109 | -------------------------------------------------------------------------------- /CBN/frameworks/models/model_arch.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | from getpass import getuser 7 | 8 | from torch import nn 9 | import torch.nn.functional as F 10 | 11 | from frameworks.models.backbone import ResNet_Backbone 12 | from frameworks.models.weight_utils import weights_init_kaiming 13 | 14 | 15 | def weights_init_classifier(m): 16 | classname = m.__class__.__name__ 17 | if classname.find('Linear') != -1: 18 | nn.init.normal_(m.weight, std=0.001) 19 | if m.bias is not None: 20 | nn.init.constant_(m.bias, 0.0) 21 | 22 | 23 | class ResNetBuilder(nn.Module): 24 | in_planes = 2048 25 | 26 | def __init__(self, num_pids=None, last_stride=1): 27 | super().__init__() 28 | self.num_pids = num_pids 29 | self.base = ResNet_Backbone(last_stride) 30 | model_path = '/home/' + getuser() + '/.torch/models/resnet50-19c8e357.pth' 31 | self.base.load_param(model_path) 32 | #self.base.random_init() 33 | 34 | bn_neck = nn.BatchNorm1d(2048, momentum=None) 35 | bn_neck.bias.requires_grad_(False) 36 | self.bottleneck = nn.Sequential(bn_neck) 37 | self.bottleneck.apply(weights_init_kaiming) 38 | if self.num_pids is not None: 39 | self.classifier = nn.Linear(2048, self.num_pids, bias=False) 40 | self.classifier.apply(weights_init_classifier) 41 | 42 | def forward(self, x): 43 | feat_before_bn = self.base(x) 44 | feat_before_bn = F.avg_pool2d(feat_before_bn, feat_before_bn.shape[2:]) 45 | feat_before_bn = feat_before_bn.view(feat_before_bn.shape[0], -1) 46 | feat_after_bn = self.bottleneck(feat_before_bn) 47 | if self.num_pids is not None: 48 | classification_results = self.classifier(feat_after_bn) 49 | return feat_after_bn, classification_results 50 | else: 51 | return feat_after_bn, None 52 | 53 | def get_optim_policy(self): 54 | base_param_group = filter(lambda p: p.requires_grad, self.base.parameters()) 55 | add_param_group = filter(lambda p: p.requires_grad, self.bottleneck.parameters()) 56 | 57 | all_param_groups = [] 58 | all_param_groups.append({'params': base_param_group, "weight_decay": 0.0005}) 59 | all_param_groups.append({'params': add_param_group, "weight_decay": 0.0005}) 60 | if self.num_pids is None: 61 | return all_param_groups 62 | 63 | cls_param_group = filter(lambda p: p.requires_grad, self.classifier.parameters()) 64 | all_param_groups.append({'params': cls_param_group, "weight_decay": 0.0005}) 65 | return all_param_groups 66 | -------------------------------------------------------------------------------- /CBN/frameworks/models/weight_utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def weights_init_kaiming(m): 5 | classname = m.__class__.__name__ 6 | if classname.find('Linear') != -1: 7 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 8 | if m.bias is not None: 9 | nn.init.constant_(m.bias, 0.0) 10 | elif classname.find('Conv') != -1: 11 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 12 | if m.bias is not None: 13 | nn.init.constant_(m.bias, 0.0) 14 | elif classname.find('BatchNorm') != -1: 15 | if m.affine: 16 | nn.init.normal_(m.weight, 1.0, 0.02) 17 | nn.init.constant_(m.bias, 
0.0) 18 | 19 | 20 | def weights_init_classifier(m): 21 | classname = m.__class__.__name__ 22 | if classname.find('Linear') != -1: 23 | nn.init.normal_(m.weight, std=0.001) 24 | if m.bias is not None: 25 | nn.init.constant_(m.bias, 0.0) 26 | -------------------------------------------------------------------------------- /CBN/frameworks/training/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | __all__ = [ 7 | 'get_our_optimizer_strategy', 8 | 'CameraClsTrainer', 9 | 'CamDataParallel' 10 | ] 11 | 12 | from .optimizers import get_our_optimizer_strategy 13 | from .trainer import CameraClsTrainer 14 | from .data_parallel import CamDataParallel -------------------------------------------------------------------------------- /CBN/frameworks/training/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import time 7 | from utils.meters import AverageMeter 8 | 9 | 10 | class BaseTrainer(object): 11 | def __init__(self, opt, model, optimizer, criterion, summary_writer): 12 | self.opt = opt 13 | self.model = model 14 | self.optimizer = optimizer 15 | self.criterion = criterion 16 | self.summary_writer = summary_writer 17 | self.global_step = 0 18 | 19 | def train(self, epoch, data_loader): 20 | self.model.train() 21 | 22 | batch_time = AverageMeter() 23 | data_time = AverageMeter() 24 | losses = AverageMeter() 25 | 26 | start = time.time() 27 | for i, inputs in enumerate(data_loader): 28 | 29 | data_time.update(time.time() - start) 30 | # model optimizer 31 | self._parse_data(inputs) 32 | self._forward(epoch) 33 | 34 | self.optimizer.zero_grad() 35 | self._backward() 36 | self.optimizer.step() 37 | 38 | batch_time.update(time.time() - start) 39 | losses.update(self.loss.item()) 40 | 41 | # tensorboard 42 | self.global_step = epoch * len(data_loader) + i 43 | self.summary_writer.add_scalar('loss', self.loss.item(), self.global_step) 44 | self.summary_writer.add_scalar('lr', self.optimizer.param_groups[0]['lr'], self.global_step) 45 | 46 | start = time.time() 47 | 48 | if (i + 1) % self.opt.print_freq == 0: 49 | print('Epoch: [{}][{}/{}]\t' 50 | 'Batch Time {:.3f} ({:.3f})\t' 51 | 'Data Time {:.3f} ({:.3f})\t' 52 | 'Loss {:.3f} ({:.3f})\t' 53 | .format(epoch, i + 1, len(data_loader), 54 | batch_time.mean, batch_time.val, 55 | data_time.mean, data_time.val, 56 | losses.mean, losses.val)) 57 | 58 | param_group = self.optimizer.param_groups 59 | print('Epoch: [{}]\tEpoch Time {:.3f} s\tLoss {:.3f}\t' 60 | 'Lr {:.2e}' 61 | .format(epoch, batch_time.sum, losses.mean, param_group[0]['lr'])) 62 | 63 | def _parse_data(self, inputs): 64 | raise NotImplementedError 65 | 66 | def _forward(self, epoch): 67 | raise NotImplementedError 68 | 69 | def _backward(self): 70 | raise NotImplementedError 71 | 
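BaseTrainer.train() above only touches the _parse_data/_forward/_backward hooks and expects _forward to leave the batch loss in self.loss. A minimal hypothetical subclass, for reference (the criterion signature here is an assumption; the real CameraClsTrainer below overrides train() entirely):

class ToyTrainer(BaseTrainer):
    def _parse_data(self, inputs):
        imgs, pids, _ = inputs
        self.data, self.pids = imgs.cuda(), pids.cuda()

    def _forward(self, epoch):
        feat, id_scores = self.model(self.data)
        self.loss = self.criterion(feat, id_scores, self.pids)  # assumed signature

    def _backward(self):
        self.loss.backward()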
-------------------------------------------------------------------------------- /CBN/frameworks/training/data_parallel.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | 3 | from torch.nn import DataParallel 4 | from torch.nn.parallel.scatter_gather import scatter, gather 5 | from torch.nn.parallel.replicate import replicate 6 | from torch.nn.parallel.parallel_apply import parallel_apply 7 | 8 | 9 | class CamDataParallel(DataParallel): 10 | def forward(self, *inputs, **kwargs): 11 | if not self.device_ids: 12 | return self.module(*inputs, **kwargs) 13 | 14 | for t in chain(self.module.parameters(), self.module.buffers()): 15 | if t.device != self.src_device_obj: 16 | raise RuntimeError("module must have its parameters and buffers " 17 | "on device {} (device_ids[0]) but found one of " 18 | "them on device: {}".format(self.src_device_obj, t.device)) 19 | 20 | all_inputs = inputs[0] 21 | all_kwargs = kwargs 22 | all_outputs = [] 23 | 24 | while len(all_inputs) > 0: 25 | num_required_gpu = min(len(all_inputs), len(self.device_ids)) 26 | actual_inputs = [all_inputs.pop(0) for _ in range(num_required_gpu)] 27 | inputs, kwargs = self.scatter(actual_inputs, all_kwargs, self.device_ids[:num_required_gpu]) 28 | replicas = self.replicate(self.module, self.device_ids[:num_required_gpu]) 29 | all_outputs.extend(self.parallel_apply(replicas, inputs, kwargs)) 30 | 31 | return self.gather(all_outputs, self.output_device) 32 | 33 | def replicate(self, module, device_ids): 34 | return replicate(module, device_ids) 35 | 36 | def scatter(self, input_list, kwargs, device_ids): 37 | inputs = [] 38 | for input, gpu in zip(input_list, device_ids): 39 | inputs.extend(scatter(input, [gpu], dim=0)) 40 | kwargs = scatter(kwargs, device_ids, dim=0) if kwargs else [] 41 | if len(inputs) < len(kwargs): 42 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) 43 | elif len(kwargs) < len(inputs): 44 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) 45 | inputs = tuple(inputs) 46 | kwargs = tuple(kwargs) 47 | return inputs, kwargs 48 | 49 | def parallel_apply(self, replicas, inputs, kwargs): 50 | return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) 51 | 52 | def gather(self, outputs, output_device): 53 | return gather(outputs, output_device, dim=self.dim) 54 | -------------------------------------------------------------------------------- /CBN/frameworks/training/optimizers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import torch 7 | 8 | 9 | def get_our_optimizer_strategy(opt, optim_policy=None): 10 | base_lr = 1e-2 11 | if opt.loss == 'triplet': 12 | base_lr = 2e-4 13 | optimizer = torch.optim.SGD( 14 | optim_policy, lr=base_lr, weight_decay=5e-4, momentum=0.9 15 | ) 16 | if opt.loss == 'softmax': 17 | def adjust_lr(optimizer, ep): 18 | if ep < opt.decay_epoch: 19 | lr = 1e-2 20 | else: 21 | lr = 1e-3 22 | for i, p in enumerate(optimizer.param_groups): 23 | p['lr'] = lr 24 | return lr 25 | elif opt.loss == 'triplet': 26 | def adjust_lr(optimizer, ep): 27 | lr = 1e-3 28 | if ep >= opt.decay_epoch: 29 | lr = 1e-3 * (0.001 ** (float(ep + 1 - opt.decay_epoch) / (opt.max_epoch + 1 - opt.decay_epoch))) 30 | for p in optimizer.param_groups: 31 | p['lr'] = lr 32 | return lr 33 | 34 | 35 | return optimizer, adjust_lr 36 | 
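With the defaults from config.py (decay_epoch=10, max_epoch=15), the softmax schedule steps from 1e-2 to 1e-3 at epoch 10, while the triplet schedule decays exponentially after epoch 10. A standalone sketch that just tabulates both schedules:

# Prints the per-epoch learning rates produced by the adjust_lr variants above.
decay_epoch, max_epoch = 10, 15
for ep in range(max_epoch):
    lr_softmax = 1e-2 if ep < decay_epoch else 1e-3
    lr_triplet = 1e-3
    if ep >= decay_epoch:
        lr_triplet = 1e-3 * (0.001 ** ((ep + 1 - decay_epoch) / (max_epoch + 1 - decay_epoch)))
    print(ep, lr_softmax, lr_triplet)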
-------------------------------------------------------------------------------- /CBN/frameworks/training/trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | import collections 6 | import torch 7 | from frameworks.training.base import BaseTrainer 8 | from utils.meters import AverageMeter 9 | import time 10 | 11 | 12 | class CameraClsTrainer(BaseTrainer): 13 | def __init__(self, opt, model, optimizer, criterion, summary_writer): 14 | super().__init__(opt, model, optimizer, criterion, summary_writer) 15 | 16 | def _parse_data(self, inputs): 17 | imgs, pids, camids = inputs 18 | self.data = imgs.cuda() 19 | self.pids = pids.cuda() 20 | self.camids = camids.cuda() 21 | 22 | def _organize_data(self): 23 | unique_camids = torch.unique(self.camids).cpu().numpy() 24 | reorg_data = [] 25 | reorg_pids = [] 26 | for current_camid in unique_camids: 27 | current_camid = (self.camids == current_camid).nonzero().view(-1) 28 | if current_camid.size(0) > 1: 29 | data = torch.index_select(self.data, index=current_camid, dim=0) 30 | pids = torch.index_select(self.pids, index=current_camid, dim=0) 31 | reorg_data.append(data) 32 | reorg_pids.append(pids) 33 | 34 | # Sort the list for our modified data-parallel 35 | # This process helps to increase efficiency when utilizing multiple GPUs 36 | # However, our experiments show that this process slightly decreases the final performance 37 | # You can enable the following process if you prefer 38 | # sort_index = [x.size(0) for x in reorg_pids] 39 | # sort_index = [i[0] for i in sorted(enumerate(sort_index), key=lambda x: x[1], reverse=True)] 40 | # reorg_data = [reorg_data[i] for i in sort_index] 41 | # reorg_pids = [reorg_pids[i] for i in sort_index] 42 | # ===== The end of the sort process ==== # 43 | self.data = reorg_data 44 | self.pids = reorg_pids 45 | 46 | def _forward(self, data): 47 | feat, id_scores = self.model(data) 48 | return feat, id_scores 49 | 50 | def _backward(self): 51 | self.loss.backward() 52 | 53 | def train(self, epoch, data_loader): 54 | self.model.train() 55 | batch_time = AverageMeter() 56 | losses = AverageMeter() 57 | for i, inputs in enumerate(data_loader): 58 | self._parse_data(inputs) 59 | self._organize_data() 60 | 61 | torch.cuda.synchronize() 62 | tic = time.time() 63 | 64 | feat, id_scores = self._forward(self.data) 65 | pids = torch.cat(self.pids, dim=0) 66 | self.loss = self.criterion(feat, id_scores, pids, self.global_step, 67 | self.summary_writer) 68 | self.optimizer.zero_grad() 69 | self._backward() 70 | self.optimizer.step() 71 | 72 | torch.cuda.synchronize() 73 | batch_time.update(time.time() - tic) 74 | losses.update(self.loss.item()) 75 | # tensorboard 76 | self.global_step = epoch * len(data_loader) + i 77 | self.summary_writer.add_scalar('loss', self.loss.item(), self.global_step) 78 | self.summary_writer.add_scalar('lr', self.optimizer.param_groups[0]['lr'], self.global_step) 79 | if (i + 1) % self.opt.print_freq == 0: 80 | print('Epoch: [{}][{}/{}]\t' 81 | 'Batch Time {:.3f} ({:.3f})\t' 82 | 'Loss {:.3f} ({:.3f})\t' 83 | .format(epoch, i + 1, len(data_loader), 84 | batch_time.mean, batch_time.val, 85 | losses.mean, losses.val)) 86 | 87 | param_group = self.optimizer.param_groups 88 | print('Epoch: [{}]\tEpoch Time {:.3f} s\tLoss {:.3f}\t' 89 | 'Lr {:.2e}' 90 | .format(epoch, batch_time.sum, losses.mean, param_group[0]['lr'])) 91 | -------------------------------------------------------------------------------- /CBN/io_stream/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 
from __future__ import unicode_literals 5 | 6 | __all__ = [ 7 | 'data_manager', 8 | 'NormalCollateFn', 9 | 'IdentitySampler', 10 | ] 11 | 12 | from . import data_manager 13 | from .samplers import * #NormalCollateFn, IdentitySampler 14 | -------------------------------------------------------------------------------- /CBN/io_stream/data_manager.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | import numpy as np 6 | 7 | from io_stream.datasets.market import Market1501 8 | from io_stream.datasets.msmt import MSMT17 9 | from io_stream.datasets.duke import Duke 10 | from io_stream.datasets.randperson import RandPerson 11 | from io_stream.datasets.combine import Combine 12 | from io_stream.datasets.unreal import Unreal 13 | from io_stream.datasets.syri import SyRI 14 | from io_stream.datasets.personx import PersonX 15 | 16 | class ReID_Data(Dataset): 17 | def __init__(self, dataset, transform, with_path=False): 18 | self.dataset = dataset 19 | self.transform = transform 20 | self.with_path = with_path 21 | 22 | def __getitem__(self, item): 23 | img_path, pid, camid = self.dataset[item] 24 | img = Image.open(img_path).convert('RGB') 25 | if self.transform is not None: 26 | img = self.transform(img) 27 | if self.with_path: 28 | return img,pid,camid,img_path 29 | return img, pid, camid 30 | 31 | def __len__(self): 32 | return len(self.dataset) 33 | 34 | 35 | 36 | """Create datasets""" 37 | 38 | __data_factory = { 39 | 'market': Market1501, 40 | 'duke': Duke, 41 | 'msmt': MSMT17, 42 | 'randperson':RandPerson, 43 | 'combine':Combine, 44 | 'unreal':Unreal, 45 | 'syri':SyRI, 46 | 'personx':PersonX 47 | } 48 | 49 | __folder_factory = { 50 | 'market': ReID_Data, 51 | 'duke': ReID_Data, 52 | 'msmt': ReID_Data, 53 | } 54 | 55 | def init_unreal_dataset(name,datasets,*args,**kwargs): 56 | sets = datasets.split(',') if type(datasets)!=tuple else datasets 57 | dataset=[] 58 | for s in sets: 59 | dataset.append(s) 60 | 61 | return __data_factory[name](dataset=dataset,*args,**kwargs) 62 | 63 | def init_combine_dataset(name,options,datasets,*args,**kwargs): 64 | sets=datasets.split(',') if type(datasets)!=tuple else datasets 65 | dataset=[] 66 | for s in sets: 67 | if s=='unreal': 68 | dataset.append(init_unreal_dataset(s,datasets=options.dataset,num_pids=options.num_pids, 69 | img_per_person = options.img_per_person)) 70 | else: 71 | dataset.append(init_dataset(s,*args,**kwargs)) 72 | 73 | return Combine(dataset=dataset,*args,**kwargs) 74 | 75 | def init_dataset(name, *args, **kwargs): 76 | if name not in __data_factory.keys(): 77 | raise KeyError("Unknown datasets: {}".format(name)) 78 | return __data_factory[name](*args, **kwargs) 79 | 80 | 81 | def init_datafolder(name, data_list, transforms,with_path=False): 82 | if name not in __folder_factory.keys(): 83 | return ReID_Data(data_list,transforms,with_path) 84 | return __folder_factory[name](data_list, transforms,with_path) 85 | -------------------------------------------------------------------------------- /CBN/io_stream/data_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import defaultdict 3 | import numpy as np 4 | 5 | 6 | def reorganize_images_by_camera(data, sample_per_camera): 7 | cams = np.unique([x[2] for x in data]) 8 | images_per_cam = defaultdict(list) 9 | images_per_cam_sampled = defaultdict(list) 
10 | for cam_id in cams: 11 | all_file_info = [x for x in data if x[2] == cam_id] 12 | all_file_info = sorted(all_file_info, key=lambda x: x[0]) # sort first so the seeded shuffle below is deterministic 13 | random.shuffle(all_file_info) 14 | images_per_cam[cam_id] = all_file_info 15 | images_per_cam_sampled[cam_id] = all_file_info[:sample_per_camera] 16 | 17 | return images_per_cam, images_per_cam_sampled 18 | -------------------------------------------------------------------------------- /CBN/io_stream/datasets/combine.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import re 3 | import numpy as np 4 | import random 5 | from os import path as osp 6 | from collections import defaultdict 7 | from io_stream.data_utils import reorganize_images_by_camera 8 | 9 | 10 | class Combine(object): 11 | """ 12 | Combine 13 | Merges several already-constructed ReID datasets into one training set: 14 | person IDs are offset so they remain unique across datasets, and camera 15 | IDs are remapped into a single global camera space unless share_cam is set. 16 | 17 | Only the train split is populated; query and gallery stay empty, and the 18 | first dataset in the list is treated as the real one (its size is kept 19 | in len_of_real_dataset). 20 | """ 21 | dataset_dir = 'market' 22 | 23 | def __init__(self, root='data', dataset=None, **kwargs): 24 | self.name = 'combine' 25 | self.dataset = dataset 26 | share_cam = kwargs['share_cam'] 27 | train, num_train_pids, num_train_imgs = [], 0, 0 28 | cam_dict = dict() 29 | num_d = 0 30 | self.cams_of_dataset = [] 31 | self.len_of_real_dataset = len(self.dataset[0].train) 32 | for d in self.dataset: 33 | d_train = d.train 34 | cams = set() 35 | new_train = [] 36 | for item in d_train: 37 | img_path, pid, camid = item 38 | pid += num_train_pids 39 | if not (num_d, camid) in cam_dict.keys(): 40 | cam_dict[(num_d, camid)] = len(cam_dict) 41 | cams.add(cam_dict[(num_d, camid)]) 42 | camid = cam_dict[(num_d, camid)] if not share_cam else camid 43 | new_train.append((img_path, pid, camid)) 44 | self.cams_of_dataset.append(cams) 45 | num_d += 1 46 | train.extend(new_train) 47 | num_train_pids += d.num_train_pids 48 | num_train_imgs += len(d.train) 49 | 50 | query, num_query_pids, num_query_imgs = [], 0, 0 51 | gallery, num_gallery_pids, num_gallery_imgs = [], 0, 0 52 | 53 | num_total_pids = num_train_pids + num_query_pids 54 | num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs 55 | 56 | print("=> Combine loaded") 57 | print("Dataset statistics:") 58 | print(" ------------------------------") 59 | print(" subset | # ids | # images") 60 | print(" ------------------------------") 61 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_imgs)) 62 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_imgs)) 63 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs)) 64 | print(" ------------------------------") 65 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_imgs)) 66 | print(" ------------------------------") 67 | print("length of real data {}".format(self.len_of_real_dataset)) 68 | self.train = train 69 | self.query = query 70 | self.gallery = gallery 71 | 72 | self.num_train_pids = num_train_pids 73 | self.num_query_pids = num_query_pids 74 | self.num_gallery_pids = num_gallery_pids 75 | 76 | 77 | def _check_before_run(self): 78 | """Check if all files are available before going deeper""" 79 | if not osp.exists(self.dataset_dir): 80 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 81 | if not osp.exists(self.train_dir): 82 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 83 | if not osp.exists(self.query_dir): 84 | 
raise RuntimeError("'{}' is not available".format(self.query_dir)) 85 | if not osp.exists(self.gallery_dir): 86 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 87 | 88 | def _process_dir(self, dir_path, relabel=False): 89 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 90 | pattern = re.compile(r'([-\d]+)_c(\d)') 91 | 92 | pid_container = set() 93 | for img_path in img_paths: 94 | pid, _ = map(int, pattern.search(img_path).groups()) 95 | if pid == -1: continue # junk images are just ignored 96 | pid_container.add(pid) 97 | 98 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 99 | if relabel == True: 100 | self.pid2label = pid2label 101 | 102 | dataset = [] 103 | for img_path in img_paths: 104 | pid, camid = map(int, pattern.search(img_path).groups()) 105 | if pid == -1: 106 | continue 107 | camid -= 1 # index starts from 0 108 | if relabel: pid = pid2label[pid] 109 | dataset.append((img_path, pid, camid)) 110 | 111 | if relabel: 112 | self.pid2label = pid2label 113 | num_pids = len(pid_container) 114 | num_imgs = len(dataset) 115 | return dataset, num_pids, num_imgs 116 | -------------------------------------------------------------------------------- /CBN/io_stream/datasets/duke.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import re 3 | from os import path as osp 4 | from io_stream.data_utils import reorganize_images_by_camera 5 | 6 | 7 | class Duke(object): 8 | """ 9 | DukeMTMC-reID 10 | Reference: 11 | Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCV Workshops 2016. 12 | Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 13 | 14 | Dataset statistics: 15 | # identities: 1404 (+408 distractors) 16 | # images: 16522 (train) + 2228 (query) + 17661 (gallery) 17 | """ 18 | dataset_dir = 'duke' 19 | 20 | def __init__(self, root='data', **kwargs): 21 | self.name = 'duke' 22 | self.dataset_dir = osp.join(root, self.dataset_dir) 23 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 24 | self.query_dir = osp.join(self.dataset_dir, 'query') 25 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 26 | self._check_before_run() 27 | 28 | train, num_train_pids, num_train_imgs = self._process_dir(self.train_dir, relabel=True) 29 | query, num_query_pids, num_query_imgs = self._process_dir(self.query_dir, relabel=False) 30 | gallery, num_gallery_pids, num_gallery_imgs = self._process_dir(self.gallery_dir, relabel=False) 31 | 32 | num_total_pids = num_train_pids + num_query_pids 33 | num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs 34 | 35 | print("=> DukeMTMC-reID loaded") 36 | print("Dataset statistics:") 37 | print(" ------------------------------") 38 | print(" subset | # ids | # images") 39 | print(" ------------------------------") 40 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_imgs)) 41 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_imgs)) 42 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs)) 43 | print(" ------------------------------") 44 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_imgs)) 45 | print(" ------------------------------") 46 | 47 | self.train = train 48 | self.query = query 49 | self.gallery = gallery 50 | 51 | self.num_train_pids = num_train_pids 52 | self.num_query_pids = num_query_pids 53 | self.num_gallery_pids = num_gallery_pids 54 | 55 | self.query_per_cam, self.query_per_cam_sampled = 
reorganize_images_by_camera(self.query, 56 | kwargs['num_bn_sample']) 57 | self.gallery_per_cam, self.gallery_per_cam_sampled = reorganize_images_by_camera(self.gallery, 58 | kwargs['num_bn_sample']) 59 | 60 | def _check_before_run(self): 61 | """Check if all files are available before going deeper""" 62 | if not osp.exists(self.dataset_dir): 63 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 64 | if not osp.exists(self.train_dir): 65 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 66 | if not osp.exists(self.query_dir): 67 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 68 | if not osp.exists(self.gallery_dir): 69 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 70 | 71 | def _process_dir(self, dir_path, relabel=False): 72 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 73 | pattern = re.compile(r'([-\d]+)_c(\d)') 74 | 75 | pid_container = set() 76 | for img_path in img_paths: 77 | pid, _ = map(int, pattern.search(img_path).groups()) 78 | if pid == -1: continue # junk images are just ignored 79 | pid_container.add(pid) 80 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 81 | 82 | dataset = [] 83 | for img_path in img_paths: 84 | pid, camid = map(int, pattern.search(img_path).groups()) 85 | if pid == -1: 86 | continue 87 | camid -= 1 # index starts from 0 88 | if relabel: pid = pid2label[pid] 89 | dataset.append((img_path, pid, camid)) 90 | if relabel: 91 | self.pid2label = pid2label 92 | num_pids = len(pid_container) 93 | num_imgs = len(dataset) 94 | 95 | dataset = sorted(dataset, key=lambda k: k[2]) 96 | return dataset, num_pids, num_imgs 97 | -------------------------------------------------------------------------------- /CBN/io_stream/datasets/market.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import re 3 | import numpy as np 4 | import random 5 | from os import path as osp 6 | from collections import defaultdict 7 | from io_stream.data_utils import reorganize_images_by_camera 8 | 9 | 10 | class Market1501(object): 11 | """ 12 | Market1501 13 | Reference: 14 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 
15 | URL: http://www.liangzheng.org/Project/project_reid.html 16 | 17 | Dataset statistics: 18 | # identities: 1501 (+1 for background) 19 | # images: 12936 (train) + 3368 (query) + 15913 (gallery) 20 | """ 21 | dataset_dir = 'market' 22 | 23 | def __init__(self, root='data', **kwargs): 24 | self.name = 'market' 25 | self.dataset_dir = osp.join(root, self.dataset_dir) 26 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 27 | self.query_dir = osp.join(self.dataset_dir, 'query') 28 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 29 | self._check_before_run() 30 | 31 | train, num_train_pids, num_train_imgs = self._process_dir(self.train_dir, relabel=True) 32 | query, num_query_pids, num_query_imgs = self._process_dir(self.query_dir, relabel=False) 33 | gallery, num_gallery_pids, num_gallery_imgs = self._process_dir(self.gallery_dir, relabel=False) 34 | 35 | num_total_pids = num_train_pids + num_query_pids 36 | num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs 37 | 38 | print("=> Market1501 loaded") 39 | print("Dataset statistics:") 40 | print(" ------------------------------") 41 | print(" subset | # ids | # images") 42 | print(" ------------------------------") 43 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_imgs)) 44 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_imgs)) 45 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs)) 46 | print(" ------------------------------") 47 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_imgs)) 48 | print(" ------------------------------") 49 | 50 | self.train = train 51 | self.query = query 52 | self.gallery = gallery 53 | 54 | self.num_train_pids = num_train_pids 55 | self.num_query_pids = num_query_pids 56 | self.num_gallery_pids = num_gallery_pids 57 | 58 | self.query_per_cam, self.query_per_cam_sampled = reorganize_images_by_camera(self.query, 59 | kwargs['num_bn_sample']) 60 | self.gallery_per_cam, self.gallery_per_cam_sampled = reorganize_images_by_camera(self.gallery, 61 | kwargs['num_bn_sample']) 62 | 63 | def _check_before_run(self): 64 | """Check if all files are available before going deeper""" 65 | if not osp.exists(self.dataset_dir): 66 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 67 | if not osp.exists(self.train_dir): 68 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 69 | if not osp.exists(self.query_dir): 70 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 71 | if not osp.exists(self.gallery_dir): 72 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 73 | 74 | def _process_dir(self, dir_path, relabel=False): 75 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 76 | pattern = re.compile(r'([-\d]+)_c(\d)') 77 | 78 | pid_container = set() 79 | for img_path in img_paths: 80 | pid, _ = map(int, pattern.search(img_path).groups()) 81 | if pid == -1: continue # junk images are just ignored 82 | pid_container.add(pid) 83 | 84 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 85 | if relabel == True: 86 | self.pid2label = pid2label 87 | 88 | dataset = [] 89 | for img_path in img_paths: 90 | pid, camid = map(int, pattern.search(img_path).groups()) 91 | if pid == -1: 92 | continue 93 | camid -= 1 # index starts from 0 94 | if relabel: pid = pid2label[pid] 95 | dataset.append((img_path, pid, camid)) 96 | 97 | if relabel: 98 | self.pid2label = pid2label 99 | num_pids = len(pid_container) 100 | num_imgs = 
len(dataset) 101 | return dataset, num_pids, num_imgs 102 | -------------------------------------------------------------------------------- /CBN/io_stream/datasets/msmt.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os.path as osp 6 | from io_stream.data_utils import reorganize_images_by_camera 7 | 8 | 9 | class MSMT17(object): 10 | """ 11 | MSMT17 12 | 13 | Reference: 14 | Wei et al. Person Transfer GAN to Bridge Domain Gap for Person Re-Identification. CVPR 2018. 15 | 16 | URL: http://www.pkuvmc.com/publications/msmt17.html 17 | 18 | Dataset statistics: 19 | # identities: 4101 20 | # images: 32621 (train) + 11659 (query) + 82161 (gallery) 21 | # cameras: 15 22 | """ 23 | dataset_dir = 'msmt17' 24 | 25 | def __init__(self, root='data', **kwargs): 26 | self.name = "msmt" 27 | self.dataset_dir = osp.join(root, self.dataset_dir) 28 | self.train_dir = osp.join(self.dataset_dir, 'train') 29 | self.test_dir = osp.join(self.dataset_dir, 'test') 30 | self.list_train_path = osp.join(self.dataset_dir, 'list_train.txt') 31 | self.list_val_path = osp.join(self.dataset_dir, 'list_val.txt') 32 | self.list_query_path = osp.join(self.dataset_dir, 'list_query.txt') 33 | self.list_gallery_path = osp.join(self.dataset_dir, 'list_gallery.txt') 34 | 35 | self._check_before_run() 36 | train = self._process_dir(self.train_dir, self.list_train_path) 37 | query = self._process_dir(self.test_dir, self.list_query_path) 38 | gallery = self._process_dir(self.test_dir, self.list_gallery_path) 39 | 40 | self.train = train 41 | self.query = query 42 | self.gallery = gallery 43 | 44 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 45 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 46 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 47 | 48 | self.num_total_pids = self.num_train_pids + self.num_query_pids 49 | self.num_total_imgs = self.num_train_imgs + self.num_query_imgs + self.num_gallery_imgs 50 | 51 | print("=> MSMT17 loaded") 52 | print("Dataset statistics:") 53 | print(" ------------------------------") 54 | print(" subset | # ids | # images") 55 | print(" ------------------------------") 56 | print(" train | {:5d} | {:8d}".format(self.num_train_pids, self.num_train_imgs)) 57 | print(" query | {:5d} | {:8d}".format(self.num_query_pids, self.num_query_imgs)) 58 | print(" gallery | {:5d} | {:8d}".format(self.num_gallery_pids, self.num_gallery_imgs)) 59 | print(" ------------------------------") 60 | print(" total | {:5d} | {:8d}".format(self.num_total_pids, self.num_total_imgs)) 61 | print(" ------------------------------") 62 | 63 | self.query_per_cam, self.query_per_cam_sampled = reorganize_images_by_camera(self.query, 64 | kwargs['num_bn_sample']) 65 | self.gallery_per_cam, self.gallery_per_cam_sampled = reorganize_images_by_camera(self.gallery, 66 | kwargs['num_bn_sample']) 67 | 68 | def get_imagedata_info(self, data): 69 | pids, cams = [], [] 70 | for _, pid, camid in data: 71 | pids += [pid] 72 | cams += [camid] 73 | pids = set(pids) 74 | cams = set(cams) 75 | num_pids = len(pids) 76 | num_cams = len(cams) 77 | num_imgs = len(data) 78 | return num_pids, num_imgs, num_cams 79 | 80 | def _check_before_run(self): 81 | """Check if all files are available before going deeper""" 
82 | if not osp.exists(self.dataset_dir): 83 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 84 | if not osp.exists(self.train_dir): 85 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 86 | if not osp.exists(self.test_dir): 87 | raise RuntimeError("'{}' is not available".format(self.test_dir)) 88 | 89 | def _process_dir(self, dir_path, list_path): 90 | with open(list_path, 'r') as txt: 91 | lines = txt.readlines() 92 | dataset = [] 93 | pid_container = set() 94 | pid2label = dict() 95 | for img_idx, img_info in enumerate(lines): 96 | img_path, pid = img_info.split(' ') 97 | pid = int(pid) # raw MSMT id; relabeled to a consecutive index below 98 | camid = int(img_path.split('_')[2]) - 1 # index starts from 0 99 | #if camid==12:continue 100 | if pid in pid2label.keys(): 101 | pid = pid2label[pid] 102 | else: 103 | pid2label[pid] = len(pid2label) 104 | pid = pid2label[pid] 105 | 106 | 107 | img_path = osp.join(dir_path, img_path) 108 | dataset.append((img_path, pid, camid)) 109 | pid_container.add(pid) 110 | 111 | # check that relabeled pids start from 0 and increment by 1 112 | for idx, pid in enumerate(pid_container): 113 | assert idx == pid, "pids must be consecutive integers starting at 0" 114 | return dataset 115 | 
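In isolation, the relabeling loop in _process_dir above maps raw MSMT person IDs to consecutive labels in order of first appearance, which is exactly what the final assert verifies:

# Standalone sketch of the pid relabeling (toy raw IDs, not repo code).
pid2label = {}
for raw_pid in [7, 3, 7, 11]:
    if raw_pid in pid2label:
        label = pid2label[raw_pid]
    else:
        pid2label[raw_pid] = len(pid2label)
        label = pid2label[raw_pid]
    print(raw_pid, '->', label)  # 7 -> 0, 3 -> 1, 7 -> 0, 11 -> 2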
52 |
53 |     def _check_before_run(self):
54 |         """Check if all files are available before going deeper"""
55 |         pass  # the directories contain glob wildcards ('*'), so they are resolved in _process_dir instead
56 |
57 |     def _process_dir(self, dir_path, relabel=False):
58 |         img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
59 |         pattern = re.compile(r'([-\d]+)_c(\d)')
60 |
61 |         pid_container = set()
62 |         for img_path in img_paths:
63 |             pid, _ = map(int, pattern.search(img_path).groups())
64 |             if pid == -1: continue  # junk images are just ignored
65 |             pid_container.add(pid)
66 |
67 |         pid2label = {pid: label for label, pid in enumerate(pid_container)}
68 |         if relabel:
69 |             self.pid2label = pid2label
70 |
71 |         dataset = []
72 |         for img_path in img_paths:
73 |             pid, camid = map(int, pattern.search(img_path).groups())
74 |             if pid == -1:
75 |                 continue
76 |             camid -= 1  # index starts from 0
77 |             if relabel: pid = pid2label[pid]
78 |             dataset.append((img_path, pid, camid))
79 |
80 |         if relabel:
81 |             self.pid2label = pid2label
82 |         num_pids = len(pid_container)
83 |         num_imgs = len(dataset)
84 |         return dataset, num_pids, num_imgs
85 |
-------------------------------------------------------------------------------- /CBN/io_stream/datasets/randperson.py: --------------------------------------------------------------------------------
1 | import glob
2 | import re
3 | import numpy as np
4 | import random
5 | from os import path as osp
6 | from collections import defaultdict
7 | from io_stream.data_utils import reorganize_images_by_camera
8 |
9 |
10 | class RandPerson(object):
11 |     """
12 |     RandPerson: a synthesized person ReID dataset with low-resolution images and varied illumination.
13 |     Only the training split is used in this loader; query and gallery are left empty.
14 |     """
15 |     dataset_dir = 'randperson'
16 |
17 |     def __init__(self, root='data', **kwargs):
18 |         self.name = 'randperson'
19 |         self.dataset_dir = osp.join(root, self.dataset_dir)
20 |         self.train_dir = self.dataset_dir
21 |         self.query_dir = None
22 |         self.gallery_dir = None
23 |         self._check_before_run()
24 |
25 |         train, num_train_pids, num_train_imgs = self._process_dir(self.train_dir, relabel=True)
26 |         query, num_query_pids, num_query_imgs = [], 0, 0  #self._process_dir(self.query_dir, relabel=False)
27 |         gallery, num_gallery_pids, num_gallery_imgs = [], 0, 0  #self._process_dir(self.gallery_dir, relabel=False)
28 |
29 |         num_total_pids = num_train_pids + num_query_pids
30 |         num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs
31 |
32 |         print("=> RandPerson loaded")
33 |         print("Dataset statistics:")
34 |         print(" ------------------------------")
35 |         print(" subset | # ids | # images")
36 |         print(" ------------------------------")
37 |         print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_imgs))
38 |         print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_imgs))
39 |         print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs))
40 |         print(" ------------------------------")
41 |         print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_imgs))
42 |         print(" ------------------------------")
43 |
44 |         self.train = train
45 |         self.query = query
46 |         self.gallery = gallery
47 |
48 |         self.num_train_pids = num_train_pids
49 |         self.num_query_pids = num_query_pids
50 |         self.num_gallery_pids = num_gallery_pids
51 |
52 |         #self.query_per_cam, self.query_per_cam_sampled = reorganize_images_by_camera(self.query,
53 |         #                                                                             kwargs['num_bn_sample'])
54 |         #self.gallery_per_cam, self.gallery_per_cam_sampled = reorganize_images_by_camera(self.gallery,
55 |         #                                                                                 kwargs['num_bn_sample'])
56 |
57 |     def _check_before_run(self):
58 |         """Check if all files are available before going deeper"""
59 |         if not osp.exists(self.dataset_dir):
60 |             raise RuntimeError("'{}' is not available".format(self.dataset_dir))
61 |         if not osp.exists(self.train_dir):
62 |             raise RuntimeError("'{}' is not available".format(self.train_dir))
63 |
64 |     def _process_dir(self, dir_path, relabel=False):
65 |         img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
66 |         pattern = re.compile(r'([-\d]+)_s([\d]+)_c([\d]+)')
67 |
68 |         pid_container = set()
69 |         cam_container = set()
70 |         random.seed(1)
71 |         # pids = random.sample(list(range(8000)),800)
72 |         for img_path in img_paths:
73 |             pid, sce, cam = map(int, pattern.search(img_path).groups())
74 |             pid_container.add(pid)
75 |
76 |             cam_container.add((sce, cam))
77 |
78 |         pid2label = {pid: label for label, pid in enumerate(pid_container)}
79 |         cam2label = {cid: label for label, cid in enumerate(cam_container)}
80 |         if relabel:
81 |             self.pid2label = pid2label
82 |
83 |         dataset = []
84 |         for img_path in img_paths:
85 |             pid, sid, camid = map(int, pattern.search(img_path).groups())
86 |             if (sid, camid) not in cam_container: continue
87 |             camid = cam2label[(sid, camid)]  # each (scene, camera) pair becomes one camera label
88 |             if relabel: pid = pid2label[pid]
89 |             dataset.append((img_path, pid, camid))
90 |
91 |
92 |         print('cameras:', len(cam_container))
93 |         num_pids = len(pid_container)
94 |         num_imgs = len(dataset)
95 |         return dataset, num_pids, num_imgs
96 |
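RandPerson file names encode a scene and a camera (`{pid}_s{scene}_c{camera}`), and the loader above gives every (scene, camera) pair its own camera label. A minimal, self-contained sketch of that relabeling (the file names below are made up for illustration; the regex is the one used in `_process_dir`):

```python
import re

pattern = re.compile(r'([-\d]+)_s([\d]+)_c([\d]+)')  # same pattern as in _process_dir
paths = ['0001_s01_c02_00001.jpg', '0001_s03_c02_00007.jpg', '0002_s01_c02_00004.jpg']

# collect the distinct (scene, camera) pairs and assign each one a dense label
cam_container = sorted({tuple(map(int, pattern.search(p).groups()[1:])) for p in paths})
cam2label = {cid: label for label, cid in enumerate(cam_container)}
for p in paths:
    pid, sce, cam = map(int, pattern.search(p).groups())
    print(p, '-> pid', pid, ', camera label', cam2label[(sce, cam)])
# camera 2 of scene 1 and camera 2 of scene 3 receive different labels (0 and 1)
```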
-------------------------------------------------------------------------------- /CBN/io_stream/datasets/soma.py: --------------------------------------------------------------------------------
1 | import glob
2 | import re
3 | import numpy as np
4 | import random
5 | from os import path as osp
6 | from collections import defaultdict
7 | from io_stream.data_utils import reorganize_images_by_camera
8 |
9 |
10 | class SOMAset(object):
11 |     """
12 |     SOMAset (a synthesized person ReID dataset)
13 |     Reference:
14 |     Barbosa et al. Looking Beyond Appearances: Synthetic Training Data for Deep CNNs in Re-identification. CVIU 2018.
15 |
16 |     Notes:
17 |     # all images are used for training; the query and gallery splits are left empty
18 |     # every image receives the same camera label (camid = 0)
19 |     # images are expected under {pid}/{subdir}/{frame}.jpg
20 |     """
21 |     dataset_dir = 'soma'
22 |
23 |     def __init__(self, root='data', **kwargs):
24 |         self.name = 'soma'
25 |         self.dataset_dir = osp.join(root, self.dataset_dir)
26 |         self.train_dir = self.dataset_dir
27 |         self.query_dir = self.dataset_dir
28 |         self.gallery_dir = self.dataset_dir
29 |         self._check_before_run()
30 |
31 |         train, num_train_pids, num_train_imgs = self._process_dir(self.train_dir, relabel=True)
32 |         query, num_query_pids, num_query_imgs = [], 0, 0  # self._process_dir(self.query_dir, relabel=False)
33 |         gallery, num_gallery_pids, num_gallery_imgs = [], 0, 0  #self._process_dir(self.gallery_dir, relabel=False)
34 |
35 |         num_total_pids = num_train_pids + num_query_pids
36 |         num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs
37 |
38 |         print("=> SOMAset loaded")
39 |         print("Dataset statistics:")
40 |         print(" ------------------------------")
41 |         print(" subset | # ids | # images")
42 |         print(" ------------------------------")
43 |         print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_imgs))
44 |         print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_imgs))
45 |         print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs))
46 |         print(" ------------------------------")
47 |         print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_imgs))
48 |         print(" ------------------------------")
49 |
50 |         self.train = train
51 |         self.query = query
52 |         self.gallery = gallery
53 |
54 |         self.num_train_pids = num_train_pids
55 |         self.num_query_pids = num_query_pids
56 |         self.num_gallery_pids = num_gallery_pids
57 |
58 |         # self.query_per_cam, self.query_per_cam_sampled = reorganize_images_by_camera(self.query,
59 |         #                                                                              kwargs['num_bn_sample'])
60 |         # self.gallery_per_cam, self.gallery_per_cam_sampled = reorganize_images_by_camera(self.gallery,
61 |         #                                                                                  kwargs['num_bn_sample'])
62 |
63 |     def _check_before_run(self):
64 |         """Check if all files are available before going deeper"""
65 |         if not osp.exists(self.dataset_dir):
66 |             raise RuntimeError("'{}' is not available".format(self.dataset_dir))
67 |         if not osp.exists(self.train_dir):
68 |             raise RuntimeError("'{}' is not available".format(self.train_dir))
69 |         if not osp.exists(self.query_dir):
70 |             raise RuntimeError("'{}' is not available".format(self.query_dir))
71 |         if not osp.exists(self.gallery_dir):
72 |             raise RuntimeError("'{}' is not available".format(self.gallery_dir))
73 |
74 |     def _process_dir(self, dir_path, relabel=False):
75 |         img_paths = glob.glob(osp.join(dir_path, '*/*/*.jpg'))
76 |
77 |         pid_container = set()
78 |         for img_path in img_paths:
79 |             arrs = img_path.split('/')
80 |             pid = arrs[-3]  # the identity is the third-to-last path component
81 |             pid_container.add(pid)
82 |
83 |         pid2label = {pid: label for label, pid in enumerate(pid_container)}
84 |         if relabel:
85 |             self.pid2label = pid2label
86 |
87 |         dataset = []
88 |         for img_path in img_paths:
89 |             arrs = img_path.split('/')
90 |             pid = arrs[-3]
91 |             camid = 0  # no camera annotation is available, so all images share camera 0
92 |             if relabel: pid = pid2label[pid]
93 |             dataset.append((img_path, pid, camid))
94 |
95 |         if relabel:
96 |             self.pid2label = pid2label
97 |         num_pids = len(pid_container)
98 |         num_imgs = len(dataset)
99 |         return dataset, num_pids, num_imgs
100 |
-------------------------------------------------------------------------------- /CBN/io_stream/datasets/syri.py: --------------------------------------------------------------------------------
1 | import glob
2 | import re
3 | import numpy as np
4 | import random
5 | from os import path as osp
6 | from collections import defaultdict
7 | from io_stream.data_utils import reorganize_images_by_camera
8 |
9 |
10 | class SyRI(object):
11 |     """
12 |     SyRI (a synthesized person ReID dataset)
13 |     Reference:
14 |     Bak et al. Domain Adaptation through Synthesis for Unsupervised Person Re-identification. ECCV 2018.
15 |
16 |     Notes:
17 |     # synthesized identities rendered under many illumination conditions
18 |     # only the training split is used here; query and gallery are left empty
19 |     # file names follow {pid}_{camera}_..., and raw camera indices are merged in pairs
20 |     """
21 |     dataset_dir = 'syri'
22 |
23 |     def __init__(self, root='data', **kwargs):
24 |         self.name = 'syri'
25 |         self.dataset_dir = osp.join(root, self.dataset_dir)
26 |         self.train_dir = self.dataset_dir
27 |         self.query_dir = osp.join(self.dataset_dir, 'query')
28 |         self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')
29 |         self._check_before_run()
30 |
31 |         train, num_train_pids, num_train_imgs = self._process_dir(self.train_dir, relabel=True)
32 |         query, num_query_pids, num_query_imgs = [], 0, 0  #self._process_dir(self.query_dir, relabel=False)
33 |         gallery, num_gallery_pids, num_gallery_imgs = [], 0, 0  # self._process_dir(self.gallery_dir, relabel=False)
34 |
35 |         num_total_pids = num_train_pids + num_query_pids
36 |         num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs
37 |
38 |         print("=> SyRI loaded")
39 |         print("Dataset statistics:")
40 |         print(" ------------------------------")
41 |         print(" subset | # ids | # images")
42 |         print(" ------------------------------")
43 |         print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_imgs))
44 |         print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_imgs))
45 |         print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs))
46 |         print(" ------------------------------")
47 |         print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_imgs))
48 |         print(" ------------------------------")
49 |
50 |         self.train = train
51 |         self.query = query
52 |         self.gallery = gallery
53 |
54 |         self.num_train_pids = num_train_pids
55 |         self.num_query_pids = num_query_pids
56 |         self.num_gallery_pids = num_gallery_pids
57 |
58 |         #self.query_per_cam, self.query_per_cam_sampled = reorganize_images_by_camera(self.query,
59 |         #                                                                             kwargs['num_bn_sample'])
60 |         #self.gallery_per_cam, self.gallery_per_cam_sampled = reorganize_images_by_camera(self.gallery,
61 |         #                                                                                 kwargs['num_bn_sample'])
62 |
63 |     def _check_before_run(self):
64 |         """Check if all files are available before going deeper"""
65 |         if not osp.exists(self.dataset_dir):
66 |             raise RuntimeError("'{}' is not available".format(self.dataset_dir))
67 |         if not osp.exists(self.train_dir):
68 |             raise RuntimeError("'{}' is not available".format(self.train_dir))
69 |
70 |     def _process_dir(self, dir_path, relabel=False):
71 |         img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
72 |         pattern = re.compile(r'([\d]+)_([\d]+)_')
73 |         dataset = []
74 |         pid_container = set()
75 |         for img_path in img_paths:
76 |             pid, camid = map(int, pattern.search(img_path).groups())
77 |             camid = camid // 2  # merge raw camera indices in pairs into one camera label
78 |             dataset.append((img_path, pid, camid))
79 |             pid_container.add(pid)
80 |
81 |         num_pids = len(pid_container)
82 |         num_imgs = len(dataset)
83 |         return dataset, num_pids, num_imgs
84 |
-------------------------------------------------------------------------------- /CBN/io_stream/datasets/unreal.py: --------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import glob
3 | import re
4 | import numpy as np
5 | import random
6 | from os import path as osp
7 | from collections import defaultdict
8 | from io_stream.data_utils import reorganize_images_by_camera
9 | import os
10 |
11 | class Unreal(object):
12 |     """
13 |     Loader for the UnrealPerson synthesized data (unreal_vX.Y directories).
14 |     Image names follow unreal_v{scene}.{version}/images/{pid}_c{cam}_{frame}.jpg.
15 |     """
16 |
17 |
18 |     def __init__(self, root='data', dataset=None, **kwargs):
19 |         self.name = 'unreal_dataset'
20 |         self.dataset_dir = dataset
21 |         if isinstance(self.dataset_dir, list):
22 |             self.dataset_dir = [osp.join(root, d) for d in self.dataset_dir]
23 |             self.train_dir = [osp.join(d, 'images') for d in self.dataset_dir]
24 |         else:
25 |             self.dataset_dir = osp.join(root, self.dataset_dir)
26 |             self.train_dir = osp.join(self.dataset_dir, 'images')
27 |
28 |         self.num_pids = kwargs['num_pids']
29 |         self.num_cams = kwargs['num_cams']
30 |         self.img_per_person = kwargs['img_per_person']
31 |         self.cams_of_dataset = None
32 |         train, num_train_pids, num_train_imgs = self._process_dir(self.train_dir, relabel=True)
33 |         with open('list_unreal_train.txt', 'w') as train_file:
34 |             for t in train:
35 |                 train_file.write('{} {} {} \n'.format(t[0], t[1], t[2]))
36 |         query, num_query_pids, num_query_imgs = [], 0, 0
37 |         gallery, num_gallery_pids, num_gallery_imgs = [], 0, 0
38 |
39 |         num_total_pids = num_train_pids + num_query_pids
40 |         num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs
41 |
42 |         print("=> Unreal {} loaded".format(dataset))
43 |         print("Dataset statistics:")
44 |         print(" ------------------------------")
45 |         print(" subset | # ids | # images")
46 |         print(" ------------------------------")
47 |         print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_imgs))
48 |         print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_imgs))
49 |         print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs))
50 |         print(" ------------------------------")
51 |         print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_imgs))
52 |         print(" ------------------------------")
53 |
54 |         self.train = train
55 |         self.query = query
56 |         self.gallery = gallery
57 |
58 |         self.num_train_pids = num_train_pids
59 |         self.num_query_pids = num_query_pids
60 |         self.num_gallery_pids = num_gallery_pids
61 |
62 |     def size(self, path):
63 |         return os.path.getsize(path) / float(1024)  # file size in KB
64 |
65 |     def _process_dir(self, dir_path, relabel=True):
66 |         if not isinstance(dir_path, list):
67 |             dir_path = [dir_path]
68 |
69 |         cid_container = set()
70 |         pid_container = set()
71 |
72 |         img_paths = []
73 |         for d in dir_path:
74 |             if not os.path.exists(d):
75 |                 assert False, 'Check the unreal data dir: {}'.format(d)
76 |
77 |             iii = glob.glob(osp.join(d, '*.*g'))
78 |             print(d, len(iii))
79 |             img_paths.extend(iii)
80 |
81 |         # regex groups: scene id, person model version, person id, camera id, frame id
82 |         pattern = re.compile(r'unreal_v([\d]+).([\d]+)/images/([-\d]+)_c([\d]+)_([\d]+)')
83 |         cid_container = set()
84 |         pid_container = set()
85 |         pid_container_sep = defaultdict(set)
86 |         for img_path in img_paths:
87 |             sid, pv, pid, cid, fid = map(int, pattern.search(img_path).groups())
88 |             # if pv==3 and pid>=1800:
89 |             ##     continue
90 |             # if pv<3 and pid>=1800:
91 |             #     continue  # For training, we use 1600 models. Others may be used for testing later.
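            # Illustrative note (ours, not part of the original code): for a path like
            #     'unreal_v3.1/images/333_c001_78.jpg'
            # the regex above yields sid=3, pv=1, pid=333, cid=1, fid=78, so the
            # person key is (pv, pid) = (1, 333) and the camera key is (sid, cid) = (3, 1).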
92 | cid_container.add((sid,cid)) 93 | pid_container_sep[pv].add((pv,pid)) 94 | for k in pid_container_sep.keys(): 95 | print("Unreal pids ({}): {}".format(k,len(pid_container_sep[k]))) 96 | print("Unreal cams: {}".format(len(cid_container))) 97 | # we need a balanced sampler here . 98 | num_pids_sep = self.num_pids // len(pid_container_sep) 99 | for k in pid_container_sep.keys(): 100 | pid_container_sep[k]=random.sample(pid_container_sep[k],num_pids_sep) if len(pid_container_sep[k])>=num_pids_sep else pid_container_sep[k] 101 | for pid in pid_container_sep[k]: 102 | pid_container.add(pid) 103 | 104 | if self.num_cams!=0: 105 | cid_container = random.sample(cid_container,self.num_cams) 106 | print(cid_container) 107 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 108 | cid2label = {cid: label for label, cid in enumerate(cid_container)} 109 | if relabel == True: 110 | self.pid2label = pid2label 111 | self.cid2label = cid2label 112 | 113 | dataset = [] 114 | ss=[] 115 | for img_path in tqdm(img_paths): 116 | sid,pv, pid, cid,fid = map(int, pattern.search(img_path).groups()) 117 | if (pv,pid) not in pid_container:continue 118 | if (sid,cid) not in cid_container:continue 119 | if relabel: 120 | pid = pid2label[(pv,pid)] 121 | camid = cid2label[(sid,cid)] 122 | # if self.size(img_path)>2.5: 123 | dataset.append((img_path, pid, camid)) 124 | print("Sampled pids: {}".format(len(pid_container))) 125 | print("Sampled imgs: {}".format(len(dataset))) 126 | if relabel: 127 | self.pid2label = pid2label 128 | if len(dataset)>self.img_per_person*self.num_pids: 129 | dataset=random.sample(dataset,self.img_per_person*self.num_pids) 130 | 131 | num_pids = len(pid_container) 132 | num_imgs = len(dataset) 133 | print("Sampled imgs: {}".format(len(dataset))) 134 | return dataset, num_pids, num_imgs 135 | 136 | 137 | -------------------------------------------------------------------------------- /CBN/io_stream/samplers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from collections import defaultdict 4 | import random 5 | import numpy as np 6 | import torch 7 | import copy 8 | from torch.utils.data.sampler import Sampler 9 | 10 | 11 | class IdentitySampler(Sampler): 12 | def __init__(self, data_source, batch_size, num_instances): 13 | if batch_size < num_instances: 14 | raise ValueError('batch_size={} must be no less ' 15 | 'than num_instances={}'.format(batch_size, num_instances)) 16 | self.data_source = data_source 17 | self.batch_size = batch_size 18 | self.num_instances = num_instances 19 | self.num_pids_per_batch = self.batch_size // self.num_instances # approximate 20 | self.index_dic = defaultdict(list) 21 | for index, (_, pid, camid) in enumerate(self.data_source): 22 | self.index_dic[pid].append(index) 23 | self.pids = list(self.index_dic.keys()) 24 | 25 | def __iter__(self): 26 | batch_idxs_dict = defaultdict(list) 27 | for pid in self.pids: 28 | idxs = copy.deepcopy(self.index_dic[pid]) 29 | if len(idxs) < self.num_instances: 30 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 31 | random.shuffle(idxs) 32 | batch_idxs = [] 33 | for idx in idxs: 34 | batch_idxs.append(idx) 35 | if len(batch_idxs) == self.num_instances: 36 | batch_idxs_dict[pid].append(batch_idxs) 37 | batch_idxs = [] 38 | 39 | avai_pids = copy.deepcopy(self.pids) 40 | final_idxs = [] 41 | while len(avai_pids) >= self.num_pids_per_batch: 42 | selected_pids = random.sample(avai_pids, 
self.num_pids_per_batch) 43 | for pid in selected_pids: 44 | batch_idxs = batch_idxs_dict[pid].pop(0) 45 | final_idxs.extend(batch_idxs) 46 | if len(batch_idxs_dict[pid]) == 0: 47 | avai_pids.remove(pid) 48 | self.length = len(final_idxs) 49 | return iter(final_idxs) 50 | 51 | def __len__(self): 52 | return self.length 53 | 54 | 55 | class IdentityCameraSampler(Sampler): 56 | def __init__(self, data_source, batch_size, num_instances,cams_of_dataset=None,len_of_real_data=None): 57 | if batch_size < num_instances: 58 | raise ValueError('batch_size={} must be no less ' 59 | 'than num_instances={}'.format(batch_size, num_instances)) 60 | self.data_source = data_source 61 | self.batch_size = batch_size 62 | self.num_instances = num_instances 63 | self.num_pids_per_batch = self.batch_size // self.num_instances # approximate 64 | self.num_cams_per_batch = 8 65 | self.index_dic = defaultdict(list) 66 | self.cam_index_dic = dict() 67 | self.num_pids_per_cam = self.num_pids_per_batch//self.num_cams_per_batch 68 | for index, (_, pid, camid) in enumerate(self.data_source): 69 | self.index_dic[pid].append(index) 70 | if camid not in self.cam_index_dic.keys(): 71 | self.cam_index_dic[camid]=defaultdict(list) 72 | self.cam_index_dic[camid][pid].append(index) 73 | self.pids = list(self.index_dic.keys()) 74 | self.cams_of_dataset=cams_of_dataset 75 | self.len_of_real_data = len_of_real_data 76 | 77 | def __iter__(self): 78 | final_idxs = [] 79 | length = 2*self.len_of_real_data if self.len_of_real_data is not None else len(self.data_source) 80 | # F setting 81 | #length = len(self.data_source) 82 | while(len(final_idxs) < length): 83 | if self.cams_of_dataset is not None: 84 | # C setting 85 | #c_rnd = np.random.choice(list(self.cam_index_dic.keys()),size=1)[0] 86 | #for cams_of_data in self.cams_of_dataset: 87 | # if c_rnd in cams_of_data: 88 | # cams = np.random.choice(list(cams_of_data),size=self.num_cams_per_batch,replace=True) 89 | # break 90 | 91 | # D setting 92 | c_rnd = np.random.choice([i for i in range(len(self.cams_of_dataset))],size=1)[0] 93 | cams = np.random.choice(list(self.cams_of_dataset[c_rnd]),size=self.num_cams_per_batch,replace=True) 94 | 95 | # E setting: data balance, mixed in mini-batches (dontsep) 96 | #cams0 = np.random.choice(list(self.cams_of_dataset[0]),size=self.num_cams_per_batch//2) 97 | #cams1 = np.random.choice(list(self.cams_of_dataset[1]),size=self.num_cams_per_batch//2) 98 | #cams = list(cams0)+list(cams1) 99 | 100 | # F setting databalfix 101 | # cams = np.random.choice(list(self.cam_index_dic.keys()),size=self.num_cams_per_batch,replace=True) 102 | else: 103 | cams = np.random.choice(list(self.cam_index_dic.keys()),size=self.num_cams_per_batch,replace=True) 104 | for c in cams: 105 | pids = np.random.choice(list(self.cam_index_dic[c].keys()),size=self.num_pids_per_cam, replace=True) 106 | for p in pids: 107 | idxs =np.random.choice(self.cam_index_dic[c][p],size=self.num_instances,replace=True) 108 | random.shuffle(idxs) 109 | final_idxs.extend(idxs) 110 | self.length=len(final_idxs) 111 | return iter(final_idxs) 112 | 113 | 114 | def __iter_old__(self): 115 | batch_idxs_dict = defaultdict(list) 116 | for pid in self.pids: 117 | idxs = copy.deepcopy(self.index_dic[pid]) 118 | if len(idxs) < self.num_instances: 119 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 120 | random.shuffle(idxs) 121 | batch_idxs = [] 122 | for idx in idxs: 123 | batch_idxs.append(idx) 124 | if len(batch_idxs) == self.num_instances: 125 | 
batch_idxs_dict[pid].append(batch_idxs) 126 | batch_idxs = [] 127 | 128 | avai_pids = copy.deepcopy(self.pids) 129 | final_idxs = [] 130 | while len(avai_pids) >= self.num_pids_per_batch: 131 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 132 | for pid in selected_pids: 133 | batch_idxs = batch_idxs_dict[pid].pop(0) 134 | final_idxs.extend(batch_idxs) 135 | if len(batch_idxs_dict[pid]) == 0: 136 | avai_pids.remove(pid) 137 | self.length = len(final_idxs) 138 | return iter(final_idxs) 139 | 140 | def __len__(self): 141 | return self.length 142 | 143 | 144 | 145 | class NormalCollateFn: 146 | def __call__(self, batch): 147 | img_tensor = [x[0] for x in batch] 148 | pids = np.array([x[1] for x in batch]) 149 | camids = np.array([x[2] for x in batch]) 150 | return torch.stack(img_tensor, dim=0), torch.from_numpy(pids), torch.from_numpy(np.array(camids)) 151 | -------------------------------------------------------------------------------- /CBN/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.3.1 2 | torchvision==0.4.2 3 | tensorboard 4 | future 5 | fire 6 | tqdm -------------------------------------------------------------------------------- /CBN/test_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | import pickle 6 | import os 7 | import sys 8 | import random 9 | import tqdm 10 | import time 11 | import numpy as np 12 | import torch 13 | from torch.utils.data import DataLoader 14 | 15 | from config import opt 16 | from io_stream import data_manager 17 | 18 | from frameworks.models import ResNetBuilder 19 | from frameworks.evaluating import evaluator_manager 20 | 21 | from utils.serialization import Logger, load_previous_model, load_moco_model 22 | from utils.transforms import TestTransform 23 | import os 24 | 25 | def test(**kwargs): 26 | opt._parse(kwargs) 27 | if opt.save_dir.startswith('pytorch'): 28 | opt.save_dir=os.path.split(opt.save_dir)[1] 29 | save_file = 'log_test_{}_{}.txt'.format(opt.testset_name,opt.testepoch) 30 | if opt.testset_name == 'unreal_test': 31 | save_file = 'log_test_{}_{}_{}.txt'.format(opt.testset_name,opt.testepoch,opt.datasets) 32 | 33 | sys.stdout = Logger( 34 | os.path.join("./pytorch-ckpt/current", opt.save_dir, save_file)) 35 | torch.manual_seed(opt.seed) 36 | random.seed(opt.seed) 37 | np.random.seed(opt.seed) 38 | 39 | use_gpu = torch.cuda.is_available() 40 | print('initializing dataset {}'.format(opt.testset_name)) 41 | if opt.testset_name != 'unreal_test': 42 | dataset = data_manager.init_dataset(name=opt.testset_name, 43 | num_bn_sample=opt.batch_num_bn_estimatation * opt.test_batch) 44 | 45 | if opt.testset_name=='unreal_test': 46 | dataset = data_manager.init_unreal_dataset(name=opt.testset_name, 47 | datasets = opt.datasets, 48 | num_pids=opt.num_pids, 49 | img_per_person = opt.img_per_person, 50 | num_bn_sample=opt.batch_num_bn_estimatation * opt.test_batch) 51 | pin_memory = True if use_gpu else False 52 | 53 | print('loading model from {} ...'.format(opt.save_dir)) 54 | model = ResNetBuilder() 55 | if opt.model_path is not None: 56 | model_path = opt.model_path 57 | else: 58 | model_path = os.path.join("./pytorch-ckpt/current", opt.save_dir, 59 | '{}.pth.tar'.format(opt.testepoch)) 60 | if opt.model_path is not None and opt.model_path.find('moco')!=-1: 61 | 
model = load_moco_model(model,model_path) 62 | else: 63 | model = load_previous_model(model, model_path, load_fc_layers=False) 64 | model.eval() 65 | 66 | if use_gpu: 67 | model = torch.nn.DataParallel(model).cuda() 68 | reid_evaluator = evaluator_manager.init_evaluator(opt.testset_name, model, flip=True) 69 | 70 | def _calculate_bn_and_features(all_data, sampled_data): 71 | time.sleep(1) 72 | all_features, all_ids, all_cams,all_paths = [], [], [], [] 73 | available_cams = list(sampled_data) 74 | cam_bn_info = dict() 75 | for current_cam in tqdm.tqdm(available_cams): 76 | camera_data = all_data[current_cam] 77 | if len(camera_data)==0: 78 | continue 79 | camera_samples = sampled_data[current_cam] 80 | data_for_camera_loader = DataLoader( 81 | data_manager.init_datafolder(opt.testset_name, camera_samples, TestTransform(opt.height, opt.width),with_path=True), 82 | batch_size=opt.test_batch, num_workers=opt.workers, 83 | pin_memory=False, drop_last=True 84 | ) 85 | bn_info = reid_evaluator.collect_sim_bn_info(data_for_camera_loader) 86 | cam_bn_info[current_cam]=bn_info 87 | camera_data = all_data[current_cam] 88 | if len(camera_data)==0: 89 | continue 90 | data_loader = DataLoader( 91 | data_manager.init_datafolder(opt.testset_name, camera_data, TestTransform(opt.height, opt.width), with_path =True), 92 | batch_size=opt.test_batch, num_workers=opt.workers, 93 | pin_memory=pin_memory, shuffle=False 94 | ) 95 | fs, pids, camids, img_paths = reid_evaluator.produce_features(data_loader, normalize=True) 96 | all_features.append(fs) 97 | all_ids.append(pids) 98 | all_cams.append(camids) 99 | all_paths.extend(img_paths) 100 | 101 | all_features = torch.cat(all_features, 0) 102 | all_ids = np.concatenate(all_ids, axis=0) 103 | all_cams = np.concatenate(all_cams, axis=0) 104 | 105 | time.sleep(1) 106 | 107 | pickle.dump(cam_bn_info,open('cam_bn_info-{}-{}.pkl'.format(opt.save_dir,opt.testset_name),'wb')) 108 | return all_features, all_ids, all_cams, all_paths 109 | 110 | print('Processing query features...') 111 | qf, q_pids, q_camids, q_paths = _calculate_bn_and_features(dataset.query_per_cam, dataset.query_per_cam_sampled) 112 | print('Processing gallery features...') 113 | gf, g_pids, g_camids, g_paths = _calculate_bn_and_features(dataset.gallery_per_cam, 114 | dataset.gallery_per_cam_sampled) 115 | if opt.testset_name =='msmt_sepcam': 116 | cid2label = dataset.cid2label 117 | label2cid = dict() 118 | for c,l in cid2label.items(): 119 | label2cid[l]=c[0] 120 | print(label2cid) 121 | q_cids = list() 122 | for qc in q_camids: 123 | q_cids.append(label2cid[qc]) 124 | g_cids = list() 125 | for gc in g_camids: 126 | g_cids.append(label2cid[gc]) 127 | q_camids = np.asarray(q_cids) 128 | g_camids = np.asarray(g_cids) 129 | 130 | pickle.dump({'qp':q_paths,'gp':g_paths},open('paths.pkl','wb')) 131 | print('Computing CMC and mAP...') 132 | reid_evaluator.get_final_results_with_features(qf, q_pids, q_camids, gf, g_pids, g_camids, q_paths,g_paths) 133 | 134 | 135 | if __name__ == '__main__': 136 | import fire 137 | 138 | fire.Fire() 139 | -------------------------------------------------------------------------------- /CBN/train_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import sys 7 | import os 8 | import random 9 | import numpy as np 10 | 11 | import torch 12 | from torch import nn 13 | 
from torch.backends import cudnn 14 | from torch.utils.data import DataLoader 15 | from torch.utils.tensorboard import SummaryWriter 16 | 17 | from config import opt 18 | 19 | from io_stream import data_manager, NormalCollateFn, IdentitySampler, IdentityCameraSampler 20 | 21 | from frameworks.models import ResNetBuilder 22 | from frameworks.training import CameraClsTrainer, get_our_optimizer_strategy, CamDataParallel 23 | 24 | from utils.serialization import Logger, save_checkpoint, load_moco_model, load_previous_model 25 | from utils.transforms import TrainTransform 26 | from utils.loss import TripletLoss 27 | 28 | 29 | def train(**kwargs): 30 | opt._parse(kwargs) 31 | # torch.backends.cudnn.deterministic = True # I think this line may slow down the training process 32 | # set random seed and cudnn benchmark 33 | torch.manual_seed(opt.seed) 34 | random.seed(opt.seed) 35 | np.random.seed(opt.seed) 36 | 37 | use_gpu = torch.cuda.is_available() 38 | sys.stdout = Logger(os.path.join('./pytorch-ckpt/current', opt.save_dir, 'log_train.txt')) 39 | 40 | if use_gpu: 41 | print('currently using GPU') 42 | cudnn.benchmark = True 43 | else: 44 | print('currently using cpu') 45 | print(opt._state_dict()) 46 | print('initializing dataset {}'.format(opt.trainset_name)) 47 | if opt.trainset_name=='combine': 48 | #input dataset name as 'datasets' 49 | train_dataset= data_manager.init_combine_dataset(name=opt.trainset_name,options=opt, 50 | datasets=opt.datasets, 51 | num_bn_sample=opt.batch_num_bn_estimatation * opt.test_batch, 52 | share_cam=opt.share_cam,num_pids=opt.num_pids) 53 | elif opt.trainset_name=='unreal': 54 | # input dataset dir in 'datasets' 55 | train_dataset = data_manager.init_unreal_dataset(name=opt.trainset_name, 56 | datasets = opt.datasets, 57 | num_pids=opt.num_pids, 58 | num_cams=opt.num_cams, 59 | img_per_person = opt.img_per_person) 60 | 61 | else: 62 | train_dataset = data_manager.init_dataset(name=opt.trainset_name, 63 | num_bn_sample=opt.batch_num_bn_estimatation * opt.test_batch,num_pids=opt.num_pids) 64 | pin_memory = True if use_gpu else False 65 | summary_writer = SummaryWriter(os.path.join('./pytorch-ckpt/current', opt.save_dir, 'tensorboard_log')) 66 | 67 | if opt.cam_bal: 68 | IDSampler=IdentityCameraSampler 69 | else: 70 | IDSampler=IdentitySampler 71 | if opt.trainset_name=='combine': 72 | samp = IDSampler(train_dataset.train, opt.train_batch, opt.num_instances,train_dataset.cams_of_dataset,train_dataset.len_of_real_dataset) 73 | else: 74 | samp = IDSampler(train_dataset.train, opt.train_batch, opt.num_instances) 75 | 76 | trainloader = DataLoader( 77 | data_manager.init_datafolder(opt.trainset_name, train_dataset.train, TrainTransform(opt.height, opt.width)), 78 | sampler=samp , 79 | batch_size=opt.train_batch, num_workers=opt.workers, 80 | pin_memory=pin_memory, drop_last=True, collate_fn=NormalCollateFn() 81 | ) 82 | print('initializing model ...') 83 | num_pid = train_dataset.num_train_pids if opt.loss=='softmax' else None 84 | model = ResNetBuilder(num_pid) 85 | if opt.model_path is not None and 'moco' in opt.model_path: 86 | model = load_moco_model(model,opt.model_path) 87 | elif opt.model_path is not None: 88 | model = load_previous_model(model, opt.model_path, load_fc_layers=False) 89 | optim_policy = model.get_optim_policy() 90 | print('model size: {:.5f}M'.format(sum(p.numel() 91 | for p in model.parameters()) / 1e6)) 92 | 93 | if use_gpu: 94 | model = CamDataParallel(model).cuda() 95 | 96 | xent = nn.CrossEntropyLoss() 97 | triplet = TripletLoss() 98 | def 
standard_cls_criterion(feat, 99 | preditions, 100 | targets, 101 | global_step, 102 | summary_writer): 103 | identity_loss = xent(preditions, targets) 104 | identity_accuracy = torch.mean((torch.argmax(preditions, dim=1) == targets).float()) 105 | summary_writer.add_scalar('cls_loss', identity_loss.item(), global_step) 106 | summary_writer.add_scalar('cls_accuracy', identity_accuracy.item(), global_step) 107 | return identity_loss 108 | 109 | def triplet_criterion(feat,preditons,targets,global_step,summary_writer): 110 | triplet_loss, acc = triplet(feat,targets) 111 | summary_writer.add_scalar('loss', triplet_loss.item(), global_step) 112 | print(np.mean(acc.item())) 113 | summary_writer.add_scalar('accuracy', acc.item(), global_step) 114 | return triplet_loss 115 | 116 | 117 | 118 | # get trainer and evaluator 119 | optimizer, adjust_lr = get_our_optimizer_strategy(opt, optim_policy) 120 | if opt.loss=='softmax': 121 | crit = standard_cls_criterion 122 | elif opt.loss=='triplet': 123 | crit = triplet_criterion 124 | reid_trainer = CameraClsTrainer(opt, model, optimizer, crit, summary_writer) 125 | 126 | print('Start training') 127 | for epoch in range(opt.max_epoch): 128 | adjust_lr(optimizer, epoch) 129 | reid_trainer.train(epoch, trainloader) 130 | if (epoch+1)%opt.save_step==0: 131 | if use_gpu: 132 | state_dict = model.module.state_dict() 133 | else: 134 | state_dict = model.state_dict() 135 | 136 | save_checkpoint({ 137 | 'state_dict': state_dict, 138 | 'epoch': epoch + 1, 139 | }, save_dir=os.path.join('./pytorch-ckpt/current', opt.save_dir), ep=epoch+1) 140 | 141 | # if (epoch+1)%15==0: 142 | # save_checkpoint({ 143 | # 'state_dict': state_dict, 144 | # 'epoch': epoch + 1, 145 | # }, save_dir=os.path.join('./pytorch-ckpt/current', opt.save_dir)) 146 | 147 | if use_gpu: 148 | state_dict = model.module.state_dict() 149 | else: 150 | state_dict = model.state_dict() 151 | 152 | save_checkpoint({ 153 | 'state_dict': state_dict, 154 | 'epoch': epoch + 1, 155 | }, save_dir=os.path.join('./pytorch-ckpt/current', opt.save_dir)) 156 | 157 | 158 | if __name__ == '__main__': 159 | import fire 160 | fire.Fire() 161 | -------------------------------------------------------------------------------- /CBN/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | -------------------------------------------------------------------------------- /CBN/utils/meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import math 7 | import numpy as np 8 | 9 | 10 | class AverageMeter(object): 11 | def __init__(self): 12 | self.n = 0 13 | self.sum = 0.0 14 | self.var = 0.0 15 | self.val = 0.0 16 | self.mean = np.nan 17 | self.std = np.nan 18 | 19 | def update(self, value, n=1): 20 | self.val = value 21 | self.sum += value 22 | self.var += value * value 23 | self.n += n 24 | 25 | if self.n == 0: 26 | self.mean, self.std = np.nan, np.nan 27 | elif self.n == 1: 28 | self.mean, self.std = self.sum, np.inf 29 | else: 30 | self.mean = self.sum / self.n 31 | self.std = math.sqrt( 32 | (self.var - self.n * self.mean * self.mean) / (self.n - 1.0)) 33 | 34 | def value(self): 35 | return self.mean, 
self.std 36 | 37 | def reset(self): 38 | self.n = 0 39 | self.sum = 0.0 40 | self.var = 0.0 41 | self.val = 0.0 42 | self.mean = np.nan 43 | self.std = np.nan 44 | -------------------------------------------------------------------------------- /CBN/utils/serialization.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import errno 7 | import os 8 | import sys 9 | 10 | import os.path as osp 11 | import torch 12 | 13 | 14 | class Logger(object): 15 | """ 16 | Write console output to external text file. 17 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py. 18 | """ 19 | 20 | def __init__(self, fpath=None): 21 | self.console = sys.stdout 22 | self.file = None 23 | if fpath is not None: 24 | mkdir_if_missing(os.path.dirname(fpath)) 25 | self.file = open(fpath, 'a') 26 | 27 | def __del__(self): 28 | self.close() 29 | 30 | def __enter__(self): 31 | pass 32 | 33 | def __exit__(self, *args): 34 | self.close() 35 | 36 | def write(self, msg): 37 | self.console.write(msg) 38 | if self.file is not None: 39 | self.file.write(msg) 40 | 41 | def flush(self): 42 | self.console.flush() 43 | if self.file is not None: 44 | self.file.flush() 45 | os.fsync(self.file.fileno()) 46 | 47 | def close(self): 48 | self.console.close() 49 | if self.file is not None: 50 | self.file.close() 51 | 52 | 53 | def mkdir_if_missing(dir_path): 54 | try: 55 | os.makedirs(dir_path) 56 | except OSError as e: 57 | if e.errno != errno.EEXIST: 58 | raise 59 | 60 | 61 | def save_checkpoint(state, save_dir,ep=None): 62 | mkdir_if_missing(save_dir) 63 | if ep ==None: 64 | fpath = osp.join(save_dir, 'model_best.pth.tar') 65 | else: 66 | fpath = osp.join(save_dir, 'model_{}.pth.tar'.format(ep)) 67 | torch.save(state, fpath) 68 | 69 | def load_moco_model(model,file_path=None): 70 | checkpoint = torch.load(file_path) 71 | 72 | state_dict = checkpoint['state_dict'] 73 | for k in list(state_dict.keys()): 74 | # retain only encoder_q up to before the embedding layer 75 | if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'): 76 | state_dict['base.'+k[len("module.encoder_q."):]] = state_dict[k] 77 | del state_dict[k] 78 | msg = model.load_state_dict(state_dict,strict=False) 79 | print(msg.missing_keys) 80 | return model 81 | 82 | def load_previous_model(model, file_path=None, load_fc_layers=True): 83 | assert file_path is not None, 'Must define the path of the saved model' 84 | ckpt = torch.load(file_path) 85 | if load_fc_layers: 86 | state_dict = ckpt['state_dict'] 87 | else: 88 | state_dict = dict() 89 | for k, v in ckpt['state_dict'].items(): 90 | if 'classif' not in k: 91 | state_dict[k] = v 92 | 93 | msg = model.load_state_dict(state_dict, strict=False) 94 | print('missing keys:',msg.missing_keys) 95 | print('model size: {:.5f}M'.format(sum(p.numel() for p in model.parameters()) / 1e6)) 96 | return model 97 | -------------------------------------------------------------------------------- /CBN/utils/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | from torchvision import transforms as T 7 | import random 8 | 9 | class TrainTransform(object): 10 | def 
__init__(self, h, w):
11 |         self.h = h
12 |         self.w = w
13 |
14 |     def __call__(self, x):
15 |         #if random.randint(0,1)==0:
16 |         #    x = T.Resize((40,20))(x)
17 |         x = T.Resize((self.h, self.w))(x)
18 |         x = T.RandomHorizontalFlip()(x)
19 |         x = T.Pad(10)(x)
20 |         x = T.RandomCrop(size=(self.h, self.w))(x)
21 |         x = T.ToTensor()(x)
22 |         x = T.Normalize(mean=[0.485, 0.456, 0.406],
23 |                         std=[0.229, 0.224, 0.225])(x)
24 |         return x
25 |
26 |
27 | class DownTransform(object):
28 |     def __init__(self, h, w):
29 |         self.h = h
30 |         self.w = w
31 |
32 |     def __call__(self, x):
33 |         x = T.Resize((40, 20))(x)  # downsample first to simulate low-resolution inputs
34 |         x = T.Resize((self.h, self.w))(x)
35 |         x = T.ToTensor()(x)
36 |         x = T.Normalize(mean=[0.485, 0.456, 0.406],
37 |                         std=[0.229, 0.224, 0.225])(x)
38 |         return x
39 |
40 | class TestTransform(object):
41 |     def __init__(self, h, w):
42 |         self.h = h
43 |         self.w = w
44 |
45 |     def __call__(self, x):
46 |         x = T.Resize((self.h, self.w))(x)
47 |         x = T.ToTensor()(x)
48 |         x = T.Normalize(mean=[0.485, 0.456, 0.406],
49 |                         std=[0.229, 0.224, 0.225])(x)
50 |         return x
51 |
-------------------------------------------------------------------------------- /Download.md: --------------------------------------------------------------------------------
1 | # Download Synthesized Data
2 |
3 | Our synthesized data (named Unreal in the paper) is generated with Makehuman, Mixamo, and UnrealEngine 4. We provide 1.2M images of 4.8K identities, captured from 4 unreal environments.
4 |
5 | Beihang Netdisk: [Download Link](https://bhpan.buaa.edu.cn:443/link/BD6502DF5A2A2434BC5FC62793F80F96) valid until: 2024-01-01
6 |
7 | BaiduPan: [Download Link](https://pan.baidu.com/s/1P_UKdhmuDvJNQHuO81ifww) password: abcd
8 |
9 | Google Drive: [Download Link](https://drive.google.com/drive/folders/1sQHVBWvDwn-SVJtMqZDpk2HK9g9tnZ1I?usp=sharing)
10 |
11 | The image path is formulated as: unreal_v{X}.{Y}/images/{P}\_c{D}_{F}.jpg,
12 | for example, unreal_v3.1/images/333_c001_78.jpg.
13 |
14 | _X_ represents the ID of the unreal environment; _Y_ is the version of the human models; _P_ is the person identity label; _D_ is the camera label; _F_ is the frame number.
15 |
16 | We provide three types of human models: version 1 is the basic type; version 2 contains accessories, like handbags, hats and backpacks; version 3 contains hard samples with similar global appearance.
17 | Four virtual environments are used in our synthesized data: the first three are city environments and the last one is a supermarket.
18 | Note that cameras under different virtual environments may have the same label, and persons of different versions may also have the same identity label.
19 | Therefore, images with the same (Y, P) belong to the same virtual person; images with the same (X, D) belong to the same camera.
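The naming rule above can be decoded with a few lines of Python. A minimal sketch (ours, not part of the released code; `parse_unreal_path` is a hypothetical helper):

```python
import re

# X = environment, Y = human model version, P = person, D = camera, F = frame
PATTERN = re.compile(r'unreal_v(\d+)\.(\d+)/images/(\d+)_c(\d+)_(\d+)\.jpg')

def parse_unreal_path(path):
    x, y, p, d, f = map(int, PATTERN.search(path).groups())
    person = (y, p)  # images with the same (Y, P) show the same virtual person
    camera = (x, d)  # images with the same (X, D) come from the same camera
    return person, camera, f

print(parse_unreal_path('unreal_v3.1/images/333_c001_78.jpg'))
# ((1, 333), (3, 1), 78)
```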
20 | 21 | ## Cite our paper 22 | 23 | If you find our work useful in your research, please kindly cite: 24 | 25 | ``` 26 | @inproceedings{unrealperson, 27 | title={UnrealPerson: An Adaptive Pipeline towards Costless Person Re-identification}, 28 | author={Tianyu Zhang and Lingxi Xie and Longhui Wei and Zijie Zhuang and Yongfei Zhang and Bo Li and Qi Tian}, 29 | year={2021}, 30 | booktitle={CVPR} 31 | } 32 | ``` 33 | 34 | If you have any questions about the data or paper, please leave an issue or contact me: 35 | zhangtianyu@buaa.edu.cn 36 | 37 | -------------------------------------------------------------------------------- /Experiments.md: -------------------------------------------------------------------------------- 1 | # Experiments of UnrealPerson 2 | 3 | The codes are based on [CBN](https://github.com/automan000/Camera-based-Person-ReID) (ECCV 2020) and [JVTC](https://github.com/ljn114514/JVTC) (ECCV 2020). 4 | 5 | ### Direct Transfer and Supervised Fine-tuning 6 | 7 | We use Camera-based Batch Normalization baseline for direct transfer and supervised fine-tuning experiments. 8 | 9 | **1. Clone this repo and change directory to CBN** 10 | ```bash 11 | git clone https://github.com/FlyHighest/UnrealPerson.git 12 | cd UnrealPerson/CBN 13 | ``` 14 | 15 | **2. Download Market-1501, DukeMTMC-reID, MSMT17, UnrealPerson data and organize them as follows:** 16 |
17 | .
18 | +-- data
19 | |   +-- market
20 | |       +-- bounding_box_train
21 | |       +-- query
22 | |       +-- bounding_box_test
23 | |   +-- duke
24 | |       +-- bounding_box_train
25 | |       +-- query
26 | |       +-- bounding_box_test
27 | |   +-- msmt17
28 | |       +-- train
29 | |       +-- test
30 | |       +-- list_train.txt
31 | |       +-- list_val.txt
32 | |       +-- list_query.txt
33 | |       +-- list_gallery.txt
34 | |   +-- unreal_vX.Y
35 | |       +-- images
36 | |   +-- unreal_vX.Y
37 | |       +-- images
38 | + -- other files in this repo
39 |
40 |
41 |
42 |
43 | **3. Install the required packages**
44 | ```console
45 | pip install -r requirements.txt
46 | ```
47 |
48 |
49 | **4. Put the official PyTorch [ResNet-50](https://download.pytorch.org/models/resnet50-19c8e357.pth) pretrained model in your home folder:
50 | '~/.torch/models/'**
51 |
52 |
53 | **5. Train a ReID model with our synthesized data**
54 |
55 | Reproduce the results in our paper:
56 |
57 | ```console
58 | CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=0,1 \
59 | python train_model.py train --trainset_name unreal --datasets='unreal_v1.1,unreal_v2.1,unreal_v3.1,unreal_v4.1,unreal_v1.2,unreal_v2.2,unreal_v3.2,unreal_v4.2,unreal_v1.3,unreal_v2.3,unreal_v3.3,unreal_v4.3' --save_dir='unreal_4678_v1v2v3_cambal_3000' --save_step 15 --num_pids 3000 --cam_bal True --img_per_person 40
60 | ```
61 |
62 | We also provide the trained weights of this experiment in the data download links above.
63 |
64 | Configs:
65 | When ``trainset_name`` is unreal, ``datasets`` contains the directories of unreal data that will be used. ``num_pids`` is the number of humans, and ``cam_bal`` denotes that the camera-balanced sampling strategy is adopted. ``img_per_person`` controls the size of the training set.
66 |
67 | More configurations are in [config.py](https://github.com/FlyHighest/UnrealPerson/CBN/config.py).
68 |
69 | **6.1 Direct transfer to real datasets**
70 | ```console
71 | CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=0 \
72 | python test_model.py test --testset_name market --save_dir='unreal_4678_v1v2v3_cambal_3000'
73 | ```
74 |
75 | **6.2 Fine-tuning**
76 | ```console
77 | CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=1,0 \
78 | python train_model.py train --trainset_name market --save_dir='market_unrealpretrain_demo' --max_epoch 60 --decay_epoch 40 --model_path pytorch-ckpt/current/unreal_4678_v1v2v3_cambal_3000/model_best.pth.tar
79 |
80 |
81 | CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=0 \
82 | python test_model.py test --testset_name market --save_dir='market_unrealpretrain_demo'
83 | ```
84 |
85 |
86 | ### Unsupervised Domain Adaptation
87 |
88 | We use the joint visual and temporal consistency (JVTC) framework. CBN is also implemented in JVTC.
89 |
90 | **1. Clone this repo and change directory to JVTC**
91 |
92 | ```bash
93 | git clone https://github.com/FlyHighest/UnrealPerson.git
94 | cd UnrealPerson/JVTC
95 | ```
96 |
97 | **2. Prepare data**
98 |
99 | Basically, it is the same as CBN, except for an extra directory ``bounding_box_train_camstyle_merge``, which can be downloaded from [ECN](https://github.com/zhunzhong07/ECN). We suggest using ``ln -s`` to save disk space.
100 |
101 | . 102 | +-- data 103 | | +-- market 104 | | +-- bounding_box_train 105 | | +-- query 106 | | +-- bounding_box_test 107 | | +-- bounding_box_train_camstyle_merge 108 | + -- other files in this repo 109 |110 | 111 | **3. Install the required packages** 112 | 113 | ```console 114 | pip install -r ../CBN/requirements.txt 115 | ``` 116 | 117 | 118 | **4. Put the official PyTorch [ResNet-50](https://download.pytorch.org/models/resnet50-19c8e357.pth) pretrained model to your home folder: 119 | '~/.torch/models/'** 120 | 121 | **5. Train and test** 122 | 123 | (Unreal to MSMT) 124 | 125 | ```console 126 | python train_cbn.py --gpu_ids 0,1,2 --src unreal --tar msmt --num_cam 6 --name unreal2msmt --max_ep 60 127 | 128 | python test_cbn.py --gpu_ids 1 --weights snapshot/unreal2msmt/resnet50_unreal2market_epoch60_cbn.pth --name 'unreal2msmt' --tar market --num_cam 6 --joint True 129 | ``` 130 | 131 | The unreal data used in JVTC is defined in list_unreal/list_unreal_train.txt. The CBN codes support generating this file (see CBN/io_stream/datasets/unreal.py). 132 | 133 | More details can be seen in [JVTC](https://github.com/ljn114514/JVTC). 134 | 135 | ### References 136 | 137 | - [1] Rethinking the Distribution Gap of Person Re-identification with Camera-Based Batch Normalization. ECCV 2020. 138 | 139 | - [2] Joint Visual and Temporal Consistency for Unsupervised Domain Adaptive Person Re-Identification. ECCV 2020. 140 | 141 | 142 | ## Cite our paper 143 | 144 | If you find our work useful in your research, please kindly cite: 145 | 146 | ``` 147 | @inproceedings{unrealperson, 148 | title={UnrealPerson: An Adaptive Pipeline towards Costless Person Re-identification}, 149 | author={Tianyu Zhang and Lingxi Xie and Longhui Wei and Zijie Zhuang and Yongfei Zhang and Bo Li and Qi Tian}, 150 | year={2021}, 151 | booktitle={CVPR} 152 | } 153 | ``` 154 | 155 | If you have any questions about the data or paper, please leave an issue or contact me: 156 | zhangtianyu@buaa.edu.cn -------------------------------------------------------------------------------- /JVTC/README.md: -------------------------------------------------------------------------------- 1 | This code is based on JVTC. We implement CBN in ``train_cbn.py`` and ``test_cbn.py``. 2 | 3 | The original repo: [JVTC](https://github.com/ljn114514/JVTC). 4 | 5 | The paper: [Joint Visual and Temporal Consistency for Unsupervised Domain Adaptive Person Re-Identification ECCV 2020](https://arxiv.org/pdf/2007.10854.pdf). 
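For orientation, the camera-based BN idea behind ``train_cbn.py`` and ``test_cbn.py`` is to normalize each camera's images with statistics computed on that camera alone, while the BN affine parameters stay shared across cameras. A minimal sketch of the idea (our illustration, not the repo's exact code):

```python
import torch
import torch.nn.functional as F

def camera_batch_norm(x, camids, bn):
    """Normalize each camera's sub-batch with its own statistics.

    x: (N, C, H, W) features; camids: (N,) integer camera labels;
    bn: a shared torch.nn.BatchNorm2d supplying the affine weight/bias.
    """
    out = torch.empty_like(x)
    for c in camids.unique():
        mask = camids == c
        # running_mean/var are None, so the statistics come from this sub-batch only
        out[mask] = F.batch_norm(x[mask], None, None, bn.weight, bn.bias,
                                 training=True, eps=bn.eps)
    return out
```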
6 |
-------------------------------------------------------------------------------- /JVTC/config.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | from __future__ import unicode_literals
5 | import warnings
6 |
7 |
8 | class DefaultConfig(object):
9 |     seed = 0
10 |     # dataset options
11 |     trainset_name = 'market'
12 |     testset_name = 'duke'
13 |     # sampler
14 |     num_cam = 6
15 |     # default optimization params
16 |     train_batch = 64
17 |     test_batch = 64
18 |     max_epoch = 60
19 |     decay_epoch = 40
20 |     # estimate bn statistics
21 |     batch_num_bn_estimatation = 10
22 |     # io
23 |     print_freq = 50
24 |     save_dir = './pytorch-ckpt/market'
25 |
26 |     def _parse(self, kwargs):
27 |         for k, v in kwargs.items():
28 |             if not hasattr(self, k):
29 |                 warnings.warn("Warning: opt has no attribute %s" % k)
30 |             setattr(self, k, v)
31 |
32 |     def _state_dict(self):
33 |         return {k: getattr(self, k) for k, _ in DefaultConfig.__dict__.items()
34 |                 if not k.startswith('_')}
35 |
36 |
37 | opt = DefaultConfig()
38 |
39 |
-------------------------------------------------------------------------------- /JVTC/list_msmt/rename.py: --------------------------------------------------------------------------------
1 | import shutil
2 | import glob
3 | import os
4 | from tqdm import tqdm
5 |
6 | train_files = open('list_train.txt','r')
7 |
8 | txt = open('list_msmt_train.txt','w')
9 |
10 | temporal_slot = ['0113morning','0113noon','0113afternoon',
11 |                  '0114morning','0114noon','0114afternoon',
12 |                  '0302morning','0302noon','0302afternoon',
13 |                  '0303morning','0303noon','0303afternoon']
14 |
15 | for f in tqdm(train_files.readlines()):
16 |     if 'fake' in f or len(f.split('_'))<5:
17 |         continue
18 |     else:
19 |
20 |         pid = f.split('/')[0]
21 |         cam = f.split('_')[2]
22 |         img_num = f.split("_")[1]
23 |         slot = f.split("_")[3]
24 |
25 |         frame = f.split("_")[4]
26 |
27 |         frame = int(frame)+10000*temporal_slot.index(slot)
28 |
29 |
30 |         new_name = '{}_c{}_0{}.jpg'.format(pid,int(cam),img_num)
31 |         txt.write('{} {} {} {}\n'.format(new_name, pid, cam, frame))
32 |
33 | txt.close()
34 |
35 |
36 |
-------------------------------------------------------------------------------- /JVTC/list_msmt/rename_test.py: --------------------------------------------------------------------------------
1 | import shutil
2 | import glob
3 | import os
4 | from tqdm import tqdm
5 |
6 | test_files = open('list_test.txt','r')
7 |
8 | txt = open('list_msmt_test.txt','w')
9 |
10 | temporal_slot = ['0113morning','0113noon','0113afternoon',
11 |                  '0114morning','0114noon','0114afternoon',
12 |                  '0302morning','0302noon','0302afternoon',
13 |                  '0303morning','0303noon','0303afternoon']
14 | # each line looks like: query/0000/0000_000_01_0303morning_0015_0.jpg 0 01 0015
15 | for f in tqdm(test_files.readlines()):
16 |     if True:
17 |         pid = f.split('/')[1]
18 |         cam = f.split('_')[2]
19 |         slot = f.split("_")[3]
20 |
21 |         frame = f.split("_")[4]
22 |
23 |         frame = int(frame)+10000*temporal_slot.index(slot)
24 |
25 |
26 |         new_name = f.split(' ')[0]
27 |         txt.write('{} {} {} {}\n'.format(new_name, pid, cam, frame))
28 |
29 | txt.close()
30 |
31 |
32 |
-------------------------------------------------------------------------------- /JVTC/multi_train_cbn.py: --------------------------------------------------------------------------------
1 | import os, torch,sys
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | from torch.utils.data
import DataLoader 5 | 6 | from utils.resnet import resnet50 7 | from utils.dataset import imgdataset, imgdataset_camtrans, IdentityCameraSampler, imgdataset_withsource 8 | from utils.losses import Losses#, LocalLoss, GlobalLoss 9 | from utils.evaluators import evaluate_all 10 | from utils.lr_adjust import StepLrUpdater, SetLr 11 | from utils.util import CamDataParallel , organize_data 12 | import argparse 13 | from tqdm import tqdm,trange 14 | from utils.logger import Logger 15 | 16 | parser = argparse.ArgumentParser(description='Training') 17 | parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 0 0,1,2 0,2') 18 | parser.add_argument('--name',default='ft_ResNet50', type=str, help='output model name') 19 | parser.add_argument('--src',default='/home/zzd/Market/pytorch',type=str, help='training dir path') 20 | parser.add_argument('--src2',default='/home/zzd/Market/pytorch',type=str, help='training dir path') 21 | parser.add_argument('--tar',default='/home/zzd/Market/pytorch',type=str, help='training dir path') 22 | 23 | parser.add_argument('--num_cam', default=8, type=int, help='tar cam') 24 | parser.add_argument('--max_ep', default=8, type=int, help='tar cam') 25 | 26 | opt = parser.parse_args() 27 | 28 | 29 | 30 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids 31 | ########### HYPER ########### 32 | base_lr = 0.01 33 | num_epoches = 100 34 | batch_size = 128 35 | num_instances=4 36 | K = 4 37 | num_cam = opt.num_cam 38 | sys.stdout = Logger(os.path.join('./snapshot',opt.name,'log_train.txt')) 39 | ########## DATASET ########### 40 | dataset_path = 'data/' 41 | if opt.src=='unreal': 42 | src_dir = dataset_path 43 | else: 44 | src_dir = dataset_path +opt.src+ '/bounding_box_train_camstyle_merge/' 45 | src_annfile = 'list_{}/list_{}_train.txt'.format(opt.src,opt.src) 46 | train_dataset = imgdataset(dataset_dir=src_dir, txt_path=src_annfile, transformer='train') 47 | 48 | dataset_path = 'data/' 49 | if opt.src2=='unreal': 50 | src_dir2 = dataset_path 51 | else: 52 | src_dir2 = dataset_path +opt.src2+ '/bounding_box_train_camstyle_merge/' 53 | src2_annfile = 'list_{}/list_{}_train.txt'.format(opt.src2,opt.src2) 54 | train_dataset2 = imgdataset(dataset_dir=src_dir2, txt_path=src2_annfile, transformer='train') 55 | 56 | train_datasource = train_dataset.data_source 57 | numpids = len(train_dataset.pids) 58 | numcams = len(train_dataset.cams) 59 | for img, pid, cam in train_dataset2.data_source: 60 | train_datasource.append((img,pid+numpids,cam+10000)) 61 | 62 | train_dataset = imgdataset_withsource(train_datasource) 63 | 64 | 65 | sampler = IdentityCameraSampler(train_dataset.data_source, batch_size, num_instances) 66 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, sampler = sampler, num_workers=4, drop_last=True) 67 | 68 | 69 | 70 | tar_dir = dataset_path + opt.tar+ '/bounding_box_train_camstyle_merge/' 71 | tar_dir_test = dataset_path +opt.tar+ '/' 72 | tar_annfile = 'list_{}/list_{}_train.txt'.format(opt.tar,opt.tar) 73 | tar_annfile_test = 'list_{}/list_{}_test.txt'.format(opt.tar,opt.tar) 74 | 75 | train_dataset_t = imgdataset_camtrans(dataset_dir=tar_dir, txt_path=tar_annfile, 76 | transformer='train', num_cam=num_cam, K=K) 77 | train_loader_t = DataLoader(dataset=train_dataset_t, batch_size=int(batch_size/K), shuffle=True, num_workers=4, drop_last=True) 78 | 79 | 80 | ########### MODEL ########### 81 | imageNet_pretrain = 'resnet50-19c8e357.pth' 82 | model, param = resnet50(pretrained=imageNet_pretrain, 
num_classes=numpids+len(train_dataset2.pids)) 83 | model.cuda() 84 | model = CamDataParallel(model)#, device_ids=[0,1]) 85 | 86 | losses = Losses(K=K, 87 | batch_size=batch_size, 88 | bank_size=len(train_dataset_t), 89 | ann_file=tar_annfile, 90 | cam_num=num_cam) 91 | losses = losses.cuda() 92 | optimizer = torch.optim.SGD(param, lr=base_lr, momentum=0.9, weight_decay=5e-4, nesterov=True) 93 | 94 | ########### TRAIN ########### 95 | target_iter = iter(train_loader_t) 96 | for epoch in trange(1, num_epoches+1): 97 | 98 | lr = StepLrUpdater(epoch, base_lr=base_lr, gamma=0.1, step=40) 99 | SetLr(lr, optimizer) 100 | 101 | print('-' * 10) 102 | print('Epoch [%d/%d], lr:%f'%(epoch, num_epoches, lr)) 103 | 104 | running_loss_src = 0.0 105 | running_loss_local = 0.0 106 | running_loss_global = 0.0 107 | 108 | if (epoch)%5 == 0 and epoch!=opt.max_ep: 109 | losses.reset_multi_label(epoch) 110 | 111 | model.train() 112 | for i, source_data in enumerate(train_loader, 1): 113 | try: 114 | target_data = next(target_iter) 115 | except: 116 | target_iter = iter(train_loader_t) 117 | target_data = next(target_iter) 118 | 119 | image_src = source_data[0].cuda() 120 | label_src = source_data[1].cuda() 121 | image_tar = target_data[0].cuda() 122 | image_tar = image_tar.view(-1, image_tar.size(2), image_tar.size(3), image_tar.size(4)) 123 | label_tar = target_data[2].cuda() 124 | cams_src = source_data[-1].cuda() 125 | cams_tar = target_data[-1].cuda().view(-1) 126 | label_tar = label_tar.view(-1) 127 | image_src_org, label_src_org = organize_data(image_src, cams_src , label_src) 128 | image_tar_org, label_tar_org = organize_data(image_tar, cams_tar , label_tar) 129 | x_src = model(image_src_org)[0] 130 | x_tar = model(image_tar_org)[2] 131 | loss_all= losses(x_src, label_src_org, x_tar, label_tar_org, epoch) 132 | loss, loss_s, loss_l, loss_g = loss_all 133 | 134 | 135 | running_loss_src += loss_s.mean().item() 136 | running_loss_local += loss_l.mean().item() 137 | running_loss_global += loss_g.mean().item() 138 | 139 | optimizer.zero_grad() 140 | loss.backward() 141 | optimizer.step() 142 | 143 | losses.update_memory(x_tar, label_tar_org, epoch=epoch) 144 | 145 | if i % 50 == 0: 146 | print(' iter: %3d/%d, loss src: %.3f, loss local: %.3f, loss global: %.3f'%(i, len(train_loader), running_loss_src/i, running_loss_local/i, running_loss_global/i)) 147 | if epoch>5 and i==300: 148 | break 149 | 150 | print('Finish {} epoch\n'.format(epoch)) 151 | 152 | if epoch % 10 ==0: 153 | if hasattr(model, 'module'): 154 | model_save = model.module 155 | else: 156 | model_save = model 157 | torch.save(model_save.state_dict(), 'snapshot/{}/resnet50_{}_{}2{}_epoch{}_cbn.pth'.format(opt.name,opt.src,opt.src2,opt.tar,epoch)) 158 | if epoch >=opt.max_ep: 159 | break 160 | 161 | -------------------------------------------------------------------------------- /JVTC/test_cbn.py: -------------------------------------------------------------------------------- 1 | import os, torch 2 | from torch.utils.data import DataLoader 3 | from torch.nn import functional as F 4 | import numpy as np 5 | from scipy.spatial.distance import cdist 6 | 7 | from utils.util import cluster, get_info 8 | from utils.util import extract_fea_camtrans, extract_fea_test, extract_fea_test_cbn 9 | from utils.resnet import resnet50 10 | from utils.dataset import imgdataset, imgdataset_camtrans 11 | from utils.rerank import re_ranking 12 | from utils.st_distribution import get_st_distribution 13 | from utils.evaluate_joint_sim import evaluate_joint 14 | import 
argparse
15 | import sys
16 | from utils.logger import Logger
17 | 
18 | parser = argparse.ArgumentParser(description='Testing')
19 | parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 0  0,1,2  0,2')
20 | parser.add_argument('--weights',type=str)
21 | parser.add_argument('--name',type=str)
22 | parser.add_argument('--tar',default='market',type=str, help='target domain')
23 | parser.add_argument('--num_cam', default=8, type=int, help='target camera number')
24 | parser.add_argument('--joint', default=True, type=lambda s: str(s).lower() in ('true','1','yes'), help='joint similarity or visual similarity')  # type=bool would treat any non-empty string, including "False", as True
25 | 
26 | opt = parser.parse_args()
27 | sys.stdout = Logger(os.path.join('./snapshot',opt.name,'log_test_{}.txt'.format(os.path.split(opt.weights)[1])))
28 | 
29 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids
30 | 
31 | dataset_path = 'data/'
32 | ann_file_train = 'list_{}/list_{}_train.txt'.format(opt.tar,opt.tar)
33 | ann_file_test = 'list_{}/list_{}_test.txt'.format(opt.tar,opt.tar)
34 | 
35 | snapshot = opt.weights
36 | #'snapshot/resnet50_{}2{}_epoch{}_cbn.pth'.format(opt.src,opt.tar,opt.epoch)
37 | 
38 | num_cam = opt.num_cam
39 | ########### DATASET ###########
40 | img_dir = dataset_path + '{}/bounding_box_train_camstyle_merge/'.format(opt.tar)
41 | train_dataset = imgdataset_camtrans(dataset_dir=img_dir, txt_path=ann_file_train,
42 |                                     transformer='test', K=num_cam, num_cam=num_cam)
43 | train_loader = DataLoader(dataset=train_dataset, batch_size=1, shuffle=False, num_workers=4)
44 | 
45 | img_dir = dataset_path + '{}/'.format(opt.tar)
46 | test_dataset = imgdataset(dataset_dir=img_dir, txt_path=ann_file_test, transformer='test')
47 | test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=4)
48 | 
49 | ########### TEST ###########
50 | model, _ = resnet50(pretrained=snapshot, num_classes=2)  # classifier weights are skipped on load, so num_classes is a dummy value here
51 | model.cuda()
52 | model.eval()
53 | 
54 | print('extract feature for testing set')
55 | test_feas,pids,cids,fids,is_q = extract_fea_test_cbn(model, img_dir, ann_file_test)
56 | 
57 | if opt.joint:
58 |     print('extract feature for training set')
59 |     train_feas, _,cam_ids, frames , _ = extract_fea_test_cbn(model,dataset_path + '{}/bounding_box_train_camstyle_merge/'.format(opt.tar),ann_file_train)
60 |     print('generate spatial-temporal distribution')
61 |     dist = cdist(train_feas, train_feas)
62 |     dist = np.power(dist,2)
63 |     #dist = re_ranking(original_dist=dist)
64 |     labels = cluster(dist)
65 |     num_ids = len(set(labels))
66 |     print('cluster id num:', num_ids)
67 |     distribution = get_st_distribution(cam_ids, labels, frames, id_num=num_ids, cam_num=num_cam)
68 | else:
69 |     distribution = None
70 | 
71 | print('evaluation')
72 | evaluate_joint(test_fea=test_feas, st_distribute=distribution, ann_file=(pids,cids,fids,is_q))
--------------------------------------------------------------------------------
/JVTC/train_cbn.py:
--------------------------------------------------------------------------------
1 | import os, torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | from torch.utils.data import DataLoader
5 | 
6 | from utils.resnet import resnet50
7 | from utils.dataset import imgdataset, imgdataset_camtrans, IdentityCameraSampler, NormalCollateFn
8 | from utils.losses import Losses  #, LocalLoss, GlobalLoss
9 | from utils.evaluators import evaluate_all
10 | from utils.lr_adjust import StepLrUpdater, SetLr
11 | from utils.util import CamDataParallel, organize_data
12 | from utils.logger import Logger
13 | import sys
14 | import argparse
15 | from tqdm import
tqdm,trange
16 | import time
17 | from getpass import getuser
18 | 
19 | parser = argparse.ArgumentParser(description='Training')
20 | parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 0  0,1,2  0,2')
21 | parser.add_argument('--name',default='ft_ResNet50', type=str, help='output model name')
22 | parser.add_argument('--src',default='market,duke,msmt17,unreal',type=str, help='source domain')
23 | parser.add_argument('--tar',default='market,duke',type=str, help='target domain')
24 | parser.add_argument('--max_ep', default=100, type=int, help='maximum epoch number')
25 | 
26 | parser.add_argument('--num_cam', default=8, type=int, help='target camera number')
27 | 
28 | opt = parser.parse_args()
29 | 
30 | t_start = time.time()
31 | 
32 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids
33 | ########### HYPER ###########
34 | base_lr = 0.01
35 | num_epoches = 100
36 | batch_size = 128
37 | num_instances = 4
38 | K = 4
39 | num_cam = opt.num_cam
40 | sys.stdout = Logger(os.path.join('./snapshot',opt.name,'log_train.txt'))
41 | ########## DATASET ###########
42 | dataset_path = 'data/'
43 | if opt.src=='unreal':
44 |     src_dir = dataset_path
45 | else:
46 |     src_dir = dataset_path +opt.src+ '/bounding_box_train_camstyle_merge/'
47 | tar_dir = dataset_path + opt.tar+ '/bounding_box_train_camstyle_merge/'
48 | tar_dir_test = dataset_path +opt.tar+ '/'
49 | 
50 | src_annfile = 'list_{}/list_{}_train.txt'.format(opt.src,opt.src)
51 | tar_annfile = 'list_{}/list_{}_train.txt'.format(opt.tar,opt.tar)
52 | tar_annfile_test = 'list_{}/list_{}_test.txt'.format(opt.tar,opt.tar)
53 | 
54 | #resnet50: https://download.pytorch.org/models/resnet50-19c8e357.pth
55 | imageNet_pretrain = '/home/'+getuser()+'/.torch/models/resnet50-19c8e357.pth'
56 | 
57 | 
58 | train_dataset = imgdataset(dataset_dir=src_dir, txt_path=src_annfile, transformer='train')
59 | 
60 | sampler = IdentityCameraSampler(train_dataset.data_source, batch_size, num_instances)
61 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, sampler = sampler, num_workers=4, drop_last=True, collate_fn = NormalCollateFn())
62 | 
63 | train_dataset_t = imgdataset_camtrans(dataset_dir=tar_dir, txt_path=tar_annfile,
64 |                                       transformer='train', num_cam=num_cam, K=K)
65 | train_loader_t = DataLoader(dataset=train_dataset_t, batch_size=int(batch_size/K), shuffle=True, num_workers=4, drop_last=True)
66 | 
67 | 
68 | ########### MODEL ###########
69 | model, param = resnet50(pretrained=imageNet_pretrain, num_classes=len(train_dataset.pids))
70 | model.cuda()
71 | model = CamDataParallel(model)  #, device_ids=[0,1])
72 | 
73 | losses = Losses(K=K,
74 |                 batch_size=batch_size,
75 |                 bank_size=len(train_dataset_t),
76 |                 ann_file=tar_annfile,
77 |                 cam_num=num_cam)
78 | losses = losses.cuda()
79 | optimizer = torch.optim.SGD(param, lr=base_lr, momentum=0.9, weight_decay=5e-4, nesterov=True)
80 | 
81 | ########### TRAIN ###########
82 | target_iter = iter(train_loader_t)
83 | for epoch in trange(1, num_epoches+1):
84 | 
85 |     lr = StepLrUpdater(epoch, base_lr=base_lr, gamma=0.1, step=40)
86 |     SetLr(lr, optimizer)
87 | 
88 |     print('-' * 10)
89 |     print('Epoch [%d/%d], lr:%f'%(epoch, num_epoches, lr))
90 | 
91 |     running_loss_src = 0.0
92 |     running_loss_local = 0.0
93 |     running_loss_global = 0.0
94 | 
95 |     if epoch % 5 == 0 and epoch != opt.max_ep:
96 |         losses.reset_multi_label(epoch)
97 | 
98 |     model.train()
99 |     for i, source_data in enumerate(train_loader, 1):
100 |         try:
101 |             target_data = next(target_iter)
102 |         except StopIteration:  # a bare except would also swallow unrelated errors
103 |             target_iter = iter(train_loader_t)
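            # The target loader acts as an endless stream: when it is exhausted,
            # the iterator is rebuilt (above) and the first batch of the new pass
            # is drawn (below), so epoch length is set by the source loader alone.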
104 | target_data = next(target_iter) 105 | image_src = source_data[0].cuda() 106 | label_src = source_data[1].cuda() 107 | image_tar = target_data[0].cuda() 108 | image_tar = image_tar.view(-1, image_tar.size(2), image_tar.size(3), image_tar.size(4)) 109 | label_tar = target_data[2].cuda() 110 | cams_src = source_data[-1].cuda() 111 | cams_tar = target_data[-1].cuda().view(-1) 112 | label_tar = label_tar.view(-1) 113 | image_src_org, label_src_org = organize_data(image_src, cams_src , label_src) 114 | image_tar_org, label_tar_org = organize_data(image_tar, cams_tar , label_tar) 115 | x_src = model(image_src_org)[0] 116 | x_tar = model(image_tar_org)[2] 117 | loss_all= losses(x_src, label_src_org, x_tar, label_tar_org, epoch) 118 | loss, loss_s, loss_l, loss_g = loss_all 119 | 120 | 121 | running_loss_src += loss_s.mean().item() 122 | running_loss_local += loss_l.mean().item() 123 | running_loss_global += loss_g.mean().item() 124 | 125 | optimizer.zero_grad() 126 | loss.backward() 127 | optimizer.step() 128 | 129 | losses.update_memory(x_tar, label_tar_org, epoch=epoch) 130 | 131 | if i % 50 == 0: 132 | print(' iter: %3d/%d, loss src: %.3f, loss local: %.3f, loss global: %.3f'%(i, len(train_loader), running_loss_src/i, running_loss_local/i, running_loss_global/i)) 133 | if i==300: 134 | break 135 | 136 | print('Finish {} epoch\n'.format(epoch)) 137 | 138 | if epoch % 10 ==0 and epoch>45: 139 | if hasattr(model, 'module'): 140 | model_save = model.module 141 | else: 142 | model_save = model 143 | torch.save(model_save.state_dict(), 'snapshot/{}/resnet50_{}2{}_epoch{}_cbn.pth'.format(opt.name,opt.src,opt.tar,epoch)) 144 | 145 | if epoch >= opt.max_ep: 146 | print('Training stops at {}. Total Time: {} minutes.'.format(epoch,(time.time()-t_start)//60)) 147 | break 148 | -------------------------------------------------------------------------------- /JVTC/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from utils import * -------------------------------------------------------------------------------- /JVTC/utils/evaluate_joint_sim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from torchvision import transforms 6 | 7 | from .rerank import re_ranking 8 | from .ranking import cmc, mean_ap 9 | from .st_distribution import joint_similarity 10 | from .util import get_info, l2_dist 11 | 12 | def compute_joint_dist(distribution, q_feas, g_feas, q_frames, g_frames, q_cams, g_cams): 13 | dists = [] 14 | for i in range(len(q_frames)): 15 | dist = joint_similarity( 16 | q_feas[i],q_cams[i],q_frames[i], 17 | g_feas, g_cams, g_frames, 18 | distribution) 19 | 20 | dist = np.expand_dims(dist, axis=0) 21 | dists.append(dist) 22 | #print(i, dist.shape) 23 | dists = np.concatenate(dists, axis=0) 24 | 25 | return dists 26 | 27 | def evaluate_joint(test_fea, st_distribute, ann_file, select_set='duke'): 28 | #fea_duke_test = np.load('duke_test_feas.npy') 29 | print( 'test feature', test_fea.shape) 30 | #print(len(cams)) 31 | #2228 for duke, 3368 for market, 11659 for msmt17 32 | if select_set == 'duke': 33 | query_num = 2228 34 | elif select_set == 'market': 35 | query_num = 3368 36 | 37 | if type(ann_file)==tuple: 38 | labels,cams,frames,is_query = ann_file 39 | 40 | query_labels = labels[is_query] 41 | query_cams = cams[is_query] 42 | query_frames = frames[is_query] 43 | query_features = test_fea[is_query] 44 | 
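        # Boolean-mask split: rows flagged in is_query (above) form the query set,
        # and the complement ~is_query (below) forms the gallery set.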
45 | gallery_labels = labels[~is_query] 46 | gallery_cams = cams[~is_query] 47 | gallery_frames = frames[~is_query] 48 | gallery_features = test_fea[~is_query] 49 | 50 | else: 51 | labels, cams, frames = get_info(ann_file) 52 | query_labels = labels[0:query_num] 53 | query_cams = cams[0:query_num] 54 | query_frames = frames[0:query_num] 55 | query_features = test_fea[0:query_num, :] 56 | 57 | gallery_labels = labels[query_num:] 58 | gallery_cams = cams[query_num:] 59 | gallery_frames = frames[query_num:] 60 | gallery_features = test_fea[query_num:, :] 61 | 62 | 63 | dist = l2_dist(query_features, gallery_features) 64 | mAP = mean_ap(dist, query_labels, gallery_labels, query_cams, gallery_cams) 65 | cmc_scores = cmc(dist, query_labels, gallery_labels, query_cams, gallery_cams, 66 | separate_camera_set=False, single_gallery_shot=False, first_match_break=True) 67 | print('performance based on visual similarity') 68 | print('mAP: %.4f, r1:%.4f, r5:%.4f, r10:%.4f, r20:%.4f'%(mAP, cmc_scores[0], cmc_scores[4], cmc_scores[9], cmc_scores[19])) 69 | if st_distribute is None: 70 | return 71 | 72 | #st_distribute = np.load('distribution_duke_train.npy') 73 | dist = compute_joint_dist(st_distribute, 74 | query_features, gallery_features, 75 | query_frames, gallery_frames, 76 | query_cams, gallery_cams) 77 | 78 | mAP = mean_ap(dist, query_labels, gallery_labels, query_cams, gallery_cams) 79 | cmc_scores = cmc(dist, query_labels, gallery_labels, query_cams, gallery_cams, 80 | separate_camera_set=False, single_gallery_shot=False, first_match_break=True) 81 | print('performance based on joint similarity') 82 | print('mAP: %.4f, r1:%.4f, r5:%.4f, r10:%.4f, r20:%.4f'%(mAP, cmc_scores[0], cmc_scores[4], cmc_scores[9], cmc_scores[19])) 83 | -------------------------------------------------------------------------------- /JVTC/utils/evaluators.py: -------------------------------------------------------------------------------- 1 | import torch, cv2, os 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from .ranking import cmc, mean_ap 6 | 7 | #os.environ["CUDA_VISIBLE_DEVICES"] = "2" 8 | def extract_features(model, data_loader, select_set='market'): 9 | if select_set == 'market': 10 | query_num = 3368 11 | elif select_set =='duke': 12 | query_num = 2228 13 | 14 | print(select_set, "feature extraction start") 15 | model.eval() 16 | features = [] 17 | labels = [] 18 | cams = [] 19 | for i, (images, label, cam) in enumerate(data_loader): 20 | with torch.no_grad(): 21 | out = model(Variable(images).cuda()) 22 | features.append(out[1]) 23 | labels.append(label) 24 | cams.append(cam) 25 | 26 | features = torch.cat(features).cpu().numpy() 27 | labels = torch.cat(labels).cpu().numpy() 28 | cams = torch.cat(cams).cpu().numpy() 29 | #print('features', features.shape, labels.shape, cams.shape) 30 | 31 | query_labels = labels[0:query_num] 32 | query_cams = cams[0:query_num] 33 | 34 | gallery_labels = labels[query_num:] 35 | gallery_cams = cams[query_num:] 36 | 37 | query_features = features[0:query_num, :] 38 | gallery_features = features[query_num:, :] 39 | print("extraction done, feature shape:", np.shape(features)) 40 | 41 | return (query_features, query_labels, query_cams), (gallery_features, gallery_labels, gallery_cams) 42 | 43 | 44 | def evaluate_all(model, data_loader, select_set='market'): 45 | query, gallery = extract_features(model, data_loader, select_set=select_set) 46 | 47 | query_features, query_labels, query_cams = query 48 | gallery_features, gallery_labels, 
gallery_cams = gallery 49 | 50 | dist = np.zeros((query_features.shape[0], gallery_features.shape[0]), dtype = np.float64) 51 | for i in range(query_features.shape[0]): 52 | dist[i, :] = np.sum((gallery_features-query_features[i,:])**2, axis=1) 53 | 54 | mAP = mean_ap(dist, query_labels, gallery_labels, query_cams, gallery_cams) 55 | cmc_scores = cmc(dist, query_labels, gallery_labels, query_cams, gallery_cams, 56 | separate_camera_set=False, single_gallery_shot=False, first_match_break=True) 57 | print('mAP: %.4f, r1:%.4f, r5:%.4f, r10:%.4f, r20:%.4f'%(mAP, cmc_scores[0], cmc_scores[4], cmc_scores[9], cmc_scores[19])) 58 | 59 | return None 60 | -------------------------------------------------------------------------------- /JVTC/utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | import sys 4 | import os.path as osp 5 | 6 | class Logger(object): 7 | """ 8 | Write console output to external text file. 9 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py. 10 | """ 11 | 12 | def __init__(self, fpath=None): 13 | self.console = sys.stdout 14 | self.file = None 15 | if fpath is not None: 16 | mkdir_if_missing(os.path.dirname(fpath)) 17 | self.file = open(fpath, 'w') 18 | 19 | def __del__(self): 20 | self.close() 21 | 22 | def __enter__(self): 23 | pass 24 | 25 | def __exit__(self, *args): 26 | self.close() 27 | 28 | def write(self, msg): 29 | self.console.write(msg) 30 | if self.file is not None: 31 | self.file.write(msg) 32 | 33 | def flush(self): 34 | self.console.flush() 35 | if self.file is not None: 36 | self.file.flush() 37 | os.fsync(self.file.fileno()) 38 | 39 | def close(self): 40 | self.console.close() 41 | if self.file is not None: 42 | self.file.close() 43 | 44 | 45 | def mkdir_if_missing(dir_path): 46 | try: 47 | os.makedirs(dir_path) 48 | except OSError as e: 49 | if e.errno != errno.EEXIST: 50 | raise 51 | 52 | -------------------------------------------------------------------------------- /JVTC/utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch, math, random 2 | import torch.nn.functional as F 3 | from torch import nn 4 | import numpy as np 5 | 6 | from scipy.spatial.distance import cdist 7 | 8 | from .rerank import re_ranking 9 | from .st_distribution import compute_joint_dist, get_st_distribution 10 | from .util import cluster 11 | from .losses_msmt import QueueLoss 12 | 13 | class Losses(nn.Module): 14 | def __init__(self, K, batch_size, bank_size, ann_file, cam_num=8, beta1=0.1, beta2=0.05): 15 | super(Losses, self).__init__() 16 | 17 | self.loss_src = nn.CrossEntropyLoss()#.cuda() 18 | 19 | self.loss_local= LocalLoss(K=K, beta=beta1) 20 | self.loss_global = GlobalLoss(K=K, beta=beta2, bank_size=bank_size, cam_num=cam_num, ann_file=ann_file) 21 | if 'msmt' in ann_file: 22 | self.loss_global = QueueLoss(K=K,beta=beta2) 23 | 24 | def forward(self, x_src, label_src, x_tar, label_tar, epoch): 25 | 26 | loss_s = self.loss_src(x_src, label_src) 27 | 28 | loss_l = self.loss_local(x_tar,label_tar) 29 | loss_g = self.loss_global(x_tar, label_tar) 30 | 31 | loss = loss_s + loss_l 32 | if epoch >= 10: 33 | loss = loss + loss_g * 0.2 34 | 35 | return loss, loss_s, loss_l, loss_g 36 | 37 | def reset_multi_label(self, epoch): 38 | if epoch >= 30: 39 | print('Reset label on target dataset', epoch) 40 | self.loss_global.reset_label_based_joint_smi() 41 | elif epoch >=10: 42 | print('Reset label on target 
dataset', epoch) 43 | self.loss_global.reset_label_based_visual_smi() 44 | 45 | def update_memory(self, x_tar, label_tar, epoch): 46 | self.loss_global.update(x_tar, label_tar, epoch=epoch) 47 | 48 | 49 | class LocalLoss(nn.Module): 50 | def __init__(self, K=4, batch_size=128, beta=0.1): 51 | super(LocalLoss, self).__init__() 52 | 53 | self.K = K 54 | self.beta = beta 55 | 56 | def forward(self, x, idx): 57 | one_hot_label = [] 58 | label_dic={} 59 | uni_idx = torch.unique(idx) 60 | v=[] 61 | for index in uni_idx: 62 | if index not in label_dic: 63 | label_dic[int(index)]=len(label_dic) 64 | uni_w = x[index==idx] 65 | uni_w = uni_w.mean(dim=0) 66 | v.append(uni_w) 67 | 68 | for index in idx: 69 | one_hot_label.append(label_dic[int(index)]) 70 | 71 | v=torch.stack(v,dim=0) 72 | v = F.normalize(v) 73 | one_hot_label = torch.tensor(one_hot_label).cuda() 74 | x = F.normalize(x) 75 | x = x.mm(v.t())/self.beta 76 | loss = F.cross_entropy(x, one_hot_label) 77 | 78 | return loss 79 | 80 | # Global loss 81 | class GlobalLoss(nn.Module): 82 | def __init__(self, K=4, beta=0.05, cam_num=8, bank_size=114514, ann_file=None): 83 | super(GlobalLoss, self).__init__() 84 | 85 | self.K = K 86 | self.beta = beta 87 | self.cam_num = cam_num 88 | self.alpha = 0.01 89 | 90 | #self.bank = torch.zeros(bank_size, 512).cuda() 91 | self.bank = torch.rand(bank_size, 512).cuda() 92 | self.bank.requires_grad = False 93 | self.labels = torch.arange(0, bank_size).cuda() 94 | print('Memory bank size', self.bank.size()) 95 | 96 | with open(ann_file) as f: 97 | lines = f.readlines() 98 | #self.img_list = [os.path.join(dataset_dir, i.split()[0]) for i in lines] 99 | self.cam_ids = [int(i.split()[2]) for i in lines] 100 | self.frames = [int(i.split()[3]) for i in lines] 101 | 102 | print('dataset size:', len(self.cam_ids)) 103 | 104 | def reset_label_based_visual_smi(self): 105 | 106 | bank_fea = self.bank.cpu().data.numpy() 107 | #fea = fea.numpy() 108 | dist = cdist(bank_fea, bank_fea) 109 | dist = np.power(dist,2) 110 | print('Compute visual similarity') 111 | rerank_dist = re_ranking(original_dist=dist) 112 | 113 | labels = cluster(rerank_dist) 114 | num_ids = len(set(labels)) 115 | 116 | self.labels = torch.tensor(labels).cuda() 117 | print('Cluster class num based on visual similarity:', num_ids) 118 | 119 | def reset_label_based_joint_smi(self,): 120 | 121 | print('Compute distance based on visual similarity') 122 | 123 | bank_fea = self.bank.cpu().data.numpy() 124 | dist = cdist(bank_fea, bank_fea) 125 | dist = np.power(dist,2) 126 | 127 | #Jaccard distance for better cluster result 128 | dist = re_ranking(original_dist=dist) 129 | labels = cluster(dist) 130 | num_ids = len(set(labels)) 131 | 132 | print('update st distribution') 133 | st_distribute = get_st_distribution(self.cam_ids, labels, self.frames, id_num=num_ids, cam_num=self.cam_num) 134 | print('Compute distance based on joint similarity') 135 | st_dist = compute_joint_dist(st_distribute, 136 | bank_fea, bank_fea, 137 | self.frames, self.frames, 138 | self.cam_ids, self.cam_ids) 139 | 140 | #Jaccard distance for better cluster result 141 | st_dist = re_ranking(original_dist=st_dist, lambda_value=0.5) 142 | labels_st = cluster(st_dist) 143 | num_ids = len(set(labels_st)) 144 | 145 | print('Cluster class num based on joint similarity:', num_ids) 146 | self.labels = torch.tensor(labels_st).cuda() 147 | 148 | def forward(self, x, idx): 149 | uni_idx = torch.unique(idx) 150 | w = [] 151 | for index in uni_idx: 152 | uni_w = x[index==idx] 153 | uni_w = 
uni_w.mean(dim=0)
154 |             w.append(uni_w)
155 |         w = torch.stack(w)
156 |         w = F.normalize(w)
157 | 
158 |         x = F.normalize(x)
159 | 
160 |         label = w.mm(self.bank.t())  # note: immediately overwritten below and never used
161 |         label = self.multi_class_label(idx)
162 |         targets = []
163 |         for i in range(label.size(0)):
164 |             targets.append(label[i,:])
165 |         targets = torch.stack(targets).detach()
166 |         x = x.mm(self.bank.t())/self.beta
167 |         x = F.log_softmax(x, dim=1)
168 |         loss = - (x * targets).sum(dim=1).mean(dim=0)
169 | 
170 |         #self.w = w.detach()
171 |         #self.idx = idx.detach()
172 |         return loss
173 | 
174 |     def update(self, x, idx, epoch=1):
175 |         uni_idx = torch.unique(idx)
176 |         momentum = min(self.alpha*epoch, 0.8)  # memory becomes harder to overwrite as training progresses
177 |         w = []
178 |         for index in uni_idx:
179 |             uni_w = x[index==idx]
180 |             uni_w = uni_w.mean(dim=0)
181 |             w.append(uni_w)
182 |         w = torch.stack(w)
183 |         w = F.normalize(w).detach()
184 |         self.bank[uni_idx] = w*(1-momentum) + self.bank[uni_idx]*momentum
185 | 
186 |         #w = x.view(int(x.size(0)/self.K), self.K, -1)
187 |         #w = w.mean(dim=1)
188 |         #w = F.normalize(w).detach()
189 | 
190 |         #momentum = min(self.alpha*epoch, 0.8)
191 |         #self.bank[idx] = w*(1-momentum) + self.bank[idx]*momentum
192 |         # for i in range(self.w.size(0)):
193 |         #     self.bank[self.idx[i]] = self.w[i,:]*(1-momentum) + self.bank[self.idx[i]]*momentum
194 |         #     self.bank[self.idx[i]] = F.normalize(self.bank[self.idx[i]], dim=0)
195 |         #     self.bank[idx[i]] /= self.bank[idx[i]].norm()
196 | 
197 |     def multi_class_label(self, index):
198 |         batch_label = self.labels[index]
199 |         target = (batch_label.unsqueeze(dim=1) == self.labels.t().unsqueeze(dim=0)).float()
200 |         target = F.normalize(target, 1)
201 |         #print target.size()
202 |         return target
203 | 
--------------------------------------------------------------------------------
/JVTC/utils/lr_adjust.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | def SetLr(lr, optimizer):
5 |     for g in optimizer.param_groups:
6 |         g['lr'] = lr * g.get('lr_mult', 1)
7 | 
8 | def StepLrUpdater(epoch, base_lr=0.01, gamma=0.1, step=[8,11]):
9 |     if isinstance(step, int):
10 |         return base_lr * (gamma**(epoch // step))
11 | 
12 |     exp = len(step)
13 |     for i, s in enumerate(step):
14 |         if epoch < s:
15 |             exp = i
16 |             break
17 |     return base_lr * gamma**exp
18 | 
19 | 
20 | def warmup_lr(cur_iters, warmup_iters=500, warmup_type='linear', warmup_ratio=1/3):
21 | 
22 |     if warmup_type == 'constant':
23 |         warmup_lr = warmup_ratio
24 |     elif warmup_type == 'linear':
25 |         warmup_lr = 1 - (1 - cur_iters / warmup_iters) * (1 - warmup_ratio)
26 |     elif warmup_type == 'exp':
27 |         warmup_lr = warmup_ratio**(1 - cur_iters / warmup_iters)
28 |     return warmup_lr
29 | 
--------------------------------------------------------------------------------
/JVTC/utils/ranking.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from collections import defaultdict
3 | 
4 | import numpy as np
5 | from sklearn.metrics import average_precision_score
6 | 
7 | #from ..utils import to_numpy
8 | 
9 | 
10 | def _unique_sample(ids_dict, num):
11 |     mask = np.zeros(num, dtype=bool)  # np.bool was removed in NumPy 1.24; the builtin bool is equivalent
12 |     for _, indices in ids_dict.items():
13 |         i = np.random.choice(indices)
14 |         mask[i] = True
15 |     return mask
16 | 
17 | 
18 | def cmc(distmat, query_ids=None, gallery_ids=None,
19 |         query_cams=None, gallery_cams=None, topk=100,
20 |         separate_camera_set=False,
21 |         single_gallery_shot=False,
22 |         first_match_break=False):
23 |     #distmat = to_numpy(distmat)
24 | 
m, n = distmat.shape 25 | # Fill up default values 26 | if query_ids is None: 27 | query_ids = np.arange(m) 28 | if gallery_ids is None: 29 | gallery_ids = np.arange(n) 30 | if query_cams is None: 31 | query_cams = np.zeros(m).astype(np.int32) 32 | if gallery_cams is None: 33 | gallery_cams = np.ones(n).astype(np.int32) 34 | # Ensure numpy array 35 | query_ids = np.asarray(query_ids) 36 | gallery_ids = np.asarray(gallery_ids) 37 | query_cams = np.asarray(query_cams) 38 | gallery_cams = np.asarray(gallery_cams) 39 | # Sort and find correct matches 40 | indices = np.argsort(distmat, axis=1) 41 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 42 | # Compute CMC for each query 43 | ret = np.zeros(topk) 44 | num_valid_queries = 0 45 | for i in range(m): 46 | # Filter out the same id and same camera 47 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 48 | (gallery_cams[indices[i]] != query_cams[i])) 49 | if separate_camera_set: 50 | # Filter out samples from same camera 51 | valid &= (gallery_cams[indices[i]] != query_cams[i]) 52 | if not np.any(matches[i, valid]): continue 53 | if single_gallery_shot: 54 | repeat = 10 55 | gids = gallery_ids[indices[i][valid]] 56 | inds = np.where(valid)[0] 57 | ids_dict = defaultdict(list) 58 | for j, x in zip(inds, gids): 59 | ids_dict[x].append(j) 60 | else: 61 | repeat = 1 62 | for _ in range(repeat): 63 | if single_gallery_shot: 64 | # Randomly choose one instance for each id 65 | sampled = (valid & _unique_sample(ids_dict, len(valid))) 66 | index = np.nonzero(matches[i, sampled])[0] 67 | else: 68 | index = np.nonzero(matches[i, valid])[0] 69 | delta = 1. / (len(index) * repeat) 70 | for j, k in enumerate(index): 71 | if k - j >= topk: break 72 | if first_match_break: 73 | ret[k - j] += 1 74 | break 75 | ret[k - j] += delta 76 | num_valid_queries += 1 77 | if num_valid_queries == 0: 78 | raise RuntimeError("No valid query") 79 | return ret.cumsum() / num_valid_queries 80 | 81 | 82 | def mean_ap(distmat, query_ids=None, gallery_ids=None, 83 | query_cams=None, gallery_cams=None): 84 | #distmat = to_numpy(distmat) 85 | m, n = distmat.shape 86 | # Fill up default values 87 | if query_ids is None: 88 | query_ids = np.arange(m) 89 | if gallery_ids is None: 90 | gallery_ids = np.arange(n) 91 | if query_cams is None: 92 | query_cams = np.zeros(m).astype(np.int32) 93 | if gallery_cams is None: 94 | gallery_cams = np.ones(n).astype(np.int32) 95 | # Ensure numpy array 96 | query_ids = np.asarray(query_ids) 97 | gallery_ids = np.asarray(gallery_ids) 98 | query_cams = np.asarray(query_cams) 99 | gallery_cams = np.asarray(gallery_cams) 100 | # Sort and find correct matches 101 | indices = np.argsort(distmat, axis=1) 102 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 103 | # Compute AP for each query 104 | aps = [] 105 | for i in range(m): 106 | # Filter out the same id and same camera 107 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 108 | (gallery_cams[indices[i]] != query_cams[i])) #& gallery_ids[indices[i]] !=-1 109 | y_true = matches[i, valid] 110 | y_score = -distmat[i][indices[i]][valid] 111 | if not np.any(y_true): continue 112 | aps.append(average_precision_score(y_true, y_score)) 113 | if len(aps) == 0: 114 | raise RuntimeError("No valid query") 115 | return np.mean(aps) 116 | -------------------------------------------------------------------------------- /JVTC/utils/rerank.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 
import numpy as np 4 | from scipy.spatial.distance import cdist 5 | 6 | 7 | def re_ranking(original_dist, k1=20, k2=6, lambda_value=0.3): 8 | 9 | all_num = original_dist.shape[0] 10 | 11 | 12 | euclidean_dist = original_dist 13 | gallery_num = original_dist.shape[0] #gallery_num=all_num 14 | 15 | #original_dist = original_dist - np.min(original_dist) 16 | original_dist = original_dist - np.min(original_dist,axis = 0) 17 | original_dist = np.transpose(original_dist/np.max(original_dist,axis = 0)) 18 | V = np.zeros_like(original_dist).astype(np.float16) 19 | initial_rank = np.argsort(original_dist).astype(np.int32) ## default axis=-1. 20 | 21 | print('Starting re_ranking...') 22 | for i in range(all_num): 23 | # k-reciprocal neighbors 24 | forward_k_neigh_index = initial_rank[i,:k1+1] ## k1+1 because self always ranks first. forward_k_neigh_index.shape=[k1+1]. forward_k_neigh_index[0] == i. 25 | backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1] ##backward.shape = [k1+1, k1+1]. For each ele in forward_k_neigh_index, find its rank k1 neighbors 26 | fi = np.where(backward_k_neigh_index==i)[0] 27 | k_reciprocal_index = forward_k_neigh_index[fi] ## get R(p,k) in the paper 28 | k_reciprocal_expansion_index = k_reciprocal_index 29 | for j in range(len(k_reciprocal_index)): 30 | candidate = k_reciprocal_index[j] 31 | candidate_forward_k_neigh_index = initial_rank[candidate,:int(np.around(k1/2))+1] 32 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,:int(np.around(k1/2))+1] 33 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0] 34 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 35 | if len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index))> 2/3*len(candidate_k_reciprocal_index): 36 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index) 37 | 38 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) ## element-wise unique 39 | weight = np.exp(-original_dist[i,k_reciprocal_expansion_index]) 40 | V[i,k_reciprocal_expansion_index] = weight/np.sum(weight) 41 | #original_dist = original_dist[:query_num,] 42 | if k2 != 1: 43 | V_qe = np.zeros_like(V,dtype=np.float16) 44 | for i in range(all_num): 45 | V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:],axis=0) 46 | V = V_qe 47 | del V_qe 48 | del initial_rank 49 | invIndex = [] 50 | for i in range(gallery_num): 51 | invIndex.append(np.where(V[:,i] != 0)[0]) #len(invIndex)=all_num 52 | 53 | jaccard_dist = np.zeros_like(original_dist,dtype = np.float16) 54 | 55 | 56 | for i in range(all_num): 57 | temp_min = np.zeros(shape=[1,gallery_num],dtype=np.float16) 58 | indNonZero = np.where(V[i,:] != 0)[0] 59 | indImages = [] 60 | indImages = [invIndex[ind] for ind in indNonZero] 61 | for j in range(len(indNonZero)): 62 | temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+ np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]]) 63 | jaccard_dist[i] = 1-temp_min/(2-temp_min) 64 | 65 | pos_bool = (jaccard_dist < 0) 66 | jaccard_dist[pos_bool] = 0.0 67 | 68 | #np.save('dist_jaccard_temoral.npy', jaccard_dist) 69 | 70 | #return jaccard_dist 71 | if lambda_value == 0: 72 | return jaccard_dist 73 | else: 74 | final_dist = jaccard_dist*(1-lambda_value) + original_dist*lambda_value 75 | return final_dist 76 | 77 | -------------------------------------------------------------------------------- /JVTC/utils/resnet.py: -------------------------------------------------------------------------------- 
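# Overview: this module defines the ResNet-50 variant shared by train_cbn.py and
# test_cbn.py. Every BatchNorm layer is built with momentum=None, i.e. PyTorch's
# cumulative moving average, so running statistics can be cheaply re-estimated
# per camera at test time. forward() returns a 3-tuple (classifier logits,
# L2-normalized feature, unnormalized 512-d feature), and resnet50() returns
# (model, params), where layers not loaded from the checkpoint get lr_mult=10.
# A minimal usage sketch (the weight path and class count are placeholders):
#
#   model, params = resnet50(pretrained='resnet50-19c8e357.pth', num_classes=751)
#   model = model.cuda()
#   optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9, weight_decay=5e-4)
#   logits, feat_norm, feat = model(images.cuda())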
1 | import torch.nn as nn 2 | import math, torch 3 | import torch.utils.model_zoo as model_zoo 4 | from torch.nn import init 5 | from torch.nn import functional as F 6 | 7 | class Bottleneck(nn.Module): 8 | expansion = 4 9 | 10 | def __init__(self, inplanes, planes, stride=1, downsample=None): 11 | super(Bottleneck, self).__init__() 12 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 13 | self.bn1 = nn.BatchNorm2d(planes, momentum=None) 14 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 15 | padding=1, bias=False) 16 | self.bn2 = nn.BatchNorm2d(planes, momentum=None) 17 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 18 | self.bn3 = nn.BatchNorm2d(planes * 4, momentum=None) 19 | self.relu = nn.ReLU(inplace=True) 20 | self.downsample = downsample 21 | self.stride = stride 22 | 23 | def forward(self, x): 24 | residual = x 25 | 26 | out = self.conv1(x) 27 | out = self.bn1(out) 28 | out = self.relu(out) 29 | 30 | out = self.conv2(out) 31 | out = self.bn2(out) 32 | out = self.relu(out) 33 | 34 | out = self.conv3(out) 35 | out = self.bn3(out) 36 | 37 | if self.downsample is not None: 38 | residual = self.downsample(x) 39 | 40 | out += residual 41 | out = self.relu(out) 42 | 43 | return out 44 | 45 | 46 | class ResNet(nn.Module): 47 | 48 | def __init__(self, block, layers, num_classes=1000, train=True): 49 | self.inplanes = 64 50 | super(ResNet, self).__init__() 51 | self.istrain = train 52 | 53 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 54 | self.bn1 = nn.BatchNorm2d(64, momentum=None) 55 | self.relu = nn.ReLU(inplace=True) 56 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 57 | self.layer1 = self._make_layer(block, 64, layers[0]) 58 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 59 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 60 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1) 61 | #self.avgpool = nn.AvgPool2d((16,8), stride=1) 62 | 63 | self.num_features = 512 64 | self.feat = nn.Linear(512 * block.expansion, self.num_features) 65 | init.kaiming_normal_(self.feat.weight, mode='fan_out') 66 | init.constant_(self.feat.bias, 0) 67 | 68 | self.feat_bn = nn.BatchNorm1d(self.num_features, momentum=None) 69 | init.constant_(self.feat_bn.weight, 1) 70 | init.constant_(self.feat_bn.bias, 0) 71 | 72 | self.classifier = nn.Linear(self.num_features, num_classes) 73 | init.normal_(self.classifier.weight, std=0.001) 74 | init.constant_(self.classifier.bias, 0) 75 | 76 | 77 | for m in self.modules(): 78 | if isinstance(m, nn.Conv2d): 79 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 80 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 81 | elif isinstance(m, nn.BatchNorm2d): 82 | m.weight.data.fill_(1) 83 | m.bias.data.zero_() 84 | 85 | def _make_layer(self, block, planes, blocks, stride=1): 86 | downsample = None 87 | if stride != 1 or self.inplanes != planes * block.expansion: 88 | downsample = nn.Sequential( 89 | nn.Conv2d(self.inplanes, planes * block.expansion, 90 | kernel_size=1, stride=stride, bias=False), 91 | nn.BatchNorm2d(planes * block.expansion, momentum=None),) 92 | 93 | layers = [] 94 | layers.append(block(self.inplanes, planes, stride, downsample)) 95 | self.inplanes = planes * block.expansion 96 | for i in range(1, blocks): 97 | layers.append(block(self.inplanes, planes)) 98 | 99 | return nn.Sequential(*layers) 100 | 101 | def forward(self, x): 102 | x = self.conv1(x) 103 | x = self.bn1(x) 104 | x = self.relu(x) 105 | x = self.maxpool(x) 106 | 107 | x = self.layer1(x) 108 | x = self.layer2(x) 109 | x = self.layer3(x) 110 | x = self.layer4(x) 111 | 112 | x = F.avg_pool2d(x, x.size()[2:]) 113 | x = x.view(x.size(0), -1) 114 | 115 | x = self.feat(x) 116 | fea = self.feat_bn(x) 117 | fea_norm = F.normalize(fea) 118 | 119 | x = F.relu(fea) 120 | x = self.classifier(x) 121 | 122 | return x, fea_norm, fea 123 | 124 | def resnet50(pretrained=None, num_classes=1000, train=True): 125 | model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes, train) 126 | weight = torch.load(pretrained) 127 | static = model.state_dict() 128 | 129 | base_param = [] 130 | for name, param in weight.items(): 131 | if name not in static: 132 | continue 133 | if 'classifier' in name: 134 | continue 135 | if isinstance(param, nn.Parameter): 136 | param = param.data 137 | static[name].copy_(param) 138 | base_param.append(name) 139 | 140 | params = [] 141 | params_dict = dict(model.named_parameters()) 142 | for key, v in params_dict.items(): 143 | if key in base_param: 144 | params += [{ 'params':v, 'lr_mult':1}] 145 | else: 146 | #new parameter have larger learning rate 147 | params += [{ 'params':v, 'lr_mult':10}] 148 | 149 | return model, params 150 | -------------------------------------------------------------------------------- /JVTC/utils/st_distribution.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os, math 3 | 4 | 5 | def joint_similarity(qf,qc,qfr,gf,gc,gfr,distribution): 6 | query = qf 7 | score = np.dot(gf,query) 8 | gamma = 5 9 | 10 | interval = 100 11 | score_st = np.zeros(len(gc)) 12 | for i in range(len(gc)): 13 | if qfr>gfr[i]: 14 | diff = qfr-gfr[i] 15 | hist_ = int(diff/interval) 16 | pr = distribution[qc-1][gc[i]-1][hist_] 17 | else: 18 | diff = gfr[i]-qfr 19 | hist_ = int(diff/interval) 20 | pr = distribution[gc[i]-1][qc-1][hist_] 21 | score_st[i] = pr 22 | 23 | score = 1-1/(1+np.exp(-gamma*score))*1/(1+2*np.exp(-gamma*score_st)) 24 | return score 25 | 26 | def compute_joint_dist(distribution, q_feas, g_feas, q_frames, g_frames, q_cams, g_cams): 27 | dists = [] 28 | for i in range(len(q_frames)): 29 | dist = joint_similarity( 30 | q_feas[i],q_cams[i],q_frames[i], 31 | g_feas, g_cams, g_frames, 32 | distribution) 33 | 34 | dist = np.expand_dims(dist, axis=0) 35 | dists.append(dist) 36 | #print(i, dist.shape) 37 | dists = np.concatenate(dists, axis=0) 38 | 39 | return dists 40 | 41 | 42 | 43 | 44 | def gaussian_func(x, u, o=50): 45 | temp1 = 1.0 / (o * math.sqrt(2 * math.pi)) 46 | temp2 = -(np.power(x - u, 2)) / (2 * np.power(o, 2)) 47 | return temp1 * np.exp(temp2) 48 | 49 | def gauss_smooth(arr,o): 50 | hist_num = len(arr) 51 | vect= 
np.zeros((hist_num,1)) 52 | for i in range(hist_num): 53 | vect[i,0]=i 54 | 55 | approximate_delta = 3*o # when x-u>approximate_delta, e.g., 6*o, the gaussian value is approximately equal to 0. 56 | gaussian_vect= gaussian_func(vect,0,o) 57 | matrix = np.zeros((hist_num,hist_num)) 58 | for i in range(hist_num): 59 | k=0 60 | for j in range(i,hist_num): 61 | if k>approximate_delta: 62 | continue 63 | matrix[i][j]=gaussian_vect[j-i] 64 | k=k+1 65 | matrix = matrix+matrix.transpose() 66 | for i in range(hist_num): 67 | matrix[i][i]=matrix[i][i]/2 68 | 69 | xxx = np.dot(matrix,arr) 70 | return xxx 71 | 72 | def get_st_distribution(camera_id, labels, frames, id_num, cam_num=8): 73 | spatial_temporal_sum = np.zeros((id_num,cam_num)) 74 | spatial_temporal_count = np.zeros((id_num,cam_num)) 75 | eps = 0.0000001 76 | interval = 100.0 77 | 78 | for i in range(len(camera_id)): 79 | label_k = int(labels[i]) #### not in order, done 80 | cam_k = int(camera_id[i]-1) ##### ### ### ### ### ### ### ### ### ### ### ### # from 1, not 0 81 | frame_k = frames[i] 82 | 83 | spatial_temporal_sum[label_k][cam_k]=spatial_temporal_sum[label_k][cam_k]+frame_k 84 | spatial_temporal_count[label_k][cam_k] = spatial_temporal_count[label_k][cam_k] + 1 85 | spatial_temporal_avg = spatial_temporal_sum/(spatial_temporal_count+eps) # spatial_temporal_avg: 702 ids, 8cameras, center point 86 | 87 | distribution = np.zeros((cam_num,cam_num,3000)) 88 | for i in range(id_num): 89 | for j in range(cam_num-1): 90 | for k in range(j+1,cam_num): 91 | if spatial_temporal_count[i][j]==0 or spatial_temporal_count[i][k]==0: 92 | continue 93 | st_ij = spatial_temporal_avg[i][j] 94 | st_ik = spatial_temporal_avg[i][k] 95 | if st_ij>st_ik: 96 | diff = st_ij-st_ik 97 | hist_ = int(diff/interval) 98 | distribution[j][k][hist_] = distribution[j][k][hist_]+1 # [big][small] 99 | else: 100 | diff = st_ik-st_ij 101 | hist_ = int(diff/interval) 102 | distribution[k][j][hist_] = distribution[k][j][hist_]+1 103 | 104 | for i in range(id_num): 105 | for j in range(cam_num): 106 | if spatial_temporal_count[i][j] >1: 107 | 108 | frames_same_cam = [] 109 | for k in range(len(camera_id)): 110 | if labels[k]==i and camera_id[k]-1 ==j: 111 | frames_same_cam.append(frames[k]) 112 | frame_id_min = min(frames_same_cam) 113 | 114 | #print 'id, cam, len',i, j, len(frames_same_cam) 115 | for item in frames_same_cam: 116 | #if item != frame_id_min: 117 | diff = item - frame_id_min 118 | hist_ = int(diff/interval) 119 | #print item, frame_id_min, diff, hist_ 120 | distribution[j][j][hist_] = distribution[j][j][hist_] + spatial_temporal_count[i][j] 121 | 122 | smooth = 50 123 | for i in range(cam_num): 124 | for j in range(cam_num): 125 | #print("gauss "+str(i)+"->"+str(j)) 126 | distribution[i][j][:]=gauss_smooth(distribution[i][j][:],smooth) 127 | 128 | sum_ = np.sum(distribution,axis=2) 129 | for i in range(cam_num): 130 | for j in range(cam_num): 131 | distribution[i][j][:]=distribution[i][j][:]/(sum_[i][j]+eps) 132 | 133 | return distribution # [to][from], to xxx camera, from xxx camera 134 | 135 | 136 | 137 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # UnrealPerson: An Adaptive Pipeline for Costless Person Re-identification 3 | In our paper ([arxiv](https://arxiv.org/abs/2012.04268v2)), we propose a novel pipeline, UnrealPerson, that decreases the costs in both the training and deployment stages of person ReID. 
4 | 
5 | We develop an automatic data synthesis toolkit and use the synthesized data in multiple ReID tasks, including (i) Direct transfer, (ii) Unsupervised domain adaptation, and (iii) Supervised fine-tuning.
6 | 
7 | 
8 | **This repo contains:**
9 | 1. The **synthesized data** we use in the paper, including more than 6,000 identities in 4 different virtual scenes. Please see this document [Download.md](Download.md) for details.
10 | 2. The code for our experiments reported in the paper. To reproduce the results on direct transfer, supervised learning and unsupervised domain adaptation, please refer to this document: [Experiments.md](Experiments.md).
11 | 3. The necessary scripts and detailed tutorials to help you generate your own data: [SynthesisToolkit.md](SynthesisToolkit.md)
12 | 
13 | ## Demonstration
14 | 
15 |  
16 | 
17 | **Highlights:**
18 | 1. In direct transfer evaluation, we achieve 38.5% rank-1 accuracy on MSMT17 and 79.0% on Market-1501 using our unreal data.
19 | 2. In unsupervised domain adaptation, we achieve 68.2% rank-1 accuracy on MSMT17 and 93.0% on Market-1501 using our unreal data.
20 | 3. We obtain a better pre-trained ReID model with our unreal data.
21 | 
22 | ## Cite our paper
23 | 
24 | If you find our work useful in your research, please kindly cite:
25 | 
26 | ```
27 | 
28 | @inproceedings{zhang2021unrealperson,
29 |   title={UnrealPerson: An Adaptive Pipeline towards Costless Person Re-identification},
30 |   author={Tianyu Zhang and Lingxi Xie and Longhui Wei and Zijie Zhuang and Yongfei Zhang and Bo Li and Qi Tian},
31 |   year={2021},
32 |   booktitle={The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}
33 | }
34 | ```
35 | 
36 | If you have any questions about the data or the paper, please open an issue or contact me:
37 | zhangtianyu@buaa.edu.cn
--------------------------------------------------------------------------------
/UnrealPerson-DataSynthesisToolkit/9_massproduce/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | 
4 | """
5 | **Project Name:** MakeHuman community mass produce
6 | 
7 | **Product Home Page:** TBD
8 | 
9 | **Code Home Page:** TBD
10 | 
11 | **Authors:** Joel Palmius
12 | 
13 | **Copyright(c):** Joel Palmius 2016
14 | 
15 | **Licensing:** MIT
16 | 
17 | Abstract
18 | --------
19 | 
20 | This plugin generates and exports series of characters
21 | 
22 | """
23 | 
24 | from .massproduce import MassProduceTaskView
25 | 
26 | category = None
27 | mpView = None
28 | 
29 | def load(app):
30 |     category = app.getCategory('Community')
31 |     downloadView = category.addTask(MassProduceTaskView(category))
32 | 
33 | def unload(app):
34 |     pass
35 | 
--------------------------------------------------------------------------------
/UnrealPerson-DataSynthesisToolkit/9_massproduce/modifiergroups.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | 
4 | MACROGROUPS = dict()
5 | MACROGROUPS["age"] = ["macrodetails/Age"]
6 | MACROGROUPS["height"] = ["macrodetails-height/Height"]
7 | MACROGROUPS["weight"] = ["macrodetails-universal/Weight"]
8 | MACROGROUPS["muscle"] = ["macrodetails-universal/Muscle"]
9 | MACROGROUPS["gender"] = ["macrodetails/Gender"]
10 | MACROGROUPS["proportion"] = ["macrodetails-proportions/BodyProportions"]
11 | MACROGROUPS["ethnicity"] = ["macrodetails/African", "macrodetails/Asian", "macrodetails/Caucasian"]
12 | 
13 | from core import G
14 | mhapi =
G.app.mhapi 15 | 16 | class _ModifierInfo(): 17 | 18 | def __init__(self): 19 | 20 | self.human = mhapi.internals.getHuman() 21 | self.nonMacroModifierGroupNames = [] 22 | self.modifierInfo = dict() 23 | 24 | for modgroup in self.human.modifierGroups: 25 | #print("MODIFIER GROUP: " + modgroup) 26 | if not "macro" in modgroup and not "genitals" in modgroup and not "armslegs" in modgroup: 27 | self.nonMacroModifierGroupNames.append(modgroup) 28 | self.modifierInfo[modgroup] = [] 29 | for mod in self.human.getModifiersByGroup(modgroup): 30 | if not mod.name.startswith("r-"): 31 | self.modifierInfo[modgroup].append(self._deduceModifierInfo(modgroup, mod)) 32 | 33 | self.nonMacroModifierGroupNames.append("arms") 34 | self.modifierInfo["arms"] = [] 35 | 36 | self.nonMacroModifierGroupNames.append("hands") 37 | self.modifierInfo["hands"] = [] 38 | 39 | self.nonMacroModifierGroupNames.append("legs") 40 | self.modifierInfo["legs"] = [] 41 | 42 | self.nonMacroModifierGroupNames.append("feet") 43 | self.modifierInfo["feet"] = [] 44 | 45 | for mod in self.human.getModifiersByGroup("armslegs"): 46 | name = mod.name 47 | if not name.startswith("r-"): 48 | if "hand" in name: 49 | self.modifierInfo["hands"].append(self._deduceModifierInfo("hands", mod)) 50 | if "foot" in name or "feet" in name: 51 | self.modifierInfo["feet"].append(self._deduceModifierInfo("feet", mod)) 52 | if "arm" in name: 53 | self.modifierInfo["arms"].append(self._deduceModifierInfo("arms", mod)) 54 | if "leg" in name: 55 | self.modifierInfo["legs"].append(self._deduceModifierInfo("legs", mod)) 56 | 57 | self.nonMacroModifierGroupNames.sort() 58 | 59 | def _deduceModifierInfo(self,groupName,modifier): 60 | modi = dict() 61 | modi["modifier"] = modifier 62 | modi["name"] = modifier.name 63 | modi["actualGroupName"] = modifier.groupName 64 | modi["groupName"] = groupName 65 | # This is the value set at the point when the user starts producing, 66 | # rather than makehuman's standard default value 67 | modi["defaultValue"] = modifier.getValue() 68 | modi["twosided"] = modifier.getMin() < -0.05 69 | modi["leftright"] = modifier.name.startswith('l-') 70 | return modi 71 | 72 | def getModifierGroupNames(self): 73 | return list(self.nonMacroModifierGroupNames) 74 | 75 | def getModifierInfoForGroup(self, groupName): 76 | return list(self.modifierInfo[groupName]) 77 | 78 | _mfinstance = None 79 | 80 | def ModifierInfo(): 81 | global _mfinstance 82 | if _mfinstance is None: 83 | _mfinstance = _ModifierInfo() 84 | return _mfinstance 85 | -------------------------------------------------------------------------------- /UnrealPerson-DataSynthesisToolkit/9_massproduce/randomizationsettings.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | from PyQt5.QtWidgets import * 5 | import sys 6 | import qtgui 7 | 8 | class RandomizationSettings: 9 | 10 | def __init__(self): 11 | 12 | self._ui = dict() 13 | 14 | def addUI(self, category, name, widget, subName=None): 15 | 16 | if widget is None: 17 | raise ValueError("Trying to add None widget") 18 | 19 | if not category in self._ui: 20 | self._ui[category] = dict() 21 | 22 | if not subName is None: 23 | if not name in self._ui[category]: 24 | self._ui[category][name] = dict() 25 | self._ui[category][name][subName] = widget 26 | else: 27 | self._ui[category][name] = widget 28 | 29 | return widget 30 | 31 | def getUI(self, category, name, subName=None): 32 | 33 | if not category in self._ui: 34 | print("No such 
category: " + category) 35 | self.dumpValues() 36 | sys.exit() 37 | 38 | if not name in self._ui[category]: 39 | print("No such name: " + category + "/" + name) 40 | self.dumpValues() 41 | sys.exit() 42 | 43 | if not subName is None: 44 | if not subName in self._ui[category][name]: 45 | print("No such subName: " + category + "/" + name + "/" + subName) 46 | self.dumpValues() 47 | sys.exit() 48 | widget = self._ui[category][name][subName] 49 | else: 50 | widget = self._ui[category][name] 51 | 52 | if widget is None: 53 | print("Got None widget for " + category + "/" + name + "/" + str(subName)) 54 | return widget 55 | 56 | def getValue(self, category, name, subName=None): 57 | 58 | widget = self.getUI(category, name, subName) 59 | 60 | if isinstance(widget, QCheckBox) or isinstance(widget, qtgui.CheckBox): 61 | return widget.selected 62 | if isinstance(widget, QTextEdit) or isinstance(widget, qtgui.TextEdit): 63 | return widget.getText() 64 | if isinstance(widget, qtgui.Slider): 65 | return widget.getValue() 66 | if isinstance(widget, QComboBox): 67 | return str(widget.getCurrentItem()) 68 | if isinstance(widget, str): 69 | return widget 70 | 71 | print("Unknown widget type") 72 | print(type(widget)) 73 | sys.exit(1) 74 | 75 | def getValueHash(self, category, name): 76 | 77 | if not category in self._ui: 78 | print("No such category: " + category) 79 | self.dumpValues() 80 | sys.exit() 81 | 82 | if not name in self._ui[category]: 83 | print("No such name: " + category + "/" + name) 84 | self.dumpValues() 85 | sys.exit() 86 | 87 | subCat = self._ui[category][name] 88 | 89 | if not isinstance(subCat, dict): 90 | print(category + "/" + name + " is not a dict") 91 | self.dumpValues() 92 | sys.exit(1) 93 | 94 | values = dict() 95 | for subName in subCat: 96 | values[subName] = self.getValue(category, name, subName) 97 | 98 | return values 99 | 100 | def getNames(self, category): 101 | return self._ui[category].keys() 102 | 103 | def setValue(self, category, name, value, subName=None): 104 | pass 105 | 106 | def dumpValues(self): 107 | 108 | for category in self._ui: 109 | print(category) 110 | for name in self._ui[category]: 111 | widget = self.getUI(category, name) 112 | if isinstance(widget, dict): 113 | print(" " + name) 114 | for subName in widget: 115 | subWidget = self.getUI(category, name, subName) 116 | value = self.getValue(category, name, subName) 117 | print(" " + subName + " (" + str(type(subWidget)) + ") = " + str(value)) 118 | else: 119 | value = self.getValue(category, name) 120 | print(" " + name + " (" + str(type(widget)) + ") = " + str(value)) 121 | -------------------------------------------------------------------------------- /UnrealPerson-DataSynthesisToolkit/9_massproduce/randomizeaction.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import gui3d 5 | 6 | mhapi = gui3d.app.mhapi 7 | 8 | class RandomizeAction(gui3d.Action): 9 | def __init__(self, human, before, after): 10 | super(RandomizeAction, self).__init__("Randomize") 11 | self.human = human 12 | self.before = before 13 | self.after = after 14 | 15 | def do(self): 16 | self._assignModifierValues(self.after) 17 | return True 18 | 19 | def undo(self): 20 | self._assignModifierValues(self.before) 21 | return True 22 | 23 | def _assignModifierValues(self, valuesDict): 24 | _tmp = self.human.symmetryModeEnabled 25 | self.human.symmetryModeEnabled = False 26 | for mName, val in list(valuesDict.items()): 27 | try: 28 | 
self.human.getModifier(mName).setValue(val) 29 | except: 30 | pass 31 | self.human.applyAllTargets() 32 | self.human.symmetryModeEnabled = _tmp -------------------------------------------------------------------------------- /UnrealPerson-DataSynthesisToolkit/caminfo_collect.py: -------------------------------------------------------------------------------- 1 | from unrealcv import client 2 | import pickle 3 | cl=client 4 | cl.connect() 5 | 6 | d=list() 7 | 8 | def gc(): 9 | global cl 10 | global d 11 | l=cl.request('vget /camera/0/location') 12 | r=cl.request('vget /camera/0/rotation') 13 | d.append((l,r)) 14 | print(d) 15 | 16 | def sa(): 17 | global d 18 | pickle.dump(d, open('../caminfo_s006_high.pkl', 'wb')) 19 | 20 | from IPython import embed 21 | embed() -------------------------------------------------------------------------------- /UnrealPerson-DataSynthesisToolkit/generate_datasets.py: -------------------------------------------------------------------------------- 1 | from unrealcv import client 2 | import sys,re,traceback 3 | import numpy as np 4 | from io import BytesIO 5 | import time 6 | import pickle,random 7 | from tqdm import trange 8 | import argparse 9 | # TODO: replace this with a better implementation 10 | class Color(object): 11 | ''' A utility class to parse color value ''' 12 | def __init__(self, color_str): 13 | self.color_str = color_str 14 | self.R,self.G,self.B,self.A=0,0,0,0 15 | color_str = color_str.replace("(","").replace(")","").replace("R=","").replace("G=","").replace("B=","").replace("A=","") 16 | try: 17 | (self.R, self.G, self.B, self.A) = [int(i) for i in color_str.split(",")] 18 | except Exception as e: 19 | print("Error in Color:") 20 | print(color_str) 21 | 22 | def __repr__(self): 23 | return self.color_str 24 | 25 | class DataUtils(object): 26 | def __init__(self,client,dir_save): 27 | self.client = client 28 | self.client.connect() 29 | self.dir_save = dir_save 30 | if not client.isconnected(): 31 | print('UnrealCV server is not running. 
Run the game downloaded from http://unrealcv.github.io first.') 32 | sys.exit(-1) 33 | res = self.client.request('vget /unrealcv/status') 34 | print(res) 35 | self.scene_objects=None 36 | self.id2color = {} 37 | 38 | def read_png(self,res): 39 | import PIL.Image 40 | img = PIL.Image.open(BytesIO(res)) 41 | return np.asarray(img) 42 | 43 | def read_pngIMG(self,res): 44 | import PIL.Image 45 | img = PIL.Image.open(BytesIO(res)) 46 | return img 47 | 48 | def get_object_color(self): 49 | self.scene_objects = self.client.request('vget /objects').split(' ') 50 | print('Number of objects in this scene:', len(self.scene_objects)) 51 | self.id2color = {} # Map from object id to the labeling color 52 | for obj_id in self.scene_objects: 53 | if obj_id not in self.id2color.keys(): 54 | if obj_id.startswith("MH"): 55 | color = Color(self.client.request('vget /object/{}/color'.format(obj_id))) 56 | self.id2color[obj_id] = color 57 | pickle.dump(self.id2color,open(self.dir_save+"object_color.pkl",'wb')) 58 | 59 | def match_color(self,object_mask, target_color, tolerance=3): 60 | match_region = np.ones(object_mask.shape[0:2], dtype=bool) 61 | for c in range(3): # r,g,b 62 | min_val = target_color[c] - tolerance 63 | max_val = target_color[c] + tolerance 64 | channel_region = (object_mask[:,:,c] >= min_val) & (object_mask[:,:,c] <= max_val) 65 | match_region &= channel_region 66 | 67 | if match_region.sum() != 0: 68 | return match_region 69 | else: 70 | return None 71 | 72 | def generate_one_frame(self,cam,frame): 73 | self.client.request('vrun ce Pause') 74 | res = self.client.request('vget /camera/0/lit png') 75 | lit = self.read_pngIMG(res) 76 | #print('The image is saved to {}'.format(res)) 77 | res = self.client.request('vget /camera/0/object_mask png') 78 | object_mask = self.read_pngIMG(res) 79 | self.client.request('vrun ce Resume') 80 | lit.save(self.dir_save + "{}_{}_lit.png".format(cam, frame)) 81 | object_mask.save(self.dir_save + "{}_{}_mask.png".format(cam, frame)) 82 | 83 | # print('%s : %s' % (obj_id, str(color))) 84 | # 85 | # id2mask = {} 86 | # for obj_id in self.scene_objects: 87 | # if obj_id.startswith("MH"): 88 | # color = self.id2color[obj_id] 89 | # mask = self.match_color(object_mask, [color.R, color.G, color.B], tolerance=3) 90 | # if mask is not None: 91 | # id2mask[obj_id] = mask 92 | # # This may take a while 93 | # # TODO: Need to find a faster implementation for this 94 | # 95 | # for k, x in id2mask.items(): 96 | # if k.startswith("MH"): 97 | # left = 9999 98 | # top = 9999 99 | # right = 0 100 | # bottom = 0 101 | # for i in range(x.shape[0]): 102 | # for j in range(x.shape[1]): 103 | # if x[i][j] == True: 104 | # if i > bottom: 105 | # bottom = i 106 | # if i < top: 107 | # top = i 108 | # if j < left: 109 | # left = j 110 | # if j > right: 111 | # right = j 112 | # img = lit.crop((left, top, right, bottom)) 113 | # if (right-left) * (bottom-top)<1500: 114 | # continue 115 | # try: 116 | # img.save(self.dir_save+k + "_{}_{}.png".format(cam,frame)) 117 | # except AttributeError or SystemError as e: 118 | # continue 119 | 120 | 121 | 122 | 123 | 124 | if __name__=="__main__": 125 | parser = argparse.ArgumentParser() 126 | parser.add_argument("--scene",type=str,choices=['s001','s002','s003','s004']) 127 | parser.add_argument("--person",type=int) 128 | parser.add_argument("--images",type=int) #images per camera 129 | 130 | opt = parser.parse_args() 131 | 132 | person_per_batch = {'s001': 100, 133 | 's002': 100, 134 | 's003': 100, 135 | 's004': 50} 136 | light_condition = 
{'s001': 1, 137 | 's002': 12, 138 | 's003': 6, 139 | 's004': 1} 140 | try: 141 | import os,glob 142 | datautils = DataUtils(client, "") 143 | cam_info = glob.glob('caminfo_'+opt.scene+"*.pkl") 144 | cam_lr = [] 145 | for cam_info_file in cam_info: 146 | cam_lr.extend(pickle.load(open(cam_info_file, 'rb'))) 147 | 148 | cam = ["c{:0>3d}".format(x) for x in range(1, 1+len(cam_lr))] 149 | print("Found {} cameras".format(len(cam_lr))) 150 | batch = opt.person // person_per_batch[opt.scene] + 1 151 | print(cam_lr) 152 | for _ in range(batch): 153 | datautils.dir_save = "f:/video/tmp"+str(int(time.time()))+"/" 154 | os.mkdir(datautils.dir_save) 155 | datautils.client.request("vrun ce DelActor") 156 | datautils.client.request("vrun ce AddActor") 157 | datautils.get_object_color() 158 | time.sleep(60) # give the newly spawned actors time to load before capture begins 159 | for cc in range(len(cam_lr)): 160 | if light_condition[opt.scene] > 1: 161 | datautils.client.request("vrun ce LCreset") 162 | 163 | datautils.client.request('vset /camera/0/location '+cam_lr[cc][0].strip()) 164 | datautils.client.request('vset /camera/0/rotation ' + cam_lr[cc][1].strip()) 165 | for i in trange(opt.images): 166 | if light_condition[opt.scene] > 1 and i % (opt.images // light_condition[opt.scene]) == 0: # switch lighting every images/num_conditions frames 167 | datautils.client.request("vrun ce LC") 168 | datautils.generate_one_frame(cam[cc], i) 169 | 170 | except Exception as e: 171 | traceback.print_exc() 172 | # the client is disconnected in the finally block below 173 | finally: 174 | datautils.client.disconnect() 175 | 176 | 177 | 178 | -------------------------------------------------------------------------------- /UnrealPerson-DataSynthesisToolkit/interplate_video.py: -------------------------------------------------------------------------------- 1 | import glob, os, shutil 2 | from PIL import Image 3 | from tqdm import tqdm 4 | input_image_dir = "F:\\video\\unreal_video_bunker_psameclo75" 5 | full_image_dir = "F:\\datasets" 6 | save_path = "F:\\unreal_data\\unreal_video_bunker_psameclo75" 7 | 8 | def get_image_info(path): 9 | basename = os.path.basename(path) 10 | arrs = basename.split('_') 11 | pid, cid, frame, left, top, right, bottom = arrs[0], arrs[1], int(arrs[2][1:]), arrs[3],arrs[4],arrs[5],arrs[6][:-4] 12 | full_image_path = os.path.join(full_image_dir,path.split("\\")[-2])+"\\{}_{}_lit.png".format(cid,frame) 13 | return pid, cid, frame, (int(left), int(top), int(right), int(bottom)), full_image_path 14 | 15 | def save_tracklet(tid, img_list): 16 | save_dir = os.path.join(save_path, str(tid)) 17 | if not os.path.exists(save_dir): 18 | os.makedirs(save_dir) 19 | normal_count = 0 20 | extra_count = 0 21 | for i in range(len(img_list)): 22 | img, frame, bbox, full_img_path = img_list[i] 23 | shutil.copy(img, os.path.join(save_dir, os.path.basename(img))) 24 | normal_count+=1 25 | if i == len(img_list)-1: 26 | break 27 | else: 28 | next_frame = img_list[i+1][1] 29 | next_bbox = img_list[i+1][2] 30 | if next_frame-frame>1: 31 | diff = next_frame-frame 32 | 33 | bbox_extras = [(bbox[0] + (next_bbox[0] - bbox[0]) / (next_frame - frame) * x, 34 | bbox[1] + (next_bbox[1] - bbox[1]) / (next_frame - frame) * x, 35 | bbox[2] + (next_bbox[2] - bbox[2]) / (next_frame - frame) * x, 36 | bbox[3] + (next_bbox[3] - bbox[3]) / (next_frame - frame) * x) 37 | for x in range(1,diff)] 38 | full_img_extras = [full_img_path.replace("{}_lit".format(frame),"{}_lit".format(frame+x)) 39 | for x in range(1, diff)] 40 | 41 | for j in range(len(full_img_extras)): 42 | full_img_extra=full_img_extras[j] 43 | full_img = Image.open(full_img_extra) 44 | img_extra = 
full_img.crop(bbox_extras[j]).convert('RGB') 45 | img_extra.save(os.path.join(save_dir, os.path.basename(img.replace("F{}".format(frame),"F{}_extra".format(frame+j+1))))) 46 | extra_count+=1 47 | 48 | return normal_count,extra_count 49 | 50 | 51 | images = glob.glob(os.path.join(input_image_dir, 'images\\*\\*.jpg')) 52 | 53 | images.sort() 54 | 55 | current_pid, current_cid, current_frame = None, None, None 56 | current_image_list = [] 57 | tid = 0 58 | pid_container = set() 59 | normal = 0 60 | extra = 0 61 | for img in tqdm(images): 62 | pid, cid, frame, bbox, full_img_path = get_image_info(img) 63 | 64 | pid_container.add(pid) 65 | if pid==current_pid and cid==current_cid and frame-current_frame<10: 66 | current_image_list.append((img,frame, bbox, full_img_path)) 67 | current_frame=frame 68 | else: 69 | # save current list 70 | if len(current_image_list)>0: 71 | nc,ec = save_tracklet(tid, current_image_list) 72 | normal+=nc 73 | extra+=ec 74 | tid += 1 75 | current_image_list = [] 76 | current_pid=pid 77 | current_cid=cid 78 | current_frame=frame 79 | current_image_list.append((img,frame, bbox, full_img_path)) 80 | 81 | print("Tracklet: ",tid) 82 | print("Pids: ",len(pid_container)) 83 | print("Normal: ",normal ) 84 | print("Extra: ",extra) 85 | 86 | 87 | -------------------------------------------------------------------------------- /UnrealPerson-DataSynthesisToolkit/levelbp_preview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FlyHighest/UnrealPerson/9fd5a882bea16f6289f0353b78140bff36c16609/UnrealPerson-DataSynthesisToolkit/levelbp_preview.png -------------------------------------------------------------------------------- /UnrealPerson-DataSynthesisToolkit/postprocess.py: -------------------------------------------------------------------------------- 1 | import glob,shutil 2 | import tqdm 3 | import os 4 | import sys,re,traceback 5 | import numpy as np 6 | from io import BytesIO 7 | import time 8 | import threading 9 | import random 10 | import PIL.Image 11 | import pickle 12 | import glob 13 | from multiprocessing import Process 14 | 15 | class HandleImageThread(Process): 16 | def __init__(self,dir_save,dir_data,dir_data_mask,cam): 17 | super(HandleImageThread, self).__init__() 18 | self.dir_save=dir_save 19 | self.dir_data=dir_data 20 | self.dir_data_mask = dir_data_mask 21 | self.cam=cam 22 | print("Task: {} {}".format(self.dir_data,self.cam)) 23 | 24 | def run(self): 25 | datautils = DataUtils(self.dir_save, self.dir_data, self.dir_data_mask) 26 | datautils.get_object_color() 27 | 28 | 29 | files = glob.glob(os.path.join(self.dir_save ,"c00{}_*lit.png".format(self.cam))) 30 | files.sort() 31 | 32 | for fi in tqdm.tqdm(files): 33 | datautils.generate_one_frame(fi) 34 | 35 | # TODO: replace this with a better implementation 36 | class Color(object): 37 | ''' A utility class to parse color value ''' 38 | def __init__(self, color_str): 39 | self.color_str = color_str 40 | self.R,self.G,self.B,self.A=0,0,0,0 41 | color_str = color_str.replace("(","").replace(")","").replace("R=","").replace("G=","").replace("B=","").replace("A=","") 42 | try: 43 | (self.R, self.G, self.B, self.A) = [int(i) for i in color_str.split(",")] 44 | except Exception as e: 45 | print("Error in Color:") 46 | print(color_str) 47 | 48 | def __repr__(self): 49 | return self.color_str 50 | 51 | class DataUtils(object): 52 | def __init__(self,dir_save,dir_data,dir_data_mask): 53 | 54 | self.dir_save = dir_save 55 | self.dir_data = dir_data 
56 | self.dir_data_mask = dir_data_mask 57 | self.id2color = {} 58 | 59 | 60 | def read_png(self,res): 61 | import PIL.Image 62 | img = PIL.Image.open(res) 63 | return np.asarray(img) 64 | 65 | def read_pngIMG(self,res): 66 | import PIL.Image 67 | img = PIL.Image.open(res) 68 | return img 69 | 70 | def get_object_color(self): 71 | self.id2color = pickle.load(open(self.dir_save+"object_color.pkl",'rb')) 72 | 73 | def match_color(self,object_mask, target_color, tolerance=3): 74 | match_region = np.ones(object_mask.shape[0:2], dtype=bool) 75 | for c in range(3): # r,g,b 76 | min_val = target_color[c] - tolerance 77 | max_val = target_color[c] + tolerance 78 | channel_region = (object_mask[:,:,c] >= min_val) & (object_mask[:,:,c] <= max_val) 79 | match_region &= channel_region 80 | 81 | if match_region.sum() > 2000: 82 | return match_region 83 | else: 84 | 85 | return None 86 | 87 | def generate_one_frame(self,lit_file): 88 | 89 | lit = self.read_pngIMG(lit_file) 90 | 91 | object_mask = self.read_png(lit_file.replace("lit","mask")) 92 | 93 | s=os.path.split(lit_file)[1] 94 | cam = s.split("_")[0] 95 | frame =s.split("_")[1] 96 | 97 | id2mask = {} 98 | for obj_id in self.id2color.keys(): 99 | color = self.id2color[obj_id] 100 | mask = self.match_color(object_mask, [color.R, color.G, color.B], tolerance=3) 101 | if mask is not None: 102 | id2mask[obj_id] = mask 103 | # This may take a while 104 | # TODO: Need to find a faster implementation for this 105 | count = 0 106 | for k, x in id2mask.items(): 107 | left = 9999 108 | top = 9999 109 | right = 0 110 | bottom = 0 111 | for i in range(x.shape[0]): 112 | for j in range(x.shape[1]): 113 | if x[i][j]: 114 | if i > bottom: 115 | bottom = i 116 | if i < top: 117 | top = i 118 | if j < left: 119 | left = j 120 | if j > right: 121 | right = j 122 | if left < 5 or top <5 or bottom >= x.shape[0] - 5 or right >= x.shape[1] - 5: 123 | continue 124 | 125 | r1 = right-left 126 | r2 = bottom - top 127 | left = max(0, left-random.randint(0,int(0.1*r1))) 128 | right= min(x.shape[1]-1, right+random.randint(0,int(0.1*r1))) 129 | top = max(0, top-random.randint(0,int(0.1*r2))) 130 | bottom= min(x.shape[0]-1, bottom+random.randint(0,int(0.1*r2))) 131 | 132 | img = lit.crop((left, top, right, bottom)) 133 | 134 | img = img.convert('RGB') 135 | 136 | # mask save disabled 137 | #x_part = x[top:bottom, left:right] 138 | #img_mask = np.zeros((bottom-top,right-left,3),dtype=np.uint8) 139 | #img_mask[x_part==True]=[255,255,255] 140 | #img_mask = PIL.Image.fromarray(img_mask) 141 | 142 | qqq = np.sum(x)/(r1*r2) # fraction of the (pre-jitter) box covered by the mask 143 | #if (1.0*(right-left))/(1.0*(bottom-top))>0.7: 144 | # continue 145 | # print('too wide') 146 | if (right-left) * (bottom-top)<2000: 147 | print("too small :"+k.replace("uplow","").replace("A","") + "_{}_{}.png".format(cam,frame)) 148 | continue 149 | if qqq < 0.3: 150 | print("too few :" + k.replace("uplow", "").replace("A", "") + "_{}_{}.png".format(cam, frame)) 151 | continue 152 | try: 153 | img.save(self.dir_data+k.replace("MHMainClass_C_","") + "_{}_{}.jpg".format(cam,frame)) 154 | # mask save disabled 155 | # img_mask.save(self.dir_data_mask+k.replace("MHMainClassDivide_C_","")+ "_{}_{}.png".format(cam,frame)) 156 | count+=1 157 | except (AttributeError, SystemError) as e: # note: "except A or B" would only ever catch A 158 | continue 159 | print("Generate {} images.".format(count)) 160 | 161 | 162 | 163 | 164 | if __name__=="__main__": 165 | try: 166 | import os,tqdm,glob 167 | import argparse 168 | parser= argparse.ArgumentParser() 169 | parser.add_argument("--path",type=str) 
170 | 171 | parser.add_argument("--cam",type=int) # parsed but currently unused; cameras are fanned out below 172 | 173 | args=parser.parse_args() 174 | 175 | print(args.path) 176 | dir_saves = [i + '\\' for i in glob.glob("f:\\datasets\\tmp*")] 177 | time.sleep(1) 178 | #dir_saves=[args.path] 179 | 180 | dir_data = "F:\\datasets\\unreal_indoor_high\\images\\" 181 | dir_data_mask = "F:\\datasets\\unreal_indoor_high\\annos\\" 182 | if not os.path.exists(dir_data): 183 | os.makedirs(dir_data) 184 | if not os.path.exists(dir_data_mask): 185 | os.makedirs(dir_data_mask) 186 | l =[] 187 | for dir_save in dir_saves: 188 | #dir_save = "/Users/zhangtianyu/Documents/MakeHuman/datasets/tmp1592218940/" 189 | for c in range(1,8): 190 | thread = HandleImageThread(dir_save+"\\",dir_data,dir_data_mask,c) 191 | thread.start() 192 | l.append(thread) 193 | 194 | for t in l: 195 | t.join() 196 | 197 | 198 | except Exception as e: 199 | 200 | traceback.print_exc() -------------------------------------------------------------------------------- /UnrealPerson-DataSynthesisToolkit/postprocess_video.py: -------------------------------------------------------------------------------- 1 | import glob,shutil 2 | import tqdm 3 | import os 4 | import sys,re,traceback 5 | import numpy as np 6 | from io import BytesIO 7 | import time 8 | import threading 9 | import random 10 | import PIL.Image 11 | import pickle 12 | import glob 13 | from multiprocessing import Process 14 | 15 | class HandleImageThread(Process): # a multiprocessing Process despite the legacy "Thread" name 16 | def __init__(self, dir_save, dir_data, dir_data_mask, cam, ptype_name): 17 | super(HandleImageThread, self).__init__() 18 | self.dir_save=dir_save 19 | self.dir_data=dir_data 20 | self.dir_data_mask = dir_data_mask 21 | self.cam=cam 22 | self.ptype_name = ptype_name 23 | 24 | 25 | def run(self): 26 | datautils = DataUtils(self.dir_save, self.dir_data, self.dir_data_mask, self.ptype_name) 27 | datautils.get_object_color() 28 | 29 | 30 | files = glob.glob(os.path.join(self.dir_save ,"c{:0>3d}_*lit.png".format(self.cam))) 31 | files.sort() 32 | 33 | for fi in tqdm.tqdm(files): 34 | datautils.generate_one_frame(fi) 35 | 36 | # TODO: replace this with a better implementation 37 | class Color(object): 38 | ''' A utility class to parse color value ''' 39 | def __init__(self, color_str): 40 | self.color_str = color_str 41 | self.R,self.G,self.B,self.A=0,0,0,0 42 | color_str = color_str.replace("(","").replace(")","").replace("R=","").replace("G=","").replace("B=","").replace("A=","") 43 | try: 44 | (self.R, self.G, self.B, self.A) = [int(i) for i in color_str.split(",")] 45 | except Exception as e: 46 | print("Error in Color:") 47 | print(color_str) 48 | 49 | def __repr__(self): 50 | return self.color_str 51 | 52 | class DataUtils(object): 53 | def __init__(self,dir_save,dir_data,dir_data_mask, ptype_name): 54 | self.ptype_name = ptype_name 55 | self.dir_save = dir_save 56 | tmp = os.path.basename(os.path.split(self.dir_save)[0]) 57 | self.dir_data = os.path.join(dir_data, tmp) 58 | self.dir_data_mask = os.path.join(dir_data_mask, tmp) 59 | try: 60 | if not os.path.exists(self.dir_data): 61 | os.makedirs(self.dir_data) 62 | if not os.path.exists(self.dir_data_mask): 63 | os.makedirs(self.dir_data_mask) 64 | except Exception as exc: 65 | pass # another worker may have created the directories first 66 | self.id2color = {} 67 | 68 | 69 | def read_png(self,res): 70 | import PIL.Image 71 | img = PIL.Image.open(res) 72 | return np.asarray(img) 73 | 74 | def read_pngIMG(self,res): 75 | import PIL.Image 76 | img = PIL.Image.open(res) 77 | return img 78 | 79 | def get_object_color(self): 80 | self.id2color = 
pickle.load(open(self.dir_save+"object_color.pkl",'rb')) 81 | 82 | def match_color(self,object_mask, target_color, tolerance=3): 83 | match_region = np.ones(object_mask.shape[0:2], dtype=bool) 84 | for c in range(3): # r,g,b 85 | min_val = target_color[c] - tolerance 86 | max_val = target_color[c] + tolerance 87 | channel_region = (object_mask[:,:,c] >= min_val) & (object_mask[:,:,c] <= max_val) 88 | match_region &= channel_region 89 | 90 | if match_region.sum() > 2000: 91 | return match_region 92 | else: 93 | 94 | return None 95 | 96 | def generate_one_frame(self,lit_file): 97 | 98 | lit = self.read_pngIMG(lit_file) 99 | 100 | object_mask = self.read_png(lit_file.replace("lit","mask")) 101 | 102 | s=os.path.split(lit_file)[1] 103 | cam = s.split("_")[0] 104 | frame =s.split("_")[1] 105 | 106 | id2mask = {} 107 | for obj_id in self.id2color.keys(): 108 | color = self.id2color[obj_id] 109 | mask = self.match_color(object_mask, [color.R, color.G, color.B], tolerance=3) 110 | if mask is not None: 111 | id2mask[obj_id] = mask 112 | # This may take a while 113 | # TODO: Need to find a faster implementation for this 114 | count = 0 115 | for k, x in id2mask.items(): 116 | left = 9999 117 | top = 9999 118 | right = 0 119 | bottom = 0 120 | for i in range(x.shape[0]): 121 | for j in range(x.shape[1]): 122 | if x[i][j]: 123 | if i > bottom: 124 | bottom = i 125 | if i < top: 126 | top = i 127 | if j < left: 128 | left = j 129 | if j > right: 130 | right = j 131 | if left < 5 or top <5 or bottom >= x.shape[0] - 5 or right >= x.shape[1] - 5: 132 | continue 133 | 134 | r1 = right-left 135 | r2 = bottom - top 136 | left = max(0, left-random.randint(0,int(0.1*r1))) 137 | right= min(x.shape[1]-1, right+random.randint(0,int(0.1*r1))) 138 | top = max(0, top-random.randint(0,int(0.1*r2))) 139 | bottom= min(x.shape[0]-1, bottom+random.randint(0,int(0.1*r2))) 140 | 141 | img = lit.crop((left, top, right, bottom)) 142 | 143 | img = img.convert('RGB') 144 | 145 | # build a binary mask image for the crop (saved alongside the jpg below) 146 | x_part = x[top:bottom, left:right] 147 | img_mask = np.zeros((bottom-top,right-left,3),dtype=np.uint8) 148 | img_mask[x_part]=[255,255,255] 149 | img_mask = PIL.Image.fromarray(img_mask) 150 | 151 | qqq = np.sum(x)/(r1*r2) # fraction of the (pre-jitter) box covered by the mask 152 | if (1.0*(right-left))/(1.0*(bottom-top)) > 0.7: # drop boxes too wide to be a standing person 153 | continue 154 | 155 | if (right-left) * (bottom-top) < 3000: 156 | continue 157 | 158 | if qqq < 0.3: 159 | continue 160 | try: 161 | img.save(os.path.join(self.dir_data,k.replace(self.ptype_name,"") + "_{}_F{}_{}_{}_{}_{}.jpg" 162 | .format(cam,frame, left,top,right,bottom))) 163 | # mask save 164 | img_mask.save(os.path.join(self.dir_data_mask,k.replace(self.ptype_name,"")+ "_{}_F{}_{}_{}_{}_{}.png" 165 | .format(cam,frame, left,top,right,bottom))) 166 | count += 1 167 | except (AttributeError, SystemError) as e: # note: "except A or B" would only ever catch A 168 | continue 169 | 170 | 171 | 172 | 173 | 174 | if __name__=="__main__": 175 | try: 176 | import os,tqdm,glob 177 | import argparse 178 | parser= argparse.ArgumentParser() 179 | parser.add_argument("--path",type=str) 180 | parser.add_argument("--scene",type=str,choices=['s001','s002','s003','s004']) 181 | parser.add_argument("--ptype",type=int) 182 | cam_num = {'s001': 6, 183 | 's002': 16, 184 | 's003': 6, 185 | 's004': 6} 186 | ptype_name = {1: "MHMasterAI_C_", 187 | 2: "MHMainClass_C_", 188 | 3: "MHMainClassDivide_C_"} 189 | args = parser.parse_args() 190 | 191 | save_dir = 'unreal_video_{}_p{}'.format(args.scene, args.ptype) 192 | 193 | dir_saves = [i + '\\' for i in glob.glob(os.path.join(args.path, 
"tmp*"))] 194 | 195 | dir_data = "F:\\video\\{}\\images\\".format(save_dir) 196 | dir_data_mask = "F:\\video\\{}\\annos\\".format(save_dir) 197 | if not os.path.exists(dir_data): 198 | os.makedirs(dir_data) 199 | if not os.path.exists(dir_data_mask): 200 | os.makedirs(dir_data_mask) 201 | l =[] 202 | for dir_save in dir_saves: 203 | for c in range(1, 1+cam_num[args.scene]): 204 | thread = HandleImageThread(dir_save+"\\",dir_data,dir_data_mask,c, ptype_name[args.ptype]) 205 | thread.start() 206 | l.append(thread) 207 | 208 | for t in l: 209 | t.join() 210 | 211 | 212 | except Exception as e: 213 | 214 | traceback.print_exc() -------------------------------------------------------------------------------- /UnrealPerson-DataSynthesisToolkit/script_clothingcoparsing_clothing_patch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Fri Jun 26 15:07:13 2020 5 | 6 | @author: zhangtianyu 7 | """ 8 | 9 | import os 10 | import glob 11 | import scipy.io 12 | import tqdm 13 | import PIL.Image 14 | import numpy as np 15 | 16 | def findMaxRect(data): 17 | 18 | '''http://stackoverflow.com/a/30418912/5008845''' 19 | 20 | nrows,ncols = data.shape 21 | w = np.zeros(dtype=int, shape=data.shape) 22 | h = np.zeros(dtype=int, shape=data.shape) 23 | skip = 0 24 | area_max = (0, []) 25 | 26 | 27 | for r in range(nrows): 28 | for c in range(ncols): 29 | if data[r][c] == skip: 30 | continue 31 | if r == 0: 32 | h[r][c] = 1 33 | else: 34 | h[r][c] = h[r-1][c]+1 35 | if c == 0: 36 | w[r][c] = 1 37 | else: 38 | w[r][c] = w[r][c-1]+1 39 | minw = w[r][c] 40 | for dh in range(h[r][c]): 41 | minw = min(minw, w[r-dh][c]) 42 | area = (dh+1)*minw 43 | if area > area_max[0]: 44 | area_max = (area, [(r-dh, c-minw+1, r, c)]) 45 | 46 | return area_max 47 | 48 | save_dir = '/Users/zhangtianyu/Downloads/clothing-co-parsing-master/ccp_patches' 49 | img_dir = '/Users/zhangtianyu/Downloads/clothing-co-parsing-master/photos' 50 | anno_dir = '/Users/zhangtianyu/Downloads/clothing-co-parsing-master/annotations/pixel-level' 51 | labels = scipy.io.loadmat('/Users/zhangtianyu/Downloads/clothing-co-parsing-master/label_list.mat') 52 | labels= labels['label_list'] 53 | 54 | pixel_label = {} 55 | 56 | for i in range(59): 57 | pixel_label[i]=str(labels[0][i][0]) 58 | 59 | 60 | annos = glob.glob(anno_dir + '/*.mat') 61 | 62 | 63 | anno_path = annos[0] 64 | for anno_path in tqdm.tqdm(annos): 65 | num=os.path.split(anno_path)[1].split('.')[0] 66 | 67 | img_path = anno_path.replace('annotations/pixel-level','photos').replace('.mat','.jpg') 68 | mat = scipy.io.loadmat(anno_path) 69 | gt = mat['groundtruth'] 70 | img = PIL.Image.open(img_path) 71 | 72 | top1=0 73 | top1_count=0 74 | top2=0 75 | top2_count=0 76 | for i in range(1,59): 77 | if i==41:continue 78 | count=(gt==i).sum() 79 | if count==0:continue 80 | if count>top1_count: 81 | top2_count=top1_count 82 | top2=top1 83 | top1_count=count 84 | top1=i 85 | elif count>top2_count: 86 | top2_count=count 87 | top2=i 88 | 89 | top1_gt = gt==top1 90 | top2_gt = gt==top2 91 | 92 | 93 | 94 | a_1 = findMaxRect(top1_gt) 95 | cood=a_1[1][0] 96 | rec_1=img.crop((cood[1],cood[0],cood[3],cood[2])) 97 | 98 | a_2 = findMaxRect(top2_gt) 99 | cood=a_2[1][0] 100 | rec_2=img.crop((cood[1],cood[0],cood[3],cood[2])) 101 | 102 | rec_1.save(save_dir+"/"+pixel_label[top1]+"_{}.png".format(num)) 103 | rec_2.save(save_dir+"/"+pixel_label[top2]+"_{}.png".format(num)) 104 | 
-------------------------------------------------------------------------------- /UnrealPerson-DataSynthesisToolkit/script_deepfashion_clothing_patch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tues 1st Sep 2020 5 | 6 | @author: zhangtianyu 7 | """ 8 | 9 | import os 10 | import glob 11 | import scipy.io 12 | import tqdm 13 | import PIL.Image 14 | import numpy as np 15 | from multiprocessing import cpu_count 16 | from multiprocessing import Pool 17 | 18 | def mt_func(seg,pixel_label): 19 | print(seg) 20 | gender = "m" if seg.find('WOMEN')==-1 else "f" 21 | idx = seg.split('/')[-2]+'_'+seg.split('/')[-1][:4] 22 | img_path = seg.replace('_segment.png','.jpg') 23 | if not os.path.exists(img_path): 24 | print(img_path+" does not exist.") 25 | return False 26 | 27 | 28 | img = PIL.Image.open(img_path) 29 | gt = PIL.Image.open(seg) 30 | gt = np.array(gt) 31 | gt = gt[:,:,:3] 32 | top1=0 33 | top1_count=0 34 | top2=0 35 | top2_count=0 36 | for i in pixel_label.keys(): 37 | 38 | count=match_color(gt,i).sum() 39 | if count==0:continue 40 | if count>top1_count: 41 | top2_count=top1_count 42 | top2=top1 43 | top1_count=count 44 | top1=i 45 | elif count>top2_count: 46 | top2_count=count 47 | top2=i 48 | 49 | top1_gt = match_color(gt,top1) 50 | top2_gt = match_color(gt,top2) 51 | 52 | 53 | 54 | a_1 = findMaxRect(top1_gt) 55 | cood=a_1[1][0] 56 | rec_1=img.crop((cood[1],cood[0],cood[3],cood[2])) 57 | 58 | a_2 = findMaxRect(top2_gt) 59 | cood=a_2[1][0] 60 | rec_2=img.crop((cood[1],cood[0],cood[3],cood[2])) 61 | 62 | rec_1.save(save_dir+"/"+pixel_label[top1]+"_{}_{}.png".format(gender,idx)) 63 | rec_2.save(save_dir+"/"+pixel_label[top2]+"_{}_{}.png".format(gender,idx)) 64 | return True 65 | 66 | 67 | 68 | def match_color(mask, color): 69 | match_region = np.ones(mask.shape[:2],dtype=bool) 70 | for c in range(3): 71 | val = color[c] 72 | channel_region= mask[:,:,c]==val 73 | match_region &= channel_region 74 | return match_region 75 | 76 | def findMaxRect(data): 77 | 78 | '''http://stackoverflow.com/a/30418912/5008845''' 79 | 80 | nrows,ncols = data.shape 81 | w = np.zeros(dtype=int, shape=data.shape) 82 | h = np.zeros(dtype=int, shape=data.shape) 83 | skip = 0 84 | area_max = (0, []) 85 | 86 | 87 | for r in range(nrows): 88 | for c in range(ncols): 89 | if data[r][c] == skip: 90 | continue 91 | if r == 0: 92 | h[r][c] = 1 93 | else: 94 | h[r][c] = h[r-1][c]+1 95 | if c == 0: 96 | w[r][c] = 1 97 | else: 98 | w[r][c] = w[r][c-1]+1 99 | minw = w[r][c] 100 | for dh in range(h[r][c]): 101 | minw = min(minw, w[r-dh][c]) 102 | area = (dh+1)*minw 103 | if area > area_max[0]: 104 | area_max = (area, [(r-dh, c-minw+1, r, c)]) 105 | 106 | return area_max 107 | 108 | save_dir = '/Users/zhangtianyu/Downloads/clothing-co-parsing-master/df_patches' 109 | img_dir = '/Users/zhangtianyu/Downloads/df_clothes/img_highres_seg' 110 | 111 | 112 | pixel_label = { 113 | (255,250,250) : "top", 114 | (250,235,215) : "skirt", 115 | (255, 250, 205) : "dress", 116 | (220, 220, 220) : "outer", 117 | (211, 211, 211) : "pants", 118 | (127, 255, 212) : "headwear", 119 | } 120 | 121 | 122 | seg_imgs = glob.glob(os.path.join(img_dir,"*/*/*/*_segment.png")) 123 | seg_imgs.sort() 124 | results=[] 125 | mt_pool = Pool(cpu_count()) 126 | for seg in seg_imgs: 127 | params=(seg,pixel_label) 128 | results.append(mt_pool.apply_async(mt_func,params)) 129 | 130 | mt_pool.close() 131 | mt_pool.join() 132 | 133 | 134 | 
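# ---- Added note on the Pool usage above (commentary, not original code) ----
# apply_async() only surfaces a worker's exception when .get() is called on
# the returned AsyncResult, so as written a crash inside mt_func passes
# silently after close()/join(). A minimal sketch of collecting the results,
# reusing the `results` list built above:
def _collect(results):
    ok = 0
    for r in results:
        try:
            if r.get():                # re-raises any exception from the worker
                ok += 1
        except Exception as exc:
            print("worker failed:", exc)
    print("{}/{} segmented images patched".format(ok, len(results)))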
-------------------------------------------------------------------------------- /UnrealPerson-DataSynthesisToolkit/script_makehuman_asset_download.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import urllib,os 3 | from bs4 import BeautifulSoup 4 | import tqdm 5 | url_entrance = { 6 | "Accessory":["http://www.makehumancommunity.org/clothes.html?field_clothes_category_value=Accessory&field_clothes_status_value=All" 7 | ,"http://www.makehumancommunity.org/clothes.html?field_clothes_category_value=Accessory&field_clothes_status_value=All&page=1"] 8 | } 9 | 10 | 11 | 12 | def download_one_file(url,filepath): 13 | urllib.request.urlretrieve(url, filename=filepath) 14 | print(filepath) 15 | 16 | def get_html_soup(url): 17 | webpage = requests.get(url) 18 | content = webpage.text 19 | soup = BeautifulSoup(content, 'html.parser') 20 | return soup 21 | 22 | def download(url,category): 23 | soup = get_html_soup(url) 24 | tbody = soup.find("tbody") 25 | 26 | td = tbody.find_all('td') 27 | file_dirpath = 'assets/Clothes/' + category 28 | if not os.path.exists(file_dirpath): 29 | os.mkdir(file_dirpath) 30 | for t in td: 31 | if t.a: 32 | print(t.a.text) 33 | try: 34 | url_case = "http://www.makehumancommunity.org"+t.a['href'] 35 | soup_case = get_html_soup(url_case) 36 | spans = soup_case.find_all("span",class_="file") 37 | asset_name = t.a.text 38 | asset_name = asset_name.replace(" ","_") 39 | 40 | for s in spans: 41 | url_download = s.a['href'] 42 | 43 | 44 | 45 | file_dirpath = 'assets/Clothes/'+category+'/'+ asset_name 46 | filepath = file_dirpath+"/"+ url_download.split('/')[-1] 47 | if not os.path.exists(file_dirpath): 48 | os.mkdir(file_dirpath) 49 | download_one_file(url_download,filepath) 50 | fs = soup_case.find_all('figure') 51 | for f in fs: 52 | url_texture = f.a['href'] 53 | filepath = file_dirpath + "/" + url_texture.split('/')[-1] 54 | download_one_file(url_texture, filepath) 55 | except Exception as e: 56 | 57 | continue 58 | 59 | if __name__=="__main__": 60 | for k,v in url_entrance.items(): 61 | for url in v: 62 | download(url,k) -------------------------------------------------------------------------------- /UnrealPerson-DataSynthesisToolkit/unrealcv/unrealcv_plugin-4.24-mac.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FlyHighest/UnrealPerson/9fd5a882bea16f6289f0353b78140bff36c16609/UnrealPerson-DataSynthesisToolkit/unrealcv/unrealcv_plugin-4.24-mac.zip -------------------------------------------------------------------------------- /UnrealPerson-DataSynthesisToolkit/unrealcv/unrealcv_plugin-4.24-win64.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FlyHighest/UnrealPerson/9fd5a882bea16f6289f0353b78140bff36c16609/UnrealPerson-DataSynthesisToolkit/unrealcv/unrealcv_plugin-4.24-win64.zip -------------------------------------------------------------------------------- /UnrealPerson-DataSynthesisToolkit/utils.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | import numpy as np 4 | import sys,pickle 5 | 6 | class Color(object): 7 | ''' A utility class to parse color value ''' 8 | def __init__(self, color_str): 9 | self.color_str = color_str 10 | self.R,self.G,self.B,self.A=0,0,0,0 11 | color_str = color_str.replace("(","").replace(")","").replace("R=","").replace("G=","").replace("B=","").replace("A=","") 12 | 
try: 13 | (self.R, self.G, self.B, self.A) = [int(i) for i in color_str.split(",")] 14 | except Exception as e: 15 | print("Error in Color:") 16 | print(color_str) 17 | 18 | def __repr__(self): 19 | return self.color_str 20 | 21 | class DataUtils(object): 22 | def __init__(self,client,dir_save): 23 | self.client = client 24 | self.client.connect() 25 | self.dir_save = dir_save 26 | if not client.isconnected(): 27 | print('UnrealCV server is not running. Run the game downloaded from http://unrealcv.github.io first.') 28 | sys.exit(-1) 29 | 30 | 31 | res = self.client.request('vget /unrealcv/status') 32 | 33 | 34 | 35 | 36 | # The image resolution and port is configured in the config file. 37 | print(res) 38 | self.scene_objects=None 39 | self.id2color = {} 40 | 41 | 42 | def read_png(self,res): 43 | import PIL.Image 44 | img = PIL.Image.open(BytesIO(res)) 45 | return np.asarray(img) 46 | 47 | def read_pngIMG(self,res): 48 | import PIL.Image 49 | img = PIL.Image.open(BytesIO(res)) 50 | return img 51 | 52 | def get_object_color(self): 53 | self.scene_objects = self.client.request('vget /objects').split(' ') 54 | print('Number of objects in this scene:', len(self.scene_objects)) 55 | self.id2color = {} # Map from object id to the labeling color 56 | for obj_id in self.scene_objects: 57 | if obj_id not in self.id2color.keys(): 58 | if obj_id.startswith("MH"): 59 | color = Color(self.client.request('vget /object/{}/color'.format(obj_id))) 60 | self.id2color[obj_id] = color 61 | pickle.dump(self.id2color,open(self.dir_save+"object_color.pkl",'wb')) 62 | 63 | def match_color(self,object_mask, target_color, tolerance=3): 64 | match_region = np.ones(object_mask.shape[0:2], dtype=bool) 65 | for c in range(3): # r,g,b 66 | min_val = target_color[c] - tolerance 67 | max_val = target_color[c] + tolerance 68 | channel_region = (object_mask[:,:,c] >= min_val) & (object_mask[:,:,c] <= max_val) 69 | match_region &= channel_region 70 | 71 | if match_region.sum() != 0: 72 | return match_region 73 | else: 74 | return None 75 | 76 | def generate_one_frame(self,cam,frame): 77 | self.client.request('vrun ce Pause') 78 | res = self.client.request('vget /camera/0/lit png') 79 | lit = self.read_pngIMG(res) 80 | #print('The image is saved to {}'.format(res)) 81 | res = self.client.request('vget /camera/0/object_mask png') 82 | object_mask = self.read_pngIMG(res) 83 | self.client.request('vrun ce Resume') 84 | lit.save(self.dir_save + "{}_{}_lit.png".format(cam, frame)) 85 | object_mask.save(self.dir_save + "{}_{}_mask.png".format(cam, frame)) 86 | -------------------------------------------------------------------------------- /imgs/unrealperson.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FlyHighest/UnrealPerson/9fd5a882bea16f6289f0353b78140bff36c16609/imgs/unrealperson.jpg --------------------------------------------------------------------------------
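# ---- Appended usage sketch for UnrealPerson-DataSynthesisToolkit/utils.py ----
# A minimal capture loop, assuming an UnrealCV-enabled game is already running
# and that camera poses were recorded with caminfo_collect.py. The pose file
# name and output directory below are hypothetical placeholders; only console
# commands that already appear in this toolkit are used, and generate_one_frame
# additionally issues the scene-specific "vrun ce Pause"/"Resume" events.
import os
import pickle
from unrealcv import client
from utils import DataUtils

out_dir = "./tmp_capture/"                 # hypothetical output directory
os.makedirs(out_dir, exist_ok=True)
datautils = DataUtils(client, out_dir)     # DataUtils.__init__ connects the client
datautils.get_object_color()               # dumps object_color.pkl for postprocessing
poses = pickle.load(open("caminfo_s001_demo.pkl", "rb"))   # hypothetical pose file
for cam_idx, (loc, rot) in enumerate(poses, start=1):
    client.request("vset /camera/0/location " + loc.strip())
    client.request("vset /camera/0/rotation " + rot.strip())
    for frame in range(10):                # a few frames per camera
        datautils.generate_one_frame("c{:0>3d}".format(cam_idx), frame)
client.disconnect()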