├── README.md ├── data ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── build.cpython-35.pyc │ ├── build.cpython-36.pyc │ ├── build.cpython-37.pyc │ ├── collate_batch.cpython-35.pyc │ ├── collate_batch.cpython-36.pyc │ └── collate_batch.cpython-37.pyc ├── build.py ├── collate_batch.py ├── datasets │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-35.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── bases.cpython-35.pyc │ │ ├── bases.cpython-36.pyc │ │ ├── bases.cpython-37.pyc │ │ ├── celeba.cpython-35.pyc │ │ ├── celeba.cpython-36.pyc │ │ ├── celeba.cpython-37.pyc │ │ ├── celeba_msk.cpython-35.pyc │ │ ├── celeba_msk.cpython-37.pyc │ │ ├── cuhk03.cpython-35.pyc │ │ ├── cuhk03.cpython-36.pyc │ │ ├── cuhk03.cpython-37.pyc │ │ ├── dataset_loader.cpython-35.pyc │ │ ├── dataset_loader.cpython-36.pyc │ │ ├── dataset_loader.cpython-37.pyc │ │ ├── dukemtmcreid.cpython-35.pyc │ │ ├── dukemtmcreid.cpython-36.pyc │ │ ├── dukemtmcreid.cpython-37.pyc │ │ ├── eval_reid.cpython-35.pyc │ │ ├── eval_reid.cpython-36.pyc │ │ ├── eval_reid.cpython-37.pyc │ │ ├── last.cpython-35.pyc │ │ ├── last.cpython-36.pyc │ │ ├── last.cpython-37.pyc │ │ ├── last_cloth.cpython-35.pyc │ │ ├── last_cloth.cpython-37.pyc │ │ ├── last_vl.cpython-37.pyc │ │ ├── lslt.cpython-35.pyc │ │ ├── lslt.cpython-36.pyc │ │ ├── lslt.cpython-37.pyc │ │ ├── ltcc.cpython-35.pyc │ │ ├── ltcc.cpython-37.pyc │ │ ├── ltcc_mask.cpython-37.pyc │ │ ├── market1501.cpython-35.pyc │ │ ├── market1501.cpython-36.pyc │ │ ├── market1501.cpython-37.pyc │ │ ├── market_triplet.cpython-35.pyc │ │ ├── market_triplet.cpython-36.pyc │ │ ├── market_triplet.cpython-37.pyc │ │ ├── msmt17.cpython-35.pyc │ │ ├── msmt17.cpython-36.pyc │ │ ├── msmt17.cpython-37.pyc │ │ ├── night.cpython-35.pyc │ │ ├── night.cpython-36.pyc │ │ ├── night.cpython-37.pyc │ │ ├── prcc.cpython-35.pyc │ │ ├── prcc.cpython-36.pyc │ │ ├── prcc.cpython-37.pyc 
│ │ ├── prcc_abc.cpython-35.pyc │ │ ├── prcc_abc.cpython-37.pyc │ │ ├── prcc_c.cpython-35.pyc │ │ ├── prcc_gcn.cpython-35.pyc │ │ ├── prcc_gcn.cpython-36.pyc │ │ └── prcc_gcn.cpython-37.pyc │ ├── bases.py │ ├── celeba.py │ ├── celeba_msk.py │ ├── cuhk03.py │ ├── dataset_loader.py │ ├── dukemtmcreid.py │ ├── eval_reid.py │ ├── market1501.py │ ├── market_triplet.py │ ├── msmt17.py │ ├── prcc.py │ └── prcc_gcn.py ├── samplers │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-35.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── triplet_sampler.cpython-35.pyc │ │ ├── triplet_sampler.cpython-36.pyc │ │ └── triplet_sampler.cpython-37.pyc │ └── triplet_sampler.py └── transforms │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── build.cpython-35.pyc │ ├── build.cpython-36.pyc │ ├── build.cpython-37.pyc │ ├── transform.cpython-35.pyc │ ├── transform.cpython-36.pyc │ ├── transform.cpython-37.pyc │ ├── transforms.cpython-35.pyc │ ├── transforms.cpython-36.pyc │ └── transforms.cpython-37.pyc │ ├── build.py │ ├── transform.py │ └── transforms.py ├── engine ├── __pycache__ │ ├── inference.cpython-35.pyc │ ├── inference.cpython-36.pyc │ ├── inference.cpython-37.pyc │ ├── trainer.cpython-35.pyc │ ├── trainer.cpython-36.pyc │ ├── trainer.cpython-37.pyc │ ├── trainer_sgap.cpython-35.pyc │ └── trainer_triplet.cpython-35.pyc ├── inference.py └── trainer.py ├── layers ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── aligned_loss.cpython-35.pyc │ ├── aligned_loss.cpython-36.pyc │ ├── aligned_loss.cpython-37.pyc │ ├── center_loss.cpython-35.pyc │ ├── center_loss.cpython-36.pyc │ ├── center_loss.cpython-37.pyc │ ├── cluster_loss.cpython-35.pyc │ ├── cluster_loss.cpython-36.pyc │ ├── cluster_loss.cpython-37.pyc │ ├── fast_ap.cpython-35.pyc │ ├── fast_ap.cpython-36.pyc │ ├── fast_ap.cpython-37.pyc 
│ ├── fast_ap_mem.cpython-35.pyc │ ├── fast_ap_mem.cpython-36.pyc │ ├── fast_ap_mem.cpython-37.pyc │ ├── histogram.cpython-35.pyc │ ├── histogram.cpython-36.pyc │ ├── histogram.cpython-37.pyc │ ├── local_dist.cpython-35.pyc │ ├── local_dist.cpython-36.pyc │ ├── local_dist.cpython-37.pyc │ ├── map_loss.cpython-35.pyc │ ├── map_loss.cpython-36.pyc │ ├── map_loss.cpython-37.pyc │ ├── range_loss.cpython-35.pyc │ ├── triplet_loss.cpython-35.pyc │ ├── triplet_loss.cpython-36.pyc │ └── triplet_loss.cpython-37.pyc ├── center_loss.py ├── cluster_loss.py ├── local_dist.py ├── range_loss.py └── triplet_loss.py ├── modeling ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── baseline.cpython-35.pyc │ ├── baseline.cpython-36.pyc │ ├── baseline.cpython-37.pyc │ ├── gcn.cpython-37.pyc │ ├── hpm.cpython-35.pyc │ ├── hpm.cpython-36.pyc │ ├── hpm.cpython-37.pyc │ ├── mgn.cpython-35.pyc │ ├── mgn.cpython-36.pyc │ ├── mgn.cpython-37.pyc │ ├── pcb.cpython-35.pyc │ ├── pcb.cpython-36.pyc │ ├── pcb.cpython-37.pyc │ ├── pcb_seg.cpython-35.pyc │ ├── pcb_seg.cpython-36.pyc │ ├── pcb_seg.cpython-37.pyc │ ├── pyramid.cpython-35.pyc │ ├── pyramid.cpython-36.pyc │ ├── pyramid.cpython-37.pyc │ ├── res_net.cpython-35.pyc │ ├── res_net.cpython-36.pyc │ └── res_net.cpython-37.pyc ├── backbones │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-35.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── cls_hrnet.cpython-35.pyc │ │ ├── cls_hrnet.cpython-36.pyc │ │ ├── cls_hrnet.cpython-37.pyc │ │ ├── model.cpython-35.pyc │ │ ├── model.cpython-36.pyc │ │ ├── model.cpython-37.pyc │ │ ├── non_local.cpython-35.pyc │ │ ├── pcb.cpython-35.pyc │ │ ├── resnet.cpython-35.pyc │ │ ├── resnet.cpython-36.pyc │ │ ├── resnet.cpython-37.pyc │ │ ├── resnet_bap.cpython-35.pyc │ │ ├── resnet_bap.cpython-36.pyc │ │ ├── resnet_bap.cpython-37.pyc │ │ ├── senet.cpython-35.pyc │ │ ├── senet.cpython-36.pyc │ │ ├── 
senet.cpython-37.pyc │ │ ├── sga_resnet.cpython-35.pyc │ │ ├── sga_resnet.cpython-36.pyc │ │ └── sga_resnet.cpython-37.pyc │ ├── cls_hrnet.py │ ├── model.py │ ├── non_local.py │ ├── resnet.py │ └── senet.py ├── gcn.py ├── hpm.py ├── mgn.py ├── pcb.py ├── pcb_seg.py └── res_net.py ├── solver ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── build.cpython-35.pyc │ ├── build.cpython-36.pyc │ ├── build.cpython-37.pyc │ ├── lr_scheduler.cpython-35.pyc │ ├── lr_scheduler.cpython-36.pyc │ ├── lr_scheduler.cpython-37.pyc │ ├── ranger.cpython-35.pyc │ ├── ranger.cpython-36.pyc │ └── ranger.cpython-37.pyc ├── build.py ├── lr_scheduler.py └── ranger.py ├── tests ├── __init__.py └── lr_scheduler_test.py ├── train_prcc_base.py ├── train_prcc_hpm.py ├── train_prcc_hpm_pix.py ├── train_prcc_mgn.py ├── train_prcc_mgn_pix.py ├── train_prcc_pcb.py ├── train_prcc_pcb_pix.py ├── train_prcc_pixel_sampling.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-35.pyc ├── __init__.cpython-36.pyc ├── __init__.cpython-37.pyc ├── distance.cpython-35.pyc ├── distance.cpython-36.pyc ├── distance.cpython-37.pyc ├── iotools.cpython-35.pyc ├── iotools.cpython-36.pyc ├── iotools.cpython-37.pyc ├── logger.cpython-35.pyc ├── loss.cpython-37.pyc ├── re_ranking.cpython-35.pyc ├── re_ranking.cpython-36.pyc ├── re_ranking.cpython-37.pyc ├── reid_metric.cpython-35.pyc ├── reid_metric.cpython-36.pyc ├── reid_metric.cpython-37.pyc ├── reid_tool.cpython-35.pyc ├── rerank.cpython-35.pyc ├── rerank.cpython-36.pyc ├── rerank.cpython-37.pyc ├── visual.cpython-35.pyc ├── visual.cpython-36.pyc └── visual.cpython-37.pyc ├── compute_map.py ├── distance.py ├── iotools.py ├── logger.py ├── loss.py ├── re_ranking.py ├── reid_metric.py ├── reid_tool.py ├── rerank.py └── visual.py /README.md: -------------------------------------------------------------------------------- 1 | # Semantic-guided Pixel Sampling for Cloth-Changing 
Person Re-identification 2 | 3 | In our paper ([published version](https://ieeexplore.ieee.org/abstract/document/9463711/), [arXiv](https://arxiv.org/abs/2107.11522)), we propose a semantic-guided pixel sampling approach for the cloth-changing person re-ID task. This repo contains the training and testing code. 4 | 5 | ## Prepare Dataset 6 | 1. Download the PRCC dataset: [PRCC](http://isee-ai.cn/~yangqize/clothing.html) 7 | 2. Obtain the human body parts using [SCHP](https://github.com/PeikeLi/Self-Correction-Human-Parsing) 8 | 3. Download the masks of the PRCC dataset: [Baidu](https://pan.baidu.com/s/1sX1qFgo3I-4OfEdEr-opSg) (password: r9kc) or [Google](https://drive.google.com/drive/folders/1HaIoKRj1R4fxjVQ9Qg_IEk2_46b1hniH?usp=sharing) 9 | 10 | 11 | ## Trained Models 12 | The trained models can be downloaded from [BaiduPan](https://pan.baidu.com/s/1JOOJp_NPbsU19DdBr7ze9g) (password: 6ulj) or [Google](https://drive.google.com/drive/folders/1aAltKSfRpHqADXb6sWOQ0VL7dj9GVwvU?usp=sharing). 13 | ``` 14 | Put the trained models into the corresponding directories: 15 | >pixel_sampling/imagenet/resnet50-19c8e357.pth 16 | >pixel_sampling/logs/prcc_base/checkpoint_best.pth 17 | >pixel_sampling/logs/prcc_hpm/checkpoint_best.pth 18 | >......
19 | ``` 20 | 21 | ## Training and Testing Models 22 | You only need to modify a few parameters (`--train` selects `train` or `test` mode; `--data_dir` points to the PRCC dataset root): 23 | ``` 24 | >parser.add_argument('--train', type=str, default='train', help='train, test') 25 | 26 | >parser.add_argument('--data_dir', type=str, default='/data/prcc/') 27 | ``` 28 | then run: 29 | ``` 30 | >python train_prcc_base.py 31 | ``` 32 | 33 | ## Citations 34 | If you find this work useful, please cite: 35 | ```bibtex 36 | @article{shu2021semantic, 37 | title={Semantic-guided Pixel Sampling for Cloth-Changing Person Re-identification}, 38 | author={Shu, Xiujun and Li, Ge and Wang, Xiao and Ruan, Weijian and Tian, Qi}, 39 | journal={IEEE Signal Processing Letters}, 40 | volume={28}, 41 | pages={1365-1369}, 42 | year={2021}, 43 | } 44 | ``` 45 | 46 | If you have any questions, please contact: shuxj@mail.ioa.ac.cn 47 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | 4 | from .build import make_data_loader, make_data_loader_mask, make_data_loader_visual 5 | from .build import make_data_loader_path, make_data_loader_prcc_gcn, make_data_loader_prcc_hrnet 6 | from .build import make_data_loader_prcc_visual, make_data_loader_prcc_eraser, make_data_loader_base 7 | from .build import make_data_loader_path_data, make_data_loader_data 8 | from .build import make_data_loader_prcc, make_data_loader_path_visual, make_data_loader_visual_mask 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/__pycache__/__init__.cpython-35.pyc --------------------------------------------------------------------------------
/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/build.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/__pycache__/build.cpython-35.pyc -------------------------------------------------------------------------------- /data/__pycache__/build.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/__pycache__/build.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/build.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/__pycache__/build.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/collate_batch.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/__pycache__/collate_batch.cpython-35.pyc 
# ---------------------------------------------------------------------------
# Repository-dump residue that shared these lines with the code below. It only
# points at compiled bytecode and is kept as comments so nothing is lost.
#   /data/__pycache__/collate_batch.cpython-36.pyc:
#     https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/__pycache__/collate_batch.cpython-36.pyc
#   /data/__pycache__/collate_batch.cpython-37.pyc:
#     https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/__pycache__/collate_batch.cpython-37.pyc
# ---------------------------------------------------------------------------
# /data/collate_batch.py
# encoding: utf-8
#
# Collate functions (presumably passed as ``collate_fn`` to a DataLoader --
# the callers are not visible here). Every function transposes ``batch``, a
# sequence of per-sample tuples, with ``zip(*batch)``. The tuple unpacking on
# its first line fixes how many fields each sample must carry; a mismatch
# raises ValueError. Image-like fields are stacked along a new dim 0. The
# mask / text fields of the GCN and VL variants are concatenated along
# dim 0. Id/label fields are turned into int64 tensors where shown. Every
# other field is returned as the plain tuple produced by ``zip``.

import torch


def _as_int64(values):
    """Convert a tuple of integer ids/labels to an ``int64`` tensor."""
    return torch.tensor(values, dtype=torch.int64)


def _stack0(tensors):
    """Stack equally-shaped per-sample tensors along a new leading dimension."""
    return torch.stack(tensors, dim=0)


def _cat0(tensors):
    """Concatenate per-sample tensors along their existing first dimension."""
    return torch.cat(tensors, dim=0)


def train_collate_fn(batch):
    """4-field samples -> ``(images, pids)``; ``pids`` as an int64 tensor."""
    images, ids, _, _ = zip(*batch)
    ids = _as_int64(ids)
    return _stack0(images), ids


def val_collate_fn(batch):
    """4-field samples -> ``(images, pids, camids)``; ids stay tuples."""
    images, ids, cams, _ = zip(*batch)
    return _stack0(images), ids, cams


def train_collate_gcn_mask(batch):
    """5-field samples ``(img, mask, pid, _, path)`` -> ``(images, pids, paths, masks)``.

    Note the mask is the SECOND field of each sample in the GCN variants.
    """
    images, masks, ids, _, paths = zip(*batch)
    ids = _as_int64(ids)
    return _stack0(images), ids, paths, _cat0(masks)


def val_collate_gcn_mask(batch):
    """5-field samples -> ``(images, pids, camids, paths, masks)``."""
    images, masks, ids, cams, paths = zip(*batch)
    return _stack0(images), ids, cams, paths, _cat0(masks)


######################################################


def train_collate_fn_path(batch):
    """4-field samples -> ``(images, pids, paths)``; ``pids`` as int64."""
    images, ids, _, paths = zip(*batch)
    ids = _as_int64(ids)
    return _stack0(images), ids, paths


def train_collate_fn_mem(batch):
    """4-field samples -> ``(images, pids, camids, paths)``; both id fields as int64."""
    images, ids, cams, paths = zip(*batch)
    ids, cams = _as_int64(ids), _as_int64(cams)
    return _stack0(images), ids, cams, paths


def train_collate_fn_idx(batch):
    """5-field samples -> ``(images, pids, idx)``; ``idx`` stays a tuple."""
    images, ids, _, _, indices = zip(*batch)
    ids = _as_int64(ids)
    return _stack0(images), ids, indices


def train_collate_fn_adv(batch):
    """5-field samples -> ``(images, pids, paths, labels)``; pids and labels as int64."""
    images, ids, _, paths, labels = zip(*batch)
    ids, labels = _as_int64(ids), _as_int64(labels)
    return _stack0(images), ids, paths, labels


def train_collate_fn_dom(batch):
    """5-field samples -> ``(images, pids, paths, ll)``; pids and ``ll`` as int64."""
    # The fifth field was named ``ll`` originally; its meaning is not visible here.
    images, ids, _, paths, aux = zip(*batch)
    ids, aux = _as_int64(ids), _as_int64(aux)
    return _stack0(images), ids, paths, aux


def train_collate_fn_vl(batch):
    """5-field samples -> ``(images, pids, paths, texts)``; texts concatenated on dim 0."""
    images, ids, _, paths, texts = zip(*batch)
    ids = _as_int64(ids)
    return _stack0(images), ids, paths, _cat0(texts)


def val_collate_fn_path(batch):
    """4-field samples -> ``(images, pids, camids, paths)``."""
    images, ids, cams, paths = zip(*batch)
    return _stack0(images), ids, cams, paths


def val_collate_fn_idx(batch):
    """5-field samples -> ``(images, pids, camids, paths)``.

    NOTE(review): the fifth field (``idx``) is unpacked but NOT returned, so
    the output matches ``val_collate_fn_path``. Kept as-is because callers
    presumably unpack exactly four values -- confirm before changing.
    """
    images, ids, cams, paths, _ = zip(*batch)
    return _stack0(images), ids, cams, paths


def val_collate_fn_vl(batch):
    """5-field samples -> ``(images, pids, camids, paths, texts)``."""
    images, ids, cams, paths, texts = zip(*batch)
    return _stack0(images), ids, cams, paths, _cat0(texts)


def train_collate_bg(batch):
    """4-field samples -> ``(images, pids, camids)``; both id fields as int64."""
    images, ids, cams, _ = zip(*batch)
    ids, cams = _as_int64(ids), _as_int64(cams)
    return _stack0(images), ids, cams


def train_collate_mask(batch):
    """6-field samples -> ``(images, pids, masks, vis)``; mask and vis stacked."""
    images, ids, _, _, masks, vis = zip(*batch)
    ids = _as_int64(ids)
    return _stack0(images), ids, _stack0(masks), _stack0(vis)


def val_collate_mask(batch):
    """6-field samples -> ``(images, pids, camids, masks, vis)``."""
    images, ids, cams, _, masks, vis = zip(*batch)
    return _stack0(images), ids, cams, _stack0(masks), _stack0(vis)


def train_collate_build(batch):
    """5-field samples -> ``(images, pids, masks)``; masks stacked."""
    images, ids, _, _, masks = zip(*batch)
    ids = _as_int64(ids)
    return _stack0(images), ids, _stack0(masks)


def val_collate_build(batch):
    """5-field samples -> ``(images, pids, camids, masks)``."""
    images, ids, cams, _, masks = zip(*batch)
    return _stack0(images), ids, cams, _stack0(masks)


def train_collate_fn_visual(batch):
    """6-field samples -> ``(images, pids, masks, paths)``."""
    images, ids, _, _, masks, paths = zip(*batch)
    ids = _as_int64(ids)
    return _stack0(images), ids, _stack0(masks), paths


def val_collate_fn_visual(batch):
    """6-field samples -> ``(images, pids, camids, masks, paths)``."""
    images, ids, cams, _, masks, paths = zip(*batch)
    return _stack0(images), ids, cams, _stack0(masks), paths


def train_collate_gcn_mask_head(batch):
    """7-field GCN samples -> ``(images, pids, paths, masks, heads, legs)``."""
    images, masks, ids, _, paths, heads, legs = zip(*batch)
    ids = _as_int64(ids)
    return (_stack0(images), ids, paths, _cat0(masks),
            _stack0(heads), _stack0(legs))


def val_collate_gcn_mask_head(batch):
    """7-field GCN samples -> ``(images, pids, camids, paths, masks, heads, legs)``."""
    images, masks, ids, cams, paths, heads, legs = zip(*batch)
    return (_stack0(images), ids, cams, paths, _cat0(masks),
            _stack0(heads), _stack0(legs))


def train_collate_gcn_mask_seg(batch):
    """7-field GCN samples -> ``(images, pids, paths, masks, boxes, seg_masks)``."""
    images, masks, ids, _, paths, boxes, seg_masks = zip(*batch)
    ids = _as_int64(ids)
    return (_stack0(images), ids, paths, _cat0(masks),
            _stack0(boxes), _stack0(seg_masks))


def val_collate_gcn_mask_seg(batch):
    """7-field GCN samples -> ``(images, pids, camids, paths, masks, boxes, seg_masks)``."""
    images, masks, ids, cams, paths, boxes, seg_masks = zip(*batch)
    return (_stack0(images), ids, cams, paths, _cat0(masks),
            _stack0(boxes), _stack0(seg_masks))


# ---------------------------------------------------------------------------
# /data/datasets/__init__.py
# encoding: utf-8
from .cuhk03 import CUHK03
from .dukemtmcreid import DukeMTMCreID
from .market1501 import Market1501
from .market_triplet import MarketTriplet
from .msmt17 import MSMT17
from .prcc import PRCC
from .prcc_gcn import PRCC_GCN
from .celeba import CELEBA
from .celeba_msk import CELEBA_MSK
from .dataset_loader import ImageDataset, ImageDatasetMask, ImageDatasetPath, ImageDatasetGcnMask
from .dataset_loader import ImageDatasetVisualMask

# Registry: dataset name accepted by ``init_dataset`` -> dataset class.
__factory = {
    'market1501': Market1501,
    'market_triplet': MarketTriplet,
    'cuhk03': CUHK03,
    'dukemtmc': DukeMTMCreID,
    'msmt17': MSMT17,
    'prcc': PRCC,
    'prcc_gcn': PRCC_GCN,
    'celeba': CELEBA,
    'celeba_msk': CELEBA_MSK,
}


def get_names():
    """Return the registered dataset names as a live ``dict_keys`` view."""
    return __factory.keys()


def init_dataset(name, *args, **kwargs):
    """Instantiate the dataset class registered under ``name``.

    Any extra positional/keyword arguments are forwarded to the class.

    Raises:
        KeyError: if ``name`` is not one of ``get_names()``.
    """
    # Every registered value is a class, so ``None`` can only mean "missing".
    dataset_cls = __factory.get(name)
    if dataset_cls is None:
        raise KeyError("Unknown datasets: {}".format(name))
    return dataset_cls(*args, **kwargs)


# ---------------------------------------------------------------------------
# Repository-dump residue that shared these lines with the code above:
#   /data/datasets/__pycache__/__init__.cpython-35.pyc:
#     https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/__init__.cpython-35.pyc
# ---------------------------------------------------------------------------
/data/datasets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/bases.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/bases.cpython-35.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/bases.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/bases.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/bases.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/bases.cpython-37.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/celeba.cpython-35.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/celeba.cpython-35.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/celeba.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/celeba.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/celeba.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/celeba.cpython-37.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/celeba_msk.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/celeba_msk.cpython-35.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/celeba_msk.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/celeba_msk.cpython-37.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/cuhk03.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/cuhk03.cpython-35.pyc 
-------------------------------------------------------------------------------- /data/datasets/__pycache__/cuhk03.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/cuhk03.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/cuhk03.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/cuhk03.cpython-37.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/dataset_loader.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/dataset_loader.cpython-35.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/dataset_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/dataset_loader.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/dataset_loader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/dataset_loader.cpython-37.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/dukemtmcreid.cpython-35.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/dukemtmcreid.cpython-35.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/dukemtmcreid.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/dukemtmcreid.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/dukemtmcreid.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/dukemtmcreid.cpython-37.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/eval_reid.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/eval_reid.cpython-35.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/eval_reid.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/eval_reid.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/eval_reid.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/eval_reid.cpython-37.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/last.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/last.cpython-35.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/last.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/last.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/last.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/last.cpython-37.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/last_cloth.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/last_cloth.cpython-35.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/last_cloth.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/last_cloth.cpython-37.pyc 
-------------------------------------------------------------------------------- /data/datasets/__pycache__/last_vl.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/last_vl.cpython-37.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/lslt.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/lslt.cpython-35.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/lslt.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/lslt.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/lslt.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/lslt.cpython-37.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/ltcc.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/ltcc.cpython-35.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/ltcc.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/ltcc.cpython-37.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/ltcc_mask.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/ltcc_mask.cpython-37.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/market1501.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/market1501.cpython-35.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/market1501.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/market1501.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/market1501.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/market1501.cpython-37.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/market_triplet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/market_triplet.cpython-35.pyc 
-------------------------------------------------------------------------------- /data/datasets/__pycache__/market_triplet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/market_triplet.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/market_triplet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/market_triplet.cpython-37.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/msmt17.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/msmt17.cpython-35.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/msmt17.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/msmt17.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/msmt17.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/msmt17.cpython-37.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/night.cpython-35.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/night.cpython-35.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/night.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/night.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/night.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/night.cpython-37.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/prcc.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/prcc.cpython-35.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/prcc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/prcc.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/prcc.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/prcc.cpython-37.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/prcc_abc.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/prcc_abc.cpython-35.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/prcc_abc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/prcc_abc.cpython-37.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/prcc_c.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/prcc_c.cpython-35.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/prcc_gcn.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/prcc_gcn.cpython-35.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/prcc_gcn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/prcc_gcn.cpython-36.pyc 
-------------------------------------------------------------------------------- /data/datasets/__pycache__/prcc_gcn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/datasets/__pycache__/prcc_gcn.cpython-37.pyc -------------------------------------------------------------------------------- /data/datasets/bases.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import numpy as np 8 | 9 | 10 | class BaseDataset(object): 11 | """ 12 | Base class of reid dataset 13 | """ 14 | 15 | def get_imagedata_info(self, data): 16 | pids, cams = [], [] 17 | for _, pid, camid in data: 18 | pids += [pid] 19 | cams += [camid] 20 | pids = set(pids) 21 | cams = set(cams) 22 | num_pids = len(pids) 23 | num_cams = len(cams) 24 | num_imgs = len(data) 25 | return num_pids, num_imgs, num_cams 26 | 27 | def get_videodata_info(self, data, return_tracklet_stats=False): 28 | pids, cams, tracklet_stats = [], [], [] 29 | for img_paths, pid, camid in data: 30 | pids += [pid] 31 | cams += [camid] 32 | tracklet_stats += [len(img_paths)] 33 | pids = set(pids) 34 | cams = set(cams) 35 | num_pids = len(pids) 36 | num_cams = len(cams) 37 | num_tracklets = len(data) 38 | if return_tracklet_stats: 39 | return num_pids, num_tracklets, num_cams, tracklet_stats 40 | return num_pids, num_tracklets, num_cams 41 | 42 | # def print_dataset_statistics(self): 43 | # raise NotImplementedError 44 | 45 | 46 | class BaseImageDataset(BaseDataset): 47 | """ 48 | Base class of image reid dataset 49 | """ 50 | 51 | def print_dataset_statistics(self, train, query, gallery): 52 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train) 53 | num_query_pids, num_query_imgs, num_query_cams = 
self.get_imagedata_info(query) 54 | num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery) 55 | 56 | print("Dataset statistics:") 57 | print(" ----------------------------------------") 58 | print(" subset | # ids | # images | # cameras") 59 | print(" ----------------------------------------") 60 | print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams)) 61 | print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams)) 62 | print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams)) 63 | print(" ----------------------------------------") 64 | 65 | 66 | class BaseVideoDataset(BaseDataset): 67 | """ 68 | Base class of video reid dataset 69 | """ 70 | 71 | def print_dataset_statistics(self, train, val, query, gallery): 72 | num_train_pids, num_train_tracklets, num_train_cams, train_tracklet_stats = \ 73 | self.get_videodata_info(train, return_tracklet_stats=True) 74 | 75 | num_query_pids, num_query_tracklets, num_query_cams, query_tracklet_stats = \ 76 | self.get_videodata_info(query, return_tracklet_stats=True) 77 | 78 | num_gallery_pids, num_gallery_tracklets, num_gallery_cams, gallery_tracklet_stats = \ 79 | self.get_videodata_info(gallery, return_tracklet_stats=True) 80 | 81 | tracklet_stats = train_tracklet_stats + query_tracklet_stats + gallery_tracklet_stats 82 | min_num = np.min(tracklet_stats) 83 | max_num = np.max(tracklet_stats) 84 | avg_num = np.mean(tracklet_stats) 85 | 86 | print("Dataset statistics:") 87 | print(" -------------------------------------------") 88 | print(" subset | # ids | # tracklets | # cameras") 89 | print(" -------------------------------------------") 90 | print(" train | {:5d} | {:11d} | {:9d}".format(num_train_pids, num_train_tracklets, num_train_cams)) 91 | print(" query | {:5d} | {:11d} | {:9d}".format(num_query_pids, num_query_tracklets, num_query_cams)) 92 | print(" gallery | {:5d} 
| {:11d} | {:9d}".format(num_gallery_pids, num_gallery_tracklets, num_gallery_cams)) 93 | print(" -------------------------------------------") 94 | print(" number of images per tracklet: {} ~ {}, average {:.2f}".format(min_num, max_num, avg_num)) 95 | print(" -------------------------------------------") 96 | -------------------------------------------------------------------------------- /data/datasets/celeba.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import os 8 | import glob 9 | import os.path as osp 10 | import numpy as np 11 | 12 | 13 | class BaseDataset(object): 14 | """ 15 | Base class of reid dataset 16 | """ 17 | 18 | def get_imagedata_info(self, data): 19 | pids, cams = [], [] 20 | for _, pid, camid in data: 21 | pids += [pid] 22 | cams += [camid] 23 | pids = set(pids) 24 | cams = set(cams) 25 | num_pids = len(pids) 26 | num_cams = len(cams) 27 | num_imgs = len(data) 28 | return num_pids, num_imgs, num_cams 29 | 30 | def get_videodata_info(self, data, return_tracklet_info=False): 31 | pids, cams, tracklet_info = [], [], [] 32 | for img_paths, pid, camid in data: 33 | pids += [pid] 34 | cams += [camid] 35 | tracklet_info += [len(img_paths)] 36 | pids = set(pids) 37 | cams = set(cams) 38 | num_pids = len(pids) 39 | num_cams = len(cams) 40 | num_tracklets = len(data) 41 | if return_tracklet_info: 42 | return num_pids, num_tracklets, num_cams, tracklet_info 43 | return num_pids, num_tracklets, num_cams 44 | 45 | 46 | 47 | 48 | class CELEBA(BaseDataset): 49 | """ 50 | -------------------------------------- 51 | subset | # ids | # images 52 | -------------------------------------- 53 | train | 632 | 20208 54 | query | 420 | 2972 55 | gallery | 420 | 11006 56 | """ 57 | dataset_dir = '' 58 | msk_dir = 'mask_6' 59 | 60 | def __init__(self, root='data', verbose=True, 
**kwargs): 61 | super(CELEBA, self).__init__() 62 | self.dataset_dir = osp.join(root, self.dataset_dir) 63 | self.train_dir = osp.join(self.dataset_dir, 'train') 64 | self.query_dir = osp.join(self.dataset_dir, 'query') 65 | self.gallery_dir = osp.join(self.dataset_dir, 'gallery') 66 | 67 | self._check_before_run() 68 | 69 | self.pid2label = self.get_pid2label(self.train_dir) 70 | self.train = self._process_dir(self.train_dir, pid2label=self.pid2label, relabel=True) # 13081 71 | self.query = self._process_dir(self.query_dir, relabel=False) # 484 72 | self.gallery = self._process_dir(self.gallery_dir, relabel=False) 73 | 74 | if verbose: 75 | print("=> CELEBA loaded") 76 | self.print_dataset_statistics(self.train, self.query, self.gallery) 77 | 78 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 79 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 80 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 81 | 82 | 83 | def get_pid2label(self, dir_path): 84 | images = os.listdir(dir_path) 85 | persons = [int(img.split('_')[0]) for img in images] 86 | pid_container = np.sort(list(set(persons))) 87 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 88 | return pid2label 89 | 90 | def _check_before_run(self): 91 | """Check if all files are available before going deeper""" 92 | if not osp.exists(self.dataset_dir): 93 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 94 | if not osp.exists(self.train_dir): 95 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 96 | if not osp.exists(self.query_dir): 97 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 98 | if not osp.exists(self.gallery_dir): 99 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 100 | 101 | def _process_dir(self, dir_path, pid2label=None, relabel=False): 
102 | images = os.listdir(dir_path) 103 | dataset = [] 104 | cid = 0 105 | for img in images: 106 | pid_s = int(img.split('_')[0]) 107 | if relabel and pid2label is not None: 108 | pid = pid2label[pid_s] 109 | else: 110 | pid = int(pid_s) 111 | img_path = os.path.join(dir_path, img) 112 | dataset.append((img_path, pid, cid)) 113 | cid += 1 114 | return dataset 115 | 116 | 117 | def print_dataset_statistics(self, train, query, gallery): 118 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train) 119 | num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query) 120 | num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery) 121 | 122 | print("Dataset statistics:") 123 | print(" --------------------------------------") 124 | print(" subset | # ids | # images") 125 | print(" --------------------------------------") 126 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_imgs)) 127 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_imgs)) 128 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs)) 129 | 130 | 131 | 132 | -------------------------------------------------------------------------------- /data/datasets/celeba_msk.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import os 8 | import glob 9 | import os.path as osp 10 | import numpy as np 11 | 12 | 13 | class BaseDataset(object): 14 | """ 15 | Base class of reid dataset 16 | """ 17 | 18 | def get_imagedata_info(self, data): 19 | pids, cams = [], [] 20 | for _, pid, camid, _ in data: 21 | pids += [pid] 22 | cams += [camid] 23 | pids = set(pids) 24 | cams = set(cams) 25 | num_pids = len(pids) 26 | num_cams = len(cams) 27 | num_imgs = len(data) 28 | return num_pids, num_imgs, num_cams 29 | 30 | def 
get_videodata_info(self, data, return_tracklet_info=False): 31 | pids, cams, tracklet_info = [], [], [] 32 | for img_paths, pid, camid in data: 33 | pids += [pid] 34 | cams += [camid] 35 | tracklet_info += [len(img_paths)] 36 | pids = set(pids) 37 | cams = set(cams) 38 | num_pids = len(pids) 39 | num_cams = len(cams) 40 | num_tracklets = len(data) 41 | if return_tracklet_info: 42 | return num_pids, num_tracklets, num_cams, tracklet_info 43 | return num_pids, num_tracklets, num_cams 44 | 45 | def print_dataset_statistics(self): 46 | raise NotImplementedError 47 | 48 | 49 | 50 | class CELEBA_MSK(BaseDataset): 51 | """ 52 | -------------------------------------- 53 | subset | # ids | # images 54 | -------------------------------------- 55 | train | 632 | 20208 56 | query | 420 | 2972 57 | gallery | 420 | 11006 58 | """ 59 | dataset_dir = '' 60 | msk_dir = 'mask_6' 61 | 62 | def __init__(self, root='data', verbose=True, **kwargs): 63 | super(CELEBA_MSK, self).__init__() 64 | self.dataset_dir = osp.join(root, self.dataset_dir) 65 | self.train_dir = osp.join(self.dataset_dir, 'train') 66 | self.query_dir = osp.join(self.dataset_dir, 'query') 67 | self.gallery_dir = osp.join(self.dataset_dir, 'gallery') 68 | 69 | self.mask_dir = osp.join(self.dataset_dir, self.msk_dir) 70 | 71 | self._check_before_run() 72 | 73 | self.pid2label = self.get_pid2label(self.train_dir) 74 | self.train = self._process_dir(self.train_dir, os.path.join(self.mask_dir, 'train'), pid2label=self.pid2label, relabel=True) # 13081 75 | self.query = self._process_dir(self.query_dir, os.path.join(self.mask_dir, 'query'), relabel=False) # 484 76 | self.gallery = self._process_dir(self.gallery_dir, os.path.join(self.mask_dir, 'gallery'), relabel=False) 77 | 78 | if verbose: 79 | print("=> CELEBA loaded") 80 | self.print_dataset_statistics(self.train, self.query, self.gallery) 81 | 82 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 83 | 
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 84 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 85 | 86 | 87 | def get_pid2label(self, dir_path): 88 | images = os.listdir(dir_path) 89 | persons = [int(img.split('_')[0]) for img in images] 90 | pid_container = np.sort(list(set(persons))) 91 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 92 | return pid2label 93 | 94 | def _check_before_run(self): 95 | """Check if all files are available before going deeper""" 96 | if not osp.exists(self.dataset_dir): 97 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 98 | if not osp.exists(self.train_dir): 99 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 100 | if not osp.exists(self.query_dir): 101 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 102 | if not osp.exists(self.gallery_dir): 103 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 104 | 105 | def _process_dir(self, dir_path, mask_dir, pid2label=None, relabel=False): 106 | images = os.listdir(dir_path) 107 | dataset = [] 108 | cid = 0 109 | for img in images: 110 | pid_s = int(img.split('_')[0]) 111 | if relabel and pid2label is not None: 112 | pid = pid2label[pid_s] 113 | else: 114 | pid = int(pid_s) 115 | img_path = os.path.join(dir_path, img) 116 | name = img.split('.')[0] + '.npy' 117 | msk_path = os.path.join(mask_dir, name) 118 | if not os.path.exists(msk_path): 119 | continue 120 | dataset.append((img_path, pid, cid, msk_path)) 121 | cid += 1 122 | return dataset 123 | 124 | 125 | def print_dataset_statistics(self, train, query, gallery): 126 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train) 127 | num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query) 128 | num_gallery_pids, num_gallery_imgs, num_gallery_cams = 
self.get_imagedata_info(gallery) 129 | 130 | print("Dataset statistics:") 131 | print(" --------------------------------------") 132 | print(" subset | # ids | # images") 133 | print(" --------------------------------------") 134 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_imgs)) 135 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_imgs)) 136 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs)) 137 | 138 | 139 | 140 | -------------------------------------------------------------------------------- /data/datasets/dukemtmcreid.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: liaoxingyu2@jd.com 5 | """ 6 | 7 | import glob 8 | import re 9 | import urllib 10 | import zipfile 11 | 12 | import os.path as osp 13 | 14 | from utils.iotools import mkdir_if_missing 15 | from .bases import BaseImageDataset 16 | 17 | 18 | class DukeMTMCreID(BaseImageDataset): 19 | """ 20 | DukeMTMC-reID 21 | Reference: 22 | 1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 23 | 2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 
24 | URL: https://github.com/layumi/DukeMTMC-reID_evaluation 25 | 26 | Dataset statistics: 27 | # identities: 1404 (train + query) 28 | # images:16522 (train) + 2228 (query) + 17661 (gallery) 29 | # cameras: 8 30 | """ 31 | dataset_dir = 'DukeMTMC-reID' 32 | 33 | def __init__(self, root='/home/haoluo/data', verbose=True, **kwargs): 34 | super(DukeMTMCreID, self).__init__() 35 | self.dataset_dir = osp.join(root, self.dataset_dir) 36 | self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip' 37 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 38 | self.query_dir = osp.join(self.dataset_dir, 'query') 39 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 40 | 41 | self._download_data() 42 | self._check_before_run() 43 | 44 | train = self._process_dir(self.train_dir, relabel=True) 45 | query = self._process_dir(self.query_dir, relabel=False) 46 | gallery = self._process_dir(self.gallery_dir, relabel=False) 47 | 48 | if verbose: 49 | print("=> DukeMTMC-reID loaded") 50 | self.print_dataset_statistics(train, query, gallery) 51 | 52 | self.train = train 53 | self.query = query 54 | self.gallery = gallery 55 | 56 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 57 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 58 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 59 | 60 | def _download_data(self): 61 | if osp.exists(self.dataset_dir): 62 | print("This dataset has been downloaded.") 63 | return 64 | 65 | print("Creating directory {}".format(self.dataset_dir)) 66 | mkdir_if_missing(self.dataset_dir) 67 | fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url)) 68 | 69 | print("Downloading DukeMTMC-reID dataset") 70 | urllib.request.urlretrieve(self.dataset_url, fpath) 71 | 72 | print("Extracting files") 73 | zip_ref = 
zipfile.ZipFile(fpath, 'r') 74 | zip_ref.extractall(self.dataset_dir) 75 | zip_ref.close() 76 | 77 | def _check_before_run(self): 78 | """Check if all files are available before going deeper""" 79 | if not osp.exists(self.dataset_dir): 80 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 81 | if not osp.exists(self.train_dir): 82 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 83 | if not osp.exists(self.query_dir): 84 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 85 | if not osp.exists(self.gallery_dir): 86 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 87 | 88 | def _process_dir(self, dir_path, relabel=False): 89 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 90 | pattern = re.compile(r'([-\d]+)_c(\d)') 91 | 92 | pid_container = set() 93 | for img_path in img_paths: 94 | pid, _ = map(int, pattern.search(img_path).groups()) 95 | pid_container.add(pid) 96 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 97 | 98 | dataset = [] 99 | for img_path in img_paths: 100 | pid, camid = map(int, pattern.search(img_path).groups()) 101 | assert 1 <= camid <= 8 102 | camid -= 1 # index starts from 0 103 | if relabel: pid = pid2label[pid] 104 | dataset.append((img_path, pid, camid)) 105 | 106 | return dataset 107 | 108 | 109 | 110 | 111 | # Dataset statistics: 112 | # ---------------------------------------- 113 | # subset | # ids | # images | # cameras 114 | # ---------------------------------------- 115 | # train | 702 | 16522 | 8 116 | # query | 702 | 2228 | 8 117 | # gallery | 1110 | 17661 | 8 118 | # ---------------------------------------- -------------------------------------------------------------------------------- /data/datasets/market1501.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import glob 8 | import re 9 
| 10 | import os.path as osp 11 | import numpy as np 12 | 13 | from .bases import BaseImageDataset 14 | 15 | 16 | class Market1501(BaseImageDataset): 17 | """ 18 | Market1501 19 | Reference: 20 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 21 | URL: http://www.liangzheng.org/Project/project_reid.html 22 | 23 | Dataset statistics: 24 | # identities: 1501 (+1 for background) 25 | # images: 12936 (train) + 3368 (query) + 15913 (gallery) 26 | """ 27 | dataset_dir = 'market1501' 28 | 29 | def __init__(self, root='/home/haoluo/data', verbose=True, **kwargs): 30 | super(Market1501, self).__init__() 31 | self.dataset_dir = osp.join(root, self.dataset_dir) 32 | self.train_dir = osp.join(self.dataset_dir, 'train') 33 | self.val_dir = osp.join(self.dataset_dir, 'val') 34 | self.query_dir = osp.join(self.dataset_dir, 'query') 35 | self.gallery_dir = osp.join(self.dataset_dir, 'gallery') 36 | 37 | self._check_before_run() 38 | 39 | pid2label = self.get_pid2label(self.train_dir) 40 | train = self._process_dir(self.train_dir, pid2label=pid2label, relabel=True) 41 | val = self._process_dir(self.val_dir, pid2label=pid2label, relabel=True) 42 | query = self._process_dir(self.query_dir, relabel=False) 43 | gallery = self._process_dir(self.gallery_dir, relabel=False) 44 | 45 | if verbose: 46 | print("=> Market1501 loaded") 47 | self.print_dataset_statistics(train, query, gallery) 48 | 49 | self.train = train 50 | self.val = val 51 | self.query = query 52 | self.gallery = gallery 53 | 54 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 55 | self.num_val_pids, self.num_val_imgs, self.num_val_cams = self.get_imagedata_info(self.val) 56 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 57 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 58 | 59 | def get_pid2label(self, dir_path): 60 | img_paths = 
glob.glob(osp.join(dir_path, '*.jpg')) 61 | pattern = re.compile(r'([-\d]+)_c(\d)') 62 | 63 | pid_container = set() 64 | for img_path in img_paths: 65 | pid, _ = map(int, pattern.search(img_path).groups()) 66 | if pid == -1: 67 | continue # junk images are just ignored 68 | pid_container.add(pid) 69 | pid_container = np.sort(list(pid_container)) 70 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 71 | return pid2label 72 | 73 | def _check_before_run(self): 74 | """Check if all files are available before going deeper""" 75 | if not osp.exists(self.dataset_dir): 76 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 77 | if not osp.exists(self.train_dir): 78 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 79 | if not osp.exists(self.query_dir): 80 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 81 | if not osp.exists(self.gallery_dir): 82 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 83 | 84 | def _process_dir(self, dir_path, pid2label=None, relabel=False): 85 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 86 | pattern = re.compile(r'([-\d]+)_c(\d)') 87 | 88 | # if relabel is True: 89 | # pid_container = set() 90 | # for img_path in img_paths: 91 | # pid, _ = map(int, pattern.search(img_path).groups()) 92 | # if pid == -1: continue # junk images are just ignored 93 | # pid_container.add(pid) 94 | # pid_container = np.sort(list(pid_container)) 95 | # pid2label = {pid: label for label, pid in enumerate(pid_container)} 96 | 97 | dataset = [] 98 | for img_path in img_paths: 99 | pid, camid = map(int, pattern.search(img_path).groups()) 100 | if pid == -1: 101 | continue # junk images are just ignored 102 | assert 0 <= pid <= 1501 # pid == 0 means background 103 | assert 1 <= camid <= 6 104 | camid -= 1 # index starts from 0 105 | if relabel and pid2label is not None: 106 | pid = pid2label[pid] 107 | dataset.append((img_path, pid, camid)) 108 | 109 | return 
dataset 110 | -------------------------------------------------------------------------------- /data/datasets/market_triplet.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import glob 8 | import re 9 | 10 | import os.path as osp 11 | import numpy as np 12 | 13 | from .bases import BaseImageDataset 14 | 15 | 16 | class MarketTriplet(BaseImageDataset): 17 | """ 18 | Market1501 19 | Reference: 20 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 21 | URL: http://www.liangzheng.org/Project/project_reid.html 22 | 23 | Dataset statistics: 24 | # identities: 1501 (+1 for background) 25 | # images: 12936 (train) + 3368 (query) + 15913 (gallery) 26 | """ 27 | dataset_dir = 'market1501' 28 | 29 | def __init__(self, root='/home/haoluo/data', verbose=True, **kwargs): 30 | super(MarketTriplet, self).__init__() 31 | self.dataset_dir = osp.join(root, self.dataset_dir) 32 | self.train_dir = osp.join(self.dataset_dir, 'train') 33 | self.val_dir = osp.join(self.dataset_dir, 'val') 34 | self.query_dir = osp.join(self.dataset_dir, 'query') 35 | self.gallery_dir = osp.join(self.dataset_dir, 'gallery') 36 | 37 | self.mask_train_dir = osp.join(self.dataset_dir, 'mask', 'train') 38 | self.mask_val_dir = osp.join(self.dataset_dir, 'mask', 'val') 39 | self.mask_query_dir = osp.join(self.dataset_dir, 'mask', 'query') 40 | self.mask_gallery_dir = osp.join(self.dataset_dir, 'mask', 'gallery') 41 | 42 | self._check_before_run() 43 | 44 | pid2label = self.get_pid2label(self.train_dir) 45 | train = self._process_dir(self.train_dir, pid2label=pid2label, relabel=True) 46 | val = self._process_dir(self.val_dir, pid2label=pid2label, relabel=True) 47 | query = self._process_dir(self.query_dir, relabel=False) 48 | gallery = self._process_dir(self.gallery_dir, relabel=False) 49 | 50 | self.mask_train = 
self._process_dir_mask(self.mask_train_dir, pid2label=pid2label, relabel=True) 51 | self.mask_val = self._process_dir_mask(self.mask_val_dir, pid2label=pid2label, relabel=True) 52 | self.mask_query = self._process_dir_mask(self.mask_query_dir, relabel=False) 53 | self.mask_gallery = self._process_dir_mask(self.mask_gallery_dir, relabel=False) 54 | 55 | if verbose: 56 | print("=> Market1501 loaded") 57 | self.print_dataset_statistics(train, query, gallery) 58 | 59 | self.train = train 60 | self.val = val 61 | self.query = query 62 | self.gallery = gallery 63 | 64 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 65 | self.num_val_pids, self.num_val_imgs, self.num_val_cams = self.get_imagedata_info(self.val) 66 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 67 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 68 | 69 | def get_pid2label(self, dir_path): 70 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 71 | pattern = re.compile(r'([-\d]+)_c(\d)') 72 | 73 | pid_container = set() 74 | for img_path in img_paths: 75 | pid, _ = map(int, pattern.search(img_path).groups()) 76 | if pid == -1: 77 | continue # junk images are just ignored 78 | pid_container.add(pid) 79 | pid_container = np.sort(list(pid_container)) 80 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 81 | return pid2label 82 | 83 | def _check_before_run(self): 84 | """Check if all files are available before going deeper""" 85 | if not osp.exists(self.dataset_dir): 86 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 87 | if not osp.exists(self.train_dir): 88 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 89 | if not osp.exists(self.query_dir): 90 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 91 | if not osp.exists(self.gallery_dir): 92 | raise 
RuntimeError("'{}' is not available".format(self.gallery_dir)) 93 | 94 | def _process_dir(self, dir_path, pid2label=None, relabel=False): 95 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 96 | pattern = re.compile(r'([-\d]+)_c(\d)') 97 | 98 | # if relabel is True: 99 | # pid_container = set() 100 | # for img_path in img_paths: 101 | # pid, _ = map(int, pattern.search(img_path).groups()) 102 | # if pid == -1: continue # junk images are just ignored 103 | # pid_container.add(pid) 104 | # pid_container = np.sort(list(pid_container)) 105 | # pid2label = {pid: label for label, pid in enumerate(pid_container)} 106 | 107 | dataset = [] 108 | for img_path in img_paths: 109 | pid, camid = map(int, pattern.search(img_path).groups()) 110 | if pid == -1: 111 | continue # junk images are just ignored 112 | assert 0 <= pid <= 1501 # pid == 0 means background 113 | assert 1 <= camid <= 6 114 | camid -= 1 # index starts from 0 115 | if relabel and pid2label is not None: 116 | pid = pid2label[pid] 117 | dataset.append((img_path, pid, camid)) 118 | 119 | return dataset 120 | 121 | def _process_dir_mask(self, dir_path, pid2label=None, relabel=False): 122 | img_paths = glob.glob(osp.join(dir_path, '*.png')) 123 | pattern = re.compile(r'([-\d]+)_c(\d)') 124 | 125 | dataset = [] 126 | for img_path in img_paths: 127 | pid, camid = map(int, pattern.search(img_path).groups()) 128 | if pid == -1: 129 | continue # junk images are just ignored 130 | assert 0 <= pid <= 1501 # pid == 0 means background 131 | assert 1 <= camid <= 6 132 | camid -= 1 # index starts from 0 133 | if relabel and pid2label is not None: 134 | pid = pid2label[pid] 135 | dataset.append((img_path, pid, camid)) 136 | 137 | return dataset 138 | -------------------------------------------------------------------------------- /data/datasets/msmt17.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/1/17 15:00 4 | # 
@Author : Hao Luo 5 | # @File : msmt17.py 6 | 7 | import glob 8 | import re 9 | 10 | import os.path as osp 11 | 12 | from .bases import BaseImageDataset 13 | 14 | 15 | class MSMT17(BaseImageDataset): 16 | """ 17 | MSMT17 18 | 19 | Reference: 20 | Wei et al. Person Transfer GAN to Bridge Domain Gap for Person Re-Identification. CVPR 2018. 21 | 22 | URL: http://www.pkuvmc.com/publications/msmt17.html 23 | 24 | Dataset statistics: 25 | # identities: 4101 26 | # images: 32621 (train) + 11659 (query) + 82161 (gallery) 27 | # cameras: 15 28 | """ 29 | dataset_dir = 'MSMT17_V1' 30 | 31 | def __init__(self, root='/home/haoluo/data', verbose=True, **kwargs): 32 | super(MSMT17, self).__init__() 33 | self.dataset_dir = osp.join(root, self.dataset_dir) 34 | self.train_dir = osp.join(self.dataset_dir, 'train') 35 | self.test_dir = osp.join(self.dataset_dir, 'test') 36 | self.list_train_path = osp.join(self.dataset_dir, 'list_train.txt') 37 | self.list_val_path = osp.join(self.dataset_dir, 'list_val.txt') 38 | self.list_query_path = osp.join(self.dataset_dir, 'list_query.txt') 39 | self.list_gallery_path = osp.join(self.dataset_dir, 'list_gallery.txt') 40 | 41 | self._check_before_run() 42 | train = self._process_dir(self.train_dir, self.list_train_path) # 30248 43 | # val, num_val_pids, num_val_imgs = self._process_dir(self.train_dir, self.list_val_path) 44 | query = self._process_dir(self.test_dir, self.list_query_path) # 11659 45 | gallery = self._process_dir(self.test_dir, self.list_gallery_path) # 82161 46 | if verbose: 47 | print("=> MSMT17 loaded") 48 | self.print_dataset_statistics(train, query, gallery) 49 | 50 | self.train = train # 30248 51 | self.query = query # 11659 52 | self.gallery = gallery # 82161 53 | 54 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 55 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 56 | self.num_gallery_pids, self.num_gallery_imgs, 
self.num_gallery_cams = self.get_imagedata_info(self.gallery) 57 | 58 | def _check_before_run(self): 59 | """Check if all files are available before going deeper""" 60 | if not osp.exists(self.dataset_dir): 61 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 62 | if not osp.exists(self.train_dir): 63 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 64 | if not osp.exists(self.test_dir): 65 | raise RuntimeError("'{}' is not available".format(self.test_dir)) 66 | 67 | def _process_dir(self, dir_path, list_path): 68 | with open(list_path, 'r') as txt: 69 | lines = txt.readlines() 70 | dataset = [] 71 | pid_container = set() 72 | for img_idx, img_info in enumerate(lines): 73 | img_path, pid = img_info.split(' ') 74 | pid = int(pid) # no need to relabel 75 | camid = int(img_path.split('_')[2]) 76 | img_path = osp.join(dir_path, img_path) 77 | dataset.append((img_path, pid, camid)) 78 | pid_container.add(pid) 79 | 80 | # check if pid starts from 0 and increments with 1 81 | for idx, pid in enumerate(pid_container): 82 | assert idx == pid, "See code comment for explanation" 83 | return dataset 84 | 85 | 86 | # Dataset statistics: 87 | # ---------------------------------------- 88 | # subset | # ids | # images | # cameras 89 | # ---------------------------------------- 90 | # train | 1041 | 30248 | 15 91 | # query | 3060 | 11659 | 15 92 | # gallery | 3060 | 82161 | 15 93 | # ---------------------------------------- -------------------------------------------------------------------------------- /data/datasets/prcc.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import os 8 | import glob 9 | import os.path as osp 10 | import numpy as np 11 | 12 | 13 | class BaseDataset(object): 14 | """ 15 | Base class of reid dataset 16 | """ 17 | 18 | def 
get_imagedata_info(self, data): 19 | pids, cams = [], [] 20 | for _, pid, camid in data: 21 | pids += [pid] 22 | cams += [camid] 23 | pids = set(pids) 24 | cams = set(cams) 25 | num_pids = len(pids) 26 | num_cams = len(cams) 27 | num_imgs = len(data) 28 | return num_pids, num_imgs, num_cams 29 | 30 | def get_videodata_info(self, data, return_tracklet_info=False): 31 | pids, cams, tracklet_info = [], [], [] 32 | for img_paths, pid, camid in data: 33 | pids += [pid] 34 | cams += [camid] 35 | tracklet_info += [len(img_paths)] 36 | pids = set(pids) 37 | cams = set(cams) 38 | num_pids = len(pids) 39 | num_cams = len(cams) 40 | num_tracklets = len(data) 41 | if return_tracklet_info: 42 | return num_pids, num_tracklets, num_cams, tracklet_info 43 | return num_pids, num_tracklets, num_cams 44 | 45 | def print_dataset_statistics(self): 46 | raise NotImplementedError 47 | 48 | 49 | 50 | class PRCC(BaseDataset): 51 | """ 52 | -------------------------------------- 53 | subset | # ids | # images 54 | -------------------------------------- 55 | train | 150 | 17896 56 | query | 71 | 213 57 | gallery | 71 | 10587 58 | """ 59 | dataset_dir = 'modify' 60 | cam2label = {'A': 0, 'B': 1, 'C': 2} 61 | 62 | def __init__(self, root='data', verbose=True, **kwargs): 63 | super(PRCC, self).__init__() 64 | self.dataset_dir = osp.join(root, self.dataset_dir) 65 | self.train_dir = osp.join(self.dataset_dir, 'train') 66 | self.query_dir = osp.join(self.dataset_dir, 'test', 'query') 67 | self.gallery_dir = osp.join(self.dataset_dir, 'test', 'gallery') 68 | 69 | self._check_before_run() 70 | 71 | self.pid2label = self.get_pid2label(self.train_dir) 72 | self.train = self._process_dir(self.train_dir, pid2label=self.pid2label, relabel=True) # 13081 73 | self.query = self._process_dir(self.query_dir, relabel=False) # 484 74 | self.gallery = self._process_dir(self.gallery_dir, relabel=False) 75 | 76 | if verbose: 77 | print("=> PRCC loaded") 78 | self.print_dataset_statistics_movie(self.train, 
self.query, self.gallery) 79 | 80 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 81 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 82 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 83 | 84 | 85 | def get_pid2label(self, dir_path): 86 | persons = os.listdir(dir_path) 87 | pid_container = np.sort(list(set(persons))) 88 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 89 | return pid2label 90 | 91 | def _check_before_run(self): 92 | """Check if all files are available before going deeper""" 93 | if not osp.exists(self.dataset_dir): 94 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 95 | if not osp.exists(self.train_dir): 96 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 97 | if not osp.exists(self.query_dir): 98 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 99 | if not osp.exists(self.gallery_dir): 100 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 101 | 102 | def _process_dir(self, dir_path, pid2label=None, relabel=False): 103 | persons = os.listdir(dir_path) 104 | dataset = [] 105 | for pid_s in persons: 106 | path_p = os.path.join(dir_path, pid_s) 107 | files = os.listdir(path_p) 108 | for file in files: 109 | cid = file.split('_')[0] 110 | cid = self.cam2label[cid] 111 | if relabel and pid2label is not None: 112 | pid = pid2label[pid_s] 113 | else: 114 | pid = int(pid_s) 115 | img_path = os.path.join(dir_path, pid_s, file) 116 | dataset.append((img_path, pid, cid)) 117 | return dataset 118 | 119 | 120 | def print_dataset_statistics_movie(self, train, query, gallery): 121 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train) 122 | num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query) 123 | num_gallery_pids, num_gallery_imgs, 
num_gallery_cams = self.get_imagedata_info(gallery) 124 | 125 | print("Dataset statistics:") 126 | print(" --------------------------------------") 127 | print(" subset | # ids | # images") 128 | print(" --------------------------------------") 129 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_imgs)) 130 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_imgs)) 131 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs)) 132 | 133 | 134 | 135 | -------------------------------------------------------------------------------- /data/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .triplet_sampler import RandomIdentitySampler, RandomIdentitySampler_alignedreid, RandomIdentitySamplerAdv, RandomIdentitySamplerStar, RandomIdentitySamplerGcn 8 | from .triplet_sampler import RandomIdentitySamplerVL 9 | -------------------------------------------------------------------------------- /data/samplers/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/samplers/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /data/samplers/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/samplers/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /data/samplers/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/samplers/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /data/samplers/__pycache__/triplet_sampler.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/samplers/__pycache__/triplet_sampler.cpython-35.pyc -------------------------------------------------------------------------------- /data/samplers/__pycache__/triplet_sampler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/samplers/__pycache__/triplet_sampler.cpython-36.pyc -------------------------------------------------------------------------------- /data/samplers/__pycache__/triplet_sampler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/samplers/__pycache__/triplet_sampler.cpython-37.pyc -------------------------------------------------------------------------------- /data/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .build import build_transforms, build_transforms_bap, build_transforms_visual, build_transforms_no_erase, build_transforms_usize, build_transforms_open, build_transforms_swap, build_transforms_head, build_transforms_eraser 8 | from .build import build_transforms_hist, build_transforms_base 9 | 10 | 11 | -------------------------------------------------------------------------------- 
/data/transforms/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/transforms/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /data/transforms/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/transforms/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /data/transforms/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/transforms/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /data/transforms/__pycache__/build.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/transforms/__pycache__/build.cpython-35.pyc -------------------------------------------------------------------------------- /data/transforms/__pycache__/build.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/transforms/__pycache__/build.cpython-36.pyc -------------------------------------------------------------------------------- /data/transforms/__pycache__/build.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/transforms/__pycache__/build.cpython-37.pyc -------------------------------------------------------------------------------- /data/transforms/__pycache__/transform.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/transforms/__pycache__/transform.cpython-35.pyc -------------------------------------------------------------------------------- /data/transforms/__pycache__/transform.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/transforms/__pycache__/transform.cpython-36.pyc -------------------------------------------------------------------------------- /data/transforms/__pycache__/transform.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/transforms/__pycache__/transform.cpython-37.pyc -------------------------------------------------------------------------------- /data/transforms/__pycache__/transforms.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/transforms/__pycache__/transforms.cpython-35.pyc -------------------------------------------------------------------------------- /data/transforms/__pycache__/transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/transforms/__pycache__/transforms.cpython-36.pyc 
-------------------------------------------------------------------------------- /data/transforms/__pycache__/transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/data/transforms/__pycache__/transforms.cpython-37.pyc -------------------------------------------------------------------------------- /data/transforms/transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms.functional as F 3 | import random 4 | import numpy as np 5 | from PIL import Image 6 | import cv2 7 | """We expect a list `cfg.transform_list`. The types specified in this list 8 | will be applied sequentially. Each type name corresponds to a function name in 9 | this file, so you have to implement the function w.r.t. your custom type. 10 | The function head should be `FUNC_NAME(in_dict, cfg)`, and it should modify `in_dict` 11 | in place. 12 | The transform list allows us to apply optional transforms in any order, while custom 13 | functions allow us to perform sync transformation for images and all labels. 14 | """ 15 | 16 | 17 | def hflip(in_dict, cfg): 18 | # Tricky!! random.random() can not reproduce the score of np.random.random(), 19 | # dropping ~1% for both Market1501 and Duke GlobalPool. 
20 | # if random.random() < 0.5: 21 | if np.random.random() < 0.5: 22 | in_dict['img'] = F.hflip(in_dict['img']) 23 | in_dict['mask'] = F.hflip(in_dict['mask']) 24 | 25 | 26 | def resize_3d_np_array(maps, resize_h_w, interpolation): # [9, 24, 8], [24, 8], 0 27 | """maps: np array with shape [C, H, W], dtype is not restricted""" 28 | return np.stack([cv2.resize(m, tuple(resize_h_w[::-1]), interpolation=interpolation) for m in maps]) 29 | 30 | 31 | # Resize image using cv2.resize() 32 | def resize(in_dict, cfg): 33 | in_dict['img'] = Image.fromarray(cv2.resize(np.array(in_dict['img']), (cfg.width, cfg.height), interpolation=cv2.INTER_LINEAR)) # [128, 64] -> [384, 128] 34 | in_dict['mask'] = Image.fromarray(cv2.resize(np.array(in_dict['mask']), (cfg.width_mask, cfg.height_mask), cv2.INTER_NEAREST), mode='L') 35 | 36 | 37 | def to_tensor(in_dict, mean=[0.486, 0.459, 0.408], std=[0.229, 0.224, 0.225]): 38 | in_dict['img'] = F.to_tensor(in_dict['img']) # [3, 256, 128] 39 | in_dict['img'] = F.normalize(in_dict['img'], mean, std) 40 | in_dict['mask'] = torch.from_numpy(np.array(in_dict['mask'])).long() # [48, 16] 41 | 42 | 43 | def to_tensor_mask(in_dict, mean=[0.486, 0.459, 0.408], std=[0.229, 0.224, 0.225]): 44 | in_dict['mask'] = torch.from_numpy(np.array(in_dict['mask'])).long() # [48, 16] 45 | 46 | def transform(in_dict, transform_list, cfg): 47 | for t in transform_list: 48 | eval('{}(in_dict, cfg)'.format(t)) 49 | to_tensor(in_dict) 50 | return in_dict 51 | 52 | 53 | def transform_mask(in_dict, transform_list, cfg): 54 | for t in transform_list: 55 | eval('{}(in_dict, cfg)'.format(t)) 56 | return in_dict 57 | 58 | -------------------------------------------------------------------------------- /data/transforms/transforms.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: liaoxingyu2@jd.com 5 | """ 6 | 7 | import math 8 | import random 9 | 10 | 11 | class 
RandomErasing(object): 12 | """ Randomly selects a rectangle region in an image and erases its pixels. 13 | 'Random Erasing Data Augmentation' by Zhong et al. 14 | See https://arxiv.org/pdf/1708.04896.pdf 15 | Args: 16 | probability: The probability that the Random Erasing operation will be performed. 17 | sl: Minimum proportion of erased area against input image. 18 | sh: Maximum proportion of erased area against input image. 19 | r1: Minimum aspect ratio of erased area. 20 | mean: Erasing value. 21 | """ 22 | 23 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 24 | self.probability = probability # 0.5 25 | self.mean = mean 26 | self.sl = sl # 0.02 27 | self.sh = sh # 0.4 28 | self.r1 = r1 # 0.3 29 | 30 | def __call__(self, img): # [3, 256, 128] 31 | 32 | if random.uniform(0, 1) >= self.probability: 33 | return img 34 | 35 | for attempt in range(100): 36 | area = img.size()[1] * img.size()[2] # 256 * 128 = 32768 37 | 38 | target_area = random.uniform(self.sl, self.sh) * area # 6675.79 39 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) # 2.1 40 | 41 | h = int(round(math.sqrt(target_area * aspect_ratio))) # 118 42 | w = int(round(math.sqrt(target_area / aspect_ratio))) # 56 43 | 44 | if w < img.size()[2] and h < img.size()[1]: 45 | x1 = random.randint(0, img.size()[1] - h) 46 | y1 = random.randint(0, img.size()[2] - w) 47 | if img.size()[0] == 3: 48 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 49 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 50 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 51 | else: 52 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 53 | return img 54 | 55 | return img 56 | 57 | 58 | 59 | 60 | 61 | 62 | class RandomSwap(object): 63 | """ Randomly selects a rectangle region in an image and erases its pixels. 64 | 'Random Erasing Data Augmentation' by Zhong et al. 
65 | See https://arxiv.org/pdf/1708.04896.pdf 66 | Args: 67 | probability: The probability that the Random Erasing operation will be performed. 68 | sl: Minimum proportion of erased area against input image. 69 | sh: Maximum proportion of erased area against input image. 70 | r1: Minimum aspect ratio of erased area. 71 | mean: Erasing value. 72 | """ 73 | 74 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 75 | self.probability = probability # 0.5 76 | self.mean = mean 77 | self.sl = sl # 0.02 78 | self.sh = sh # 0.4 79 | self.r1 = r1 # 0.3 80 | 81 | def __call__(self, img, swap): # [3, 256, 128], [3, 256, 128] 82 | 83 | if random.uniform(0, 1) >= self.probability: 84 | return img 85 | 86 | for attempt in range(100): 87 | area = img.size()[1] * img.size()[2] # 256 * 128 = 32768 88 | 89 | target_area = random.uniform(self.sl, self.sh) * area # 6675.79 90 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) # 2.1 91 | 92 | h = int(round(math.sqrt(target_area * aspect_ratio))) # 118 93 | w = int(round(math.sqrt(target_area / aspect_ratio))) # 56 94 | 95 | if w < img.size()[2] and h < img.size()[1]: 96 | x1 = random.randint(0, img.size()[1] - h) 97 | y1 = random.randint(0, img.size()[2] - w) 98 | if img.size()[0] == 3: 99 | img[0, x1:x1 + h, y1:y1 + w] = swap[0, x1:x1 + h, y1:y1 + w] 100 | img[1, x1:x1 + h, y1:y1 + w] = swap[1, x1:x1 + h, y1:y1 + w] 101 | img[2, x1:x1 + h, y1:y1 + w] = swap[2, x1:x1 + h, y1:y1 + w] 102 | else: 103 | img[0, x1:x1 + h, y1:y1 + w] = swap[0, x1:x1 + h, y1:y1 + w] 104 | return img 105 | 106 | return img 107 | 108 | 109 | 110 | 111 | 112 | 113 | -------------------------------------------------------------------------------- /engine/__pycache__/inference.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/engine/__pycache__/inference.cpython-35.pyc 
-------------------------------------------------------------------------------- /engine/__pycache__/inference.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/engine/__pycache__/inference.cpython-36.pyc -------------------------------------------------------------------------------- /engine/__pycache__/inference.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/engine/__pycache__/inference.cpython-37.pyc -------------------------------------------------------------------------------- /engine/__pycache__/trainer.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/engine/__pycache__/trainer.cpython-35.pyc -------------------------------------------------------------------------------- /engine/__pycache__/trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/engine/__pycache__/trainer.cpython-36.pyc -------------------------------------------------------------------------------- /engine/__pycache__/trainer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/engine/__pycache__/trainer.cpython-37.pyc -------------------------------------------------------------------------------- /engine/__pycache__/trainer_sgap.cpython-35.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/engine/__pycache__/trainer_sgap.cpython-35.pyc -------------------------------------------------------------------------------- /engine/__pycache__/trainer_triplet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/engine/__pycache__/trainer_triplet.cpython-35.pyc -------------------------------------------------------------------------------- /layers/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /layers/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /layers/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /layers/__pycache__/aligned_loss.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/aligned_loss.cpython-35.pyc -------------------------------------------------------------------------------- 
/layers/__pycache__/aligned_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/aligned_loss.cpython-36.pyc -------------------------------------------------------------------------------- /layers/__pycache__/aligned_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/aligned_loss.cpython-37.pyc -------------------------------------------------------------------------------- /layers/__pycache__/center_loss.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/center_loss.cpython-35.pyc -------------------------------------------------------------------------------- /layers/__pycache__/center_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/center_loss.cpython-36.pyc -------------------------------------------------------------------------------- /layers/__pycache__/center_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/center_loss.cpython-37.pyc -------------------------------------------------------------------------------- /layers/__pycache__/cluster_loss.cpython-35.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/cluster_loss.cpython-35.pyc -------------------------------------------------------------------------------- /layers/__pycache__/cluster_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/cluster_loss.cpython-36.pyc -------------------------------------------------------------------------------- /layers/__pycache__/cluster_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/cluster_loss.cpython-37.pyc -------------------------------------------------------------------------------- /layers/__pycache__/fast_ap.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/fast_ap.cpython-35.pyc -------------------------------------------------------------------------------- /layers/__pycache__/fast_ap.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/fast_ap.cpython-36.pyc -------------------------------------------------------------------------------- /layers/__pycache__/fast_ap.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/fast_ap.cpython-37.pyc -------------------------------------------------------------------------------- 
/layers/__pycache__/fast_ap_mem.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/fast_ap_mem.cpython-35.pyc -------------------------------------------------------------------------------- /layers/__pycache__/fast_ap_mem.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/fast_ap_mem.cpython-36.pyc -------------------------------------------------------------------------------- /layers/__pycache__/fast_ap_mem.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/fast_ap_mem.cpython-37.pyc -------------------------------------------------------------------------------- /layers/__pycache__/histogram.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/histogram.cpython-35.pyc -------------------------------------------------------------------------------- /layers/__pycache__/histogram.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/histogram.cpython-36.pyc -------------------------------------------------------------------------------- /layers/__pycache__/histogram.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/histogram.cpython-37.pyc -------------------------------------------------------------------------------- /layers/__pycache__/local_dist.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/local_dist.cpython-35.pyc -------------------------------------------------------------------------------- /layers/__pycache__/local_dist.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/local_dist.cpython-36.pyc -------------------------------------------------------------------------------- /layers/__pycache__/local_dist.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/local_dist.cpython-37.pyc -------------------------------------------------------------------------------- /layers/__pycache__/map_loss.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/map_loss.cpython-35.pyc -------------------------------------------------------------------------------- /layers/__pycache__/map_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/map_loss.cpython-36.pyc -------------------------------------------------------------------------------- 
/layers/__pycache__/map_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/map_loss.cpython-37.pyc -------------------------------------------------------------------------------- /layers/__pycache__/range_loss.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/range_loss.cpython-35.pyc -------------------------------------------------------------------------------- /layers/__pycache__/triplet_loss.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/triplet_loss.cpython-35.pyc -------------------------------------------------------------------------------- /layers/__pycache__/triplet_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/triplet_loss.cpython-36.pyc -------------------------------------------------------------------------------- /layers/__pycache__/triplet_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/layers/__pycache__/triplet_loss.cpython-37.pyc -------------------------------------------------------------------------------- /layers/center_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class 
CenterLoss(nn.Module): 8 | """Center loss. 9 | 10 | Reference: 11 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. 12 | 13 | Args: 14 | num_classes (int): number of classes. 15 | feat_dim (int): feature dimension. 16 | """ 17 | 18 | def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True): 19 | super(CenterLoss, self).__init__() 20 | self.num_classes = num_classes # 751 21 | self.feat_dim = feat_dim # 2048 22 | self.use_gpu = use_gpu # False 23 | 24 | if self.use_gpu: 25 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda()) 26 | else: 27 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) # [751, 2048] 28 | 29 | def forward(self, x, labels): # x->[64, 2048], labels->[64,] 30 | """ 31 | Args: 32 | x: feature matrix with shape (batch_size, feat_dim). 33 | labels: ground truth labels with shape (num_classes). 34 | """ 35 | assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)" 36 | 37 | batch_size = x.size(0) # 64 38 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \ 39 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() # [16, 751] 40 | distmat.addmm_(1, -2, x, self.centers.t()) # (x - center)^2 41 | 42 | classes = torch.arange(self.num_classes).long() # [751,] [0, 1, 2, ..., 750] 43 | if self.use_gpu: 44 | classes = classes.cuda() 45 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) # [16, 751] [[275,275,...275], [153,153,...,153], ...,] 46 | mask = labels.eq(classes.expand(batch_size, self.num_classes)) # [16, 751] one_hot 47 | 48 | dist = [] 49 | for i in range(batch_size): 50 | value = distmat[i][mask[i]] # 2641.08 51 | value = value.clamp(min=1e-12, max=1e+12) # for numerical stability 52 | dist.append(value) 53 | dist = torch.cat(dist) # [16,] 54 | loss = dist.mean() # 2740.9619 55 | return loss 56 | 57 | 58 | 59 | 60 | 
class CenterLossPart(nn.Module): 61 | """Center loss. 62 | 63 | Reference: 64 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. 65 | 66 | Args: 67 | num_classes (int): number of classes. 68 | feat_dim (int): feature dimension. 69 | """ 70 | 71 | def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True): 72 | super(CenterLossPart, self).__init__() 73 | self.num_classes = num_classes # 751 74 | self.feat_dim = feat_dim # 2048 75 | self.use_gpu = use_gpu # False 76 | 77 | if self.use_gpu: 78 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda()) 79 | else: 80 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) # [751, 2048] 81 | 82 | def forward(self, x, labels, visible): # x->[16, 2048], labels->[16,] 83 | """ 84 | Args: 85 | x: feature matrix with shape (batch_size, feat_dim). 86 | labels: ground truth labels with shape (num_classes). 87 | """ 88 | assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)" 89 | 90 | batch_size = x.size(0) # 16 91 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \ 92 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() # [16, 751] 93 | distmat.addmm_(1, -2, x, self.centers.t()) # (x - center)^2 94 | 95 | classes = torch.arange(self.num_classes).long() # [751,] [0, 1, 2, ..., 750] 96 | if self.use_gpu: 97 | classes = classes.cuda() 98 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) # [16, 751] [[275,275,...275], [153,153,...,153], ...,] 99 | mask = labels.eq(classes.expand(batch_size, self.num_classes)) # [16, 751] one_hot 100 | 101 | dist = [] 102 | for i in range(batch_size): 103 | value = distmat[i][mask[i]] # 2641.08 104 | value = value.clamp(min=1e-12, max=1e+12) # for numerical stability 105 | dist.append(value) 106 | dist = torch.cat(dist) # [16,] 107 | loss = (dist * visible).mean() / 
visible.mean() # 2740.9619 108 | return loss 109 | 110 | 111 | 112 | 113 | 114 | if __name__ == '__main__': 115 | use_gpu = False 116 | center_loss = CenterLoss(use_gpu=use_gpu) 117 | features = torch.rand(16, 2048) 118 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).long() 119 | if use_gpu: 120 | features = torch.rand(16, 2048).cuda() 121 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).cuda() 122 | 123 | loss = center_loss(features, targets) # to minimize the intra-class distance 124 | print(loss) 125 | 126 | 127 | 128 | 129 | 130 | -------------------------------------------------------------------------------- /layers/local_dist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def batch_euclidean_dist(x, y): 5 | """ 6 | Args: 7 | x: pytorch Variable, with shape [Batch size, Local part, Feature channel] 8 | y: pytorch Variable, with shape [Batch size, Local part, Feature channel] 9 | Returns: 10 | dist: pytorch Variable, with shape [Batch size, Local part, Local part] 11 | """ 12 | assert len(x.size()) == 3 13 | assert len(y.size()) == 3 14 | assert x.size(0) == y.size(0) 15 | assert x.size(-1) == y.size(-1) 16 | 17 | N, m, d = x.size() 18 | N, n, d = y.size() 19 | 20 | # shape [N, m, n] 21 | xx = torch.pow(x, 2).sum(-1, keepdim=True).expand(N, m, n) 22 | yy = torch.pow(y, 2).sum(-1, keepdim=True).expand(N, n, m).permute(0, 2, 1) 23 | dist = xx + yy 24 | dist.baddbmm_(1, -2, x, y.permute(0, 2, 1)) 25 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 26 | return dist 27 | 28 | 29 | def shortest_dist(dist_mat): 30 | """Parallel version. 
31 | Args: 32 | dist_mat: pytorch Variable, available shape: 33 | 1) [m, n] 34 | 2) [m, n, N], N is batch size 35 | 3) [m, n, *], * can be arbitrary additional dimensions 36 | Returns: 37 | dist: three cases corresponding to `dist_mat`: 38 | 1) scalar 39 | 2) pytorch Variable, with shape [N] 40 | 3) pytorch Variable, with shape [*] 41 | """ 42 | m, n = dist_mat.size()[:2] 43 | # Just offering some reference for accessing intermediate distance. 44 | dist = [[0 for _ in range(n)] for _ in range(m)] 45 | for i in range(m): 46 | for j in range(n): 47 | if (i == 0) and (j == 0): 48 | dist[i][j] = dist_mat[i, j] 49 | elif (i == 0) and (j > 0): 50 | dist[i][j] = dist[i][j - 1] + dist_mat[i, j] 51 | elif (i > 0) and (j == 0): 52 | dist[i][j] = dist[i - 1][j] + dist_mat[i, j] 53 | else: 54 | dist[i][j] = torch.min(dist[i - 1][j], dist[i][j - 1]) + dist_mat[i, j] 55 | dist = dist[-1][-1] 56 | return dist 57 | 58 | 59 | def hard_example_mining(dist_mat, labels, return_inds=False): # [32, 32], [32,] 60 | """For each anchor, find the hardest positive and negative sample. 61 | Args: 62 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N] 63 | labels: pytorch LongTensor, with shape [N] 64 | return_inds: whether to return the indices. Save time if `False`(?) 65 | Returns: 66 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N] 67 | dist_an: pytorch Variable, distance(anchor, negative); shape [N] 68 | p_inds: pytorch LongTensor, with shape [N]; 69 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1 70 | n_inds: pytorch LongTensor, with shape [N]; 71 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1 72 | NOTE: Only consider the case in which all labels have same num of samples, 73 | thus we can cope with all anchors in parallel. 
74 | """ 75 | 76 | assert len(dist_mat.size()) == 2 77 | assert dist_mat.size(0) == dist_mat.size(1) 78 | N = dist_mat.size(0) # 32 79 | 80 | # shape [N, N] 81 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()) # [32, 32], torch.uint8 82 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t()) # [32, 32] 83 | 84 | # `dist_ap` means distance(anchor, positive) 85 | # both `dist_ap` and `relative_p_inds` with shape [N, 1] 86 | dist_ap, relative_p_inds = torch.max( 87 | dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True) # [32, 1], [32, 1] 88 | # `dist_an` means distance(anchor, negative) 89 | # both `dist_an` and `relative_n_inds` with shape [N, 1] 90 | dist_an, relative_n_inds = torch.min( 91 | dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True) # [32, 1], [32, 1] 92 | # shape [N] 93 | dist_ap = dist_ap.squeeze(1) # [32,] 94 | dist_an = dist_an.squeeze(1) # [32,] 95 | 96 | if return_inds: 97 | # shape [N, N] 98 | ind = (labels.new().resize_as_(labels) 99 | .copy_(torch.arange(0, N).long()) 100 | .unsqueeze(0).expand(N, N)) # [32, 32] 101 | # shape [N, 1] 102 | p_inds = torch.gather( 103 | ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data) 104 | n_inds = torch.gather( 105 | ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data) 106 | # shape [N] 107 | p_inds = p_inds.squeeze(1) 108 | n_inds = n_inds.squeeze(1) 109 | return dist_ap, dist_an, p_inds, n_inds # [32,], [32,], [32,], [32,] 110 | 111 | return dist_ap, dist_an 112 | 113 | 114 | def euclidean_dist(x, y): 115 | """ 116 | Args: 117 | x: pytorch Variable, with shape [m, d] 118 | y: pytorch Variable, with shape [n, d] 119 | Returns: 120 | dist: pytorch Variable, with shape [m, n] 121 | """ 122 | m, n = x.size(0), y.size(0) 123 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 124 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 125 | dist = xx + yy 126 | dist.addmm_(1, -2, x, y.t()) 127 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 
128 | return dist 129 | 130 | 131 | def batch_local_dist(x, y): 132 | """ 133 | Args: 134 | x: pytorch Variable, with shape [N, m, d] 135 | y: pytorch Variable, with shape [N, n, d] 136 | Returns: 137 | dist: pytorch Variable, with shape [N] 138 | """ 139 | assert len(x.size()) == 3 140 | assert len(y.size()) == 3 141 | assert x.size(0) == y.size(0) 142 | assert x.size(-1) == y.size(-1) 143 | 144 | # shape [N, m, n] 145 | dist_mat = batch_euclidean_dist(x, y) # [32, 8, 8] 146 | dist_mat = (torch.exp(dist_mat) - 1.) / (torch.exp(dist_mat) + 1.) # [32, 8, 8] 147 | # shape [N] 148 | dist = shortest_dist(dist_mat.permute(1, 2, 0)) # [32,] 149 | return dist 150 | 151 | 152 | if __name__ == '__main__': 153 | x = torch.randn(32, 2048) 154 | y = torch.randn(32, 2048) 155 | dist_mat = euclidean_dist(x, y) 156 | dist_ap, dist_an, p_inds, n_inds = hard_example_mining(dist_mat, return_inds=True) 157 | from IPython import embed 158 | 159 | embed() 160 | -------------------------------------------------------------------------------- /modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | from .pcb import pcb_p6 4 | from .pcb_seg import pcb_global 5 | from .hpm import HPM 6 | from .mgn import MGN 7 | 8 | 9 | def build_model(num_classes=None, model_type='base'): 10 | if model_type == 'global': 11 | model = pcb_global(num_classes) 12 | if model_type == 'pcb': 13 | model = pcb_p6(num_classes) 14 | elif model_type == 'hpm': 15 | model = HPM(num_classes) 16 | elif model_type == 'mgn': 17 | model = MGN(num_classes) 18 | else: 19 | pass 20 | return model 21 | 22 | 23 | -------------------------------------------------------------------------------- /modeling/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /modeling/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /modeling/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /modeling/__pycache__/baseline.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/__pycache__/baseline.cpython-35.pyc -------------------------------------------------------------------------------- /modeling/__pycache__/baseline.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/__pycache__/baseline.cpython-36.pyc -------------------------------------------------------------------------------- /modeling/__pycache__/baseline.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/__pycache__/baseline.cpython-37.pyc -------------------------------------------------------------------------------- 
/modeling/__pycache__/gcn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/__pycache__/gcn.cpython-37.pyc -------------------------------------------------------------------------------- /modeling/__pycache__/hpm.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/__pycache__/hpm.cpython-35.pyc -------------------------------------------------------------------------------- /modeling/__pycache__/hpm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/__pycache__/hpm.cpython-36.pyc -------------------------------------------------------------------------------- /modeling/__pycache__/hpm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/__pycache__/hpm.cpython-37.pyc -------------------------------------------------------------------------------- /modeling/__pycache__/mgn.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/__pycache__/mgn.cpython-35.pyc -------------------------------------------------------------------------------- /modeling/__pycache__/mgn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/__pycache__/mgn.cpython-36.pyc 
-------------------------------------------------------------------------------- /modeling/__pycache__/mgn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/__pycache__/mgn.cpython-37.pyc -------------------------------------------------------------------------------- /modeling/__pycache__/pcb.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/__pycache__/pcb.cpython-35.pyc -------------------------------------------------------------------------------- /modeling/__pycache__/pcb.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/__pycache__/pcb.cpython-36.pyc -------------------------------------------------------------------------------- /modeling/__pycache__/pcb.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/__pycache__/pcb.cpython-37.pyc -------------------------------------------------------------------------------- /modeling/__pycache__/pcb_seg.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/__pycache__/pcb_seg.cpython-35.pyc -------------------------------------------------------------------------------- /modeling/__pycache__/pcb_seg.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/__pycache__/pcb_seg.cpython-36.pyc -------------------------------------------------------------------------------- /modeling/__pycache__/pcb_seg.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/__pycache__/pcb_seg.cpython-37.pyc -------------------------------------------------------------------------------- /modeling/__pycache__/pyramid.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/__pycache__/pyramid.cpython-35.pyc -------------------------------------------------------------------------------- /modeling/__pycache__/pyramid.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/__pycache__/pyramid.cpython-36.pyc -------------------------------------------------------------------------------- /modeling/__pycache__/pyramid.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/__pycache__/pyramid.cpython-37.pyc -------------------------------------------------------------------------------- /modeling/__pycache__/res_net.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/__pycache__/res_net.cpython-35.pyc -------------------------------------------------------------------------------- 
/modeling/__pycache__/res_net.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/__pycache__/res_net.cpython-36.pyc -------------------------------------------------------------------------------- /modeling/__pycache__/res_net.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/__pycache__/res_net.cpython-37.pyc -------------------------------------------------------------------------------- /modeling/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/backbones/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/backbones/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/backbones/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/cls_hrnet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/backbones/__pycache__/cls_hrnet.cpython-35.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/cls_hrnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/backbones/__pycache__/cls_hrnet.cpython-36.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/cls_hrnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/backbones/__pycache__/cls_hrnet.cpython-37.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/model.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/backbones/__pycache__/model.cpython-35.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/backbones/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/backbones/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/non_local.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/backbones/__pycache__/non_local.cpython-35.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/pcb.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/backbones/__pycache__/pcb.cpython-35.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/resnet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/backbones/__pycache__/resnet.cpython-35.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/backbones/__pycache__/resnet.cpython-36.pyc 
-------------------------------------------------------------------------------- /modeling/backbones/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/backbones/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/resnet_bap.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/backbones/__pycache__/resnet_bap.cpython-35.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/resnet_bap.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/backbones/__pycache__/resnet_bap.cpython-36.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/resnet_bap.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/backbones/__pycache__/resnet_bap.cpython-37.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/senet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/backbones/__pycache__/senet.cpython-35.pyc -------------------------------------------------------------------------------- 
/modeling/backbones/__pycache__/senet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/backbones/__pycache__/senet.cpython-36.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/senet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/backbones/__pycache__/senet.cpython-37.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/sga_resnet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/backbones/__pycache__/sga_resnet.cpython-35.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/sga_resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/backbones/__pycache__/sga_resnet.cpython-36.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/sga_resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/modeling/backbones/__pycache__/sga_resnet.cpython-37.pyc -------------------------------------------------------------------------------- /modeling/backbones/non_local.py: -------------------------------------------------------------------------------- 1 | # encoding: 
utf-8 2 | 3 | import torch 4 | from torch import nn 5 | 6 | # Dot-product non-local block: 1x1 convs embed x into inter_channels, the pairwise affinity between all spatial positions (divided by their count) aggregates g(x), and W projects the result back to in_channels. W's BatchNorm is zero-initialised, so the block starts out as an identity mapping (z == x). 7 | class Non_local(nn.Module): 8 | def __init__(self, in_channels, reduc_ratio=2): 9 | super(Non_local, self).__init__() 10 | 11 | self.in_channels = in_channels 12 | self.inter_channels = max(1, in_channels // reduc_ratio) # bottleneck width; was reduc_ratio // reduc_ratio (always 1). Checkpoints saved with the old width will not load. 13 | 14 | self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 15 | kernel_size=1, stride=1, padding=0) 16 | 17 | self.W = nn.Sequential( 18 | nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0), 19 | nn.BatchNorm2d(self.in_channels), 20 | ) 21 | nn.init.constant_(self.W[1].weight, 0.0) 22 | nn.init.constant_(self.W[1].bias, 0.0) 23 | 24 | self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 25 | kernel_size=1, stride=1, padding=0) 26 | 27 | self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 28 | kernel_size=1, stride=1, padding=0) 29 | 30 | def forward(self, x): # [32, 512, 32, 16] 31 | ''' 32 | :param x: (b, c, h, w) feature map, c == in_channels 33 | :return z: (b, c, h, w), x plus the non-local response W(y) 34 | ''' 35 | batch_size = x.size(0) # shape comments assume x = [32, 512, 32, 16] and reduc_ratio=2 36 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) # [32, 256, 512] 37 | g_x = g_x.permute(0, 2, 1) # [32, 512, 256] 38 | 39 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) # [32, 256, 512] 40 | theta_x = theta_x.permute(0, 2, 1) # [32, 512, 256] 41 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) # [32, 256, 512] 42 | f = torch.matmul(theta_x, phi_x) # [32, 512, 512] 43 | N = f.size(-1) # number of spatial positions (h * w) 44 | f_div_C = f / N # [32, 512, 512] 45 | 46 | y = torch.matmul(f_div_C, g_x) # [32, 512, 256] 47 | y = y.permute(0, 2, 1).contiguous() # [32, 256, 512] 48 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) # [32, 256, 32, 16] 49 | W_y = self.W(y) 50 | z = W_y + x # [32, 512, 32, 16] 51 | return z 52 | -------------------------------------------------------------------------------- /modeling/mgn.py:
-------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | from torchvision.models.resnet import resnet50, Bottleneck 8 | 9 | # Factory hook: `args` is forwarded unchanged as MGN's num_classes. 10 | def make_model(args): 11 | return MGN(args) 12 | 13 | # MGN: a shared ResNet-50 trunk (conv1 .. layer3[0]) feeding a global branch p1 (stock layer4) and two part branches p2/p3 (stride-1 layer4) that are pooled into 2 and 3 horizontal stripes; every pooled vector is reduced to 256-d and classified. 14 | class MGN(nn.Module): 15 | def __init__(self, num_classes): 16 | super(MGN, self).__init__() 17 | resnet = resnet50(pretrained=True) 18 | # 'backone' (sic) is kept: renaming the attribute would change the state_dict keys of saved checkpoints. 19 | self.backone = nn.Sequential( 20 | resnet.conv1, 21 | resnet.bn1, 22 | resnet.relu, 23 | resnet.maxpool, 24 | resnet.layer1, 25 | resnet.layer2, 26 | resnet.layer3[0], 27 | ) 28 | 29 | res_conv4 = nn.Sequential(*resnet.layer3[1:]) 30 | 31 | res_g_conv5 = resnet.layer4 32 | # Stride-1 copy of layer4 for the part branches (keeps the layer3 spatial size), initialised from the pretrained layer4 weights. 33 | res_p_conv5 = nn.Sequential( 34 | Bottleneck(1024, 512, downsample=nn.Sequential(nn.Conv2d(1024, 2048, 1, bias=False), nn.BatchNorm2d(2048))), 35 | Bottleneck(2048, 512), 36 | Bottleneck(2048, 512)) 37 | res_p_conv5.load_state_dict(resnet.layer4.state_dict()) 38 | 39 | self.p1 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_g_conv5)) 40 | self.p2 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5)) 41 | self.p3 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5)) 42 | 43 | # NOTE: all pools below are adaptive *average* pools despite their maxpool_* names. 44 | 45 | 46 | self.maxpool_zg_p1 = nn.AdaptiveAvgPool2d(1) 47 | self.maxpool_zg_p2 = nn.AdaptiveAvgPool2d(1) 48 | self.maxpool_zg_p3 = nn.AdaptiveAvgPool2d(1) 49 | self.maxpool_zp2 = nn.AdaptiveAvgPool2d((2, 1)) 50 | self.maxpool_zp3 = nn.AdaptiveAvgPool2d((3, 1)) 51 | 52 | feats = 256 53 | reduction = nn.Sequential(nn.Conv2d(2048, feats, 1, bias=False), nn.BatchNorm2d(feats), nn.ReLU()) # 2048->256 54 | 55 | self._init_reduction(reduction) 56 | self.reduction_0 = copy.deepcopy(reduction) 57 | self.reduction_1 = copy.deepcopy(reduction) 58 | self.reduction_2 = copy.deepcopy(reduction) 59 | self.reduction_3 = copy.deepcopy(reduction) 60 | self.reduction_4 =
copy.deepcopy(reduction) 61 | self.reduction_5 = copy.deepcopy(reduction) 62 | self.reduction_6 = copy.deepcopy(reduction) 63 | self.reduction_7 = copy.deepcopy(reduction) 64 | 65 | # ID classifiers: all eight take the 256-d reduced features, despite the *_2048_* names. 66 | self.fc_id_2048_0 = nn.Linear(feats, num_classes) # 256->751 67 | self.fc_id_2048_1 = nn.Linear(feats, num_classes) 68 | self.fc_id_2048_2 = nn.Linear(feats, num_classes) 69 | 70 | self.fc_id_256_1_0 = nn.Linear(feats, num_classes) 71 | self.fc_id_256_1_1 = nn.Linear(feats, num_classes) 72 | self.fc_id_256_2_0 = nn.Linear(feats, num_classes) 73 | self.fc_id_256_2_1 = nn.Linear(feats, num_classes) 74 | self.fc_id_256_2_2 = nn.Linear(feats, num_classes) 75 | 76 | self._init_fc(self.fc_id_2048_0) 77 | self._init_fc(self.fc_id_2048_1) 78 | self._init_fc(self.fc_id_2048_2) 79 | 80 | self._init_fc(self.fc_id_256_1_0) 81 | self._init_fc(self.fc_id_256_1_1) 82 | self._init_fc(self.fc_id_256_2_0) 83 | self._init_fc(self.fc_id_256_2_1) 84 | self._init_fc(self.fc_id_256_2_2) 85 | 86 | @staticmethod 87 | def _init_reduction(reduction): 88 | # conv 89 | nn.init.kaiming_normal_(reduction[0].weight, mode='fan_in') 90 | # no bias init: the conv is built with bias=False 91 | 92 | # bn 93 | nn.init.normal_(reduction[1].weight, mean=1., std=0.02) 94 | nn.init.constant_(reduction[1].bias, 0.) 95 | 96 | @staticmethod 97 | def _init_fc(fc): 98 | nn.init.kaiming_normal_(fc.weight, mode='fan_out') 99 | # nn.init.normal_(fc.weight, std=0.001) 100 | nn.init.constant_(fc.bias, 0.)
101 | 102 | def forward(self, x): # [64, 3, 384, 128] 103 | # Shape comments assume a [64, 3, 384, 128] batch and num_classes=751. 104 | x = self.backone(x) # [64, 1024, 24, 8] 105 | 106 | p1 = self.p1(x) # [64, 2048, 12, 4] 107 | p2 = self.p2(x) # [64, 2048, 24, 8] 108 | p3 = self.p3(x) # [64, 2048, 24, 8] 109 | 110 | zg_p1 = self.maxpool_zg_p1(p1) # [64, 2048, 1, 1] 111 | zg_p2 = self.maxpool_zg_p2(p2) # [64, 2048, 1, 1] 112 | zg_p3 = self.maxpool_zg_p3(p3) # [64, 2048, 1, 1] 113 | 114 | zp2 = self.maxpool_zp2(p2) # [64, 2048, 2, 1] 115 | z0_p2 = zp2[:, :, 0:1, :] # [64, 2048, 1, 1] 116 | z1_p2 = zp2[:, :, 1:2, :] # [64, 2048, 1, 1] 117 | 118 | zp3 = self.maxpool_zp3(p3) # [64, 2048, 3, 1] 119 | z0_p3 = zp3[:, :, 0:1, :] # [64, 2048, 1, 1] 120 | z1_p3 = zp3[:, :, 1:2, :] # [64, 2048, 1, 1] 121 | z2_p3 = zp3[:, :, 2:3, :] # [64, 2048, 1, 1] 122 | 123 | fg_p1 = self.reduction_0(zg_p1).squeeze(dim=3).squeeze(dim=2) # [64, 256] 124 | fg_p2 = self.reduction_1(zg_p2).squeeze(dim=3).squeeze(dim=2) # [64, 256] 125 | fg_p3 = self.reduction_2(zg_p3).squeeze(dim=3).squeeze(dim=2) # [64, 256] 126 | f0_p2 = self.reduction_3(z0_p2).squeeze(dim=3).squeeze(dim=2) # [64, 256] 127 | f1_p2 = self.reduction_4(z1_p2).squeeze(dim=3).squeeze(dim=2) # [64, 256] 128 | f0_p3 = self.reduction_5(z0_p3).squeeze(dim=3).squeeze(dim=2) # [64, 256] 129 | f1_p3 = self.reduction_6(z1_p3).squeeze(dim=3).squeeze(dim=2) # [64, 256] 130 | f2_p3 = self.reduction_7(z2_p3).squeeze(dim=3).squeeze(dim=2) # [64, 256] 131 | 132 | ''' 133 | l_p1 = self.fc_id_2048_0(zg_p1.squeeze(dim=3).squeeze(dim=2)) 134 | l_p2 = self.fc_id_2048_1(zg_p2.squeeze(dim=3).squeeze(dim=2)) 135 | l_p3 = self.fc_id_2048_2(zg_p3.squeeze(dim=3).squeeze(dim=2)) 136 | ''' 137 | l_p1 = self.fc_id_2048_0(fg_p1) # [64, 751] 138 | l_p2 = self.fc_id_2048_1(fg_p2) # [64, 751] 139 | l_p3 = self.fc_id_2048_2(fg_p3) # [64, 751] 140 | 141 | l0_p2 = self.fc_id_256_1_0(f0_p2) # [64, 751] 142 | l1_p2 = self.fc_id_256_1_1(f1_p2) # [64, 751] 143 | l0_p3 = self.fc_id_256_2_0(f0_p3) # [64, 751] 144 | l1_p3 =
self.fc_id_256_2_1(f1_p3) # [64, 751] 145 | l2_p3 = self.fc_id_256_2_2(f2_p3) # [64, 751] 146 | 147 | predict = torch.cat([fg_p1, fg_p2, fg_p3, f0_p2, f1_p2, f0_p3, f1_p3, f2_p3], dim=1) # [64, 2048] 148 | # Eval returns only the concatenated 2048-d embedding; training also returns the three global 256-d features and the eight ID logit tensors. 149 | if not self.training: 150 | return predict 151 | 152 | return predict, [fg_p1, fg_p2, fg_p3], [l_p1, l_p2, l_p3, l0_p2, l1_p2, l0_p3, l1_p3, l2_p3] 153 | -------------------------------------------------------------------------------- /solver/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | from .build import make_optimizer, make_optimizer_with_center, make_optimizer_with_pcb, make_optimizer_fine, make_optimizer_with_triplet 4 | from .build import make_optimizer_with_global 5 | from .lr_scheduler import WarmupMultiStepLR -------------------------------------------------------------------------------- /solver/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/solver/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /solver/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/solver/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /solver/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/solver/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /solver/__pycache__/build.cpython-35.pyc:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/solver/__pycache__/build.cpython-35.pyc -------------------------------------------------------------------------------- /solver/__pycache__/build.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/solver/__pycache__/build.cpython-36.pyc -------------------------------------------------------------------------------- /solver/__pycache__/build.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/solver/__pycache__/build.cpython-37.pyc -------------------------------------------------------------------------------- /solver/__pycache__/lr_scheduler.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/solver/__pycache__/lr_scheduler.cpython-35.pyc -------------------------------------------------------------------------------- /solver/__pycache__/lr_scheduler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/solver/__pycache__/lr_scheduler.cpython-36.pyc -------------------------------------------------------------------------------- /solver/__pycache__/lr_scheduler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/solver/__pycache__/lr_scheduler.cpython-37.pyc 
-------------------------------------------------------------------------------- /solver/__pycache__/ranger.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/solver/__pycache__/ranger.cpython-35.pyc -------------------------------------------------------------------------------- /solver/__pycache__/ranger.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/solver/__pycache__/ranger.cpython-36.pyc -------------------------------------------------------------------------------- /solver/__pycache__/ranger.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/solver/__pycache__/ranger.cpython-37.pyc -------------------------------------------------------------------------------- /solver/build.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import torch 3 | # Each make_optimizer* builder creates one param group per trainable tensor: names containing "bias" get lr * bias_lr_factor and weight_decay_bias, all others cfg.lr and cfg.weight_decay; SGD additionally receives cfg.momentum. 4 | def make_optimizer(cfg, model): 5 | params = [] 6 | for key, value in model.named_parameters(): 7 | if not value.requires_grad: 8 | continue 9 | lr = cfg.lr 10 | weight_decay = cfg.weight_decay 11 | if "bias" in key: 12 | lr = cfg.lr * cfg.bias_lr_factor 13 | weight_decay = cfg.weight_decay_bias 14 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 15 | if cfg.optimizer_name == 'SGD': 16 | optimizer = getattr(torch.optim, cfg.optimizer_name)(params, momentum=cfg.momentum) 17 | else: 18 | optimizer = getattr(torch.optim, cfg.optimizer_name)(params) 19 | return optimizer 20 | 21 | 22 | 23 | def make_optimizer_with_triplet(cfg, model): 24 | params = [] 25 | for key, value in model.named_parameters(): 26 | if
not value.requires_grad: 27 | continue 28 | lr = cfg.lr 29 | weight_decay = cfg.weight_decay 30 | if "bias" in key: 31 | lr = cfg.lr * cfg.bias_lr_factor 32 | weight_decay = cfg.weight_decay_bias 33 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 34 | if cfg.optimizer_name == 'SGD': 35 | optimizer = getattr(torch.optim, cfg.optimizer_name)(params, momentum=cfg.momentum) 36 | else: 37 | optimizer = getattr(torch.optim, cfg.optimizer_name)(params) 38 | return optimizer 39 | 40 | 41 | 42 | def make_optimizer_fine(cfg, model): 43 | params = [] 44 | for key, value in model.named_parameters(): 45 | if not value.requires_grad: 46 | continue 47 | lr = cfg.lr 48 | weight_decay = cfg.weight_decay 49 | if "bias" in key: 50 | lr = cfg.lr * cfg.bias_lr_factor 51 | weight_decay = cfg.weight_decay_bias 52 | # if 'classifier' in key: 53 | # lr = lr * 0.1 54 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 55 | if cfg.optimizer_name == 'SGD': 56 | optimizer = getattr(torch.optim, cfg.optimizer_name)(params, momentum=cfg.momentum) 57 | else: 58 | optimizer = getattr(torch.optim, cfg.optimizer_name)(params) 59 | 60 | return optimizer 61 | 62 | 63 | def make_optimizer_with_global(cfg, model): 64 | params = [] 65 | for key, value in model.named_parameters(): 66 | if not value.requires_grad: 67 | continue 68 | lr = cfg.lr 69 | weight_decay = cfg.weight_decay 70 | if "bias" in key: 71 | lr = cfg.lr * cfg.bias_lr_factor 72 | weight_decay = cfg.weight_decay_bias 73 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 74 | if cfg.optimizer_name == 'SGD': 75 | optimizer = getattr(torch.optim, cfg.optimizer_name)(params, momentum=cfg.momentum) 76 | else: 77 | optimizer = getattr(torch.optim, cfg.optimizer_name)(params) 78 | 79 | return optimizer 80 | 81 | 82 | 83 | 84 | def make_optimizer_with_pcb(cfg, model): 85 | params = [] 86 | for key, value in model.named_parameters(): 87 | if not value.requires_grad: 88 | continue
89 | lr = cfg.lr 90 | weight_decay = cfg.weight_decay 91 | if "bias" in key: 92 | lr = cfg.lr * cfg.bias_lr_factor 93 | weight_decay = cfg.weight_decay_bias 94 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 95 | if cfg.optimizer_name == 'SGD': 96 | optimizer = getattr(torch.optim, cfg.optimizer_name)(params, momentum=cfg.momentum) 97 | else: 98 | optimizer = getattr(torch.optim, cfg.optimizer_name)(params) 99 | 100 | return optimizer 101 | 102 | 103 | 104 | # Returns (optimizer, optimizer_center); the second is a plain SGD over center_criterion's parameters with lr=cfg.lr_center. 105 | 106 | def make_optimizer_with_center(cfg, model, center_criterion): 107 | params = [] 108 | for key, value in model.named_parameters(): 109 | if not value.requires_grad: 110 | continue 111 | lr = cfg.lr 112 | weight_decay = cfg.weight_decay 113 | if "bias" in key: 114 | lr = cfg.lr * cfg.bias_lr_factor 115 | weight_decay = cfg.weight_decay_bias 116 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 117 | if cfg.optimizer_name == 'SGD': 118 | optimizer = getattr(torch.optim, cfg.optimizer_name)(params, momentum=cfg.momentum) 119 | else: 120 | optimizer = getattr(torch.optim, cfg.optimizer_name)(params) 121 | optimizer_center = torch.optim.SGD(center_criterion.parameters(), lr=cfg.lr_center) 122 | return optimizer, optimizer_center 123 | 124 | -------------------------------------------------------------------------------- /solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | from bisect import bisect_right 7 | import torch 8 | 9 | 10 | # FIXME ideally this would be achieved with a CombinedLRScheduler, 11 | # separating MultiStepLR with WarmupLR 12 | # but the current LRScheduler design doesn't allow it 13 | 14 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 15 | def __init__( 16 | self, 17 | optimizer, 18 | milestones, 19 | gamma=0.1, 20 | warmup_factor=1.0 / 3, 21 | warmup_iters=500,
22 | warmup_method="linear", 23 | last_epoch=-1, 24 | ): 25 | if not list(milestones) == sorted(milestones): 26 | raise ValueError( 27 | "Milestones should be a list of" " increasing integers. Got {}".format( 28 | milestones), 29 | ) 30 | 31 | if warmup_method not in ("constant", "linear"): 32 | raise ValueError( 33 | "Only 'constant' or 'linear' warmup_method accepted, " 34 | "got {}".format(warmup_method) 35 | ) 36 | self.milestones = milestones # scheduler steps at which lr is multiplied by gamma 37 | self.gamma = gamma # decay factor applied at each milestone 38 | self.warmup_factor = warmup_factor # lr multiplier at the start of warmup 39 | self.warmup_iters = warmup_iters # warmup length, in scheduler steps 40 | self.warmup_method = warmup_method # "constant" or "linear" 41 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 42 | # lr = base_lr * warmup_factor * gamma ** (number of milestones <= last_epoch); during warmup the factor is constant or ramps linearly from warmup_factor to 1. 43 | def get_lr(self): 44 | warmup_factor = 1 45 | if self.last_epoch < self.warmup_iters: 46 | if self.warmup_method == "constant": 47 | warmup_factor = self.warmup_factor 48 | elif self.warmup_method == "linear": 49 | alpha = self.last_epoch / self.warmup_iters 50 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 51 | return [ 52 | base_lr 53 | * warmup_factor 54 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 55 | for base_lr in self.base_lrs 56 | ] 57 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | -------------------------------------------------------------------------------- /tests/lr_scheduler_test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import unittest 3 | 4 | import torch 5 | from torch import nn 6 | 7 | sys.path.append('.') 8 | from solver.lr_scheduler import WarmupMultiStepLR 9 | from solver.build import make_optimizer 10 | from config import cfg 11 | 12 | 13 | class MyTestCase(unittest.TestCase): 14 | def test_something(self): 15 | net =
nn.Linear(10, 10) 16 | optimizer = make_optimizer(cfg, net) 17 | lr_scheduler = WarmupMultiStepLR(optimizer, [20, 40], warmup_iters=10) 18 | for i in range(50): 19 | lr_scheduler.step() 20 | for j in range(3): 21 | print(i, lr_scheduler.get_lr()[0]) 22 | optimizer.step() 23 | 24 | 25 | if __name__ == '__main__': 26 | unittest.main() 27 | -------------------------------------------------------------------------------- /train_prcc_base.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import argparse 4 | import os 5 | import sys 6 | import torch 7 | from torch.backends import cudnn 8 | 9 | sys.path.append('.') 10 | from data import make_data_loader_prcc_hrnet as make_data_loader 11 | from engine.trainer import do_train_prcc_base as do_train 12 | from modeling import build_model 13 | from layers import make_loss_with_triplet_entropy 14 | from solver import make_optimizer_with_triplet 15 | from solver import WarmupMultiStepLR 16 | from engine.inference import inference_prcc_global 17 | import datetime 18 | 19 | # Resume helper: returns (model, 0, 0.0) unchanged when logs_dir has no checkpoint_best.pth. 20 | def load_network_pretrain(model, cfg): 21 | path = os.path.join(cfg.logs_dir, 'checkpoint_best.pth') 22 | if not os.path.exists(path): 23 | return model, 0, 0.0 24 | pre_dict = torch.load(path) 25 | model.load_state_dict(pre_dict['state_dict']) 26 | start_epoch = pre_dict['epoch'] 27 | best_acc = pre_dict['best_acc'] 28 | print('start_epoch:', start_epoch) 29 | print('best_acc:', best_acc) 30 | return model, start_epoch, best_acc 31 | 32 | 33 | def main(cfg): 34 | # prepare dataset 35 | train_loader, train_loader_ca, train_loader_cb, val_loader_c, val_loader_b, num_query_c, num_query_b, num_classes = make_data_loader(cfg, h=256, w=128) # num_query=3368, num_classes=751 36 | 37 | # prepare model 38 | model = build_model(num_classes, 'global') # num_classes=751 39 | model = torch.nn.DataParallel(model).cuda() if torch.cuda.is_available() else model 40 | 41 | loss_func =
make_loss_with_triplet_entropy(cfg, num_classes) 42 | optimizer = make_optimizer_with_triplet(cfg, model) 43 | 44 | if cfg.lr_type == 'step': 45 | scheduler = WarmupMultiStepLR(optimizer, cfg.steps, cfg.gamma, cfg.warmup_factor, cfg.warmup_iters, cfg.warmup_method) 46 | else: 47 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=10, factor=0.5) 48 | 49 | if cfg.train == 'train': 50 | start_epoch = int(0) 51 | acc_best = 0.0 52 | if cfg.resume == 1: 53 | model, start_epoch, acc_best = load_network_pretrain(model, cfg) 54 | 55 | do_train(cfg, model, train_loader, val_loader_c, optimizer, scheduler, loss_func, num_query_c, start_epoch, acc_best, lr_type=cfg.lr_type) 56 | elif cfg.train == 'test': 57 | # Test 58 | last_model_wts = torch.load(os.path.join(cfg.logs_dir, 'checkpoint_best.pth')) 59 | model.load_state_dict(last_model_wts['state_dict']) 60 | 61 | mAP, cmc1 = inference_prcc_global(model, val_loader_c, num_query_c) 62 | start_time = datetime.datetime.now() 63 | start_time = '%4d:%d:%d-%2d:%2d:%2d' % (start_time.year, start_time.month, start_time.day, start_time.hour, start_time.minute, start_time.second) 64 | line = '{} - Test: cmc1: {:.1%}, mAP: {:.1%}\n'.format(start_time, cmc1, mAP) 65 | print(line) 66 | 67 | 68 | 69 | 70 | 71 | if __name__ == '__main__': 72 | gpu_id = 0 73 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id) 74 | cudnn.benchmark = True 75 | parser = argparse.ArgumentParser(description="ReID Baseline Training") 76 | 77 | # DATA 78 | parser.add_argument('--batch_size', type=int, default=64) 79 | parser.add_argument('--img_per_id', type=int, default=8) 80 | parser.add_argument('--batch_size_test', type=int, default=128) 81 | parser.add_argument('--workers', type=int, default=8) 82 | parser.add_argument('--height', type=int, default=256) 83 | parser.add_argument('--width', type=int, default=128) 84 | parser.add_argument('--height_mask', type=int, default=256) 85 | parser.add_argument('--width_mask', type=int,
default=128) 86 | 87 | 88 | # MODEL 89 | parser.add_argument('--features', type=int, default=128) 90 | parser.add_argument('--dropout', type=float, default=0.0) 91 | parser.add_argument('--parts', type=int, default=14) 92 | 93 | # OPTIMIZER 94 | parser.add_argument('--seed', type=int, default=1) 95 | parser.add_argument('--lr', type=float, default=0.0035) 96 | parser.add_argument('--lr_center', type=float, default=0.5) 97 | parser.add_argument('--lr_type', type=str, default='step', help='step, plateau') 98 | parser.add_argument('--center_loss_weight', type=float, default=0.0005) 99 | parser.add_argument('--steps', type=int, nargs='+', default=[40, 80]) # e.g. --steps 40 80 (type=list split each argument into characters) 100 | parser.add_argument('--gamma', type=float, default=0.1) 101 | parser.add_argument('--cluster_margin', type=float, default=0.3) 102 | parser.add_argument('--bias_lr_factor', type=float, default=1.0) 103 | parser.add_argument('--weight_decay', type=float, default=5e-4) 104 | parser.add_argument('--weight_decay_bias', type=float, default=5e-4) 105 | parser.add_argument('--range_k', type=float, default=2) 106 | parser.add_argument('--range_margin', type=float, default=0.3) 107 | parser.add_argument('--range_alpha', type=float, default=0) 108 | parser.add_argument('--range_beta', type=float, default=1) 109 | parser.add_argument('--range_loss_weight', type=float, default=1) 110 | parser.add_argument('--warmup_factor', type=float, default=0.01) 111 | parser.add_argument('--warmup_iters', type=float, default=10) 112 | parser.add_argument('--warmup_method', type=str, default='linear') 113 | parser.add_argument('--margin', type=float, default=0.3) 114 | parser.add_argument('--optimizer_name', type=str, default="SGD", help="SGD, Adam") 115 | parser.add_argument('--momentum', type=float, default=0.9) 116 | 117 | 118 | # TRAINER 119 | parser.add_argument('--max_epochs', type=int, default=60) 120 | parser.add_argument('--train', type=str, default='train', help='train, test') # change train or test mode 121 |
parser.add_argument('--resume', type=int, default=0) 122 | parser.add_argument('--num_works', type=int, default=8) 123 | 124 | # misc 125 | working_dir = os.path.dirname(os.path.abspath(__file__)) 126 | parser.add_argument('--dataset', type=str, default='prcc_gcn') 127 | parser.add_argument('--data_dir', type=str, default='/data/shuxj/data/PReID/prcc/') 128 | parser.add_argument('--logs_dir', type=str, default=os.path.join(working_dir, 'logs/prcc_base')) 129 | 130 | cfg = parser.parse_args() 131 | if not os.path.exists(cfg.logs_dir): 132 | os.makedirs(cfg.logs_dir) 133 | 134 | 135 | main(cfg) 136 | -------------------------------------------------------------------------------- /train_prcc_hpm.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import argparse 3 | import os 4 | import sys 5 | import torch 6 | from torch.backends import cudnn 7 | 8 | sys.path.append('.') 9 | from data import make_data_loader_prcc_hrnet as make_data_loader 10 | from engine.trainer import do_train_prcc_hpm as do_train 11 | from modeling import build_model 12 | from layers import make_loss_with_triplet_entropy 13 | from solver import make_optimizer_with_global 14 | from solver import WarmupMultiStepLR 15 | from engine.inference import inference_prcc_global, inference_prcc_visual_rank 16 | import datetime 17 | 18 | # Resume helper: returns (model, 0, 0.0) unchanged when logs_dir has no checkpoint_best.pth. 19 | def load_network_pretrain(model, cfg): 20 | path = os.path.join(cfg.logs_dir, 'checkpoint_best.pth') 21 | if not os.path.exists(path): 22 | return model, 0, 0.0 23 | pre_dict = torch.load(path) 24 | model.load_state_dict(pre_dict['state_dict']) 25 | start_epoch = pre_dict['epoch'] 26 | best_acc = pre_dict['best_acc'] 27 | print('start_epoch:', start_epoch) 28 | print('best_acc:', best_acc) 29 | return model, start_epoch, best_acc 30 | 31 | 32 | 33 | def main(cfg): 34 | # prepare dataset 35 | train_loader, train_loader_ca, train_loader_cb, val_loader_c, val_loader_b, num_query_c, num_query_b, num_classes =
make_data_loader(cfg, h=256, w=128) # num_query=3368, num_classes=751 36 | 37 | # prepare model 38 | model = build_model(num_classes, 'hpm') # num_classes=751 39 | model = torch.nn.DataParallel(model).cuda() if torch.cuda.is_available() else model 40 | 41 | loss_func = make_loss_with_triplet_entropy(cfg, num_classes) # modified by gu 42 | optimizer = make_optimizer_with_global(cfg, model) 43 | 44 | if cfg.lr_type == 'step': 45 | scheduler = WarmupMultiStepLR(optimizer, cfg.steps, cfg.gamma, cfg.warmup_factor, cfg.warmup_iters, cfg.warmup_method) 46 | else: 47 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=10, factor=0.5) 48 | 49 | if cfg.train == 'train': 50 | start_epoch = int(0) 51 | acc_best = 0.0 52 | if cfg.resume == 1: 53 | model, start_epoch, acc_best = load_network_pretrain(model, cfg) 54 | 55 | do_train(cfg, model, train_loader, val_loader_c, optimizer, scheduler, loss_func, num_query_c, start_epoch, acc_best, lr_type=cfg.lr_type) 56 | elif cfg.train == 'test': 57 | # Test 58 | last_model_wts = torch.load(os.path.join(cfg.logs_dir, 'checkpoint_best.pth')) 59 | model.load_state_dict(last_model_wts['state_dict']) 60 | 61 | mAP, cmc1 = inference_prcc_global(model, val_loader_c, num_query_c) 62 | start_time = datetime.datetime.now() 63 | start_time = '%4d:%d:%d-%2d:%2d:%2d' % (start_time.year, start_time.month, start_time.day, start_time.hour, start_time.minute, start_time.second) 64 | line = '{} - Test: cmc1: {:.1%}, mAP: {:.1%}\n'.format(start_time, cmc1, mAP) 65 | print(line) 66 | elif cfg.train == 'rank': 67 | # Visual ranking 68 | last_model_wts = torch.load(os.path.join(cfg.logs_dir, 'checkpoint_best.pth')) 69 | model.load_state_dict(last_model_wts['state_dict']) 70 | 71 | home = os.path.join('logs', 'rank', os.path.basename(os.path.normpath(cfg.logs_dir))) # normpath: the default logs_dir ends with '/', whose basename would be '' 72 | inference_prcc_visual_rank(model, val_loader_c, num_query_c, home=home, show_rank=20, use_flip=True) 73 | print('finish') 74 | 75 | 76 | 77 | 78 | 79 | if __name__ == '__main__': 80 |
gpu_id = 0 81 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id) 82 | cudnn.benchmark = True 83 | 84 | parser = argparse.ArgumentParser(description="ReID Baseline Training") 85 | 86 | # DATA 87 | parser.add_argument('--batch_size', type=int, default=64) 88 | parser.add_argument('--img_per_id', type=int, default=4) 89 | parser.add_argument('--batch_size_test', type=int, default=128) 90 | parser.add_argument('--workers', type=int, default=8) 91 | parser.add_argument('--height', type=int, default=256) 92 | parser.add_argument('--width', type=int, default=128) 93 | parser.add_argument('--height_mask', type=int, default=256) 94 | parser.add_argument('--width_mask', type=int, default=128) 95 | 96 | 97 | # MODEL 98 | parser.add_argument('--features', type=int, default=128) 99 | parser.add_argument('--dropout', type=float, default=0.0) 100 | parser.add_argument('--parts', type=int, default=6) 101 | 102 | # OPTIMIZER 103 | parser.add_argument('--seed', type=int, default=1) 104 | parser.add_argument('--lr', type=float, default=0.0035) 105 | parser.add_argument('--lr_center', type=float, default=0.5) 106 | parser.add_argument('--lr_type', type=str, default='step', help='step, plateau') 107 | parser.add_argument('--center_loss_weight', type=float, default=0.0005) 108 | parser.add_argument('--steps', type=int, nargs='+', default=[40, 80]) # e.g. --steps 40 80 (type=list split each argument into characters) 109 | parser.add_argument('--gamma', type=float, default=0.1) 110 | parser.add_argument('--cluster_margin', type=float, default=0.3) 111 | parser.add_argument('--bias_lr_factor', type=float, default=1.0) 112 | parser.add_argument('--weight_decay', type=float, default=5e-4) 113 | parser.add_argument('--weight_decay_bias', type=float, default=5e-4) 114 | parser.add_argument('--range_k', type=float, default=2) 115 | parser.add_argument('--range_margin', type=float, default=0.3) 116 | parser.add_argument('--range_alpha', type=float, default=0) 117 | parser.add_argument('--range_beta', type=float, default=1) 118 | parser.add_argument('--range_loss_weight',
type=float, default=1) 119 | parser.add_argument('--warmup_factor', type=float, default=0.01) 120 | parser.add_argument('--warmup_iters', type=float, default=10) 121 | parser.add_argument('--warmup_method', type=str, default='linear') 122 | parser.add_argument('--margin', type=float, default=0.3) 123 | parser.add_argument('--optimizer_name', type=str, default="SGD") 124 | parser.add_argument('--momentum', type=float, default=0.9) 125 | 126 | 127 | # TRAINER 128 | parser.add_argument('--max_epochs', type=int, default=60) 129 | parser.add_argument('--train', type=str, default='train', help='train, test, rank') # change train or test mode 130 | parser.add_argument('--resume', type=int, default=0) 131 | parser.add_argument('--num_works', type=int, default=0) 132 | 133 | # misc 134 | working_dir = os.path.dirname(os.path.abspath(__file__)) 135 | parser.add_argument('--dataset', type=str, default='prcc_gcn') 136 | parser.add_argument('--data_dir', type=str, default='/data/prcc/') 137 | parser.add_argument('--logs_dir', type=str, default=os.path.join(working_dir, 'logs/prcc_hpm/')) 138 | 139 | cfg = parser.parse_args() 140 | if not os.path.exists(cfg.logs_dir): 141 | os.makedirs(cfg.logs_dir) 142 | 143 | 144 | main(cfg) 145 | -------------------------------------------------------------------------------- /train_prcc_hpm_pix.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import argparse 3 | import os 4 | import sys 5 | import torch 6 | from torch.backends import cudnn 7 | 8 | sys.path.append('.') 9 | from data import make_data_loader_prcc_hrnet as make_data_loader 10 | from engine.trainer import do_train_prcc_hpm_pix as do_train 11 | from modeling import build_model 12 | from layers import make_loss_with_triplet_entropy 13 | from solver import make_optimizer_with_global 14 | from solver import WarmupMultiStepLR 15 | from engine.inference import inference_prcc_global, inference_prcc_visual_rank 16 | import
datetime 17 | 18 | 19 | def load_network_pretrain(model, cfg): 20 | path = os.path.join(cfg.logs_dir, 'checkpoint_best.pth') 21 | if not os.path.exists(path): 22 | return model, 0, 0.0 23 | pre_dict = torch.load(path) 24 | model.load_state_dict(pre_dict['state_dict']) 25 | start_epoch = pre_dict['epoch'] 26 | best_acc = pre_dict['best_acc'] 27 | print('start_epoch:', start_epoch) 28 | print('best_acc:', best_acc) 29 | return model, start_epoch, best_acc 30 | 31 | 32 | 33 | def main(cfg): 34 | # prepare dataset 35 | train_loader, train_loader_ca, train_loader_cb, val_loader_c, val_loader_b, num_query_c, num_query_b, num_classes = make_data_loader(cfg, h=256, w=128) # num_query=3368, num_classes=751 36 | 37 | # prepare model 38 | model = build_model(num_classes, 'hpm') # num_classes=751 39 | model = torch.nn.DataParallel(model).cuda() if torch.cuda.is_available() else model 40 | 41 | loss_func = make_loss_with_triplet_entropy(cfg, num_classes) # modified by gu 42 | optimizer = make_optimizer_with_global(cfg, model) 43 | 44 | if cfg.lr_type == 'step': 45 | scheduler = WarmupMultiStepLR(optimizer, cfg.steps, cfg.gamma, cfg.warmup_factor, cfg.warmup_iters, cfg.warmup_method) 46 | else: 47 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=10, factor=0.5) 48 | 49 | if cfg.train == 'train': 50 | start_epoch = int(0) 51 | acc_best = 0.0 52 | if cfg.resume == 1: 53 | model, start_epoch, acc_best = load_network_pretrain(model, cfg) 54 | 55 | do_train(cfg, model, train_loader, val_loader_c, optimizer, scheduler, loss_func, num_query_c, start_epoch, acc_best, lr_type=cfg.lr_type) 56 | elif cfg.train == 'test': 57 | # Test 58 | last_model_wts = torch.load(os.path.join(cfg.logs_dir, 'checkpoint_best.pth')) 59 | model.load_state_dict(last_model_wts['state_dict']) 60 | 61 | mAP, cmc1 = inference_prcc_global(model, val_loader_c, num_query_c) 62 | start_time = datetime.datetime.now() 63 | start_time = '%4d:%d:%d-%2d:%2d:%2d' % (start_time.year, 
start_time.month, start_time.day, start_time.hour, start_time.minute, start_time.second) 64 | line = '{} - Test: cmc1: {:.1%}, mAP: {:.1%}\n'.format(start_time, cmc1, mAP) 65 | print(line) 66 | elif cfg.train == 'rank': 67 | # Test 68 | last_model_wts = torch.load(os.path.join(cfg.logs_dir, 'checkpoint_best.pth')) 69 | model.load_state_dict(last_model_wts['state_dict']) 70 | 71 | home = os.path.join('logs', 'rank', os.path.basename(cfg.logs_dir)) 72 | inference_prcc_visual_rank(model, val_loader_c, num_query_c, home=home, show_rank=20, use_flip=True) 73 | print('finish') 74 | 75 | 76 | 77 | 78 | if __name__ == '__main__': 79 | gpu_id = 0 80 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id) 81 | cudnn.benchmark = True 82 | 83 | parser = argparse.ArgumentParser(description="ReID Baseline Training") 84 | 85 | # DATA 86 | parser.add_argument('--batch_size', type=int, default=64) 87 | parser.add_argument('--img_per_id', type=int, default=4) 88 | parser.add_argument('--batch_size_test', type=int, default=128) 89 | parser.add_argument('--workers', type=int, default=8) 90 | parser.add_argument('--height', type=int, default=256) 91 | parser.add_argument('--width', type=int, default=128) 92 | parser.add_argument('--height_mask', type=int, default=256) 93 | parser.add_argument('--width_mask', type=int, default=128) 94 | 95 | 96 | # MODEL 97 | parser.add_argument('--features', type=int, default=128) 98 | parser.add_argument('--dropout', type=float, default=0.0) 99 | parser.add_argument('--parts', type=int, default=6) 100 | 101 | # OPTIMIZER 102 | parser.add_argument('--seed', type=int, default=1) 103 | parser.add_argument('--lr', type=float, default=0.0035) 104 | parser.add_argument('--lr_center', type=float, default=0.5) 105 | parser.add_argument('--lr_type', type=str, default='step', help='step, plateau') 106 | parser.add_argument('--center_loss_weight', type=float, default=0.0005) 107 | parser.add_argument('--steps', type=list, default=[40, 80]) 108 | 
parser.add_argument('--gamma', type=float, default=0.1) 109 | parser.add_argument('--cluster_margin', type=float, default=0.3) 110 | parser.add_argument('--bias_lr_factor', type=float, default=1.0) 111 | parser.add_argument('--weight_decay', type=float, default=5e-4) 112 | parser.add_argument('--weight_decay_bias', type=float, default=5e-4) 113 | parser.add_argument('--range_k', type=float, default=2) 114 | parser.add_argument('--range_margin', type=float, default=0.3) 115 | parser.add_argument('--range_alpha', type=float, default=0) 116 | parser.add_argument('--range_beta', type=float, default=1) 117 | parser.add_argument('--range_loss_weight', type=float, default=1) 118 | parser.add_argument('--warmup_factor', type=float, default=0.01) 119 | parser.add_argument('--warmup_iters', type=float, default=10) 120 | parser.add_argument('--warmup_method', type=str, default='linear') 121 | parser.add_argument('--margin', type=float, default=0.3) 122 | parser.add_argument('--optimizer_name', type=str, default="SGD") 123 | parser.add_argument('--momentum', type=float, default=0.9) 124 | 125 | 126 | # TRAINER 127 | parser.add_argument('--max_epochs', type=int, default=60) 128 | parser.add_argument('--train', type=str, default='train', help='train, test, rank') # change train or test mode 129 | parser.add_argument('--resume', type=int, default=0) 130 | parser.add_argument('--num_works', type=int, default=0) 131 | 132 | # misc 133 | working_dir = os.path.dirname(os.path.abspath(__file__)) 134 | parser.add_argument('--dataset', type=str, default='prcc_gcn') 135 | parser.add_argument('--data_dir', type=str, default='/data/prcc/') 136 | parser.add_argument('--logs_dir', type=str, default=os.path.join(working_dir, 'logs/prcc_hpm_pix/')) 137 | 138 | cfg = parser.parse_args() 139 | if not os.path.exists(cfg.logs_dir): 140 | os.makedirs(cfg.logs_dir) 141 | 142 | 143 | main(cfg) 144 | -------------------------------------------------------------------------------- /train_prcc_mgn.py: 
-------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import argparse 3 | import os 4 | import sys 5 | import torch 6 | from torch.backends import cudnn 7 | 8 | sys.path.append('.') 9 | from data import make_data_loader_prcc_hrnet as make_data_loader 10 | from engine.trainer import do_train_prcc_mgn as do_train 11 | from modeling import build_model 12 | from layers import make_loss_with_mgn 13 | from solver import make_optimizer_fine as make_optimizer 14 | from solver import WarmupMultiStepLR 15 | from engine.inference import inference_prcc_global 16 | import datetime 17 | 18 | 19 | def load_network_pretrain(model, cfg): 20 | path = os.path.join(cfg.logs_dir, 'checkpoint_best.pth') 21 | if not os.path.exists(path): 22 | return model, 0, 0.0 23 | pre_dict = torch.load(path) 24 | model.load_state_dict(pre_dict['state_dict']) 25 | start_epoch = pre_dict['epoch'] 26 | best_acc = pre_dict['best_acc'] 27 | print('start_epoch:', start_epoch) 28 | print('best_acc:', best_acc) 29 | return model, start_epoch, best_acc 30 | 31 | 32 | def main(cfg): 33 | # prepare dataset 34 | train_loader, train_loader_ca, train_loader_cb, val_loader_c, val_loader_b, num_query_c, num_query_b, num_classes = make_data_loader(cfg, h=256, w=128) # num_query=3368, num_classes=751 35 | 36 | # prepare model 37 | model = build_model(num_classes, 'mgn') # num_classes=751 38 | model = torch.nn.DataParallel(model).cuda() if torch.cuda.is_available() else model 39 | 40 | loss_func = make_loss_with_mgn(cfg, num_classes) # modified by gu 41 | optimizer = make_optimizer(cfg, model) 42 | 43 | 44 | if cfg.lr_type == 'step': 45 | scheduler = WarmupMultiStepLR(optimizer, cfg.steps, cfg.gamma, cfg.warmup_factor, cfg.warmup_iters, cfg.warmup_method) 46 | else: 47 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=10, factor=0.5) 48 | 49 | if cfg.train == 'train': 50 | start_epoch = int(0) 51 | acc_best = 0.0 52 | if cfg.resume 
== 1: 53 | model, start_epoch, acc_best = load_network_pretrain(model, cfg) 54 | 55 | do_train(cfg, model, train_loader, val_loader_c, optimizer, scheduler, loss_func, num_query_c, start_epoch, acc_best, lr_type=cfg.lr_type) 56 | elif cfg.train == 'test': 57 | # Test 58 | last_model_wts = torch.load(os.path.join(cfg.logs_dir, 'checkpoint_best.pth')) 59 | model.load_state_dict(last_model_wts['state_dict']) 60 | 61 | mAP, cmc1 = inference_prcc_global(model, val_loader_c, num_query_c) 62 | start_time = datetime.datetime.now() 63 | start_time = '%4d:%d:%d-%2d:%2d:%2d' % (start_time.year, start_time.month, start_time.day, start_time.hour, start_time.minute, start_time.second) 64 | line = '{} - Test: cmc1: {:.1%}, mAP: {:.1%}\n'.format(start_time, cmc1, mAP) 65 | print(line) 66 | 67 | 68 | 69 | 70 | if __name__ == '__main__': 71 | gpu_id = 0 72 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id) 73 | cudnn.benchmark = True 74 | 75 | parser = argparse.ArgumentParser(description="ReID Baseline Training") 76 | 77 | # DATA 78 | parser.add_argument('--batch_size', type=int, default=64) 79 | parser.add_argument('--img_per_id', type=int, default=4) 80 | parser.add_argument('--batch_size_test', type=int, default=128) 81 | parser.add_argument('--workers', type=int, default=8) 82 | parser.add_argument('--height', type=int, default=256) 83 | parser.add_argument('--width', type=int, default=128) 84 | parser.add_argument('--height_mask', type=int, default=256) 85 | parser.add_argument('--width_mask', type=int, default=128) 86 | 87 | 88 | # MODEL 89 | parser.add_argument('--features', type=int, default=128) 90 | parser.add_argument('--dropout', type=float, default=0.0) 91 | parser.add_argument('--parts', type=int, default=6) 92 | 93 | # OPTIMIZER 94 | parser.add_argument('--seed', type=int, default=1) 95 | parser.add_argument('--lr', type=float, default=0.0035) 96 | parser.add_argument('--lr_center', type=float, default=0.5) 97 | parser.add_argument('--lr_type', type=str, 
default='step', help='step, plateau') 98 | parser.add_argument('--center_loss_weight', type=float, default=0.0005) 99 | parser.add_argument('--steps', type=list, default=[40, 80]) 100 | parser.add_argument('--gamma', type=float, default=0.1) 101 | parser.add_argument('--cluster_margin', type=float, default=0.3) 102 | parser.add_argument('--bias_lr_factor', type=float, default=1.0) 103 | parser.add_argument('--weight_decay', type=float, default=5e-4) 104 | parser.add_argument('--weight_decay_bias', type=float, default=5e-4) 105 | parser.add_argument('--range_k', type=float, default=2) 106 | parser.add_argument('--range_margin', type=float, default=0.3) 107 | parser.add_argument('--range_alpha', type=float, default=0) 108 | parser.add_argument('--range_beta', type=float, default=1) 109 | parser.add_argument('--range_loss_weight', type=float, default=1) 110 | parser.add_argument('--warmup_factor', type=float, default=0.01) 111 | parser.add_argument('--warmup_iters', type=float, default=10) 112 | parser.add_argument('--warmup_method', type=str, default='linear') 113 | parser.add_argument('--margin', type=float, default=0.3) 114 | parser.add_argument('--optimizer_name', type=str, default="SGD") 115 | parser.add_argument('--momentum', type=float, default=0.9) 116 | 117 | 118 | # TRAINER 119 | parser.add_argument('--max_epochs', type=int, default=60) 120 | parser.add_argument('--train', type=str, default='train', help='train, test') # change train or test mode 121 | parser.add_argument('--resume', type=int, default=0) 122 | parser.add_argument('--num_works', type=int, default=8) 123 | 124 | # misc 125 | working_dir = os.path.dirname(os.path.abspath(__file__)) 126 | parser.add_argument('--dataset', type=str, default='prcc_gcn') 127 | parser.add_argument('--data_dir', type=str, default='/data/prcc/') 128 | parser.add_argument('--logs_dir', type=str, default=os.path.join(working_dir, 'logs/prcc_mgn')) 129 | 130 | cfg = parser.parse_args() 131 | if not 
os.path.exists(cfg.logs_dir): 132 | os.makedirs(cfg.logs_dir) 133 | 134 | main(cfg) 135 | -------------------------------------------------------------------------------- /train_prcc_mgn_pix.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import argparse 4 | import os 5 | import sys 6 | import torch 7 | from torch.backends import cudnn 8 | 9 | sys.path.append('.') 10 | from data import make_data_loader_prcc_hrnet as make_data_loader 11 | from engine.trainer import do_train_prcc_mgn_pix as do_train 12 | from modeling import build_model 13 | from layers import make_loss_with_mgn 14 | from solver import make_optimizer_fine as make_optimizer 15 | from solver import WarmupMultiStepLR 16 | from engine.inference import inference_prcc_global 17 | import datetime 18 | 19 | 20 | def load_network_pretrain(model, cfg): 21 | path = os.path.join(cfg.logs_dir, 'checkpoint_best.pth') 22 | if not os.path.exists(path): 23 | return model, 0, 0.0 24 | pre_dict = torch.load(path) 25 | model.load_state_dict(pre_dict['state_dict']) 26 | start_epoch = pre_dict['epoch'] 27 | best_acc = pre_dict['best_acc'] 28 | print('start_epoch:', start_epoch) 29 | print('best_acc:', best_acc) 30 | return model, start_epoch, best_acc 31 | 32 | 33 | 34 | def main(cfg): 35 | # prepare dataset 36 | train_loader, train_loader_ca, train_loader_cb, val_loader_c, val_loader_b, num_query_c, num_query_b, num_classes = make_data_loader(cfg, h=256, w=128) # num_query=3368, num_classes=751 37 | 38 | # prepare model 39 | model = build_model(num_classes, 'mgn') # num_classes=751 40 | model = torch.nn.DataParallel(model).cuda() if torch.cuda.is_available() else model 41 | 42 | loss_func = make_loss_with_mgn(cfg, num_classes) # modified by gu 43 | optimizer = make_optimizer(cfg, model) 44 | 45 | 46 | if cfg.lr_type == 'step': 47 | scheduler = WarmupMultiStepLR(optimizer, cfg.steps, cfg.gamma, cfg.warmup_factor, cfg.warmup_iters, cfg.warmup_method) 48 | 
else: 49 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=10, factor=0.5) 50 | 51 | if cfg.train == 'train': 52 | start_epoch = int(0) 53 | acc_best = 0.0 54 | if cfg.resume == 1: 55 | model, start_epoch, acc_best = load_network_pretrain(model, cfg) 56 | 57 | do_train(cfg, model, train_loader, val_loader_c, optimizer, scheduler, loss_func, num_query_c, start_epoch, acc_best, lr_type=cfg.lr_type) 58 | elif cfg.train == 'test': 59 | # Test 60 | last_model_wts = torch.load(os.path.join(cfg.logs_dir, 'checkpoint_best.pth')) 61 | model.load_state_dict(last_model_wts['state_dict']) 62 | 63 | mAP, cmc1 = inference_prcc_global(model, val_loader_c, num_query_c) 64 | start_time = datetime.datetime.now() 65 | start_time = '%4d:%d:%d-%2d:%2d:%2d' % (start_time.year, start_time.month, start_time.day, start_time.hour, start_time.minute, start_time.second) 66 | line = '{} - Test: cmc1: {:.1%}, mAP: {:.1%}\n'.format(start_time, cmc1, mAP) 67 | print(line) 68 | 69 | 70 | 71 | if __name__ == '__main__': 72 | gpu_id = 0 73 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id) 74 | cudnn.benchmark = True 75 | 76 | parser = argparse.ArgumentParser(description="ReID Baseline Training") 77 | 78 | # DATA 79 | parser.add_argument('--batch_size', type=int, default=64) 80 | parser.add_argument('--img_per_id', type=int, default=4) 81 | parser.add_argument('--batch_size_test', type=int, default=128) 82 | parser.add_argument('--workers', type=int, default=8) 83 | parser.add_argument('--height', type=int, default=256) 84 | parser.add_argument('--width', type=int, default=128) 85 | parser.add_argument('--height_mask', type=int, default=256) 86 | parser.add_argument('--width_mask', type=int, default=128) 87 | 88 | 89 | # MODEL 90 | parser.add_argument('--features', type=int, default=128) 91 | parser.add_argument('--dropout', type=float, default=0.0) 92 | parser.add_argument('--parts', type=int, default=6) 93 | 94 | # OPTIMIZER 95 | parser.add_argument('--seed', 
type=int, default=1) 96 | parser.add_argument('--lr', type=float, default=0.0035) 97 | parser.add_argument('--lr_center', type=float, default=0.5) 98 | parser.add_argument('--lr_type', type=str, default='step', help='step, plateau') 99 | parser.add_argument('--center_loss_weight', type=float, default=0.0005) 100 | parser.add_argument('--steps', type=list, default=[40, 80]) 101 | parser.add_argument('--gamma', type=float, default=0.1) 102 | parser.add_argument('--cluster_margin', type=float, default=0.3) 103 | parser.add_argument('--bias_lr_factor', type=float, default=1.0) 104 | parser.add_argument('--weight_decay', type=float, default=5e-4) 105 | parser.add_argument('--weight_decay_bias', type=float, default=5e-4) 106 | parser.add_argument('--range_k', type=float, default=2) 107 | parser.add_argument('--range_margin', type=float, default=0.3) 108 | parser.add_argument('--range_alpha', type=float, default=0) 109 | parser.add_argument('--range_beta', type=float, default=1) 110 | parser.add_argument('--range_loss_weight', type=float, default=1) 111 | parser.add_argument('--warmup_factor', type=float, default=0.01) 112 | parser.add_argument('--warmup_iters', type=float, default=10) 113 | parser.add_argument('--warmup_method', type=str, default='linear') 114 | parser.add_argument('--margin', type=float, default=0.3) 115 | parser.add_argument('--optimizer_name', type=str, default="SGD") 116 | parser.add_argument('--momentum', type=float, default=0.9) 117 | 118 | 119 | # TRAINER 120 | parser.add_argument('--max_epochs', type=int, default=60) 121 | parser.add_argument('--train', type=str, default='train', help='train, test, cam') # change train or test mode 122 | parser.add_argument('--resume', type=int, default=0) 123 | parser.add_argument('--num_works', type=int, default=8) 124 | 125 | # misc 126 | working_dir = os.path.dirname(os.path.abspath(__file__)) 127 | parser.add_argument('--dataset', type=str, default='prcc_gcn') 128 | parser.add_argument('--data_dir', 
type=str, default='/data/prcc/') 129 | parser.add_argument('--logs_dir', type=str, default=os.path.join(working_dir, 'logs/prcc_mgn_pix')) 130 | 131 | cfg = parser.parse_args() 132 | if not os.path.exists(cfg.logs_dir): 133 | os.makedirs(cfg.logs_dir) 134 | 135 | main(cfg) 136 | -------------------------------------------------------------------------------- /train_prcc_pcb.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import argparse 3 | import os 4 | import sys 5 | import torch 6 | from torch.backends import cudnn 7 | 8 | sys.path.append('.') 9 | from data import make_data_loader_prcc_hrnet as make_data_loader 10 | from engine.trainer import do_train_prcc_pcb as do_train 11 | from modeling import build_model 12 | from layers import make_loss_with_pcb 13 | from solver import make_optimizer_with_pcb 14 | from solver import WarmupMultiStepLR 15 | from engine.inference import inference_prcc_pcb 16 | import datetime 17 | 18 | 19 | def load_network_pretrain(model, cfg): 20 | path = os.path.join(cfg.logs_dir, 'checkpoint_best.pth') 21 | if not os.path.exists(path): 22 | return model, 0, 0.0 23 | pre_dict = torch.load(path) 24 | model.load_state_dict(pre_dict['state_dict']) 25 | start_epoch = pre_dict['epoch'] 26 | best_acc = pre_dict['best_acc'] 27 | print('start_epoch:', start_epoch) 28 | print('best_acc:', best_acc) 29 | return model, start_epoch, best_acc 30 | 31 | 32 | 33 | def main(cfg): 34 | # prepare dataset 35 | train_loader, train_loader_ca, train_loader_cb, val_loader_c, val_loader_b, num_query_c, num_query_b, num_classes = make_data_loader(cfg, h=256, w=128) # num_query=3368, num_classes=751 36 | 37 | # prepare model 38 | model = build_model(num_classes, 'pcb') # num_classes=751 39 | model = torch.nn.DataParallel(model).cuda() if torch.cuda.is_available() else model 40 | 41 | loss_func = make_loss_with_pcb(cfg, num_classes) 42 | optimizer = make_optimizer_with_pcb(cfg, model) 43 | 44 | if 
cfg.lr_type == 'step': 45 | scheduler = WarmupMultiStepLR(optimizer, cfg.steps, cfg.gamma, cfg.warmup_factor, cfg.warmup_iters, cfg.warmup_method) 46 | else: 47 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=10, factor=0.5) 48 | 49 | if cfg.train == 'train': 50 | start_epoch = int(0) 51 | acc_best = 0.0 52 | if cfg.resume == 1: 53 | model, start_epoch, acc_best = load_network_pretrain(model, cfg) 54 | 55 | do_train(cfg, model, train_loader, val_loader_c, optimizer, scheduler, loss_func, num_query_c, start_epoch, acc_best, lr_type=cfg.lr_type) 56 | elif cfg.train == 'test': 57 | # Test 58 | last_model_wts = torch.load(os.path.join(cfg.logs_dir, 'checkpoint_best.pth')) 59 | model.load_state_dict(last_model_wts['state_dict']) 60 | 61 | mAP, cmc1 = inference_prcc_pcb(model, val_loader_c, num_query_c) 62 | start_time = datetime.datetime.now() 63 | start_time = '%4d:%d:%d-%2d:%2d:%2d' % (start_time.year, start_time.month, start_time.day, start_time.hour, start_time.minute, start_time.second) 64 | line = '{} - Test: cmc1: {:.1%}, mAP: {:.1%}\n'.format(start_time, cmc1, mAP) 65 | print(line) 66 | 67 | 68 | 69 | 70 | 71 | if __name__ == '__main__': 72 | gpu_id = 0 73 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id) 74 | cudnn.benchmark = True 75 | 76 | parser = argparse.ArgumentParser(description="ReID Baseline Training") 77 | 78 | # DATA 79 | parser.add_argument('--batch_size', type=int, default=64) 80 | parser.add_argument('--img_per_id', type=int, default=4) 81 | parser.add_argument('--batch_size_test', type=int, default=128) 82 | parser.add_argument('--workers', type=int, default=8) 83 | parser.add_argument('--height', type=int, default=256) 84 | parser.add_argument('--width', type=int, default=128) 85 | parser.add_argument('--height_mask', type=int, default=256) 86 | parser.add_argument('--width_mask', type=int, default=128) 87 | 88 | 89 | # MODEL 90 | parser.add_argument('--features', type=int, default=128) 91 | 
parser.add_argument('--dropout', type=float, default=0.0) 92 | parser.add_argument('--parts', type=int, default=6) 93 | 94 | # OPTIMIZER 95 | parser.add_argument('--seed', type=int, default=1) 96 | parser.add_argument('--lr', type=float, default=0.0035) 97 | parser.add_argument('--lr_center', type=float, default=0.5) 98 | parser.add_argument('--lr_type', type=str, default='step', help='step, plateau') 99 | parser.add_argument('--center_loss_weight', type=float, default=0.0005) 100 | parser.add_argument('--steps', type=list, default=[40, 80]) 101 | parser.add_argument('--gamma', type=float, default=0.1) 102 | parser.add_argument('--cluster_margin', type=float, default=0.3) 103 | parser.add_argument('--bias_lr_factor', type=float, default=1.0) 104 | parser.add_argument('--weight_decay', type=float, default=5e-4) 105 | parser.add_argument('--weight_decay_bias', type=float, default=5e-4) 106 | parser.add_argument('--range_k', type=float, default=2) 107 | parser.add_argument('--range_margin', type=float, default=0.3) 108 | parser.add_argument('--range_alpha', type=float, default=0) 109 | parser.add_argument('--range_beta', type=float, default=1) 110 | parser.add_argument('--range_loss_weight', type=float, default=1) 111 | parser.add_argument('--warmup_factor', type=float, default=0.01) 112 | parser.add_argument('--warmup_iters', type=float, default=10) 113 | parser.add_argument('--warmup_method', type=str, default='linear') 114 | parser.add_argument('--margin', type=float, default=0.3) 115 | parser.add_argument('--optimizer_name', type=str, default="SGD") 116 | parser.add_argument('--momentum', type=float, default=0.9) 117 | 118 | 119 | # TRAINER 120 | parser.add_argument('--max_epochs', type=int, default=60) 121 | parser.add_argument('--train', type=str, default='train', help='train, test') # change train or test mode 122 | parser.add_argument('--resume', type=int, default=0) 123 | parser.add_argument('--num_works', type=int, default=8) 124 | 125 | # misc 126 | 
working_dir = os.path.dirname(os.path.abspath(__file__)) 127 | parser.add_argument('--dataset', type=str, default='prcc_gcn') 128 | parser.add_argument('--data_dir', type=str, default='/data/prcc/') 129 | parser.add_argument('--logs_dir', type=str, default=os.path.join(working_dir, 'logs/prcc_pcb/')) 130 | 131 | cfg = parser.parse_args() 132 | if not os.path.exists(cfg.logs_dir): 133 | os.makedirs(cfg.logs_dir) 134 | 135 | main(cfg) 136 | -------------------------------------------------------------------------------- /train_prcc_pcb_pix.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import argparse 8 | import os 9 | import sys 10 | import torch 11 | from torch.backends import cudnn 12 | 13 | sys.path.append('.') 14 | from data import make_data_loader_prcc_hrnet as make_data_loader 15 | from engine.trainer import do_train_prcc_pcb_pix as do_train 16 | from modeling import build_model 17 | from layers import make_loss_with_pcb 18 | from solver import make_optimizer_with_pcb 19 | from solver import WarmupMultiStepLR 20 | from engine.inference import inference_prcc_pcb 21 | import datetime 22 | 23 | 24 | def load_network_pretrain(model, cfg): 25 | path = os.path.join(cfg.logs_dir, 'checkpoint_best.pth') 26 | if not os.path.exists(path): 27 | return model, 0, 0.0 28 | pre_dict = torch.load(path) 29 | model.load_state_dict(pre_dict['state_dict']) 30 | start_epoch = pre_dict['epoch'] 31 | best_acc = pre_dict['best_acc'] 32 | print('start_epoch:', start_epoch) 33 | print('best_acc:', best_acc) 34 | return model, start_epoch, best_acc 35 | 36 | 37 | 38 | def main(cfg): 39 | # prepare dataset 40 | train_loader, train_loader_ca, train_loader_cb, val_loader_c, val_loader_b, num_query_c, num_query_b, num_classes = make_data_loader(cfg, h=256, w=128) # num_query=3368, num_classes=751 41 | 42 | # prepare model 43 | model = 
build_model(num_classes, 'pcb') # num_classes=751 44 | model = torch.nn.DataParallel(model).cuda() if torch.cuda.is_available() else model 45 | 46 | loss_func = make_loss_with_pcb(cfg, num_classes) 47 | optimizer = make_optimizer_with_pcb(cfg, model) 48 | 49 | if cfg.lr_type == 'step': 50 | scheduler = WarmupMultiStepLR(optimizer, cfg.steps, cfg.gamma, cfg.warmup_factor, cfg.warmup_iters, cfg.warmup_method) 51 | else: 52 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=10, factor=0.5) 53 | 54 | if cfg.train == 'train': 55 | start_epoch = int(0) 56 | acc_best = 0.0 57 | if cfg.resume == 1: 58 | model, start_epoch, acc_best = load_network_pretrain(model, cfg) 59 | 60 | do_train(cfg, model, train_loader, val_loader_c, optimizer, scheduler, loss_func, num_query_c, start_epoch, acc_best, lr_type=cfg.lr_type) 61 | elif cfg.train == 'test': 62 | # Test 63 | last_model_wts = torch.load(os.path.join(cfg.logs_dir, 'checkpoint_best.pth')) 64 | model.load_state_dict(last_model_wts['state_dict']) 65 | 66 | mAP, cmc1 = inference_prcc_pcb(model, val_loader_c, num_query_c) 67 | start_time = datetime.datetime.now() 68 | start_time = '%4d:%d:%d-%2d:%2d:%2d' % (start_time.year, start_time.month, start_time.day, start_time.hour, start_time.minute, start_time.second) 69 | line = '{} - Test: cmc1: {:.1%}, mAP: {:.1%}\n'.format(start_time, cmc1, mAP) 70 | print(line) 71 | 72 | 73 | 74 | 75 | if __name__ == '__main__': 76 | gpu_id = 0 77 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id) 78 | cudnn.benchmark = True 79 | 80 | parser = argparse.ArgumentParser(description="ReID Baseline Training") 81 | 82 | # DATA 83 | parser.add_argument('--batch_size', type=int, default=64) 84 | parser.add_argument('--img_per_id', type=int, default=4) 85 | parser.add_argument('--batch_size_test', type=int, default=128) 86 | parser.add_argument('--workers', type=int, default=8) 87 | parser.add_argument('--height', type=int, default=256) 88 | parser.add_argument('--width', 
type=int, default=128) 89 | parser.add_argument('--height_mask', type=int, default=256) 90 | parser.add_argument('--width_mask', type=int, default=128) 91 | 92 | 93 | # MODEL 94 | parser.add_argument('--features', type=int, default=128) 95 | parser.add_argument('--dropout', type=float, default=0.0) 96 | parser.add_argument('--parts', type=int, default=6) 97 | 98 | # OPTIMIZER 99 | parser.add_argument('--seed', type=int, default=1) 100 | parser.add_argument('--lr', type=float, default=0.0035) 101 | parser.add_argument('--lr_center', type=float, default=0.5) 102 | parser.add_argument('--lr_type', type=str, default='step', help='step, plateau') 103 | parser.add_argument('--center_loss_weight', type=float, default=0.0005) 104 | parser.add_argument('--steps', type=list, default=[40, 80]) 105 | parser.add_argument('--gamma', type=float, default=0.1) 106 | parser.add_argument('--cluster_margin', type=float, default=0.3) 107 | parser.add_argument('--bias_lr_factor', type=float, default=1.0) 108 | parser.add_argument('--weight_decay', type=float, default=5e-4) 109 | parser.add_argument('--weight_decay_bias', type=float, default=5e-4) 110 | parser.add_argument('--range_k', type=float, default=2) 111 | parser.add_argument('--range_margin', type=float, default=0.3) 112 | parser.add_argument('--range_alpha', type=float, default=0) 113 | parser.add_argument('--range_beta', type=float, default=1) 114 | parser.add_argument('--range_loss_weight', type=float, default=1) 115 | parser.add_argument('--warmup_factor', type=float, default=0.01) 116 | parser.add_argument('--warmup_iters', type=float, default=10) 117 | parser.add_argument('--warmup_method', type=str, default='linear') 118 | parser.add_argument('--margin', type=float, default=0.3) 119 | parser.add_argument('--optimizer_name', type=str, default="SGD") 120 | parser.add_argument('--momentum', type=float, default=0.9) 121 | 122 | 123 | # TRAINER 124 | parser.add_argument('--max_epochs', type=int, default=60) 125 | 
parser.add_argument('--train', type=str, default='train', help='train, test') # change train or test mode 126 | parser.add_argument('--resume', type=int, default=0) 127 | parser.add_argument('--num_works', type=int, default=8) 128 | 129 | # misc 130 | working_dir = os.path.dirname(os.path.abspath(__file__)) 131 | parser.add_argument('--dataset', type=str, default='prcc_gcn') 132 | parser.add_argument('--data_dir', type=str, default='/data/prcc/') 133 | parser.add_argument('--logs_dir', type=str, default=os.path.join(working_dir, 'logs/prcc_pcb_pix/')) 134 | 135 | cfg = parser.parse_args() 136 | if not os.path.exists(cfg.logs_dir): 137 | os.makedirs(cfg.logs_dir) 138 | 139 | main(cfg) 140 | -------------------------------------------------------------------------------- /train_prcc_pixel_sampling.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import argparse 4 | import os 5 | import sys 6 | import torch 7 | from torch.backends import cudnn 8 | 9 | sys.path.append('.') 10 | from data import make_data_loader_prcc_hrnet as make_data_loader 11 | from engine.trainer import do_train_prcc_pix_mse as do_train 12 | from modeling import build_model 13 | from layers import make_loss_with_triplet_entropy_mse 14 | from solver import make_optimizer_with_triplet 15 | from solver import WarmupMultiStepLR 16 | from engine.inference import inference_prcc_global, inference_prcc_visual_rank 17 | import datetime 18 | 19 | 20 | def load_network_pretrain(model, cfg): 21 | path = os.path.join(cfg.logs_dir, 'checkpoint_best.pth') 22 | if not os.path.exists(path): 23 | return model, 0, 0.0 24 | pre_dict = torch.load(path) 25 | model.load_state_dict(pre_dict['state_dict']) 26 | start_epoch = pre_dict['epoch'] 27 | best_acc = pre_dict['best_acc'] 28 | print('start_epoch:', start_epoch) 29 | print('best_acc:', best_acc) 30 | return model, start_epoch, best_acc 31 | 32 | 33 | 34 | def main(cfg): 35 | # prepare dataset 36 | 
train_loader, train_loader_ca, train_loader_cb, val_loader_c, val_loader_b, num_query_c, num_query_b, num_classes = make_data_loader(cfg, h=256, w=128) # num_query=3368, num_classes=751 37 | 38 | # prepare model 39 | model = build_model(num_classes, 'global') 40 | model = torch.nn.DataParallel(model).cuda() if torch.cuda.is_available() else model 41 | 42 | loss_func = make_loss_with_triplet_entropy_mse(cfg, num_classes) 43 | optimizer = make_optimizer_with_triplet(cfg, model) 44 | 45 | if cfg.lr_type == 'step': 46 | scheduler = WarmupMultiStepLR(optimizer, cfg.steps, cfg.gamma, cfg.warmup_factor, cfg.warmup_iters, cfg.warmup_method) 47 | else: 48 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=10, factor=0.5) 49 | 50 | if cfg.train == 'train': 51 | start_epoch = int(0) 52 | acc_best = 0.0 53 | if cfg.resume == 1: 54 | model, start_epoch, acc_best = load_network_pretrain(model, cfg) 55 | 56 | do_train(cfg, model, train_loader, val_loader_c, optimizer, scheduler, loss_func, num_query_c, start_epoch, acc_best, lr_type='step') 57 | elif cfg.train == 'test': 58 | # Test 59 | last_model_wts = torch.load(os.path.join(cfg.logs_dir, 'checkpoint_best.pth')) 60 | model.load_state_dict(last_model_wts['state_dict']) 61 | 62 | mAP, cmc1 = inference_prcc_global(model, val_loader_c, num_query_c) 63 | start_time = datetime.datetime.now() 64 | start_time = '%4d:%d:%d-%2d:%2d:%2d' % (start_time.year, start_time.month, start_time.day, start_time.hour, start_time.minute, start_time.second) 65 | line = '{} - Test: cmc1: {:.1%}, mAP: {:.1%}\n'.format(start_time, cmc1, mAP) 66 | print(line) 67 | elif cfg.train == 'rank': 68 | # Test 69 | last_model_wts = torch.load(os.path.join(cfg.logs_dir, 'checkpoint_best.pth')) 70 | model.load_state_dict(last_model_wts['state_dict']) 71 | 72 | home = os.path.join('logs', 'rank', os.path.basename(cfg.logs_dir)) 73 | inference_prcc_visual_rank(model, val_loader_c, num_query_c, home=home, show_rank=20, use_flip=True) 
74 | print('finish') 75 | 76 | 77 | 78 | 79 | 80 | if __name__ == '__main__': 81 | gpu_id = 1 82 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id) 83 | cudnn.benchmark = True 84 | parser = argparse.ArgumentParser(description="ReID Baseline Training") 85 | 86 | # DATA 87 | parser.add_argument('--batch_size', type=int, default=64) 88 | parser.add_argument('--img_per_id', type=int, default=4) 89 | parser.add_argument('--batch_size_test', type=int, default=128) 90 | parser.add_argument('--workers', type=int, default=8) 91 | parser.add_argument('--height', type=int, default=256) 92 | parser.add_argument('--width', type=int, default=128) 93 | parser.add_argument('--height_mask', type=int, default=256) 94 | parser.add_argument('--width_mask', type=int, default=128) 95 | 96 | 97 | # MODEL 98 | parser.add_argument('--features', type=int, default=128) 99 | parser.add_argument('--dropout', type=float, default=0.0) 100 | parser.add_argument('--parts', type=int, default=14) 101 | 102 | # OPTIMIZER 103 | parser.add_argument('--seed', type=int, default=1) 104 | parser.add_argument('--lr', type=float, default=0.00035) 105 | parser.add_argument('--lr_center', type=float, default=0.5) 106 | parser.add_argument('--lr_type', type=str, default='step', help='step, plateau') 107 | parser.add_argument('--center_loss_weight', type=float, default=0.0005) 108 | parser.add_argument('--steps', type=list, default=[40, 80]) 109 | parser.add_argument('--gamma', type=float, default=0.1) 110 | parser.add_argument('--cluster_margin', type=float, default=0.3) 111 | parser.add_argument('--bias_lr_factor', type=float, default=1.0) 112 | parser.add_argument('--weight_decay', type=float, default=5e-4) 113 | parser.add_argument('--weight_decay_bias', type=float, default=5e-4) 114 | parser.add_argument('--range_k', type=float, default=2) 115 | parser.add_argument('--range_margin', type=float, default=0.3) 116 | parser.add_argument('--range_alpha', type=float, default=0) 117 | 
parser.add_argument('--range_beta', type=float, default=1) 118 | parser.add_argument('--range_loss_weight', type=float, default=1) 119 | parser.add_argument('--warmup_factor', type=float, default=0.01) 120 | parser.add_argument('--warmup_iters', type=float, default=10) 121 | parser.add_argument('--warmup_method', type=str, default='linear') 122 | parser.add_argument('--margin', type=float, default=0.3) 123 | parser.add_argument('--optimizer_name', type=str, default="SGD", help="SGD, Adam") 124 | parser.add_argument('--momentum', type=float, default=0.9) 125 | 126 | 127 | # TRAINER 128 | parser.add_argument('--max_epochs', type=int, default=60) 129 | parser.add_argument('--train', type=str, default='train', help='train, test, rank') # change train or test mode 130 | parser.add_argument('--resume', type=int, default=0) 131 | parser.add_argument('--num_works', type=int, default=0) 132 | 133 | # misc 134 | working_dir = os.path.dirname(os.path.abspath(__file__)) 135 | parser.add_argument('--dataset', type=str, default='prcc_gcn') 136 | parser.add_argument('--data_dir', type=str, default='/data/prcc/') 137 | parser.add_argument('--logs_dir', type=str, default=os.path.join(working_dir, 'logs/prcc_pix_mse')) 138 | 139 | cfg = parser.parse_args() 140 | if not os.path.exists(cfg.logs_dir): 141 | os.makedirs(cfg.logs_dir) 142 | 143 | main(cfg) 144 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/utils/__pycache__/__init__.cpython-35.pyc 
-------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/distance.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/utils/__pycache__/distance.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/distance.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/utils/__pycache__/distance.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/distance.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/utils/__pycache__/distance.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/iotools.cpython-35.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/utils/__pycache__/iotools.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/iotools.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/utils/__pycache__/iotools.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/iotools.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/utils/__pycache__/iotools.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logger.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/utils/__pycache__/logger.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/utils/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/re_ranking.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/utils/__pycache__/re_ranking.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/re_ranking.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/utils/__pycache__/re_ranking.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/re_ranking.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/utils/__pycache__/re_ranking.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/reid_metric.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/utils/__pycache__/reid_metric.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/reid_metric.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/utils/__pycache__/reid_metric.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/reid_metric.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/utils/__pycache__/reid_metric.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/reid_tool.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/utils/__pycache__/reid_tool.cpython-35.pyc 
-------------------------------------------------------------------------------- /utils/__pycache__/rerank.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/utils/__pycache__/rerank.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/rerank.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/utils/__pycache__/rerank.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/rerank.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/utils/__pycache__/rerank.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/visual.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/utils/__pycache__/visual.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/visual.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/utils/__pycache__/visual.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/visual.cpython-37.pyc: -------------------------------------------------------------------------------- 
# https://raw.githubusercontent.com/shuxjweb/pixel_sampling/f686282a9609b8770c11647e7af04bd07e2f33ea/utils/__pycache__/visual.cpython-37.pyc
# --------------------------------------------------------------------------------
# /utils/compute_map.py:
# --------------------------------------------------------------------------------
import numpy as np
import torch

from data.datasets.eval_reid import eval_func


def get_map(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
    """Evaluate a query-by-gallery distance matrix with ``eval_func``.

    Args:
        distmat: distance matrix, presumably (num_query, num_gallery) -- TODO confirm.
        q_pids, g_pids: person ids of the query / gallery samples.
        q_camids, g_camids: camera ids of the query / gallery samples.
        max_rank: forwarded to ``eval_func``.

    Returns:
        Whatever ``eval_func`` returns (``R1_mAP.compute`` below unpacks it
        as ``cmc, mAP``).
    """
    # Previously the result was discarded (so this returned None) and
    # ``max_rank`` was ignored in favour of a hard-coded 50.
    return eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=max_rank)


# --------------------------------------------------------------------------------
# /utils/iotools.py:
# --------------------------------------------------------------------------------
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""

import errno
import json
import os

import os.path as osp


def mkdir_if_missing(directory):
    """Create ``directory`` (including parents) if it does not exist yet.

    An empty string denotes the current directory, which always exists; it is
    skipped because ``os.makedirs('')`` raises ``FileNotFoundError`` -- this
    used to break ``write_json(obj, 'out.json')``.
    """
    if not directory or osp.exists(directory):
        return
    try:
        os.makedirs(directory)
    except OSError as e:
        # Tolerate the directory appearing between the existence check and
        # makedirs; anything else is a real error.
        if e.errno != errno.EEXIST:
            raise


def check_isfile(path):
    """Return True if ``path`` is an existing regular file; print a warning otherwise."""
    isfile = osp.isfile(path)
    if not isfile:
        print("=> Warning: no file found at '{}' (ignored)".format(path))
    return isfile


def read_json(fpath):
    """Load and return the JSON document stored at ``fpath``."""
    with open(fpath, 'r') as f:
        obj = json.load(f)
    return obj


def write_json(obj, fpath):
    """Write ``obj`` as 4-space-indented JSON to ``fpath``, creating its parent directory if needed."""
    mkdir_if_missing(osp.dirname(fpath))
    with open(fpath, 'w') as f:
        json.dump(obj, f, indent=4, separators=(',', ': '))


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, the weighted running sum, the count and the average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


# --------------------------------------------------------------------------------
# /utils/logger.py:
# --------------------------------------------------------------------------------
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""

import logging
import os
import sys


def setup_logger(name, save_dir, distributed_rank):
    """Configure and return the logger called ``name`` at DEBUG level.

    When ``distributed_rank`` is 0 (or negative) a stdout handler is attached,
    plus a handler writing to ``<save_dir>/log.txt`` (truncated on open) when
    ``save_dir`` is truthy. For ``distributed_rank > 0`` no handlers are added.

    Note: handlers are added on every call, so configuring the same ``name``
    twice duplicates every message.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    # don't log results for the non-master process
    if distributed_rank > 0:
        return logger
    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    if save_dir:
        fh = logging.FileHandler(os.path.join(save_dir, "log.txt"), mode='w')
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    return logger


# --------------------------------------------------------------------------------
# /utils/re_ranking.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri, 25 May 2018 20:29:09

@author: luohao
"""

"""
CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017.
url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf
Matlab version: https://github.com/zhunzhong07/person-re-ranking
"""

"""
API

probFea: all feature vectors of the query set (torch tensor)
galFea: all feature vectors of the gallery set (torch tensor)
k1,k2,lambda_value: parameters, the original paper uses (k1=20,k2=6,lambda=0.3)
local_distmat: optional extra distance added to the (query+gallery) x (query+gallery)
               feature distance matrix (or used alone when only_local is True)
only_local: if True, skip the feature distance and use local_distmat only
(The MemorySave / Minibatch options found in other versions are not implemented here.)
"""

import numpy as np
import torch


def re_ranking(probFea, galFea, k1, k2, lambda_value, local_distmat=None, only_local=False):
    """k-reciprocal re-ranking; returns a (num_query, num_gallery) numpy distance matrix."""
    # if feature vector is numpy, you should use 'torch.tensor' transform it to tensor
    query_num = probFea.size(0)
    all_num = query_num + galFea.size(0)
    if only_local:
        original_dist = local_distmat
    else:
        feat = torch.cat([probFea, galFea])
        print('using GPU to compute original distance')
        # Squared Euclidean distance: |a|^2 + |b|^2 - 2 a.b
        distmat = torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num) + \
                  torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num).t()
        # Keyword form of the deprecated positional addmm_(beta, alpha, m1, m2) overload.
        distmat.addmm_(feat, feat.t(), beta=1, alpha=-2)
        original_dist = distmat.cpu().numpy()
        del feat
        if local_distmat is not None:
            original_dist = original_dist + local_distmat
    gallery_num = original_dist.shape[0]
    # Scale each column by its maximum, then transpose.
    original_dist = np.transpose(original_dist / np.max(original_dist, axis=0))
    V = np.zeros_like(original_dist).astype(np.float16)
    initial_rank = np.argsort(original_dist).astype(np.int32)

    print('starting re_ranking')
    for i in range(all_num):
        # k-reciprocal neighbors
        forward_k_neigh_index = initial_rank[i, :k1 + 1]
        backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
        fi = np.where(backward_k_neigh_index == i)[0]
        k_reciprocal_index = forward_k_neigh_index[fi]
        k_reciprocal_expansion_index = k_reciprocal_index
        for j in range(len(k_reciprocal_index)):
            candidate = k_reciprocal_index[j]
            candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1]
            candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,
                                               :int(np.around(k1 / 2)) + 1]
            fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
            candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
            # Expand with the candidate's own k/2-reciprocal set when it overlaps enough.
            if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len(
                    candidate_k_reciprocal_index):
                k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index)

        k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
        weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
        V[i, k_reciprocal_expansion_index] = weight / np.sum(weight)
    original_dist = original_dist[:query_num, ]
    if k2 != 1:
        # Local query expansion: average the encodings of the k2 nearest neighbours.
        V_qe = np.zeros_like(V, dtype=np.float16)
        for i in range(all_num):
            V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
        V = V_qe
        del V_qe
    del initial_rank
    # Inverted index: for each column, the rows with a non-zero encoding.
    invIndex = []
    for i in range(gallery_num):
        invIndex.append(np.where(V[:, i] != 0)[0])

    jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)

    for i in range(query_num):
        temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16)
        indNonZero = np.where(V[i, :] != 0)[0]
        indImages = [invIndex[ind] for ind in indNonZero]
        for j in range(len(indNonZero)):
            temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]],
                                                                               V[indImages[j], indNonZero[j]])
        jaccard_dist[i] = 1 - temp_min / (2 - temp_min)

    final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value
    del original_dist
    del V
    del jaccard_dist
    final_dist = final_dist[:query_num, query_num:]
    return final_dist


# --------------------------------------------------------------------------------
# /utils/reid_metric.py:
# --------------------------------------------------------------------------------
# encoding: utf-8

import numpy as np
import torch
from ignite.metrics import Metric
from data.datasets.eval_reid import eval_func


class R1_mAP(Metric):
    """Ignite metric that accumulates (feat, pid, camid) batches and reports CMC and mAP.

    The first ``num_query`` accumulated samples are treated as queries and the
    remainder as the gallery, so the evaluation loader must yield queries first.
    """

    def __init__(self, num_query, max_rank=50, feat_norm='yes'):
        super(R1_mAP, self).__init__()
        self.num_query = num_query
        # NOTE(review): max_rank is stored but compute() calls eval_func with its default.
        self.max_rank = max_rank
        self.feat_norm = feat_norm
        self.reset()

    def reset(self):
        """Drop all accumulated features, person ids and camera ids."""
        self.feats = []
        self.pids = []
        self.camids = []

    def update(self, output):
        """Accumulate one ``(feat, pid, camid)`` batch."""
        feat, pid, camid = output
        self.feats.append(feat)
        self.pids.extend(np.asarray(pid))
        self.camids.extend(np.asarray(camid))

    def compute(self):
        """Return ``(cmc, mAP)`` from squared-Euclidean query/gallery distances."""
        feats = torch.cat(self.feats, dim=0)
        if self.feat_norm == 'yes':
            feats = torch.nn.functional.normalize(feats, dim=1, p=2)
        # query
        qf = feats[:self.num_query]
        q_pids = np.asarray(self.pids[:self.num_query])
        q_camids = np.asarray(self.camids[:self.num_query])
        # gallery
        gf = feats[self.num_query:]
        g_pids = np.asarray(self.pids[self.num_query:])
        g_camids = np.asarray(self.camids[self.num_query:])
        m, n = qf.shape[0], gf.shape[0]
        distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
                  torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
        # Explicit form instead of the deprecated positional addmm_ overload.
        distmat = distmat - 2 * torch.matmul(qf, gf.t())
        distmat = distmat.cpu().numpy()
        cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids)

        return cmc, mAP


# --------------------------------------------------------------------------------
# /utils/reid_tool.py:
# --------------------------------------------------------------------------------
# Python 3 only (the repo ships cpython-35/36/37 caches): the former
# `from __future__ import absolute_import` / `print_function` lines were no-ops
# and are omitted.

import numpy as np
import os
import os.path as osp
import shutil

from .iotools import mkdir_if_missing


def _cp_img_to(src, dst, rank, prefix):
    """Copy one ranked image, or every frame of a tracklet, into ``dst``.

    Args:
        src: image path, or a tuple/list of frame paths (vidreid).
        dst: target directory.
        rank: ranked position starting from 1 (0 is used for the query itself).
        prefix: file-name prefix, e.g. 'query' or 'gallery'.
    """
    if isinstance(src, (tuple, list)):
        dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3))
        mkdir_if_missing(dst)
        for img_path in src:
            shutil.copy(img_path, dst)
    else:
        dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3) + '_name_' + osp.basename(src))
        shutil.copy(src, dst)


def _query_dir(save_dir, qimg_path):
    """Create and return the output folder for one query.

    For vidreid ``qimg_path`` is a sequence of frame paths; the first frame's
    file name is used (``osp.basename`` of a tuple raises TypeError).
    """
    first = qimg_path[0] if isinstance(qimg_path, (tuple, list)) else qimg_path
    qdir = osp.join(save_dir, osp.basename(first))
    mkdir_if_missing(qdir)
    return qdir


def visualize_ranked_results(distmat, dataset, save_dir='log/ranked_results', topk=20):
    """
    Visualize ranked results

    Support both imgreid and vidreid

    Args:
    - distmat: distance matrix of shape (num_query, num_gallery).
    - dataset: a 2-tuple containing (query, gallery), each contains a list of (img_path, pid, camid);
               for imgreid, img_path is a string, while for vidreid, img_path is a tuple containing
               a sequence of strings.
    - save_dir: directory to save output images.
    - topk: int, denoting top-k images in the rank list to be visualized.

    Gallery entries with the same pid AND camid as the query are skipped and
    do not consume a rank.
    """
    num_q, num_g = distmat.shape

    print("Visualizing top-{} ranks".format(topk))
    print("# query: {}\n# gallery {}".format(num_q, num_g))
    print("Saving images to '{}'".format(save_dir))

    query, gallery = dataset
    assert num_q == len(query)
    assert num_g == len(gallery)

    indices = np.argsort(distmat, axis=1)
    mkdir_if_missing(save_dir)

    for q_idx in range(num_q):
        qimg_path, qpid, qcamid = query[q_idx]
        qdir = _query_dir(save_dir, qimg_path)
        _cp_img_to(qimg_path, qdir, rank=0, prefix='query')

        rank_idx = 1
        for g_idx in indices[q_idx, :]:
            gimg_path, gpid, gcamid = gallery[g_idx]
            invalid = (qpid == gpid) & (qcamid == gcamid)
            if not invalid:
                _cp_img_to(gimg_path, qdir, rank=rank_idx, prefix='gallery')
                rank_idx += 1
                if rank_idx > topk:
                    break
        if q_idx % 100 == 0:
            print(num_q, q_idx + 1)

    print("Done")


def visualize_ranked_results_all(distmat, dataset, save_dir='log/ranked_results', topk=20):
    """
    Visualize ranked results, additionally copying every true match

    Support both imgreid and vidreid

    Args:
    - distmat: distance matrix of shape (num_query, num_gallery).
    - dataset: a 2-tuple containing (query, gallery), each contains a list of (img_path, pid, camid);
               for imgreid, img_path is a string, while for vidreid, img_path is a tuple containing
               a sequence of strings.
    - save_dir: directory to save output images.
    - topk: int, number of leading valid matches copied regardless of identity.

    After the first ``topk`` valid gallery entries, the whole remaining rank
    list is still scanned and only entries whose pid equals the query's are
    copied (with consecutive rank numbers). Same pid AND camid entries are
    skipped entirely.
    """
    num_q, num_g = distmat.shape

    print("Visualizing top-{} ranks".format(topk))
    print("# query: {}\n# gallery {}".format(num_q, num_g))
    print("Saving images to '{}'".format(save_dir))

    query, gallery = dataset
    assert num_q == len(query)
    assert num_g == len(gallery)

    indices = np.argsort(distmat, axis=1)
    mkdir_if_missing(save_dir)

    for q_idx in range(num_q):
        qimg_path, qpid, qcamid = query[q_idx]
        qdir = _query_dir(save_dir, qimg_path)
        _cp_img_to(qimg_path, qdir, rank=0, prefix='query')

        rank_idx = 1
        for g_idx in indices[q_idx, :]:
            gimg_path, gpid, gcamid = gallery[g_idx]
            invalid = (qpid == gpid) & (qcamid == gcamid)
            if not invalid:
                # Past top-k, keep only true matches of the query identity.
                if rank_idx > topk and int(qpid) != int(gpid):
                    continue
                _cp_img_to(gimg_path, qdir, rank=rank_idx, prefix='gallery')
                rank_idx += 1

        if q_idx % 100 == 0:
            print(num_q, q_idx + 1)

    print("Done")


# --------------------------------------------------------------------------------
# /utils/visual.py:
# --------------------------------------------------------------------------------
# Python 3 only: the former `from __future__ import absolute_import` /
# `print_function` lines were no-ops and are omitted.

import numpy as np
import os
import os.path as osp
import shutil
import errno


def mkdir_if_missing(directory):
    """Create ``directory`` (including parents) if it does not exist yet.

    An empty string (the current directory) is skipped, because
    ``os.makedirs('')`` would raise ``FileNotFoundError``.
    """
    if not directory or osp.exists(directory):
        return
    try:
        os.makedirs(directory)
    except OSError as e:
        # Tolerate concurrent creation; re-raise anything else.
        if e.errno != errno.EEXIST:
            raise


def visualize_ranked_results(distmat, query, gallery, save_dir='log/ranked_results', topk=20, max_query=100):
    """
    Visualize ranked results for the first ``max_query`` queries

    Support both imgreid and vidreid

    Args:
    - distmat: distance matrix of shape (num_query, num_gallery).
    - query, gallery: lists of (img_path, pid, camid); for imgreid, img_path is a string,
                      while for vidreid, img_path is a tuple containing a sequence of strings.
    - save_dir: directory to save output images.
    - topk: int, denoting top-k images in the rank list to be visualized.
    - max_query: only the first ``max_query`` queries are visualized (default 100, the
                 previously hard-coded limit); ``None`` visualizes every query.
    """
    num_q, num_g = distmat.shape

    print("Visualizing top-{} ranks".format(topk))
    print("# query: {}\n# gallery {}".format(num_q, num_g))
    print("Saving images to '{}'".format(save_dir))

    assert num_q == len(query)
    assert num_g == len(gallery)

    indices = np.argsort(distmat, axis=1)
    mkdir_if_missing(save_dir)

    def _cp_img_to(src, dst, rank, prefix):
        """
        - src: image path or tuple (for vidreid)
        - dst: target directory
        - rank: int, denoting ranked position, starting from 1
        - prefix: string
        """
        if isinstance(src, (tuple, list)):
            dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3))
            mkdir_if_missing(dst)
            for img_path in src:
                shutil.copy(img_path, dst)
        else:
            dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3) + '_name_' + osp.basename(src))
            shutil.copy(src, dst)

    if max_query is not None:
        num_q = min(max_query, num_q)
    for q_idx in range(num_q):
        qimg_path, qpid, qcamid = query[q_idx]
        # osp.basename of a vidreid tuple raises TypeError; name the folder after the first frame.
        qname = qimg_path[0] if isinstance(qimg_path, (tuple, list)) else qimg_path
        qdir = osp.join(save_dir, osp.basename(qname))
        mkdir_if_missing(qdir)
        _cp_img_to(qimg_path, qdir, rank=0, prefix='query')

        rank_idx = 1
        for g_idx in indices[q_idx, :]:
            gimg_path, gpid, gcamid = gallery[g_idx]
            # Same identity under the same camera is excluded from the ranking.
            invalid = (qpid == gpid) & (qcamid == gcamid)
            if not invalid:
                _cp_img_to(gimg_path, qdir, rank=rank_idx, prefix='gallery')
                rank_idx += 1
                if rank_idx > topk:
                    break

    print("Done")