├── README.md ├── network.pdf ├── network.png ├── reid ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── evaluators.cpython-36.pyc │ ├── evaluators.cpython-37.pyc │ ├── evaluators_dis.cpython-37.pyc │ ├── evaluators_l.cpython-37.pyc │ ├── evaluators_new.cpython-37.pyc │ ├── evaluators_notrans.cpython-37.pyc │ ├── evaluators_total.cpython-37.pyc │ ├── evaluators_transfer.cpython-36.pyc │ ├── evaluators_transfer.cpython-37.pyc │ ├── evaluators_transfer2.cpython-36.pyc │ ├── evaluators_transfer2.cpython-37.pyc │ ├── trainer_tsne.cpython-37.pyc │ ├── trainers.cpython-36.pyc │ ├── trainers.cpython-37.pyc │ ├── trainers_alldis.cpython-37.pyc │ ├── trainers_dis.cpython-37.pyc │ ├── trainers_gan.cpython-37.pyc │ ├── trainers_kd.cpython-36.pyc │ ├── trainers_kd.cpython-37.pyc │ ├── trainers_mem.cpython-36.pyc │ ├── trainers_mem.cpython-37.pyc │ ├── trainers_mmd.cpython-37.pyc │ ├── trainers_newtransfer.cpython-36.pyc │ ├── trainers_newtransfer.cpython-37.pyc │ ├── trainers_resnet.cpython-37.pyc │ ├── trainers_total.cpython-37.pyc │ ├── trainers_transfer.cpython-37.pyc │ ├── trainers_transfer2.cpython-36.pyc │ ├── trainers_transfer2.cpython-37.pyc │ ├── trainers_transmem.cpython-36.pyc │ └── trainers_transmem.cpython-37.pyc ├── datasets │ ├── TestData.py │ ├── TotalData.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── TestData.cpython-36.pyc │ │ ├── TestData.cpython-37.pyc │ │ ├── TotalData.cpython-36.pyc │ │ ├── TotalData.cpython-37.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── domain_adaptation.cpython-36.pyc │ │ ├── domain_adaptation.cpython-37.pyc │ │ ├── duke.cpython-36.pyc │ │ ├── duke.cpython-37.pyc │ │ ├── market.cpython-36.pyc │ │ └── market.cpython-37.pyc │ └── domain_adaptation.py ├── evaluation_metrics │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── classification.cpython-36.pyc │ │ ├── classification.cpython-37.pyc │ │ ├── ranking.cpython-36.pyc │ │ └── ranking.cpython-37.pyc │ ├── classification.py │ └── ranking.py ├── evaluators.py ├── loss │ ├── .tripletloss.py.swp │ ├── __init__.py │ ├── __pycache__ │ │ ├── MLS.cpython-37.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── dual.cpython-36.pyc │ │ ├── dual.cpython-37.pyc │ │ ├── gyneightboor.cpython-36.pyc │ │ ├── gyneightboor.cpython-37.pyc │ │ ├── identityloss.cpython-37.pyc │ │ ├── lsr.cpython-37.pyc │ │ ├── mmd.cpython-36.pyc │ │ ├── mmd.cpython-37.pyc │ │ ├── neightboor.cpython-36.pyc │ │ ├── neightboor.cpython-37.pyc │ │ ├── newneightboor.cpython-36.pyc │ │ ├── newneightboor.cpython-37.pyc │ │ ├── triplet.cpython-37.pyc │ │ ├── tripletloss.cpython-36.pyc │ │ ├── tripletloss.cpython-37.pyc │ │ ├── wassdistance.cpython-36.pyc │ │ └── wassdistance.cpython-37.pyc │ ├── identityloss.py │ ├── tripletloss.py │ └── wassdistance.py ├── models │ ├── DDAN.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── DDAN.cpython-37.pyc │ │ ├── PFE.cpython-37.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── basenet.cpython-36.pyc │ │ ├── basenet.cpython-37.pyc │ │ ├── disnet.cpython-36.pyc │ │ ├── disnet.cpython-37.pyc │ │ ├── gannet.cpython-36.pyc │ │ ├── gannet.cpython-37.pyc │ │ ├── kdnet.cpython-36.pyc │ │ ├── kdnet.cpython-37.pyc │ │ ├── mobilenet.cpython-36.pyc │ │ ├── mobilenet.cpython-37.pyc │ │ ├── mobilenet_o.cpython-36.pyc │ │ ├── mobilenet_o.cpython-37.pyc │ │ ├── resnet.cpython-36.pyc │ │ ├── resnet.cpython-37.pyc │ │ ├── totalnet.cpython-36.pyc │ │ ├── 
totalnet.cpython-37.pyc │ │ ├── transfernet.cpython-36.pyc │ │ ├── transfernet.cpython-37.pyc │ │ ├── transfernet2.cpython-36.pyc │ │ └── transfernet2.cpython-37.pyc │ └── mobilenet.py ├── trainers.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── logging.cpython-36.pyc │ ├── logging.cpython-37.pyc │ ├── meters.cpython-36.pyc │ ├── meters.cpython-37.pyc │ ├── osutils.cpython-36.pyc │ ├── osutils.cpython-37.pyc │ ├── serialization.cpython-36.pyc │ └── serialization.cpython-37.pyc │ ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── preprocessor.cpython-36.pyc │ │ ├── preprocessor.cpython-37.pyc │ │ ├── sampler.cpython-36.pyc │ │ ├── sampler.cpython-37.pyc │ │ ├── transforms.cpython-36.pyc │ │ └── transforms.cpython-37.pyc │ ├── preprocessor.py │ ├── sampler.py │ └── transforms.py │ ├── logging.py │ ├── meters.py │ ├── osutils.py │ └── serialization.py ├── run.sh └── transmem.py /README.md: -------------------------------------------------------------------------------- 1 | # Dual Distribution Alignment Network for Generalizable Person Re-Identification 2 | 3 | This repository is the official implementation of our AAAI 2021 [Paper](https://arxiv.org/abs/2007.13249). 4 | ![network](./network.png) 5 | 6 | #### Dependencies 7 | * Python 3.7.5 8 | * PyTorch == 1.3.1 9 | #### Datasets 10 | You may need the [CUHK-SYSU](https://drive.google.com/file/d/1yoQOTp--ULGPct6erCsAQ_hd46hENE5G/view?usp=sharing) and [iLIDS](https://drive.google.com/file/d/1_2bYbnH0GIDE6BfjZdtWVQE2nK134ZLi/view?usp=sharing) datasets. 11 | The other datasets are available from their original authors' websites. 12 | #### Usage 13 | * train/test 14 | ``` 15 | bash run.sh 16 | ``` 17 | #### Model 18 | [Download](https://drive.google.com/file/d/1ece571WcZ3ietIfDqA31yZ1loe0GfBka/view?usp=sharing) the trained DDAN model. 19 | 20 | #### Citations 21 | If our paper helps your research, please cite it in your publications: 22 | ``` 23 | @inproceedings{chen2021dual, 24 | Author={Peixian Chen and Pingyang Dai and Jianzhuang Liu and Feng Zheng and Mingliang Xu and Qi Tian and Rongrong Ji}, 25 | Title={Dual Distribution Alignment Network for Generalizable Person Re-Identification}, 26 | Booktitle={Thirty-Fifth AAAI Conference on Artificial Intelligence}, 27 | Year={2021}} 28 | ``` 29 | 30 | -------------------------------------------------------------------------------- /network.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/network.pdf -------------------------------------------------------------------------------- /network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/network.png -------------------------------------------------------------------------------- /reid/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . import datasets 4 | from . import evaluation_metrics 5 | from . import loss 6 | from . import models 7 | from . import utils 8 | from . import evaluators 9 | from . 
import trainers 10 | -------------------------------------------------------------------------------- /reid/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /reid/__pycache__/evaluators.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/evaluators.cpython-36.pyc -------------------------------------------------------------------------------- /reid/__pycache__/evaluators.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/evaluators.cpython-37.pyc -------------------------------------------------------------------------------- /reid/__pycache__/evaluators_dis.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/evaluators_dis.cpython-37.pyc -------------------------------------------------------------------------------- /reid/__pycache__/evaluators_l.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/evaluators_l.cpython-37.pyc -------------------------------------------------------------------------------- /reid/__pycache__/evaluators_new.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/evaluators_new.cpython-37.pyc -------------------------------------------------------------------------------- /reid/__pycache__/evaluators_notrans.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/evaluators_notrans.cpython-37.pyc -------------------------------------------------------------------------------- /reid/__pycache__/evaluators_total.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/evaluators_total.cpython-37.pyc -------------------------------------------------------------------------------- /reid/__pycache__/evaluators_transfer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/evaluators_transfer.cpython-36.pyc 
-------------------------------------------------------------------------------- /reid/__pycache__/evaluators_transfer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/evaluators_transfer.cpython-37.pyc -------------------------------------------------------------------------------- /reid/__pycache__/evaluators_transfer2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/evaluators_transfer2.cpython-36.pyc -------------------------------------------------------------------------------- /reid/__pycache__/evaluators_transfer2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/evaluators_transfer2.cpython-37.pyc -------------------------------------------------------------------------------- /reid/__pycache__/trainer_tsne.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/trainer_tsne.cpython-37.pyc -------------------------------------------------------------------------------- /reid/__pycache__/trainers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/trainers.cpython-36.pyc -------------------------------------------------------------------------------- /reid/__pycache__/trainers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/trainers.cpython-37.pyc -------------------------------------------------------------------------------- /reid/__pycache__/trainers_alldis.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/trainers_alldis.cpython-37.pyc -------------------------------------------------------------------------------- /reid/__pycache__/trainers_dis.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/trainers_dis.cpython-37.pyc -------------------------------------------------------------------------------- /reid/__pycache__/trainers_gan.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/trainers_gan.cpython-37.pyc -------------------------------------------------------------------------------- /reid/__pycache__/trainers_kd.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/trainers_kd.cpython-36.pyc 
-------------------------------------------------------------------------------- /reid/__pycache__/trainers_kd.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/trainers_kd.cpython-37.pyc -------------------------------------------------------------------------------- /reid/__pycache__/trainers_mem.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/trainers_mem.cpython-36.pyc -------------------------------------------------------------------------------- /reid/__pycache__/trainers_mem.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/trainers_mem.cpython-37.pyc -------------------------------------------------------------------------------- /reid/__pycache__/trainers_mmd.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/trainers_mmd.cpython-37.pyc -------------------------------------------------------------------------------- /reid/__pycache__/trainers_newtransfer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/trainers_newtransfer.cpython-36.pyc -------------------------------------------------------------------------------- /reid/__pycache__/trainers_newtransfer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/trainers_newtransfer.cpython-37.pyc -------------------------------------------------------------------------------- /reid/__pycache__/trainers_resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/trainers_resnet.cpython-37.pyc -------------------------------------------------------------------------------- /reid/__pycache__/trainers_total.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/trainers_total.cpython-37.pyc -------------------------------------------------------------------------------- /reid/__pycache__/trainers_transfer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/trainers_transfer.cpython-37.pyc -------------------------------------------------------------------------------- /reid/__pycache__/trainers_transfer2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/trainers_transfer2.cpython-36.pyc 
-------------------------------------------------------------------------------- /reid/__pycache__/trainers_transfer2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/trainers_transfer2.cpython-37.pyc -------------------------------------------------------------------------------- /reid/__pycache__/trainers_transmem.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/trainers_transmem.cpython-36.pyc -------------------------------------------------------------------------------- /reid/__pycache__/trainers_transmem.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/__pycache__/trainers_transmem.cpython-37.pyc -------------------------------------------------------------------------------- /reid/datasets/TestData.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import numpy as np 4 | import pdb 5 | from glob import glob 6 | import re 7 | import random 8 | 9 | 10 | class TestData(object): 11 | 12 | def __init__(self, data_dir): 13 | 14 | self.prid_images_dir = osp.join(data_dir, "DG/PRID") 15 | self.grid_images_dir = osp.join(data_dir, "DG/GRID") 16 | self.viper_images_dir = osp.join(data_dir, "DG/VIPeR") 17 | self.ilid_images_dir = osp.join(data_dir, "DG/iLIDS") 18 | 19 | self.PRID_ID = random.sample([i for i in range(1, 201)], 100) 20 | self.GRID_ID = random.sample([i for i in range(1, 251)], 125) 21 | self.VIPeR_ID = random.sample([i for i in range(1, 633)], 316) 22 | self.ILID_ID = random.sample([i for i in range(1,120)], 60) 23 | 24 | 25 | self.PRID_Pair = [i for i in range(1, 201)] 26 | self.GRID_Pair = [i for i in range(1, 251)] 27 | self.VIPeR_Pair = [i for i in range(1, 633)] 28 | 29 | 30 | # training image dir 31 | self.gallery_path = 'gallery_test' 32 | self.query_path = 'query_test' 33 | self.load() 34 | 35 | def preprocess(self, images_dir, path, queryID, IDpair, relabel=True): 36 | domain = 5 37 | pattern = re.compile(r'([-\d]+)_c?(\d)') 38 | all_pids = {} 39 | ret = [] 40 | fpaths = sorted(glob(osp.join(images_dir, path, '*.jpg'))) 41 | type = ['*jpeg', '*.png', '*bmp'] 42 | t = 0 43 | while fpaths == []: 44 | fpaths = sorted(glob(osp.join(images_dir, path, type[t]))) 45 | t += 1 46 | for fpath in fpaths: 47 | fname = osp.basename(fpath) 48 | pid, cam = map(int, pattern.search(fname).groups()) 49 | # if path == self.query_path and pid not in queryID: 50 | # continue 51 | # if path == self.gallery_path and (pid not in queryID and pid in IDpair): 52 | # continue 53 | if pid == -1: continue 54 | if relabel: 55 | if pid not in all_pids: 56 | all_pids[pid] = len(all_pids) 57 | else: 58 | if pid not in all_pids: 59 | all_pids[pid] = pid 60 | pid = all_pids[pid] 61 | cam -= 1 62 | ret.append((fname, pid, cam, cam)) 63 | return ret, int(len(all_pids)) 64 | 65 | def ILIDS_preprocess(self, images_dir, path, queryID, relabel=True): 66 | domain = 5 67 | ilid_query, ilid_gallery = [], [] 68 | queryid, galleryid = [], [] 69 | 70 | pattern = re.compile(r'([-\d]+)_c?(\d)') 71 | all_pids = {} 72 | ret = [] 73 | fpaths = 
glob(osp.join(images_dir, path, '*.jpg')) 74 | type = ['*jpeg', '*.png', '*bmp'] 75 | t = 0 76 | while fpaths == []: 77 | fpaths = glob(osp.join(images_dir, path, type[t])) 78 | t += 1 79 | for fpath in fpaths: 80 | fname = osp.basename(fpath) 81 | pid, cam = map(int, pattern.search(fname).groups()) 82 | # if (pid not in queryID) or ((pid in queryid) and (pid in galleryid)): 83 | # continue 84 | if pid == -1: continue 85 | if relabel: 86 | if pid not in all_pids: 87 | all_pids[pid] = len(all_pids) 88 | else: 89 | if pid not in all_pids: 90 | all_pids[pid] = pid 91 | pid = all_pids[pid] 92 | cam -= 1 93 | if pid not in queryid: 94 | ilid_query.append((fname, pid, cam, cam)) 95 | queryid.append(pid) 96 | else: 97 | ilid_gallery.append((fname, pid, cam,cam)) 98 | galleryid.append(pid) 99 | return ilid_query, ilid_gallery, int(len(all_pids)) 100 | 101 | def load(self): 102 | # DG TEST 103 | self.prid_gallery, self.num_gallery_ids = self.preprocess(self.prid_images_dir, self.gallery_path, self.PRID_ID, self.PRID_Pair, False) 104 | self.prid_query, self.num_query_ids = self.preprocess(self.prid_images_dir, self.query_path, self.PRID_ID, self.PRID_Pair, False) 105 | self.grid_gallery, self.num_gallery_ids = self.preprocess(self.grid_images_dir, self.gallery_path, self.GRID_ID, self.GRID_Pair, False) 106 | self.grid_query, self.num_query_ids = self.preprocess(self.grid_images_dir, self.query_path, self.GRID_ID, self.GRID_Pair, False) 107 | self.viper_gallery, self.num_gallery_ids = self.preprocess(self.viper_images_dir, self.gallery_path, self.VIPeR_ID, self.VIPeR_Pair, False) 108 | self.viper_query, self.num_query_ids = self.preprocess(self.viper_images_dir, self.query_path, self.VIPeR_ID, self.VIPeR_Pair, False) 109 | 110 | self.ilid_query, self.ilid_gallery, self.num_query_ids = self.ILIDS_preprocess(self.ilid_images_dir, "images", self.ILID_ID, False) 111 | 112 | print(self.__class__.__name__, "dataset loaded") 113 | print(" subset | # query | # gallery") 114 | print(" ------------------------------------") 115 | print(" prid train | {:8d} | {:8d} " 116 | .format(len(self.prid_query), len(self.prid_gallery))) 117 | print(" grid train | {:8d} | {:8d} " 118 | .format(len(self.grid_query), len(self.grid_gallery))) 119 | print(" viper train | {:8d} | {:8d} " 120 | .format(len(self.viper_query), len(self.viper_gallery))) 121 | print(" ilid train | {:8d} | {:8d} " 122 | .format(len(self.ilid_query), len(self.ilid_gallery))) 123 | -------------------------------------------------------------------------------- /reid/datasets/TotalData.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import numpy as np 4 | import pdb 5 | from glob import glob 6 | import re 7 | import random 8 | 9 | 10 | class TotalData(object): 11 | 12 | def __init__(self, data_dir): 13 | 14 | self.prid_images_dir = osp.join(data_dir, "DG/PRID") 15 | self.grid_images_dir = osp.join(data_dir, "DG/GRID") 16 | self.viper_images_dir = osp.join(data_dir, "DG/VIPeR") 17 | self.ilid_images_dir = osp.join(data_dir, "DG/iLIDS") 18 | 19 | # training image dir 20 | self.gallery_path = 'gallery_test' 21 | self.query_path = 'query_test' 22 | self.load() 23 | 24 | def preprocess(self, images_dir, path, relabel=True): 25 | domain = 5 26 | pattern = re.compile(r'([-\d]+)_c?(\d)') 27 | all_pids = {} 28 | ret = [] 29 | fpaths = sorted(glob(osp.join(images_dir, path, '*.jpg'))) 30 | type = ['*jpeg', '*.png', '*bmp'] 31 | t = 0 32 
| while fpaths == []: 33 | fpaths = sorted(glob(osp.join(images_dir, path, type[t]))) 34 | t += 1 35 | for fpath in fpaths: 36 | fname = osp.basename(fpath) 37 | pid, cam = map(int, pattern.search(fname).groups()) 38 | if pid == -1: continue 39 | if relabel: 40 | if pid not in all_pids: 41 | all_pids[pid] = len(all_pids) 42 | else: 43 | if pid not in all_pids: 44 | all_pids[pid] = pid 45 | pid = all_pids[pid] 46 | cam -= 1 47 | ret.append((fname, pid, cam, cam)) 48 | return ret, int(len(all_pids)) 49 | 50 | def ILIDS_preprocess(self, images_dir, path, relabel=True): 51 | domain = 5 52 | ilid_query, ilid_gallery = [], [] 53 | 54 | pattern = re.compile(r'([-\d]+)_c?(\d)') 55 | all_pids = {} 56 | ret = [] 57 | fpaths = glob(osp.join(images_dir, path, '*.jpg')) 58 | type = ['*jpeg', '*.png', '*bmp'] 59 | t = 0 60 | while fpaths == []: 61 | fpaths = sorted(glob(osp.join(images_dir, path, type[t]))) 62 | t += 1 63 | for fpath in fpaths: 64 | fname = osp.basename(fpath) 65 | pid, cam = map(int, pattern.search(fname).groups()) 66 | if pid == -1: continue 67 | if relabel: 68 | if pid not in all_pids: 69 | all_pids[pid] = len(all_pids) 70 | else: 71 | if pid not in all_pids: 72 | all_pids[pid] = pid 73 | pid = all_pids[pid] 74 | cam -= 1 75 | ilid_query.append((fname, pid, cam, cam)) 76 | ilid_gallery.append((fname, pid, cam,cam)) 77 | return ilid_query, ilid_gallery, int(len(all_pids)) 78 | 79 | def load(self): 80 | # DG TEST 81 | self.prid_gallery, self.num_gallery_ids = self.preprocess(self.prid_images_dir, self.gallery_path, False) 82 | self.prid_query, self.num_query_ids = self.preprocess(self.prid_images_dir, self.query_path, False) 83 | self.grid_gallery, self.num_gallery_ids = self.preprocess(self.grid_images_dir, self.gallery_path, False) 84 | self.grid_query, self.num_query_ids = self.preprocess(self.grid_images_dir, self.query_path, False) 85 | self.viper_gallery, self.num_gallery_ids = self.preprocess(self.viper_images_dir, self.gallery_path, False) 86 | self.viper_query, self.num_query_ids = self.preprocess(self.viper_images_dir, self.query_path, False) 87 | 88 | self.ilid_query, self.ilid_gallery, self.num_query_ids = self.ILIDS_preprocess(self.ilid_images_dir, "images", False) 89 | 90 | print(self.__class__.__name__, "dataset loaded") 91 | print(" subset | # query | # gallery") 92 | print(" ------------------------------------") 93 | print(" prid train | {:8d} | {:8d} " 94 | .format(len(self.prid_query), len(self.prid_gallery))) 95 | print(" grid train | {:8d} | {:8d} " 96 | .format(len(self.grid_query), len(self.grid_gallery))) 97 | print(" viper train | {:8d} | {:8d} " 98 | .format(len(self.viper_query), len(self.viper_gallery))) 99 | print(" ilid train | {:8d} | {:8d} " 100 | .format(len(self.ilid_query), len(self.ilid_gallery))) 101 | -------------------------------------------------------------------------------- /reid/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | def names(): 4 | return sorted(__factory.keys()) 5 | 6 | 7 | def create(name, root, *args, **kwargs): 8 | """ 9 | Create a dataset instance. 10 | 11 | Parameters 12 | ---------- 13 | name : str 14 | The dataset name. Can be one of 'market', 'duke'. 15 | root : str 16 | The path to the dataset directory. 
17 | """ 18 | if name not in __factory: 19 | raise KeyError("Unknown dataset:", name) 20 | return __factory[name](root, *args, **kwargs) 21 | 22 | -------------------------------------------------------------------------------- /reid/datasets/__pycache__/TestData.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/datasets/__pycache__/TestData.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/TestData.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/datasets/__pycache__/TestData.cpython-37.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/TotalData.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/datasets/__pycache__/TotalData.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/TotalData.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/datasets/__pycache__/TotalData.cpython-37.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/datasets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/datasets/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/domain_adaptation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/datasets/__pycache__/domain_adaptation.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/domain_adaptation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/datasets/__pycache__/domain_adaptation.cpython-37.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/duke.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/datasets/__pycache__/duke.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/duke.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/datasets/__pycache__/duke.cpython-37.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/market.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/datasets/__pycache__/market.cpython-36.pyc -------------------------------------------------------------------------------- /reid/datasets/__pycache__/market.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/datasets/__pycache__/market.cpython-37.pyc -------------------------------------------------------------------------------- /reid/datasets/domain_adaptation.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import numpy as np 4 | import pdb 5 | from glob import glob 6 | import re 7 | 8 | 9 | class DA(object): 10 | 11 | def __init__(self, data_dir): 12 | 13 | # training image dir 14 | self.market_images_dir = osp.join(data_dir, "DG/market1501") 15 | self.duke_images_dir = osp.join(data_dir, "DG/DukeMTMC-reID") 16 | self.sysu_images_dir = osp.join(data_dir, "DG/CUHK-SYSU") 17 | self.cuhk03_images_dir = osp.join(data_dir, "DG/cuhk03") 18 | self.cuhk02_images_dir = osp.join(data_dir, "DG/cuhk02") 19 | 20 | self.source_train_path = 'bounding_box_train' 21 | self.load() 22 | 23 | def preprocess(self, images_dir, path, relabel=True): 24 | add_pid = 0 25 | add_cam = 0 26 | domain_all = 0 27 | domain = 0 28 | 29 | if images_dir == self.duke_images_dir: 30 | add_pid = 1501 31 | add_cam = 10 32 | domain_all = 1 33 | domain = 0 34 | if images_dir == self.cuhk03_images_dir: 35 | add_pid = 1501+1812 36 | add_cam = 10+2 37 | domain_all = 2 38 | domain = 0 39 | if images_dir == self.cuhk02_images_dir: 40 | add_pid = 1501+1812+1467 41 | add_cam = 10+2+8 42 | domain_all = 3 43 | domain = 0 44 | if images_dir == self.sysu_images_dir: 45 | add_pid = 1501+1816+1467+1812 46 | add_cam = 10+2+8+6 47 | domain_all = 4 48 | domain = 1 49 | 50 | pattern = re.compile(r'([-\d]+)_c?(\d)') 51 | all_pids = {} 52 | ret = [] 53 | fpaths = sorted(glob(osp.join(images_dir, path, '*.jpg'))) 54 | type = ['*jpeg', '*.png', '*bmp'] 55 | t = 0 56 | while fpaths == []: 57 | fpaths = sorted(glob(osp.join(images_dir, path, type[t]))) 58 | t += 1 59 | for fpath in fpaths: 60 | # fname = osp.basename(fpath) 61 | fname = fpath 62 | pid, cam = map(int, pattern.search(fname).groups()) 63 | if pid == -1: continue 64 | if relabel: 65 | if pid not in all_pids: 66 | all_pids[pid] = len(all_pids) 67 | else: 68 | if pid not in all_pids: 69 | all_pids[pid] = pid 70 | pid = all_pids[pid] 71 | cam -= 1 72 | ret.append((fname, pid+add_pid, domain, domain_all)) 73 | return ret, int(len(all_pids)) 74 | 75 | def load(self): 76 | self.market_train, self.num_market_ids = self.preprocess(self.market_images_dir, self.source_train_path) 77 | self.duke_train, self.num_duke_ids = self.preprocess(self.duke_images_dir, self.source_train_path) 78 | self.sysu_train, self.num_sysu_ids = self.preprocess(self.sysu_images_dir, self.source_train_path) 79 | self.cuhk03_train, 
self.num_cuhk03_ids = self.preprocess(self.cuhk03_images_dir, self.source_train_path) 80 | self.cuhk02_train, self.num_cuhk02_ids = self.preprocess(self.cuhk02_images_dir, self.source_train_path) 81 | 82 | unique_ids = set() 83 | all_dataset = [self.market_train, self.duke_train, self.cuhk03_train, self.cuhk02_train, self.sysu_train] 84 | for dataset in all_dataset: 85 | ids = set(i for _, i, _, _ in dataset) 86 | assert not unique_ids & ids 87 | unique_ids |= ids 88 | 89 | self.source_train = self.market_train + self.duke_train + self.cuhk03_train + self.cuhk02_train + self.sysu_train 90 | self.num_source_ids = self.num_market_ids + self.num_duke_ids + self.num_cuhk03_ids + self.num_cuhk02_ids + self.num_sysu_ids 91 | 92 | print(self.__class__.__name__, "dataset loaded") 93 | print(" subset | # ids | # images") 94 | print(" ------------------------------------") 95 | print(" market train | {:5d} | {:8d}" 96 | .format(self.num_market_ids, len(self.market_train))) 97 | print(" duke train | {:5d} | {:8d}" 98 | .format(self.num_duke_ids, len(self.duke_train))) 99 | print(" cuhk03 train | {:5d} | {:8d}" 100 | .format(self.num_cuhk03_ids, len(self.cuhk03_train))) 101 | print(" sysu train | {:5d} | {:8d}" 102 | .format(self.num_sysu_ids, len(self.sysu_train))) 103 | print(" cuhk02 train | {:5d} | {:8d}" 104 | .format(self.num_cuhk02_ids, len(self.cuhk02_train))) 105 | print(" total train | {:5d} | {:8d}" 106 | .format(self.num_source_ids, len(self.source_train))) 107 | -------------------------------------------------------------------------------- /reid/evaluation_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .classification import accuracy 4 | from .ranking import cmc, mean_ap 5 | 6 | __all__ = [ 7 | 'accuracy', 8 | 'cmc', 9 | 'mean_ap', 10 | ] 11 | -------------------------------------------------------------------------------- /reid/evaluation_metrics/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/evaluation_metrics/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/evaluation_metrics/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/evaluation_metrics/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /reid/evaluation_metrics/__pycache__/classification.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/evaluation_metrics/__pycache__/classification.cpython-36.pyc -------------------------------------------------------------------------------- /reid/evaluation_metrics/__pycache__/classification.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/evaluation_metrics/__pycache__/classification.cpython-37.pyc -------------------------------------------------------------------------------- 
/reid/evaluation_metrics/__pycache__/ranking.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/evaluation_metrics/__pycache__/ranking.cpython-36.pyc -------------------------------------------------------------------------------- /reid/evaluation_metrics/__pycache__/ranking.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/evaluation_metrics/__pycache__/ranking.cpython-37.pyc -------------------------------------------------------------------------------- /reid/evaluation_metrics/classification.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from ..utils import to_torch 4 | 5 | 6 | def accuracy(output, target, topk=(1,)): 7 | output, target = to_torch(output), to_torch(target) 8 | maxk = max(topk) 9 | batch_size = target.size(0) 10 | 11 | _, pred = output.topk(maxk, 1, True, True) 12 | pred = pred.t() 13 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 14 | 15 | ret = [] 16 | for k in topk: 17 | correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True) 18 | ret.append(correct_k.mul_(1. / batch_size)) 19 | return ret 20 | -------------------------------------------------------------------------------- /reid/evaluation_metrics/ranking.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | from sklearn.metrics.base import _average_binary_score 6 | from sklearn.metrics import precision_recall_curve, auc 7 | # from sklearn.metrics import average_precision_score 8 | 9 | 10 | from ..utils import to_numpy 11 | 12 | 13 | def _unique_sample(ids_dict, num): 14 | mask = np.zeros(num, dtype=np.bool) 15 | for _, indices in ids_dict.items(): 16 | i = np.random.choice(indices) 17 | mask[i] = True 18 | return mask 19 | 20 | 21 | def average_precision_score(y_true, y_score, average="macro", 22 | sample_weight=None): 23 | def _binary_average_precision(y_true, y_score, sample_weight=None): 24 | precision, recall, thresholds = precision_recall_curve( 25 | y_true, y_score, sample_weight=sample_weight) 26 | return auc(recall, precision) 27 | 28 | return _average_binary_score(_binary_average_precision, y_true, y_score, 29 | average, sample_weight=sample_weight) 30 | 31 | 32 | def cmc(distmat, query_ids=None, gallery_ids=None, 33 | query_cams=None, gallery_cams=None, topk=100, 34 | separate_camera_set=False, 35 | single_gallery_shot=False, 36 | first_match_break=False): 37 | distmat = to_numpy(distmat) 38 | m, n = distmat.shape 39 | # Fill up default values 40 | if query_ids is None: 41 | query_ids = np.arange(m) 42 | if gallery_ids is None: 43 | gallery_ids = np.arange(n) 44 | if query_cams is None: 45 | query_cams = np.zeros(m).astype(np.int32) 46 | if gallery_cams is None: 47 | gallery_cams = np.ones(n).astype(np.int32) 48 | # Ensure numpy array 49 | query_ids = np.asarray(query_ids) 50 | gallery_ids = np.asarray(gallery_ids) 51 | query_cams = np.asarray(query_cams) 52 | gallery_cams = np.asarray(gallery_cams) 53 | # Sort and find correct matches 54 | indices = np.argsort(distmat, axis=1) 55 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 56 | # Compute CMC for each query 
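# (For each query below: gallery entries sharing both the query's id and camera are masked out via `valid`; with single_gallery_shot, one gallery instance per id is randomly kept and the measurement is repeated 10 times; with first_match_break, a query contributes 1 at the rank of its first correct match, otherwise every correct match adds fractional credit 1/(len(index) * repeat).)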
57 |     ret = np.zeros(topk) 58 |     num_valid_queries = 0 59 |     for i in range(m): 60 |         # Filter out the same id and same camera 61 |         valid = ((gallery_ids[indices[i]] != query_ids[i]) | 62 |                  (gallery_cams[indices[i]] != query_cams[i])) 63 |         if separate_camera_set: 64 |             # Filter out samples from same camera 65 |             valid &= (gallery_cams[indices[i]] != query_cams[i]) 66 |         if not np.any(matches[i, valid]): continue 67 |         if single_gallery_shot: 68 |             repeat = 10 69 |             gids = gallery_ids[indices[i][valid]] 70 |             inds = np.where(valid)[0] 71 |             ids_dict = defaultdict(list) 72 |             for j, x in zip(inds, gids): 73 |                 ids_dict[x].append(j) 74 |         else: 75 |             repeat = 1 76 |         for _ in range(repeat): 77 |             if single_gallery_shot: 78 |                 # Randomly choose one instance for each id 79 |                 sampled = (valid & _unique_sample(ids_dict, len(valid))) 80 |                 index = np.nonzero(matches[i, sampled])[0] 81 |             else: 82 |                 index = np.nonzero(matches[i, valid])[0] 83 |             delta = 1. / (len(index) * repeat) 84 |             for j, k in enumerate(index): 85 |                 if k - j >= topk: break 86 |                 if first_match_break: 87 |                     ret[k - j] += 1 88 |                     break 89 |                 ret[k - j] += delta 90 |         num_valid_queries += 1 91 |     if num_valid_queries == 0: 92 |         raise RuntimeError("No valid query") 93 |     return ret.cumsum() / num_valid_queries 94 | 95 | 96 | def mean_ap(distmat, query_ids=None, gallery_ids=None, 97 |             query_cams=None, gallery_cams=None): 98 |     distmat = to_numpy(distmat) 99 |     m, n = distmat.shape 100 |     # Fill up default values 101 |     if query_ids is None: 102 |         query_ids = np.arange(m) 103 |     if gallery_ids is None: 104 |         gallery_ids = np.arange(n) 105 |     if query_cams is None: 106 |         query_cams = np.zeros(m).astype(np.int32) 107 |     if gallery_cams is None: 108 |         gallery_cams = np.ones(n).astype(np.int32) 109 |     # Ensure numpy array 110 |     query_ids = np.asarray(query_ids) 111 |     gallery_ids = np.asarray(gallery_ids) 112 |     query_cams = np.asarray(query_cams) 113 |     gallery_cams = np.asarray(gallery_cams) 114 |     # Sort and find correct matches 115 |     indices = np.argsort(distmat, axis=1) 116 |     matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 117 |     # Compute AP for each query 118 |     aps = [] 119 |     for i in range(m): 120 |         # Filter out the same id and same camera 121 |         valid = ((gallery_ids[indices[i]] != query_ids[i]) | 122 |                  (gallery_cams[indices[i]] != query_cams[i])) 123 | 124 |         y_true = matches[i, valid] 125 |         y_score = -distmat[i][indices[i]][valid] 126 |         if not np.any(y_true): continue 127 |         aps.append(average_precision_score(y_true, y_score)) 128 |     if len(aps) == 0: 129 |         raise RuntimeError("No valid query") 130 |     return np.mean(aps) 131 | 132 | -------------------------------------------------------------------------------- /reid/evaluators.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | from collections import OrderedDict 4 | import pdb 5 | from collections import defaultdict 6 | 7 | import torch 8 | import numpy as np 9 | 10 | from .evaluation_metrics import cmc, mean_ap 11 | from .utils.meters import AverageMeter 12 | 13 | from torch.autograd import Variable 14 | from .utils import to_torch 15 | from .utils import to_numpy 16 | 17 | 18 | def fliplr(img): 19 |     '''flip horizontal''' 20 |     inv_idx = torch.arange(img.size(3)-1,-1,-1).long()  # N x C x H x W; N: nSamples in minibatch, i.e., batchsize x nChannels x Height x Width 21 |     img_flip = img.index_select(3,inv_idx) 22 |     return img_flip 23 | 24 | def extract_cnn_feature(model, inputs, 
output_feature=None): 25 | # encoder, transfer, _, pfeNet = model 26 | encoder, transfer, _ = model 27 | encoder.eval() 28 | transfer.eval() 29 | # pfeNet.eval() 30 | # inputs = to_torch(inputs) 31 | # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 32 | # inputs = inputs.to(device) 33 | with torch.no_grad(): 34 | feature = encoder(inputs.cuda()) 35 | outputs = transfer(feature, output_feature=output_feature) 36 | # outputs = pfeNet(outputs) 37 | # _, outputs = model[0](inputs, output_feature=output_feature) 38 | # outputs = model.module.base(inputs) 39 | 40 | # outputs = torch.FloatTensor(inputs.size(0),1280).zero_() 41 | # for i in range(2): 42 | # if(i==1): 43 | # inputs = fliplr(inputs) 44 | # # inputs = 45 | # # Variable(inputs.cuda()) 46 | # o = model(inputs, output_feature='pool5') 47 | # # o = model.module.base(inputs.cuda()) 48 | # f = o.data.cpu() 49 | # outputs = outputs+f 50 | 51 | fnorm = torch.norm(outputs, p=2, dim=1, keepdim=True) 52 | outputs = outputs.div(fnorm.expand_as(outputs)) 53 | outputs = outputs.data.cpu() 54 | return outputs 55 | 56 | 57 | def extract_features(model, data_loader, print_freq=1, output_feature=None): 58 | # encoder, transfer, _, pfeNet = model 59 | encoder, transfer, _ = model 60 | encoder.eval() 61 | transfer.eval() 62 | # pfeNet.eval() 63 | 64 | batch_time = AverageMeter() 65 | data_time = AverageMeter() 66 | 67 | features = OrderedDict() 68 | labels = OrderedDict() 69 | 70 | end = time.time() 71 | for i, (imgs, fnames, pids, _, _) in enumerate(data_loader): 72 | data_time.update(time.time() - end) 73 | 74 | outputs = extract_cnn_feature(model, imgs, output_feature) 75 | for fname, output, pid in zip(fnames, outputs, pids): 76 | features[fname] = output 77 | labels[fname] = pid 78 | 79 | batch_time.update(time.time() - end) 80 | end = time.time() 81 | 82 | #if (i + 1) % print_freq == 0: 83 | # print('Extract Features: [{}/{}]\t' 84 | # 'Time {:.3f} ({:.3f})\t' 85 | # 'Data {:.3f} ({:.3f})\t' 86 | # .format(i + 1, len(data_loader), 87 | # batch_time.val, batch_time.avg, 88 | # data_time.val, data_time.avg)) 89 | 90 | return features, labels 91 | 92 | def pairwise_distance(query_features, gallery_features, query=None, gallery=None): 93 | x = torch.cat([query_features[f].unsqueeze(0) for f, _, _,_ in query], 0) 94 | y = torch.cat([gallery_features[f].unsqueeze(0) for f, _, _,_ in gallery], 0) 95 | m, n = x.size(0), y.size(0) 96 | x = x.view(m, -1) 97 | y = y.view(n, -1) 98 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 99 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 100 | dist.addmm_(1, -2, x, y.t()) 101 | return dist 102 | 103 | def evaluate_all(distmat, query=None, gallery=None, 104 | query_ids=None, gallery_ids=None, 105 | query_cams=None, gallery_cams=None, 106 | cmc_topk=(1, 5, 10, 20)): 107 | if query is not None and gallery is not None: 108 | query_ids = [pid for _, pid, _,_ in query] 109 | gallery_ids = [pid for _, pid, _,_ in gallery] 110 | query_cams = [cam for _, _, cam,_ in query] 111 | gallery_cams = [cam for _, _, cam,_ in gallery] 112 | else: 113 | assert (query_ids is not None and gallery_ids is not None 114 | and query_cams is not None and gallery_cams is not None) 115 | 116 | # Compute mean AP 117 | mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams) 118 | print('Mean AP: {:4.1%}'.format(mAP)) 119 | 120 | # Compute CMC scores 121 | cmc_configs = { 122 | 'market1501': dict(separate_camera_set=False, 123 | single_gallery_shot=False, 124 | 
first_match_break=False)} 125 |     cmc_scores = {name: cmc(distmat, query_ids, gallery_ids, 126 |                             query_cams, gallery_cams, **params) 127 |                   for name, params in cmc_configs.items()} 128 | 129 |     print('CMC Scores') 130 |     for k in cmc_topk: 131 |         print('  top-{:<4}{:12.1%}' 132 |               .format(k, cmc_scores['market1501'][k - 1])) 133 | 134 |     return cmc_scores['market1501'][0] 135 | 155 | class Evaluator(object): 156 |     def __init__(self, model): 157 |         super(Evaluator, self).__init__() 158 |         self.model = model 159 |         self.num_folds = 10 160 |         self.test_all = False 161 |         print(f'[evaluate] num_folds = {self.num_folds}, test_all = {self.test_all}') 162 | 163 |     def evaluate(self, query_loader, gallery_loader, query, gallery, output_feature=None, rerank=False): 164 |         query_features, _ = extract_features(self.model, query_loader, 1, output_feature) 165 |         gallery_features, _ = extract_features(self.model, gallery_loader, 1, output_feature) 166 |         if rerank: 167 |             # NOTE: reranking is not defined or imported in this file, so rerank=True would raise a NameError 168 |             distmat = reranking(query_features, gallery_features, query, gallery) 169 |         else: 170 |             distmat = pairwise_distance(query_features, gallery_features, query, gallery) 171 |         return evaluate_all(distmat, query=query, gallery=gallery) 
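# Usage sketch (hypothetical names, for illustration only; extract_cnn_feature above unpacks model as an (encoder, transfer, discriminator) triple): # evaluator = Evaluator((encoder, transfer, disnet)) # top1 = evaluator.evaluate(query_loader, gallery_loader, query_set, gallery_set, output_feature='pool5')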
172 |     def eval_viper(self, query_loader, gallery_loader, query_data, gallery_data, output_feature='pool5', seed=0): 173 |         rs = np.random.RandomState(seed) 174 |         # rs = np.random.RandomState() 175 | 176 |         # extract features for all samples 177 |         query_features, _ = extract_features(self.model, query_loader, 1, output_feature) 178 |         all_query_features = list(query_features.values()) 179 | 180 |         gallery_features, _ = extract_features(self.model, gallery_loader, 1, output_feature) 181 |         all_gallery_features = list(gallery_features.values()) 182 | 183 |         # basic metadata 184 |         _, all_query_pids, all_query_cids, _ = map(np.array, zip(*query_data)) 185 |         _, all_gallery_pids, all_gallery_cids, _ = map(np.array, zip(*gallery_data)) 186 | 187 |         q_pid_to_idx = defaultdict(list) 188 |         for i, pid in enumerate(all_query_pids): 189 |             q_pid_to_idx[pid].append(i) 190 |         g_pid_to_idx = defaultdict(list) 191 |         for i, pid in enumerate(all_gallery_pids): 192 |             g_pid_to_idx[pid].append(i) 193 | 194 |         # run multiple test folds 195 |         num_tests = 632 if self.test_all else 316 196 |         q_num_unique_pids = 632 197 |         all_mAP, all_CMC, all_CMC5, all_CMC10 = [], [], [], [] 198 | 199 |         for j in range(self.num_folds): 200 |             if j % 2 == 0: 201 |                 selected_pids = rs.choice(q_num_unique_pids, num_tests, replace=False) + 1 202 |             else: 203 |                 new_selected_pids = [i for i in range(1, q_num_unique_pids+1) if i not in selected_pids] 204 |                 selected_pids = new_selected_pids 205 |             gallery_idx, query_idx = [], [] 206 |             for pid in selected_pids: 207 |                 q_idx = q_pid_to_idx[pid][0] 208 |                 query_idx.append(q_idx) 209 |             for pid in selected_pids: 210 |                 g_idx = g_pid_to_idx[pid][0] 211 |                 gallery_idx.append(g_idx) 212 |             # randomly select num_tests ids to test 213 |             selected_pids = rs.choice(q_num_unique_pids, num_tests, replace=False) + 1 214 |             #selected_pids = sorted(selected_pids) 215 | 216 |             # # split gallery and query, i.e. pick 2 samples per pid 217 |             # if j % 2 ==0: 218 |             #     gallery_idx, query_idx = [], [] 219 |             #     for pid in selected_pids: 220 |             #         q_idx = q_pid_to_idx[pid][0] 221 |             #         query_idx.append(q_idx) 222 |             #     for pid in selected_pids: 223 |             #         g_idx = g_pid_to_idx[pid][0] 224 |             #         gallery_idx.append(g_idx) 225 |             # else: 226 |             #     gallery_idx, query_idx = query_idx, gallery_idx 227 | 228 |             def get(x, idx): 229 |                 return [x[i] for i in idx] 230 | 231 |             # build the gallery 232 |             gallery_features = torch.stack(get(all_gallery_features, gallery_idx)) 233 |             gallery_pids = get(all_gallery_pids, gallery_idx) 234 |             gallery_cids = get(all_gallery_cids, gallery_idx) 235 | 236 |             # build the query set 237 |             query_features = torch.stack(get(all_query_features, query_idx)) 238 |             query_pids = get(all_query_pids, query_idx) 239 |             query_cids = get(all_query_cids, query_idx) 240 | 241 |             # evaluate 242 |             dist = self._pdist(query_features, gallery_features) 243 |             #dist = self._pdist(gallery_features, query_features) 244 |             mAP, CMC, CMC5, CMC10 = self._eval_dist(dist, gallery_pids, gallery_cids, query_pids, query_cids) 245 |             acc = self._simple_acc(dist, gallery_pids, query_pids) 246 |             all_mAP.append(mAP) 247 |             all_CMC.append(CMC) 248 |             all_CMC5.append(CMC5) 249 |             all_CMC10.append(CMC10) 250 |             # all_acc.append(acc) 251 | 252 |         print('VIPeR') 253 |         print(f'[map] {np.mean(all_mAP):.2%} |', ' '.join(map('{:.2%}'.format, sorted(all_mAP)))) 254 |         print(f'[cmc] {np.mean(all_CMC):.2%} |', ' '.join(map('{:.2%}'.format, sorted(all_CMC)))) 255 |         print("5:", np.mean(all_CMC5)) 256 |         print("10:", np.mean(all_CMC10)) 257 |         #print(f'[acc] {np.mean(all_acc):.2%} |', ' '.join(map('{:.2%}'.format, sorted(all_acc))))  # acc == cmc(top-1) 258 |         return np.mean(all_CMC) 
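# NOTE: eval_grid below follows the same fold scheme as eval_viper, but additionally appends every pid-0 image to each fold's gallery via g_pid_to_idx[0]; these are the GRID distractor images that match no query identity.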
259 |     def eval_grid(self, query_loader, gallery_loader, query_data, gallery_data, output_feature='pool5', seed=0): 260 |         rs = np.random.RandomState(seed) 261 |         # extract features for all samples 262 |         query_features, _ = extract_features(self.model, query_loader, 1, output_feature) 263 |         all_query_features = list(query_features.values()) 264 | 265 |         gallery_features, _ = extract_features(self.model, gallery_loader, 1, output_feature) 266 |         all_gallery_features = list(gallery_features.values()) 267 | 268 |         # basic metadata 269 |         _, all_query_pids, all_query_cids, _ = map(np.array, zip(*query_data)) 270 |         _, all_gallery_pids, all_gallery_cids, _ = map(np.array, zip(*gallery_data)) 271 | 272 |         q_pid_to_idx = defaultdict(list) 273 |         for i, pid in enumerate(all_query_pids): 274 |             q_pid_to_idx[pid].append(i) 275 |         g_pid_to_idx = defaultdict(list) 276 |         for i, pid in enumerate(all_gallery_pids): 277 |             g_pid_to_idx[pid].append(i) 278 | 279 |         # run multiple test folds 280 |         num_tests = 250 if self.test_all else 125 281 |         q_num_unique_pids = 250 282 |         all_mAP, all_CMC, all_CMC5, all_CMC10 = [], [], [], [] 283 | 284 |         for _ in range(self.num_folds): 285 |             # randomly select num_tests ids to test 286 |             selected_pids = rs.choice(q_num_unique_pids, num_tests, replace=False) + 1 287 | 288 |             # split gallery and query, i.e. pick 2 samples per pid 289 |             gallery_idx, query_idx = [], [] 290 |             for pid in selected_pids: 291 |                 q_idx = q_pid_to_idx[pid][0] 292 |                 query_idx.append(q_idx) 293 |             gallery_idx = gallery_idx + g_pid_to_idx[0] 294 |             for pid in selected_pids: 295 |                 g_idx = g_pid_to_idx[pid][0] 296 |                 gallery_idx.append(g_idx) 297 | 298 |             def get(x, idx): 299 |                 return [x[i] for i in idx] 300 | 301 |             # build the gallery 302 |             gallery_features = torch.stack(get(all_gallery_features, gallery_idx)) 303 |             gallery_pids = get(all_gallery_pids, gallery_idx) 304 |             gallery_cids = get(all_gallery_cids, gallery_idx) 305 | 306 |             # build the query set 307 |             query_features = torch.stack(get(all_query_features, query_idx)) 308 |             query_pids = get(all_query_pids, query_idx) 309 |             query_cids = get(all_query_cids, query_idx) 310 | 311 |             # evaluate 312 |             dist = self._pdist(query_features, gallery_features) 313 |             mAP, CMC, CMC5, CMC10 = self._eval_dist(dist, gallery_pids, gallery_cids, query_pids, query_cids) 314 |             all_mAP.append(mAP) 315 |             all_CMC.append(CMC) 316 |             all_CMC5.append(CMC5) 317 |             all_CMC10.append(CMC10) 318 | 319 |         print('GRID') 320 |         print(f'[map] {np.mean(all_mAP):.2%} |', ' '.join(map('{:.2%}'.format, sorted(all_mAP)))) 321 |         print(f'[cmc] {np.mean(all_CMC):.2%} |', ' '.join(map('{:.2%}'.format, sorted(all_CMC)))) 322 |         print("5:", np.mean(all_CMC5)) 323 |         print("10:", np.mean(all_CMC10)) 324 |         return np.mean(all_CMC) 
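# NOTE: eval_prid below pads each fold's gallery with ids 201-749 on top of the selected ids, reflecting PRID's protocol in which the gallery camera contains identities that never appear among the queries.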
390 | def eval_ilids(self, all_loader, all_data, output_feature=None, seed=0): 391 | """ 392 | Args: 393 | all_loader: loader of iLids inputs 394 | all_data: list of labels, (fname, person_id, cam_id, ?) 395 | output_feature: ? 396 | seed: random seed only for eval 397 | 398 | Note: 399 | loader and data are expected to be ordered, [id1.cam1/2/3, id2.cam1/2/3, ....] 400 | unordered input should also work 401 | """ 402 | # this random state is used only inside eval and does not disturb the external random state 403 | rs = np.random.RandomState(seed) 404 | 405 | # extract features for every sample 406 | features, _ = extract_features(self.model, all_loader, 1, output_feature) 407 | features = list(features.values()) 408 | 409 | # basic metadata 410 | _, all_pids, all_cids, _ = map(np.array, zip(*all_data)) 411 | num_unique_pids = len(set(all_pids)) 412 | 413 | pid_to_idx = defaultdict(list) 414 | for i, pid in enumerate(all_pids): 415 | pid_to_idx[pid].append(i) 416 | 417 | # repeated random trials 418 | num_tests = 119 if self.test_all else 60 419 | all_mAP, all_CMC, all_CMC5, all_CMC10 = [], [], [], [] 420 | for _ in range(self.num_folds): 421 | # randomly pick num_tests identities for this trial 422 | selected_pids = rs.choice(num_unique_pids, num_tests, replace=False) + 1 423 | 424 | # split into gallery and query, i.e. pick 2 samples per pid 425 | gallery_idx, query_idx = [], [] 426 | for pid in selected_pids: 427 | idx1, idx2 = rs.choice(pid_to_idx[pid], 2, replace=False) 428 | gallery_idx.append(idx1) 429 | query_idx.append(idx2) 430 | 431 | def get(x, idx): 432 | return [x[i] for i in idx] 433 | 434 | # build the gallery 435 | gallery_features = torch.stack(get(features, gallery_idx)) 436 | gallery_pids = get(all_pids, gallery_idx) 437 | gallery_cids = get(all_cids, gallery_idx) 438 | 439 | # build the query 440 | query_features = torch.stack(get(features, query_idx)) 441 | query_pids = get(all_pids, query_idx) 442 | query_cids = get(all_cids, query_idx) 443 | 444 | # evaluate 445 | dist = self._pdist(query_features, gallery_features) 446 | mAP, CMC, CMC5, CMC10 = self._eval_dist(dist, gallery_pids, gallery_cids, query_pids, query_cids) 447 | all_mAP.append(mAP) 448 | all_CMC.append(CMC) 449 | all_CMC5.append(CMC5) 450 | all_CMC10.append(CMC10) 451 | 452 | print('iLIDs') 453 | print(f'[map] {np.mean(all_mAP):.2%} |', ' '.join(map('{:.2%}'.format, sorted(all_mAP)))) 454 | print(f'[cmc] {np.mean(all_CMC):.2%} |', ' '.join(map('{:.2%}'.format, sorted(all_CMC)))) 455 | print("5:", np.mean(all_CMC5)) 456 | print("10:", np.mean(all_CMC10)) 457 | return np.mean(all_CMC) 458 | # an alternative cosine-based distance kept from the original source: 459 | # def _pdist(input1, input2): 460 | # dist = 1 - torch.mm(input1, input2.t()) 461 | # dist = 1 - torch.cosine_similarity(input1,input2, dim=1) 462 | # return dist 463 | @staticmethod 464 | def _pdist(x, y): 465 | m, n = x.size(0), y.size(0) 466 | x, y = x.view(m, -1), y.view(n, -1) 467 | xx = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) 468 | yy = torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 469 | xy = x @ y.t() 470 | dist = xx - 2 * xy + yy 471 | return dist 472 | 473 | @staticmethod 474 | def _eval_dist(dist, gallery_pids, gallery_cids, query_pids, query_cids): 475 | args = (dist, query_pids, gallery_pids, query_cids, gallery_cids) 476 | kwargs = dict(separate_camera_set=False, single_gallery_shot=False, first_match_break=False) 477 | mAP, CMC = mean_ap(*args), cmc(*args, **kwargs) 478 | return mAP, CMC[0], CMC[1], CMC[5] # note: CMC[1]/CMC[5] are rank-2/rank-6 accuracies; top-5/top-10 would be CMC[4]/CMC[9] 479 | 480 | @staticmethod 481 | def _simple_acc(dist, gallery_pids, query_pids): 482 | pred = dist.argmin(1) 483 | total = len(query_pids) 484 | good = 
sum([gallery_pids[pred[i]] == query_pids[i] for i in range(pred.shape[0])]) 484 | return 1. * good / total 485 | 486 | 487 | -------------------------------------------------------------------------------- /reid/loss/.tripletloss.py.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/loss/.tripletloss.py.swp -------------------------------------------------------------------------------- /reid/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .tripletloss import TripletLoss 4 | from .identityloss import InvNet 5 | 6 | __all__ = [ 7 | 'TripletLoss', 8 | 'InvNet', 9 | ] 10 | -------------------------------------------------------------------------------- /reid/loss/__pycache__/MLS.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/loss/__pycache__/MLS.cpython-37.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/loss/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/loss/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/dual.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/loss/__pycache__/dual.cpython-36.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/dual.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/loss/__pycache__/dual.cpython-37.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/gyneightboor.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/loss/__pycache__/gyneightboor.cpython-36.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/gyneightboor.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/loss/__pycache__/gyneightboor.cpython-37.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/identityloss.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/loss/__pycache__/identityloss.cpython-37.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/lsr.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/loss/__pycache__/lsr.cpython-37.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/mmd.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/loss/__pycache__/mmd.cpython-36.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/mmd.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/loss/__pycache__/mmd.cpython-37.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/neightboor.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/loss/__pycache__/neightboor.cpython-36.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/neightboor.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/loss/__pycache__/neightboor.cpython-37.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/newneightboor.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/loss/__pycache__/newneightboor.cpython-36.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/newneightboor.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/loss/__pycache__/newneightboor.cpython-37.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/triplet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/loss/__pycache__/triplet.cpython-37.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/tripletloss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/loss/__pycache__/tripletloss.cpython-36.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/tripletloss.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/loss/__pycache__/tripletloss.cpython-37.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/wassdistance.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/loss/__pycache__/wassdistance.cpython-36.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/wassdistance.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/loss/__pycache__/wassdistance.cpython-37.pyc -------------------------------------------------------------------------------- /reid/loss/identityloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn, autograd 4 | from torch.autograd import Variable, Function 5 | import numpy as np 6 | import math 7 | 8 | from .wassdistance import SinkhornDistance 9 | 10 | class ExemplarMemory(Function): 11 | def __init__(self, em, em_cnt, alpha=0.01): 12 | super(ExemplarMemory, self).__init__() 13 | self.em = em 14 | self.em_cnt = em_cnt # how many samples each memory entry currently averages 15 | self.alpha = alpha 16 | 17 | def forward(self, inputs, targets): 18 | self.save_for_backward(inputs, targets) 19 | if self.em.sum() == 0: 20 | outputs = inputs.mm(self.em.t()) 21 | else: 22 | em = self.em / self.em.norm(p=2, dim=1, keepdim=True) 23 | outputs = inputs.mm(em.t()) 24 | return outputs 25 | 26 | def backward(self, grad_outputs): 27 | inputs, targets = self.saved_tensors 28 | grad_inputs = None 29 | if self.needs_input_grad[0]: 30 | grad_inputs = grad_outputs.mm(self.em) 31 | # for x, y in zip(inputs, targets): 32 | # n = self.em_cnt[y] 33 | # n += 1 34 | # self.em[y] = self.em[y] * (n - 1) / n + x / n 35 | 36 | for x, y in zip(inputs, targets): # fixed: x and y were previously undefined at this point 37 | n = self.em_cnt[y]; n += 1 # the in-place add also advances the stored count 38 | self.em[y] = (self.em[y] * (n - 1) + x) / n 39 | 40 | return grad_inputs, None 41 | 42 | class InvNet(nn.Module): 43 | def __init__(self, num_features, num_classes, batchsize, beta=0.05, knn=6, alpha=0.01): 44 | super(InvNet, self).__init__() 45 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 46 | self.num_features = num_features 47 | self.num_classes = num_classes 48 | self.alpha = alpha # Memory update rate 49 | self.beta = beta # Temperature factor 50 | self.knn = knn # Knn for neighborhood invariance 51 | 52 | # Exemplar memory 53 | self.em = nn.Parameter(torch.zeros(num_classes, num_features), requires_grad=False) 54 | self.em_cnt = torch.zeros(num_classes) 55 | self.untouched_targets = set(range(num_classes)) 56 | self.em_current = torch.zeros(num_classes, num_features).cuda() 57 | self.em_last = torch.zeros(num_classes, num_features).cuda() 58 | self.domain = torch.from_numpy(np.array([0 for _ in range(1816)] + [1 for _ in range(1467)] + [2 for _ in range(1812)] + [3 for _ in range(1501)] + [4 for i in range(11934)])).unsqueeze(0).cuda() 59 | self.domain = self.domain.repeat(batchsize, 1) 60 | 61 | self.id = [[] for _ in range(18530)] 62 | 63 | def forward(self, inputs, label, domain, epoch=None, step=None, fnames_target=None): 64 | ''' 65 | inputs: [128, 2048], each t's 2048-d feature 66 | label: [128], each t's label 67 | ''' 68 | if step == 0: 69 | print(self.em) 70 | # alpha = self.alpha * epoch 71 | alpha = self.alpha 72 | if epoch == 0: 73 | alpha = 0 74 | if epoch > 0 and step == 0: 75 | # self.em_cnt = self.em_cnt * 0 + 2 76 | 77 | self.em_last = self.em.clone() 78 | self.em_cnt = self.em_cnt * 0 79 | if epoch > 0: 80 | em = self.em / self.em.norm(p=2, dim=1, keepdim=True) 81 | else: 82 | em = self.em 83 | 84 | tgt_feature = inputs.mm(em.t()) 85 | tgt_feature /= 0.05 86 | 87 | 88 | loss = self.smooth_loss(self.em, inputs, tgt_feature, label, domain, epoch) 89 | 90 | for x, y in zip(inputs, label): 91 | n = self.em_cnt[y] 92 | n += 1 93 | # self.em.data[y] = self.em.data[y] * (n - 1) / n + x.data / n 94 | 95 | self.em_current[y] = self.em_current[y] * (n - 1) / n + x.data / n 96 | self.em.data[y] = self.em_last[y] * alpha + self.em_current[y] * (1-alpha) 97 | return loss 98 | 
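# Memory update sketch (illustrative): within an epoch, em_current[y] tracks the running mean of
# the features assigned to identity y, and the stored entry blends it with last epoch's value:
#   em[y] = alpha * em_last[y] + (1 - alpha) * em_current[y]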
99 | def smooth_loss(self, em, inputs_feature, tgt_feature, label, domain, epoch): 100 | ''' 101 | tgt_feature: [128, 16522], similarity of batch & targets 102 | label: see forward 103 | ''' 104 | mask = self.smooth_hot(tgt_feature.detach().clone(), label.detach().clone(), self.knn, domain) 105 | 106 | # batch size is 64 107 | new_feature = [] 108 | for m in mask: 109 | index = m.nonzero() 110 | new_feature.append(em[index]) 111 | new_feature = torch.cat(new_feature,0).squeeze(1) 112 | # inputs_feature = torch.cat([inputs_feature for _ in range(self.knn*4)],1) 113 | # inputs_feature = torch.reshape(inputs_feature,(64*self.knn*4, 1280)) 114 | 115 | 116 | # ------------12.30------------------------ 117 | # if the domain is SYSU, do not contrast it with the others: set it to 0 118 | # mask_sysu = domain.clone() 119 | # mask_sysu[domain != 4] = 1 120 | # mask_sysu[domain == 4] = 0 121 | 122 | # inputs_feature = inputs_feature * mask_sysu.unsqueeze(1) 123 | # to cut off the gradient flowing down from inputs 124 | # ------------12.30------------------------ 125 | 126 | inputs_feature = torch.cat([inputs_feature for _ in range(self.knn)],1) 127 | # batchsize=64 128 | inputs_feature = torch.reshape(inputs_feature,(64*self.knn, 1280)) 129 | if epoch > 0: 130 | inputs_feature = F.softmax(inputs_feature / self.beta, dim=1) 131 | new_feature = F.softmax(new_feature / self.beta, dim=1) 132 | # print(inputs_feature) 133 | # print(new_feature) 134 | loss = (nn.KLDivLoss()(torch.log(inputs_feature + 1e-8), new_feature) + nn.KLDivLoss()(torch.log(new_feature + 1e-8), inputs_feature)) 135 | # print(loss) 136 | # print(a) 137 | 138 | # outputs = F.log_softmax(tgt_feature, dim=1) 139 | # loss = - (mask * outputs) 140 | # loss = loss.sum(dim=1) 141 | # loss = loss.mean(dim=0) 142 | 143 | # wassdistance 144 | # sinkhorn = SinkhornDistance(eps=0.1, max_iter=100, reduction=None) 145 | # loss, _, _ = sinkhorn(inputs_feature, new_feature) 146 | 147 | return loss 148 | 
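# Loss sketch (illustrative): each input feature (repeated knn times) and its k selected memory
# neighbours are turned into distributions via softmax over the 1280-d feature dimension at
# temperature beta; the loss is then the symmetric KL divergence KL(p || q) + KL(q || p).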
149 | def smooth_hot(self, tgt_feature, targets, k=6, domain=None): 150 | ''' 151 | see smooth_loss 152 | ''' 153 | mask = torch.zeros(tgt_feature.size()).to(self.device) 154 | 155 | # if the domain is SYSU, do not contrast it with the others: set it to 0 156 | # mask_sysu = domain.clone() 157 | # mask_sysu[domain != 4] = 1 158 | # mask_sysu[domain == 4] = 0 159 | 160 | # resulting mask_d: 0 for the sample's own domain, 1 for every other domain 161 | domain = domain.unsqueeze(1).repeat(1,18530) 162 | mask_d = domain-self.domain 163 | mask_d[mask_d!=0] = 1 164 | # print(mask_d) 165 | # print(mask_d[mask_d == 0].shape) 166 | # mask = 1 - mask 167 | 168 | #----------------12.30--------------- 169 | # for SYSU, pull samples within its own domain closer 170 | mask_d[domain == 4] = 1 - mask_d[domain == 4] 171 | #----------------12.30--------------- 172 | 173 | # # in the memory bank, find k entries for every domain 174 | # _feature = tgt_feature * mask_d 175 | # # 1st domain, 1816 IDs 176 | # _, topk = tgt_feature[:,:1816].topk(k, dim=1) 177 | # mask.scatter_(1, topk, 1) 178 | 179 | # # 2nd domain, 1467 IDs 180 | # _, topk = tgt_feature[:,1816:1816+1467].topk(k, dim=1) 181 | # mask.scatter_(1, topk+1816, 1) 182 | 183 | # # 3rd domain, 1812 IDs 184 | # _, topk = tgt_feature[:,1816+1467:1816+1467+1812].topk(k, dim=1) 185 | # mask.scatter_(1, topk+1816+1467, 1) 186 | 187 | # # 4th domain, 1501 IDs 188 | # _, topk = tgt_feature[:,1816+1467+1812:1816+1467+1812+1501].topk(k, dim=1) 189 | # mask.scatter_(1, topk+1816+1467+1812, 1) 190 | 191 | # # 5th domain, 11934 IDs 192 | # _, topk = tgt_feature[:,1816+1467+1812+1501:].topk(k, dim=1) 193 | # mask.scatter_(1, topk+1816+1467+1812+1501, 1) 194 | # # resulting mask: 4*k ones in total (IDs within the sample's own domain are all 0) 195 | # mask = mask * mask_d.float() 196 | 197 | _feature = tgt_feature * mask_d #* mask_sysu.unsqueeze(1) 198 | _, topk = _feature.topk(k, dim=1) 199 | # print("different----------") 200 | # for i, k in enumerate(topk): 201 | # if domain[i][0] == 2: 202 | # # if targets[i] == 4459: 203 | # print(domain[i][0].item(), targets[i], k[0].item(), tgt_feature[i][8983]) 204 | # # # print(a) 205 | # print("different----------") 206 | # # find samples in the same domain whose similarity is low but which still look alike: 207 | # mask_d = 1 - mask_d 208 | # _, topk = (tgt_feature * mask_d).topk(80, dim=1) 209 | # print("same----------") 210 | # for i, k in enumerate(topk): 211 | # if domain[i][0] == 2: 212 | # print(tgt_feature[i][k[0].item()]) 213 | # print(domain[i][0].item(), targets[i], k[-1].item(), tgt_feature[i][3667]) 214 | 215 | # print("same----------") 216 | 217 | mask.scatter_(1, topk, 1) 218 | 219 | # print(mask.sum()) 220 | # # the sample itself 221 | # index_2d = targets[..., None] 222 | # mask.scatter_(1, index_2d, 3) 223 | 224 | return mask 225 | -------------------------------------------------------------------------------- /reid/loss/tripletloss.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | 11 | def normalize(x, axis=-1): 12 | """Normalizing to unit length along the specified dimension. 13 | Args: 14 | x: pytorch Variable 15 | Returns: 16 | x: pytorch Variable, same shape as input 17 | """ 18 | x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12) 19 | return x 20 | 21 | 22 | def euclidean_dist(x, y): 23 | """ 24 | Args: 25 | x: pytorch Variable, with shape [m, d] 26 | y: pytorch Variable, with shape [n, d] 27 | Returns: 28 | dist: pytorch Variable, with shape [m, n] 29 | """ 30 | m, n = x.size(0), y.size(0) 31 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 32 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 33 | dist = xx + yy 34 | dist.addmm_(1, -2, x, y.t()) 35 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 36 | return dist 37 | 38 | 
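# Worked example for euclidean_dist (illustrative):
#   x = torch.tensor([[0., 0.]]); y = torch.tensor([[3., 4.]]) -> dist = [[5.]]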
39 | def hard_example_mining(dist_mat, labels, return_inds=False): 40 | """For each anchor, find the hardest positive and negative sample. 41 | Args: 42 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N] 43 | labels: pytorch LongTensor, with shape [N] 44 | return_inds: whether to return the indices. Save time if `False`(?) 45 | Returns: 46 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N] 47 | dist_an: pytorch Variable, distance(anchor, negative); shape [N] 48 | p_inds: pytorch LongTensor, with shape [N]; 49 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1 50 | n_inds: pytorch LongTensor, with shape [N]; 51 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1 52 | NOTE: Only consider the case in which all labels have same num of samples, 53 | thus we can cope with all anchors in parallel. 54 | """ 55 | 56 | assert len(dist_mat.size()) == 2 57 | assert dist_mat.size(0) == dist_mat.size(1) 58 | N = dist_mat.size(0) 59 | 60 | # shape [N, N] 61 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()) 62 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t()) 63 | 64 | # `dist_ap` means distance(anchor, positive) 65 | # both `dist_ap` and `relative_p_inds` with shape [N, 1] 66 | dist_ap, relative_p_inds = torch.max( 67 | dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True) 68 | # `dist_an` means distance(anchor, negative) 69 | # both `dist_an` and `relative_n_inds` with shape [N, 1] 70 | dist_an, relative_n_inds = torch.min( 71 | dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True) 72 | # shape [N] 73 | dist_ap = dist_ap.squeeze(1) 74 | dist_an = dist_an.squeeze(1) 75 | 76 | if return_inds: 77 | # shape [N, N] 78 | ind = (labels.new().resize_as_(labels) 79 | .copy_(torch.arange(0, N).long()) 80 | .unsqueeze(0).expand(N, N)) 81 | # shape [N, 1] 82 | p_inds = torch.gather( 83 | ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data) 84 | n_inds = torch.gather( 85 | ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data) 86 | # shape [N] 87 | p_inds = p_inds.squeeze(1) 88 | n_inds = n_inds.squeeze(1) 89 | return dist_ap, dist_an, p_inds, n_inds 90 | 91 | return dist_ap, dist_an 92 | 93 | 94 | 95 | class TripletLoss(object): 96 | """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid). 97 | Related Triplet Loss theory can be found in paper 'In Defense of the Triplet 98 | Loss for Person Re-Identification'.""" 99 | 100 | def __init__(self, margin=None): 101 | self.margin = margin 102 | if margin is not None: 103 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 104 | else: 105 | self.ranking_loss = nn.SoftMarginLoss() 106 | 107 | def __call__(self, global_feat, labels, normalize_feature=False): 108 | # if normalize_feature: 109 | global_feat = normalize(global_feat, axis=-1) 110 | dist_mat = euclidean_dist(global_feat, global_feat) 111 | dist_ap, dist_an = hard_example_mining( 112 | dist_mat, labels) 113 | y = dist_an.new().resize_as_(dist_an).fill_(1) 114 | if self.margin is not None: 115 | # loss = self.ranking_loss(dist_an, dist_ap, y) 116 | # to only pull same-class samples together without pushing other classes away: set y = 1 and use a smaller margin 117 | # y *= -1 118 | # dist_ap = torch.zeros(dist_an.shape).cuda() 119 | loss = self.ranking_loss(dist_an, dist_ap, y) 120 | else: 121 | loss = self.ranking_loss(dist_an - dist_ap, y) 122 | return loss 123 | 124 | class CrossEntropyLabelSmooth(nn.Module): 125 | """Cross entropy loss with label smoothing regularizer. 126 | Reference: 127 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 128 | Equation: y = (1 - epsilon) * y + epsilon / K. 129 | Args: 130 | num_classes (int): number of classes. 131 | epsilon (float): weight. 
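Example (illustrative): with num_classes=10 and epsilon=0.1, the smoothed target assigns 0.91 to the true class and 0.01 to each of the other classes. 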
132 | """ 133 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True): 134 | super(CrossEntropyLabelSmooth, self).__init__() 135 | self.num_classes = num_classes 136 | self.epsilon = epsilon 137 | self.use_gpu = use_gpu 138 | self.logsoftmax = nn.LogSoftmax(dim=1) 139 | 140 | def forward(self, inputs, targets): 141 | """ 142 | Args: 143 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 144 | targets: ground truth labels with shape (num_classes) 145 | """ 146 | log_probs = self.logsoftmax(inputs) 147 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 148 | if self.use_gpu: targets = targets.cuda() 149 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 150 | loss = (- targets * log_probs).mean(0).sum() 151 | return loss -------------------------------------------------------------------------------- /reid/loss/wassdistance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # Adapted from https://github.com/gpeyre/SinkhornAutoDiff 5 | class SinkhornDistance(nn.Module): 6 | r""" 7 | Given two empirical measures each with :math:`P_1` locations 8 | :math:`x\in\mathbb{R}^{D_1}` and :math:`P_2` locations :math:`y\in\mathbb{R}^{D_2}`, 9 | outputs an approximation of the regularized OT cost for point clouds. 10 | Args: 11 | eps (float): regularization coefficient 12 | max_iter (int): maximum number of Sinkhorn iterations 13 | reduction (string, optional): Specifies the reduction to apply to the output: 14 | 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 15 | 'mean': the sum of the output will be divided by the number of 16 | elements in the output, 'sum': the output will be summed. 
Default: 'none' 17 | Shape: 18 | - Input: :math:`(N, P_1, D_1)`, :math:`(N, P_2, D_2)` 19 | - Output: :math:`(N)` or :math:`()`, depending on `reduction` 20 | """ 21 | def __init__(self, eps, max_iter, reduction='none'): 22 | super(SinkhornDistance, self).__init__() 23 | self.eps = eps 24 | self.max_iter = max_iter 25 | self.reduction = reduction 26 | 27 | def forward(self, x, y): 28 | # The Sinkhorn algorithm takes as input three variables : 29 | C = self._cost_matrix(x, y) # Wasserstein cost function 30 | x_points = x.shape[-2] 31 | y_points = y.shape[-2] 32 | if x.dim() == 2: 33 | batch_size = 1 34 | else: 35 | batch_size = x.shape[0] 36 | 37 | # both marginals are fixed with equal weights 38 | mu = torch.empty(batch_size, x_points, dtype=torch.float, 39 | requires_grad=False).fill_(1.0 / x_points).squeeze()#.cuda() 40 | nu = torch.empty(batch_size, y_points, dtype=torch.float, 41 | requires_grad=False).fill_(1.0 / y_points).squeeze()#.cuda() 42 | 43 | u = torch.zeros_like(mu) 44 | v = torch.zeros_like(nu) 45 | # To check if algorithm terminates because of threshold 46 | # or max iterations reached 47 | actual_nits = 0 48 | # Stopping criterion 49 | thresh = 1e-1 50 | 51 | # Sinkhorn iterations 52 | for i in range(self.max_iter): 53 | u1 = u # useful to check the update 54 | u = self.eps * (torch.log(mu+1e-8) - torch.logsumexp(self.M(C, u, v), dim=-1)) + u 55 | v = self.eps * (torch.log(nu+1e-8) - torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)) + v 56 | err = (u - u1).abs().sum(-1).mean() 57 | 58 | actual_nits += 1 59 | if err.item() < thresh: 60 | break 61 | 62 | U, V = u, v 63 | # Transport plan pi = diag(a)*K*diag(b) 64 | pi = torch.exp(self.M(C, U, V)) 65 | # Sinkhorn distance 66 | cost = torch.sum(pi * C, dim=(-2, -1)) 67 | if self.reduction == 'mean': 68 | cost = cost.mean() 69 | elif self.reduction == 'sum': 70 | cost = cost.sum() 71 | 72 | return cost, pi, C 73 | 74 | def M(self, C, u, v): 75 | "Modified cost for logarithmic updates" 76 | "$M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$" 77 | return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps 78 | 79 | @staticmethod 80 | def _cost_matrix(x, y, p=2): 81 | "Returns the matrix of $|x_i-y_j|^p$." 82 | x_col = x.unsqueeze(-2) 83 | y_lin = y.unsqueeze(-3) 84 | C = torch.sum((torch.abs(x_col - y_lin)) ** p, -1) 85 | return C 86 | 87 | @staticmethod 88 | def ave(u, u1, tau): 89 | "Barycenter subroutine, used by kinetic acceleration through extrapolation." 90 | return tau * u + (1 - tau) * u1 -------------------------------------------------------------------------------- /reid/models/DDAN.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import init 6 | import torchvision 7 | import pdb 8 | from . 
import mobilenet as mobilenet 9 | 10 | 11 | class Encoder(nn.Module): 12 | def __init__(self, pretrained=True, cut_at_pooling=False, 13 | num_features=0, norm=False, dropout=0, num_classes=0, triplet_features=0): 14 | super(Encoder, self).__init__() 15 | self.base = mobilenet.mobilenet_v2(pretrained=True) 16 | 17 | def forward(self, x, output_feature=None): 18 | x = self.base(x,'encoder') 19 | 20 | return x 21 | 22 | class TransferNet(nn.Module): 23 | def __init__(self, pretrained=True, cut_at_pooling=False, 24 | num_features=0, norm=False, dropout=0, num_classes=0, triplet_features=0): 25 | super(TransferNet, self).__init__() 26 | 27 | self.base = mobilenet.mobilenet_v2(pretrained=True) 28 | 29 | self.dropout = dropout 30 | self.num_classes = num_classes 31 | out_planes = 1280 32 | self.bn = nn.BatchNorm1d(out_planes) 33 | init.constant_(self.bn.weight, 1) 34 | init.constant_(self.bn.bias, 0) 35 | 36 | self.drop = nn.Dropout(self.dropout) 37 | self.classifier = nn.Linear(out_planes, self.num_classes) 38 | 39 | 40 | def forward(self, x, output_feature=None): 41 | x = self.base(x,'task') 42 | x = F.avg_pool2d(x, x.size()[2:]) 43 | x = x.view(x.size(0), -1) 44 | 45 | x_feature = x 46 | x = self.bn(x) 47 | if output_feature == 'pool5': 48 | return x 49 | x = self.drop(x) 50 | x_class = self.classifier(x) 51 | return x_class, x_feature 52 | 53 | class CamDiscriminator(nn.Module): 54 | def __init__(self, channels=1280): 55 | super(CamDiscriminator, self).__init__() 56 | 57 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 58 | out_planes = channels 59 | self.feat = nn.Linear(out_planes, 128) 60 | self.bn = nn.BatchNorm1d(128) 61 | init.constant_(self.bn.weight, 1) 62 | init.constant_(self.bn.bias, 0) 63 | self.num_classes = 2 64 | self.classifier = nn.Linear(128, self.num_classes) 65 | 66 | def forward(self, x): 67 | x = self.feat(x) 68 | x = self.bn(x) 69 | # self.drop = nn.Dropout(0.5) # dead code in the original: constructed on every forward pass but never applied 70 | x = self.classifier(x) 71 | return x 72 | 73 | def DDAN(**kwargs): 74 | return Encoder(50, **kwargs), TransferNet(50, **kwargs), CamDiscriminator(channels=1280) # note: the leading 50 binds to the (unused) pretrained argument 75 | 
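# Usage sketch (illustrative; the num_classes value below is hypothetical):
#   encoder, transfer, cam_d = DDAN(num_classes=751)
#   feat_map = encoder(imgs)           # shared MobileNetV2 stage, features[0:18]
#   logits, feat = transfer(feat_map)  # re-id head: features[18:], pooling, BN, classifier
#   cam_logits = cam_d(feat)           # 2-way camera/domain discriminator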
-------------------------------------------------------------------------------- /reid/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .DDAN import DDAN 4 | 5 | __factory = { 6 | 'DDAN':DDAN, 7 | } 8 | 9 | 10 | def names(): 11 | return sorted(__factory.keys()) 12 | 13 | 14 | def create(name, *args, **kwargs): 15 | """ 16 | Create a model instance. 17 | 18 | Parameters 19 | ---------- 20 | name : str 21 | Model name. Must be one of the keys registered in __factory; 22 | currently only 'DDAN'. 23 | pretrained : bool, optional 24 | Only applied for 'resnet*' models. If True, will use ImageNet pretrained 25 | model. Default: True 26 | cut_at_pooling : bool, optional 27 | If True, will cut the model before the last global pooling layer and 28 | ignore the remaining kwargs. Default: False 29 | num_features : int, optional 30 | If positive, will append a Linear layer after the global pooling layer, 31 | with this number of output units, followed by a BatchNorm layer. 32 | Otherwise these layers will not be appended. Default: 256 for 33 | 'inception', 0 for 'resnet*' 34 | norm : bool, optional 35 | If True, will normalize the feature to be unit L2-norm for each sample. 36 | Otherwise will append a ReLU layer after the above Linear layer if 37 | num_features > 0. Default: False 38 | dropout : float, optional 39 | If positive, will append a Dropout layer with this dropout rate. 40 | Default: 0 41 | num_classes : int, optional 42 | If positive, will append a Linear layer at the end as the classifier 43 | with this number of output units. Default: 0 44 | """ 45 | if name not in __factory: 46 | raise KeyError("Unknown model:", name) 47 | return __factory[name](*args, **kwargs) 48 | -------------------------------------------------------------------------------- /reid/models/__pycache__/DDAN.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/models/__pycache__/DDAN.cpython-37.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/PFE.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/models/__pycache__/PFE.cpython-37.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/basenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/models/__pycache__/basenet.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/basenet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/models/__pycache__/basenet.cpython-37.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/disnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/models/__pycache__/disnet.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/disnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/models/__pycache__/disnet.cpython-37.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/gannet.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/models/__pycache__/gannet.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/gannet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/models/__pycache__/gannet.cpython-37.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/kdnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/models/__pycache__/kdnet.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/kdnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/models/__pycache__/kdnet.cpython-37.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/mobilenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/models/__pycache__/mobilenet.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/mobilenet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/models/__pycache__/mobilenet.cpython-37.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/mobilenet_o.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/models/__pycache__/mobilenet_o.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/mobilenet_o.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/models/__pycache__/mobilenet_o.cpython-37.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/models/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/models/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/totalnet.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/models/__pycache__/totalnet.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/totalnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/models/__pycache__/totalnet.cpython-37.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/transfernet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/models/__pycache__/transfernet.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/transfernet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/models/__pycache__/transfernet.cpython-37.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/transfernet2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/models/__pycache__/transfernet2.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/transfernet2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/models/__pycache__/transfernet2.cpython-37.pyc -------------------------------------------------------------------------------- /reid/models/mobilenet.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torchvision.models.utils import load_state_dict_from_url 3 | 4 | 5 | __all__ = ['MobileNetV2', 'mobilenet_v2'] 6 | 7 | 8 | model_urls = { 9 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 10 | } 11 | 12 | 13 | def _make_divisible(v, divisor, min_value=None): 14 | """ 15 | This function is taken from the original tf repo. 16 | It ensures that all layers have a channel number that is divisible by 8 17 | It can be seen here: 18 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 19 | :param v: 20 | :param divisor: 21 | :param min_value: 22 | :return: 23 | """ 24 | if min_value is None: 25 | min_value = divisor 26 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 27 | # Make sure that round down does not go down by more than 10%. 
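# Worked example (illustrative): _make_divisible(30, 8) -> 32 (nearest multiple of 8), while
# _make_divisible(9, 8) first rounds down to 8, which is below 0.9 * 9, so the branch below bumps it to 16.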
28 | if new_v < 0.9 * v: 29 | new_v += divisor 30 | return new_v 31 | 32 | 33 | class ConvBNReLU(nn.Sequential): 34 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): 35 | padding = (kernel_size - 1) // 2 36 | super(ConvBNReLU, self).__init__( 37 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 38 | nn.BatchNorm2d(out_planes), 39 | nn.ReLU6(inplace=True) 40 | ) 41 | 42 | 43 | class InvertedResidual(nn.Module): 44 | def __init__(self, inp, oup, stride, expand_ratio): 45 | super(InvertedResidual, self).__init__() 46 | self.stride = stride 47 | assert stride in [1, 2] 48 | 49 | hidden_dim = int(round(inp * expand_ratio)) 50 | self.use_res_connect = self.stride == 1 and inp == oup 51 | 52 | layers = [] 53 | if expand_ratio != 1: 54 | # pw 55 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 56 | layers.extend([ 57 | # dw 58 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), 59 | # pw-linear 60 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 61 | nn.BatchNorm2d(oup), 62 | ]) 63 | self.conv = nn.Sequential(*layers) 64 | 65 | def forward(self, x): 66 | if self.use_res_connect: 67 | return x + self.conv(x) 68 | else: 69 | return self.conv(x) 70 | 71 | 72 | class MobileNetV2(nn.Module): 73 | def __init__(self, 74 | num_classes=1000, 75 | width_mult=1.0, 76 | inverted_residual_setting=None, 77 | round_nearest=8, 78 | block=None): 79 | 80 | """ 81 | MobileNet V2 main class 82 | Args: 83 | num_classes (int): Number of classes 84 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 85 | inverted_residual_setting: Network structure 86 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 87 | Set to 1 to turn off rounding 88 | block: Module specifying inverted residual building block for mobilenet 89 | """ 90 | super(MobileNetV2, self).__init__() 91 | 92 | if block is None: 93 | block = InvertedResidual 94 | input_channel = 32 95 | last_channel = 1280 96 | 97 | if inverted_residual_setting is None: 98 | inverted_residual_setting = [ 99 | # t, c, n, s 100 | [1, 16, 1, 1], 101 | [6, 24, 2, 2], 102 | [6, 32, 3, 2], 103 | [6, 64, 4, 2], 104 | [6, 96, 3, 1], 105 | [6, 160, 3, 2], 106 | [6, 320, 1, 1], 107 | ] 108 | 109 | # only check the first element, assuming user knows t,c,n,s are required 110 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 111 | raise ValueError("inverted_residual_setting should be non-empty " 112 | "or a 4-element list, got {}".format(inverted_residual_setting)) 113 | 114 | # building first layer 115 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 116 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 117 | features = [ConvBNReLU(3, input_channel, stride=2)] 118 | # building inverted residual blocks 119 | for t, c, n, s in inverted_residual_setting: 120 | output_channel = _make_divisible(c * width_mult, round_nearest) 121 | for i in range(n): 122 | stride = s if i == 0 else 1 123 | features.append(block(input_channel, output_channel, stride, expand_ratio=t)) 124 | input_channel = output_channel 125 | # building last several layers 126 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) 127 | # make it nn.Sequential 128 | self.features = nn.Sequential(*features) 129 | 130 | # building classifier 131 | self.classifier = nn.Sequential( 132 | nn.Dropout(0.2), 133 | 
nn.Linear(self.last_channel, num_classes), 134 | ) 135 | 136 | # weight initialization 137 | for m in self.modules(): 138 | if isinstance(m, nn.Conv2d): 139 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 140 | if m.bias is not None: 141 | nn.init.zeros_(m.bias) 142 | elif isinstance(m, nn.BatchNorm2d): 143 | nn.init.ones_(m.weight) 144 | nn.init.zeros_(m.bias) 145 | elif isinstance(m, nn.Linear): 146 | nn.init.normal_(m.weight, 0, 0.01) 147 | nn.init.zeros_(m.bias) 148 | 149 | def forward(self, x, p): # p selects the stage: 'teacher' = full backbone, 'encoder' = features[0:18], 'task' = features[18:] 150 | if p == 'teacher': 151 | for i, module in enumerate(self.features): 152 | x = module(x) 153 | return x 154 | # x = self.features(x) 155 | if p == 'encoder': 156 | for i, module in enumerate(self.features): 157 | # print(i,module) 158 | if i == 18: 159 | break 160 | x = module(x) 161 | if p == 'task': 162 | for i, module in enumerate(self.features): 163 | # print(module) 164 | if i < 18: 165 | continue 166 | x = module(x) 167 | 168 | # x = x.mean([2, 3]) 169 | # x = self.classifier(x) 170 | return x 171 | 172 | # Allow for accessing forward method in an inherited class 173 | # forward = _forward 174 | 175 | 176 | def mobilenet_v2(pretrained=False, progress=True, **kwargs): 177 | """ 178 | Constructs a MobileNetV2 architecture from 179 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_. 180 | Args: 181 | pretrained (bool): If True, returns a model pre-trained on ImageNet 182 | progress (bool): If True, displays a progress bar of the download to stderr 183 | """ 184 | model = MobileNetV2(**kwargs) 185 | if pretrained: 186 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], 187 | progress=progress) 188 | model.load_state_dict(state_dict) 189 | return model 190 | -------------------------------------------------------------------------------- /reid/trainers.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | 4 | import torch 5 | from torch.autograd import Variable 6 | 7 | from .evaluation_metrics import accuracy 8 | from .loss import TripletLoss 9 | from .utils.meters import AverageMeter 10 | import torch.nn.functional as F 11 | 12 | 13 | class BaseTrainer(object): 14 | def __init__(self, model, criterion, InvNet=None): 15 | super(BaseTrainer, self).__init__() 16 | self.model = model 17 | self.criterion = criterion 18 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 19 | self.InvNet = InvNet 20 | 21 | def train(self, epoch, data_loader, optimizer, tri_weight, adv_weight, mem_weight, print_freq=1): 22 | optimizer_Encoder, optimizer_Transfer, optimizer_Cam = optimizer 23 | self.Encoder, self.Transfer, self.CamDis = self.model 24 | 25 | self.Encoder.train() 26 | self.Transfer.train() 27 | self.CamDis.train() 28 | 29 | batch_time = AverageMeter() 30 | data_time = AverageMeter() 31 | losses_id = AverageMeter() 32 | losses_tri = AverageMeter() 33 | losses_cam = AverageMeter() 34 | losses_s_cam = AverageMeter() 35 | losses_mem = AverageMeter() 36 | precisions = AverageMeter() 37 | 38 | end = time.time() 39 | for i, inputs in enumerate(data_loader): 40 | data_time.update(time.time() - end) 41 | 42 | inputs, _, pids, cams, domain = self._parse_data(inputs) 43 | loss_id, loss_tri, loss_cam, loss_s_cam, loss_mem, prec1 = self._forward(inputs, pids, cams, domain, epoch, step=i) 44 | 45 | losses_id.update(loss_id.item(), pids.size(0)) 46 | losses_tri.update(loss_tri.item(), pids.size(0)) 47 | losses_cam.update(loss_cam.item(), 
pids.size(0)) 48 | losses_s_cam.update(loss_s_cam.item(), pids.size(0)) 49 | losses_mem.update(loss_mem.item(), pids.size(0)) 50 | precisions.update(prec1, pids.size(0)) 51 | 52 | # if epoch > 10: 53 | # loss = loss_id + tri_weight * loss_tri 54 | if epoch < 3: 55 | loss = loss_id + tri_weight * loss_tri + adv_weight * loss_s_cam + 0 * loss_mem 56 | else: 57 | loss = loss_id + tri_weight * loss_tri + adv_weight * loss_s_cam + mem_weight * loss_mem 58 | optimizer_Transfer.zero_grad() 59 | optimizer_Encoder.zero_grad() 60 | loss.backward() 61 | optimizer_Transfer.step() 62 | optimizer_Encoder.step() 63 | 64 | loss = loss_cam 65 | optimizer_Cam.zero_grad() 66 | loss.backward() 67 | optimizer_Cam.step() 68 | 69 | batch_time.update(time.time() - end) 70 | end = time.time() 71 | 72 | if (i + 1) % print_freq == 0: 73 | print('Epoch: [{}][{}/{}]\t' 74 | 'Time {:.3f} ({:.3f})\t' 75 | # 'Data {:.3f} ({:.3f})\t' 76 | 'ID {:.3f} ({:.3f})\t' 77 | 'Tri {:.3f} ({:.3f})\t' 78 | 'cam {:.3f} ({:.3f})\t' 79 | 'advcam {:.3f} ({:.3f})\t' 80 | 'mem {:.5f} ({:.5f})\t' 81 | 'Prec {:.2%} ({:.2%})\t' 82 | .format(epoch, i + 1, len(data_loader), 83 | batch_time.val, batch_time.avg, 84 | # data_time.val, data_time.avg, 85 | losses_id.val, losses_id.avg, 86 | losses_tri.val, losses_tri.avg, 87 | losses_cam.val, losses_cam.avg, 88 | losses_s_cam.val, losses_s_cam.avg, 89 | losses_mem.val, losses_mem.avg, 90 | precisions.val, precisions.avg)) 91 | 92 | def _parse_data(self, inputs): 93 | raise NotImplementedError 94 | 95 | def _forward(self, inputs, pids): 96 | raise NotImplementedError 97 | 98 | 99 | class Trainer(BaseTrainer): 100 | def _parse_data(self, inputs): 101 | imgs, fnames, pids, cams, domain = inputs 102 | inputs = imgs.to(self.device) 103 | pids = pids.to(self.device) 104 | cams = cams.to(self.device) 105 | domain = domain.to(self.device) 106 | return inputs, fnames, pids, cams, domain 107 | 108 | def _forward(self, inputs, pids, cams, domain, epoch, step): 109 | x_feature = self.Encoder(inputs) 110 | _, trans_feature = self.Transfer(x_feature.clone().detach()) # transfer feature 111 | s_outputs, s_feature = self.Transfer(x_feature) 112 | 113 | # id 114 | loss_id = self.criterion[0](s_outputs, pids) 115 | loss_tri = self.criterion[1](s_feature, pids) 116 | prec, = accuracy(s_outputs.data, pids.data) 117 | prec = prec[0] 118 | 119 | # domain(cam) 120 | c_outputs = self.CamDis(trans_feature.clone().detach()) 121 | loss_cam = self.criterion[0](c_outputs, cams) 122 | # information-entropy style loss on the camera classifier output 123 | outputs_cam_s = self.CamDis(trans_feature) 124 | loss_s_cam = -torch.mean(torch.log(F.softmax(outputs_cam_s, dim=1) + 1e-6)) # fixed: specify dim and add the epsilon after the softmax so the log is safe 125 | 126 | # newcam = torch.ones(cams.shape).long().cuda() 127 | # loss_s_cam = self.criterion[0](outputs_cam_s, newcam) 128 | 129 | # mem 130 | loss_mem = self.InvNet(trans_feature, pids, domain, epoch=epoch, step=step) 131 | 132 | return loss_id, loss_tri, loss_cam, loss_s_cam, loss_mem, prec 133 | -------------------------------------------------------------------------------- /reid/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | 6 | def to_numpy(tensor): 7 | if torch.is_tensor(tensor): 8 | return tensor.cpu().numpy() 9 | elif type(tensor).__module__ != 'numpy': 10 | raise ValueError("Cannot convert {} to numpy array" 11 | .format(type(tensor))) 12 | return tensor 13 | 14 | 15 | def to_torch(ndarray): 16 | if type(ndarray).__module__ == 'numpy': 17 | return torch.from_numpy(ndarray) 18 | elif 
not torch.is_tensor(ndarray): 19 | raise ValueError("Cannot convert {} to torch tensor" 20 | .format(type(ndarray))) 21 | return ndarray 22 | -------------------------------------------------------------------------------- /reid/utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/logging.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/utils/__pycache__/logging.cpython-36.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/logging.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/utils/__pycache__/logging.cpython-37.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/meters.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/utils/__pycache__/meters.cpython-36.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/meters.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/utils/__pycache__/meters.cpython-37.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/osutils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/utils/__pycache__/osutils.cpython-36.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/osutils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/utils/__pycache__/osutils.cpython-37.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/serialization.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/utils/__pycache__/serialization.cpython-36.pyc -------------------------------------------------------------------------------- /reid/utils/__pycache__/serialization.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/utils/__pycache__/serialization.cpython-37.pyc -------------------------------------------------------------------------------- /reid/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .preprocessor import Preprocessor 4 | -------------------------------------------------------------------------------- /reid/utils/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/utils/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/utils/data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/utils/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /reid/utils/data/__pycache__/preprocessor.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/utils/data/__pycache__/preprocessor.cpython-36.pyc -------------------------------------------------------------------------------- /reid/utils/data/__pycache__/preprocessor.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/utils/data/__pycache__/preprocessor.cpython-37.pyc -------------------------------------------------------------------------------- /reid/utils/data/__pycache__/sampler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/utils/data/__pycache__/sampler.cpython-36.pyc -------------------------------------------------------------------------------- /reid/utils/data/__pycache__/sampler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/utils/data/__pycache__/sampler.cpython-37.pyc -------------------------------------------------------------------------------- /reid/utils/data/__pycache__/transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/utils/data/__pycache__/transforms.cpython-36.pyc -------------------------------------------------------------------------------- /reid/utils/data/__pycache__/transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeixianChen/DDAN/7440576eb787e0afc4c45e9ac090a5e153c0cbe5/reid/utils/data/__pycache__/transforms.cpython-37.pyc -------------------------------------------------------------------------------- /reid/utils/data/preprocessor.py: -------------------------------------------------------------------------------- 1 | from __future__ import 
--------------------------------------------------------------------------------
/reid/utils/data/preprocessor.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
import os.path as osp

from PIL import Image


class IdentityPreprocessor(object):
    def __init__(self, dataset, root=None, transform=None):
        super(IdentityPreprocessor, self).__init__()
        self.dataset = dataset
        self.root = root
        self.transform = transform
        self.pindex = 0

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, indices):
        if isinstance(indices, (tuple, list)):
            return [self._get_single_item(index) for index in indices]
        return self._get_single_item(indices)

    def _get_single_item(self, index):
        fname, pid, camid, domainall = self.dataset[index]
        fpath = fname
        try:
            if self.root is not None:
                fpath = osp.join(self.root, fname)
            img = Image.open(fpath).convert('RGB')
        except OSError:
            # Fall back to the raw file name when the joined path cannot be opened.
            img = Image.open(fname).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, fname, pid, camid, domainall


class Preprocessor(object):
    def __init__(self, dataset, root=None, transform=None):
        super(Preprocessor, self).__init__()
        self.dataset = dataset
        self.root = root
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, indices):
        if isinstance(indices, (tuple, list)):
            return [self._get_single_item(index) for index in indices]
        return self._get_single_item(indices)

    def _get_single_item(self, index):
        fname, pid, camid, domainall = self.dataset[index]
        fpath = fname
        if self.root is not None:
            fpath = osp.join(self.root, fname)
        img = Image.open(fpath).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, fname, pid, camid, domainall
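
A minimal, hypothetical sketch of how `Preprocessor` is consumed (file names, labels and the image root are made up; any list of `(fname, pid, camid, domainall)` tuples works):

```
from torch.utils.data import DataLoader
from torchvision import transforms
from reid.utils.data.preprocessor import Preprocessor

items = [('0001_c1_000151.jpg', 1, 0, 0),
         ('0002_c2_000291.jpg', 2, 1, 0)]
loader = DataLoader(
    Preprocessor(items, root='/path/to/images',
                 transform=transforms.Compose([transforms.Resize((256, 128)),
                                               transforms.ToTensor()])),
    batch_size=2, shuffle=False)
for img, fname, pid, camid, domainall in loader:
    print(img.shape)   # torch.Size([2, 3, 256, 128])
```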
27 | """ 28 | 29 | def __init__(self, data_source, batch_size, num_instances): 30 | self.data_source = data_source 31 | self.batch_size = batch_size 32 | self.num_instances = num_instances 33 | self.num_pids_per_batch = self.batch_size // self.num_instances 34 | self.index_dic = defaultdict(list) 35 | for index, (_, pid, _, _) in enumerate(self.data_source): 36 | self.index_dic[pid].append(index) 37 | self.pids = list(self.index_dic.keys()) 38 | 39 | # estimate number of examples in an epoch 40 | self.length = 0 41 | for pid in self.pids: 42 | idxs = self.index_dic[pid] 43 | num = len(idxs) 44 | if num < self.num_instances: 45 | num = self.num_instances 46 | self.length += num - num % self.num_instances 47 | 48 | def __iter__(self): 49 | batch_idxs_dict = defaultdict(list) 50 | 51 | for pid in self.pids: 52 | idxs = copy.deepcopy(self.index_dic[pid]) 53 | if len(idxs) < self.num_instances: 54 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 55 | random.shuffle(idxs) 56 | batch_idxs = [] 57 | for idx in idxs: 58 | batch_idxs.append(idx) 59 | if len(batch_idxs) == self.num_instances: 60 | batch_idxs_dict[pid].append(batch_idxs) 61 | batch_idxs = [] 62 | 63 | avai_pids = copy.deepcopy(self.pids) 64 | final_idxs = [] 65 | 66 | while len(avai_pids) >= self.num_pids_per_batch: 67 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 68 | for pid in selected_pids: 69 | batch_idxs = batch_idxs_dict[pid].pop(0) 70 | final_idxs.extend(batch_idxs) 71 | if len(batch_idxs_dict[pid]) == 0: 72 | avai_pids.remove(pid) 73 | 74 | self.length = len(final_idxs) 75 | return iter(final_idxs) 76 | 77 | def __len__(self): 78 | return self.length 79 | # class RandomIdentitySampler(Sampler): 80 | # def __init__(self, data_source, num_instances=1): 81 | # self.data_source = data_source 82 | # self.num_instances = num_instances 83 | # self.index_dic = defaultdict(list) 84 | # for index, (_, pid, _) in enumerate(data_source): 85 | # self.index_dic[pid].append(index) 86 | # self.pids = list(self.index_dic.keys()) 87 | # self.num_samples = len(self.pids) 88 | 89 | # def __len__(self): 90 | # return self.num_samples * self.num_instances 91 | 92 | # def __iter__(self): 93 | # indices = torch.randperm(self.num_samples) 94 | # ret = [] 95 | # for i in indices: 96 | # pid = self.pids[i] 97 | # t = self.index_dic[pid] 98 | # if len(t) >= self.num_instances: 99 | # t = np.random.choice(t, size=self.num_instances, replace=False) 100 | # else: 101 | # t = np.random.choice(t, size=self.num_instances, replace=True) 102 | # ret.extend(t) 103 | # return iter(ret) 104 | 105 | class IdentitySampler(Sampler): 106 | """Sample person identities evenly in each batch. 
--------------------------------------------------------------------------------
/reid/utils/data/transforms.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import

from torchvision.transforms import *
from PIL import Image
import random
import math


class RectScale(object):
    def __init__(self, height, width, interpolation=Image.BILINEAR):
        self.height = height
        self.width = width
        self.interpolation = interpolation

    def __call__(self, img):
        w, h = img.size
        if h == self.height and w == self.width:
            return img
        return img.resize((self.width, self.height), self.interpolation)


class RandomSizedRectCrop(object):
    def __init__(self, height, width, interpolation=Image.BILINEAR):
        self.height = height
        self.width = width
        self.interpolation = interpolation

    def __call__(self, img):
        for attempt in range(10):
            area = img.size[0] * img.size[1]
            target_area = random.uniform(0.64, 1.0) * area
            # aspect_ratio is h/w, so crops are 2-3x taller than wide,
            # matching pedestrian bounding boxes.
            aspect_ratio = random.uniform(2, 3)

            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))

            if w <= img.size[0] and h <= img.size[1]:
                x1 = random.randint(0, img.size[0] - w)
                y1 = random.randint(0, img.size[1] - h)

                img = img.crop((x1, y1, x1 + w, y1 + h))
                assert img.size == (w, h)

                return img.resize((self.width, self.height), self.interpolation)

        # Fallback: plain rescale if no valid crop was found in 10 attempts.
        scale = RectScale(self.height, self.width,
                          interpolation=self.interpolation)
        return scale(img)


class RandomErasing(object):
    # Applied after ToTensor: erases a random rectangle and fills it with
    # the per-channel ImageNet mean.
    def __init__(self, EPSILON=0.5, mean=[0.485, 0.456, 0.406]):
        self.EPSILON = EPSILON
        self.mean = mean

    def __call__(self, img):
        if random.uniform(0, 1) > self.EPSILON:
            return img

        for attempt in range(100):
            area = img.size()[1] * img.size()[2]

            target_area = random.uniform(0.02, 0.2) * area
            aspect_ratio = random.uniform(0.3, 3)

            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))

            if w <= img.size()[2] and h <= img.size()[1]:
                x1 = random.randint(0, img.size()[1] - h)
                y1 = random.randint(0, img.size()[2] - w)
                img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
                img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
                img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
                return img

        return img
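
A minimal sketch of the two PIL-level ops (the input size is made up); whatever the input, both produce a `width x height` image:

```
from PIL import Image
from reid.utils.data.transforms import RandomSizedRectCrop, RectScale

img = Image.new('RGB', (64, 128))        # hypothetical input image
crop = RandomSizedRectCrop(256, 128)
scale = RectScale(256, 128)
print(crop(img).size, scale(img).size)   # (128, 256) (128, 256); PIL reports (w, h)
```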
--------------------------------------------------------------------------------
/reid/utils/logging.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
import os
import sys

from .osutils import mkdir_if_missing


class Logger(object):
    """Tee-style logger: everything written goes to the console and,
    if `fpath` is given, to a log file as well."""

    def __init__(self, fpath=None):
        self.console = sys.stdout
        self.file = None
        if fpath is not None:
            mkdir_if_missing(os.path.dirname(fpath))
            self.file = open(fpath, 'w')

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def write(self, msg):
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            os.fsync(self.file.fileno())

    def close(self):
        # Close only the log file; the console stream (the real stdout)
        # must stay open for later printing.
        if self.file is not None:
            self.file.close()
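
This tee pattern is how `transmem.py` mirrors stdout into a log file; a minimal sketch (the path is a placeholder):

```
import sys
from reid.utils.logging import Logger

sys.stdout = Logger('./logs/log.txt')   # hypothetical log path
print('recorded on the console and in ./logs/log.txt')
```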
--------------------------------------------------------------------------------
/reid/utils/meters.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
--------------------------------------------------------------------------------
/reid/utils/osutils.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
import os
import errno


def mkdir_if_missing(dir_path):
    try:
        os.makedirs(dir_path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise
--------------------------------------------------------------------------------
/reid/utils/serialization.py:
--------------------------------------------------------------------------------
from __future__ import print_function, absolute_import
import os.path as osp

import torch
from torch.nn import Parameter

from .osutils import mkdir_if_missing


def save_checkpoint(state, fpath='checkpoint.pth.tar'):
    mkdir_if_missing(osp.dirname(fpath))
    torch.save(state, fpath)


def load_checkpoint(fpath):
    if osp.isfile(fpath):
        checkpoint = torch.load(fpath)
        print("=> Loaded checkpoint '{}'".format(fpath))
        return checkpoint
    else:
        raise ValueError("=> No checkpoint found at '{}'".format(fpath))


def copy_state_dict(state_dict, model, strip=None):
    # Copy parameters by name, skipping unknown names and shape mismatches;
    # `strip` removes a prefix (e.g. 'module.') from the source names.
    tgt_state = model.state_dict()
    copied_names = set()
    for name, param in state_dict.items():
        if strip is not None and name.startswith(strip):
            name = name[len(strip):]
        if name not in tgt_state:
            continue
        if isinstance(param, Parameter):
            param = param.data
        if param.size() != tgt_state[name].size():
            print('mismatch:', name, param.size(), tgt_state[name].size())
            continue
        tgt_state[name].copy_(param)
        copied_names.add(name)

    missing = set(tgt_state.keys()) - copied_names
    if len(missing) > 0:
        print("missing keys in state_dict:", missing)

    return model
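
A short round-trip sketch of the three helpers above; the path is a placeholder and the `Linear` module is a stand-in for the real Encoder, whose checkpoint key matches what `transmem.py` saves:

```
import torch
from reid.utils.serialization import save_checkpoint, load_checkpoint, copy_state_dict

model = torch.nn.Linear(4, 4)                       # stand-in module
save_checkpoint({'Encoder': model.state_dict(), 'epoch': 1},
                fpath='./logs/checkpoint.pth.tar')  # hypothetical path
checkpoint = load_checkpoint('./logs/checkpoint.pth.tar')
# Tolerant, name-based restore; strip='module.' handles nn.DataParallel prefixes.
model = copy_state_dict(checkpoint['Encoder'], model, strip='module.')
print(checkpoint['epoch'])
```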
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
CUDA_VISIBLE_DEVICES=3 \
python -u transmem.py \
--data-dir /home/chenpeixian/reid/dataset/ \
-a DDAN \
-b 64 \
--height 256 \
--width 128 \
--logs-dir ./logs/ \
--epochs 100 \
--workers=4 \
--lr 0.1 \
--num-instance 4 \
--tri-weight 1 \
--margin 0.3 \
--adv-weight 0.18 \
--mem-weight 0.05 \
--knn 8 \
--beta 0.002 \
--alpha 0.05 \
--features 1280 \
--dropout 0.5 \
--seed 0
# To evaluate a trained checkpoint instead of training, append:
#   --resume ./logs/checkpoint-100.pth.tar \
#   --evaluate
--------------------------------------------------------------------------------
/transmem.py:
--------------------------------------------------------------------------------
from __future__ import print_function, absolute_import
import argparse
import os.path as osp
import random
import sys

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader

from reid.datasets.domain_adaptation import DA
from reid.datasets.TotalData import TotalData
from reid import models
from reid.trainers import Trainer
from reid.evaluators import Evaluator
from reid.utils.data import transforms as T
from reid.loss import TripletLoss, InvNet
from reid.utils.data.preprocessor import Preprocessor
from reid.utils.data.sampler import RandomIdentitySampler, IdentitySampler
from reid.utils.logging import Logger
from reid.utils.serialization import load_checkpoint, save_checkpoint


def fix(s):
    # Seed every RNG (python, numpy, torch CPU/GPU) and force deterministic
    # cuDNN kernels so runs are reproducible.
    torch.manual_seed(s)
    torch.cuda.manual_seed_all(s)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(s)
    random.seed(s)


def get_data(data_dir, height, width, batch_size, num_instances, re=0, workers=8):
    dataset = DA(data_dir)
    test_dataset = TotalData(data_dir)

    normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])

    num_classes = dataset.num_source_ids

    train_transformer = T.Compose([
        T.Resize((height, width), interpolation=3),
        T.Pad(10),
        T.RandomCrop((height, width)),
        T.RandomHorizontalFlip(0.5),
        T.RandomRotation(5),
        T.ColorJitter(brightness=(0.5, 2.0), saturation=(0.5, 2.0), hue=(-0.1, 0.1)),
        T.ToTensor(),
        normalizer,
        # T.RandomErasing(EPSILON=re),
    ])

    test_transformer = T.Compose([
        T.Resize((height, width), interpolation=3),
        T.ToTensor(),
        normalizer,
    ])

    # Train: P x K identity sampling over the merged source domains.
    source_train_loader = DataLoader(
        Preprocessor(dataset.source_train, transform=train_transformer),
        batch_size=batch_size, num_workers=workers,
        sampler=RandomIdentitySampler(dataset.source_train, batch_size, num_instances),
        pin_memory=True, drop_last=True)

    # Test: query/gallery loaders for the four unseen benchmarks.
    grid_query_loader = DataLoader(
        Preprocessor(test_dataset.grid_query,
                     root=osp.join(test_dataset.grid_images_dir, test_dataset.query_path), transform=test_transformer),
        batch_size=64, num_workers=4,
        shuffle=False, pin_memory=True)
    grid_gallery_loader = DataLoader(
        Preprocessor(test_dataset.grid_gallery,
                     root=osp.join(test_dataset.grid_images_dir, test_dataset.gallery_path), transform=test_transformer),
        batch_size=64, num_workers=4,
        shuffle=False, pin_memory=True)
    prid_query_loader = DataLoader(
        Preprocessor(test_dataset.prid_query,
                     root=osp.join(test_dataset.prid_images_dir, test_dataset.query_path), transform=test_transformer),
        batch_size=64, num_workers=4,
        shuffle=False, pin_memory=True)
    prid_gallery_loader = DataLoader(
        Preprocessor(test_dataset.prid_gallery,
                     root=osp.join(test_dataset.prid_images_dir, test_dataset.gallery_path), transform=test_transformer),
        batch_size=64, num_workers=4,
        shuffle=False, pin_memory=True)
    viper_query_loader = DataLoader(
        Preprocessor(test_dataset.viper_query,
                     root=osp.join(test_dataset.viper_images_dir, test_dataset.query_path), transform=test_transformer),
        batch_size=64, num_workers=4,
        shuffle=False, pin_memory=True)
    viper_gallery_loader = DataLoader(
        Preprocessor(test_dataset.viper_gallery,
                     root=osp.join(test_dataset.viper_images_dir, test_dataset.gallery_path), transform=test_transformer),
        batch_size=64, num_workers=4,
        shuffle=False, pin_memory=True)
    ilid_query_loader = DataLoader(
        Preprocessor(test_dataset.ilid_query,
                     root=osp.join(test_dataset.ilid_images_dir, "images"), transform=test_transformer),
        batch_size=64, num_workers=4,
        shuffle=False, pin_memory=True)
    ilid_gallery_loader = DataLoader(
        Preprocessor(test_dataset.ilid_gallery,
                     root=osp.join(test_dataset.ilid_images_dir, "images"), transform=test_transformer),
        batch_size=64, num_workers=4,
        shuffle=False, pin_memory=True)

    return (dataset, test_dataset, num_classes, source_train_loader,
            grid_query_loader, grid_gallery_loader,
            prid_query_loader, prid_gallery_loader,
            viper_query_loader, viper_gallery_loader,
            ilid_query_loader, ilid_gallery_loader)
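
# A hypothetical direct call to get_data (e.g. from a notebook); the data root
# is a placeholder. It returns twelve values: the two dataset objects, the
# number of source identities, the training loader, then query/gallery loaders
# for GRID, PRID, VIPeR and iLIDS, in that order:
#
#     (dataset, test_dataset, num_classes, source_train_loader,
#      *test_loaders) = get_data('/path/to/dataset/', 256, 128,
#                                batch_size=64, num_instances=4, workers=4)
#     assert len(test_loaders) == 8   # four benchmarks x (query, gallery)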
def main(args):
    fix(args.seed)
    # Redirect print to both console and log file
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))

    print(args)
    # Create data loaders
    (dataset, test_dataset, num_classes, source_train_loader,
     grid_query_loader, grid_gallery_loader,
     prid_query_loader, prid_gallery_loader,
     viper_query_loader, viper_gallery_loader,
     ilid_query_loader, ilid_gallery_loader) = \
        get_data(args.data_dir, args.height, args.width,
                 args.batch_size, args.num_instance, args.re, args.workers)

    # Create model: feature encoder, transfer branch and camera discriminator.
    Encoder, Transfer, CamDis = models.create(args.arch, num_features=args.features,
                                              dropout=args.dropout, num_classes=num_classes)

    invNet = InvNet(args.features, num_classes, args.batch_size,
                    beta=args.beta, knn=args.knn, alpha=args.alpha).cuda()

    # Load from checkpoint
    start_epoch = 0
    if args.resume:
        checkpoint = load_checkpoint(args.resume)
        Encoder.load_state_dict(checkpoint['Encoder'])
        Transfer.load_state_dict(checkpoint['Transfer'])
        CamDis.load_state_dict(checkpoint['CamDis'])
        invNet.load_state_dict(checkpoint['InvNet'])
        start_epoch = checkpoint['epoch']

    Encoder = Encoder.cuda()
    Transfer = Transfer.cuda()
    CamDis = CamDis.cuda()

    model = [Encoder, Transfer, CamDis]

    # Evaluator
    evaluator = Evaluator(model)
    if args.evaluate:
        evaluator.eval_viper(viper_query_loader, viper_gallery_loader, test_dataset.viper_query, test_dataset.viper_gallery, args.output_feature, seed=57)
        evaluator.eval_prid(prid_query_loader, prid_gallery_loader, test_dataset.prid_query, test_dataset.prid_gallery, args.output_feature, seed=40)
        evaluator.eval_grid(grid_query_loader, grid_gallery_loader, test_dataset.grid_query, test_dataset.grid_gallery, args.output_feature, seed=35)
        evaluator.eval_ilids(ilid_query_loader, test_dataset.ilid_query, args.output_feature, seed=24)
        # Evaluation only; do not fall through into training.
        return

    criterion = []
    criterion.append(nn.CrossEntropyLoss().cuda())
    criterion.append(TripletLoss(margin=args.margin))

    # Optimizer: pretrained backbone parameters train at 0.1x the base lr.
    base_param_ids = set(map(id, Encoder.base.parameters()))
    new_params = [p for p in Encoder.parameters() if id(p) not in base_param_ids]
    param_groups = [
        {'params': Encoder.base.parameters(), 'lr_mult': 0.1},
        {'params': new_params, 'lr_mult': 1.0}]
    optimizer_Encoder = torch.optim.SGD(param_groups, lr=args.lr,
                                        momentum=0.9, weight_decay=5e-4, nesterov=True)
    # ====
    base_param_ids = set(map(id, Transfer.base.parameters()))
    new_params = [p for p in Transfer.parameters() if id(p) not in base_param_ids]
    param_groups = [
        {'params': Transfer.base.parameters(), 'lr_mult': 0.1},
        {'params': new_params, 'lr_mult': 1.0}]
    optimizer_Transfer = torch.optim.SGD(param_groups, lr=args.lr,
                                         momentum=0.9, weight_decay=5e-4, nesterov=True)
    # ====
    param_groups = [
        {'params': CamDis.parameters(), 'lr_mult': 1.0},
    ]
    optimizer_Cam = torch.optim.SGD(param_groups, lr=args.lr,
                                    momentum=0.9, weight_decay=5e-4, nesterov=True)

    optimizer = [optimizer_Encoder, optimizer_Transfer, optimizer_Cam]

    # Trainer
    trainer = Trainer(model, criterion, InvNet=invNet)

    # Schedule learning rate: step decay, 10x drop every 40 epochs.
    def adjust_lr(epoch):
        step_size = 40
        lr = args.lr * (0.1 ** (epoch // step_size))
        for g in optimizer_Encoder.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)
        for g in optimizer_Transfer.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)
        for g in optimizer_Cam.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)
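
    # Worked example of the schedule above with run.sh's --lr 0.1:
    #   epochs  0-39: lr = 0.1    (backbone groups: 0.1 * lr_mult 0.1 = 0.01)
    #   epochs 40-79: lr = 0.01
    #   epochs 80-99: lr = 0.001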
    # Start training
    for epoch in range(start_epoch, args.epochs):
        adjust_lr(epoch)
        trainer.train(epoch, source_train_loader, optimizer,
                      args.tri_weight, args.adv_weight, args.mem_weight)

        save_checkpoint({
            'Encoder': Encoder.state_dict(),
            'Transfer': Transfer.state_dict(),
            'CamDis': CamDis.state_dict(),
            'InvNet': invNet.state_dict(),
            'epoch': epoch + 1,
        }, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar'))

        print('\n * Finished epoch {:3d} \n'.format(epoch))

    # Final test on the four unseen benchmarks.
    print('Test with best model:')
    evaluator = Evaluator(model)
    evaluator.eval_viper(viper_query_loader, viper_gallery_loader, test_dataset.viper_query, test_dataset.viper_gallery, args.output_feature, seed=57)
    evaluator.eval_prid(prid_query_loader, prid_gallery_loader, test_dataset.prid_query, test_dataset.prid_gallery, args.output_feature, seed=40)
    evaluator.eval_grid(grid_query_loader, grid_gallery_loader, test_dataset.grid_query, test_dataset.grid_gallery, args.output_feature, seed=35)
    evaluator.eval_ilids(ilid_query_loader, test_dataset.ilid_query, args.output_feature, seed=24)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="baseline")
    # seed
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('-b', '--batch-size', type=int, default=128, help="batch size for source")
    parser.add_argument('-j', '--workers', type=int, default=1)
    parser.add_argument('--height', type=int, default=256,
                        help="input height, default: 256")
    parser.add_argument('--width', type=int, default=128,
                        help="input width, default: 128")
    # model
    parser.add_argument('-a', '--arch', type=str, default='resnet50',
                        choices=models.names())
    parser.add_argument('--features', type=int, default=1024)
    parser.add_argument('--dropout', type=float, default=0.5)
    # optimizer
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight-decay', type=float, default=5e-4)
    # training configs
    parser.add_argument('--resume', type=str, default='', metavar='PATH')
    parser.add_argument('--evaluate', action='store_true',
                        help="evaluation only")
    parser.add_argument('--epochs', type=int, default=60)
    parser.add_argument('--print-freq', type=int, default=1)
    # metric learning
    parser.add_argument('--dist-metric', type=str, default='euclidean')
    parser.add_argument('--num-instance', type=int, default=4)
    parser.add_argument('--tri-weight', type=float, default=0.3)
    parser.add_argument('--margin', type=float, default=0.3)

    parser.add_argument('--adv-weight', type=float, default=0.5)
    parser.add_argument('--mem-weight', type=float, default=0.5)
    parser.add_argument('--knn', type=int, default=3)
    parser.add_argument('--beta', type=float, default=0.05)
    parser.add_argument('--alpha', type=float, default=0.01)

    # misc
    working_dir = osp.dirname(osp.abspath(__file__))
    parser.add_argument('--data-dir', type=str, metavar='PATH',
                        default=osp.join(working_dir, 'data'))
    parser.add_argument('--logs-dir', type=str, metavar='PATH',
                        default=osp.join(working_dir, 'logs'))
    parser.add_argument('--output_feature', type=str, default='pool5')
    # random erasing
    parser.add_argument('--re', type=float, default=0)
    # perform re-ranking
    parser.add_argument('--rerank', action='store_true', help="perform re-ranking")

    main(parser.parse_args())
--------------------------------------------------------------------------------