├── Figures ├── CVPR23.jpg └── Framework.png ├── README.md ├── configs ├── default_img.py ├── default_img_single.py └── res50_cels_cal.yaml ├── data ├── __init__.py ├── dataloader.py ├── dataset_loader.py ├── datasets │ ├── ltcc.py │ └── prcc.py ├── img_transforms.py └── samplers.py ├── demo.sh ├── demo_single_image.py ├── losses ├── __init__.py ├── arcface_loss.py ├── circle_loss.py ├── clothes_based_adversarial_loss.py ├── contrastive_loss.py ├── cosface_loss.py ├── cross_entropy_loss_with_label_smooth.py ├── gather.py └── triplet_loss.py ├── main.py ├── models ├── Fuse.py ├── Fusion.py ├── PM.py ├── ResNet.py ├── __init__.py ├── classifier.py ├── img_resnet.py ├── lr_scheduler.py └── utils │ ├── c3d_blocks.py │ ├── inflate.py │ ├── nonlocal_blocks.py │ └── pooling.py ├── test.py ├── test_AIM.sh ├── tools ├── eval_metrics.py └── utils.py └── train.py /Figures/CVPR23.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoomShakaY/AIM-CCReID/c3bda2b54a3c5d81eb65ea838ae3502aecd61b67/Figures/CVPR23.jpg -------------------------------------------------------------------------------- /Figures/Framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoomShakaY/AIM-CCReID/c3bda2b54a3c5d81eb65ea838ae3502aecd61b67/Figures/Framework.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AIM-CCReID 2 | An official implementation of our CVPR'23 paper: 3 | 4 | 《Good is Bad: Causality Inspired Cloth-Debiasing for Cloth-Changing Person Re-Identification》 5 | 6 | [\[Paper Link\]](https://openaccess.thecvf.com/content/CVPR2023/papers/Yang_Good_Is_Bad_Causality_Inspired_Cloth-Debiasing_for_Cloth-Changing_Person_Re-Identification_CVPR_2023_paper.pdf) 7 | [\[Repo\]](https://github.com/BoomShakaY/AIM-CCReID) 8 | [\[About 
Me\]](https://gavinyoung1.github.io/) 9 | 10 | 11 |
12 | motivation 13 | framework 14 |
15 | 16 | #### News 17 | 2023.10.24 The full code has been released! 18 | 19 | #### Requirements 20 | - Python 3.6 21 | - PyTorch 1.6.0 22 | - yacs 23 | - apex 24 | (Note: apex is optional and not recommended if you have enough GPU memory; in that case, just comment out all AMP-related code.) 25 | 26 | ## Performance of AIM 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 |
PRCCLTCC
Standard Cloth-Changing Standard Cloth-Changing
R@1mAPR@1mAPR@1mAPR@1mAP
Paper100.099.957.958.376.341.140.619.1
Repo100.099.858.258.075.941.740.819.2
74 | The metrics obtained with this repo are broadly consistent with those reported in the paper, and some are even slightly better (depending on which metric you focus on). 75 | 76 | ## Datasets 77 | PRCC is available at [Here](https://drive.google.com/file/d/1yTYawRm4ap3M-j0PjLQJ--xmZHseFDLz/view). 78 | 79 | LTCC is available at [Here](https://naiq.github.io/LTCC_Perosn_ReID.html). 80 | 81 | LaST is available at [Here](https://github.com/shuxjweb/last). 82 | 83 | ## Testing 84 | The trained models (weights) are available at [Baidu Disk](https://pan.baidu.com/s/1Du1XgoCim6I_bZtNRm3yPw?pwd=v4ly) or [Google Drive](https://drive.google.com/drive/folders/1xohg_OAHjNyy7LLq3Fq_KowcEP9IlY8k?usp=sharing). 85 | The testing scripts for PRCC and LTCC are provided in `test_AIM.sh`; change the resume path in it to the location where you placed the weights file. 86 | 87 | Note that you need to change `DATA.ROOT` and `OUTPUT` in `configs/default_img.py` to your own paths before testing. 88 | 89 | ## 📖 Citation 90 | 91 | If you find our work useful in your research, please consider citing: 92 | 93 | ```bibtex 94 | @inproceedings{yang2023good, 95 | title={Good is bad: Causality inspired cloth-debiasing for cloth-changing person re-identification}, 96 | author={Yang, Zhengwei and Lin, Meng and Zhong, Xian and Wu, Yu and Wang, Zheng}, 97 | booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition}, 98 | pages={1472--1481}, 99 | year={2023} 100 | } 101 | ``` -------------------------------------------------------------------------------- /configs/default_img.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | from yacs.config import CfgNode as CN 4 | import time 5 | 6 | 7 | _C = CN() 8 | # ----------------------------------------------------------------------------- 9 | # Data settings 10 | # ----------------------------------------------------------------------------- 11 | _C.DATA = CN() 12 | # Root
path for dataset directory 13 | _C.DATA.ROOT = 'DATA_ROOT' 14 | # Dataset for evaluation 15 | _C.DATA.DATASET = 'ltcc' 16 | # Workers for dataloader 17 | _C.DATA.NUM_WORKERS = 8 18 | # Height of input image 19 | _C.DATA.HEIGHT = 384 20 | # Width of input image 21 | _C.DATA.WIDTH = 192 22 | # Batch size for training 23 | _C.DATA.TRAIN_BATCH = 32 24 | # Batch size for testing 25 | _C.DATA.TEST_BATCH = 128 26 | # The number of instances per identity for training sampler 27 | _C.DATA.NUM_INSTANCES = 8 28 | # ----------------------------------------------------------------------------- 29 | # Augmentation settings 30 | # ----------------------------------------------------------------------------- 31 | _C.AUG = CN() 32 | # Random crop prob 33 | _C.AUG.RC_PROB = 0.5 34 | # Random erase prob 35 | _C.AUG.RE_PROB = 0.5 36 | # Random flip prob 37 | _C.AUG.RF_PROB = 0.5 38 | # ----------------------------------------------------------------------------- 39 | # Model settings 40 | # ----------------------------------------------------------------------------- 41 | _C.MODEL = CN() 42 | # Model name 43 | _C.MODEL.NAME = 'resnet50' 44 | # The stride for laery4 in resnet 45 | _C.MODEL.RES4_STRIDE = 1 46 | # feature dim 47 | _C.MODEL.FEATURE_DIM = 4096 48 | # Model path for resuming 49 | _C.MODEL.RESUME = '' 50 | # Global pooling after the backbone 51 | _C.MODEL.POOLING = CN() 52 | # Choose in ['avg', 'max', 'gem', 'maxavg'] 53 | _C.MODEL.POOLING.NAME = 'maxavg' 54 | # Initialized power for GeM pooling 55 | _C.MODEL.POOLING.P = 3 56 | # ----------------------------------------------------------------------------- 57 | # Losses for training 58 | # ----------------------------------------------------------------------------- 59 | _C.LOSS = CN() 60 | # Classification loss 61 | _C.LOSS.CLA_LOSS = 'crossentropylabelsmooth' 62 | # Clothes classification loss 63 | _C.LOSS.CLOTHES_CLA_LOSS = 'cosface' 64 | # Scale for classification loss 65 | _C.LOSS.CLA_S = 16. 
66 | # Margin for classification loss 67 | _C.LOSS.CLA_M = 0. 68 | # Clothes-based adversarial loss 69 | _C.LOSS.CAL = 'cal' 70 | # Epsilon for clothes-based adversarial loss 71 | _C.LOSS.EPSILON = 0.1 72 | # Momentum for clothes-based adversarial loss with memory bank 73 | _C.LOSS.MOMENTUM = 0. 74 | # ----------------------------------------------------------------------------- 75 | # Training settings 76 | # ----------------------------------------------------------------------------- 77 | _C.TRAIN = CN() 78 | _C.TRAIN.START_EPOCH = 0 79 | _C.TRAIN.MAX_EPOCH = 100 80 | # Start epoch for clothes classification 81 | _C.TRAIN.START_EPOCH_CC = 25 82 | # Start epoch for adversarial training 83 | _C.TRAIN.START_EPOCH_ADV = 25 84 | # Start epoch for debias 85 | _C.TRAIN.START_EPOCH_GENERAL = 25 86 | # Optimizer 87 | _C.TRAIN.OPTIMIZER = CN() 88 | _C.TRAIN.OPTIMIZER.NAME = 'adam' 89 | # Learning rate 90 | _C.TRAIN.OPTIMIZER.LR = 0.00035 91 | _C.TRAIN.OPTIMIZER.WEIGHT_DECAY = 5e-4 92 | # LR scheduler 93 | _C.TRAIN.LR_SCHEDULER = CN() 94 | # Stepsize to decay learning rate 95 | _C.TRAIN.LR_SCHEDULER.STEPSIZE = [20, 40] 96 | # LR decay rate, used in StepLRScheduler 97 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 98 | # ----------------------------------------------------------------------------- 99 | # Testing settings 100 | # ----------------------------------------------------------------------------- 101 | _C.TEST = CN() 102 | # Perform evaluation after every N epochs (set to -1 to test after training) 103 | _C.TEST.EVAL_STEP = 5 104 | # Start to evaluate after specific epoch 105 | _C.TEST.START_EVAL = 0 106 | # ----------------------------------------------------------------------------- 107 | # Misc 108 | # ----------------------------------------------------------------------------- 109 | # Fixed random seed 110 | _C.SEED = 1 111 | # Perform evaluation only 112 | _C.EVAL_MODE = False 113 | # GPU device ids for CUDA_VISIBLE_DEVICES 114 | _C.GPU = '0' 115 | # Path to output 
folder, overwritten by command line argument 116 | _C.OUTPUT = 'OUTPUT_PATH' 117 | # Tag of experiment, overwritten by command line argument 118 | _C.TAG = 'eval' 119 | # ----------------------------------------------------------------------------- 120 | # Hyperparameters 121 | _C.k_cal = 1.0 122 | _C.k_kl = 1.0 123 | # ----------------------------------------------------------------------------- 124 | 125 | def update_config(config, args): 126 | config.defrost() 127 | config.merge_from_file(args.cfg) 128 | 129 | # merge from specific arguments 130 | if args.root: 131 | config.DATA.ROOT = args.root 132 | if args.output: 133 | config.OUTPUT = args.output 134 | if args.resume: 135 | config.MODEL.RESUME = args.resume 136 | if args.eval: 137 | config.EVAL_MODE = True 138 | if args.tag: 139 | config.TAG = args.tag 140 | if args.dataset: 141 | config.DATA.DATASET = args.dataset 142 | if args.gpu: 143 | config.GPU = args.gpu 144 | 145 | # output folder 146 | config.OUTPUT = os.path.join(config.OUTPUT, config.DATA.DATASET, config.TAG) 147 | config.freeze() 148 | 149 | 150 | def get_img_config(args): 151 | """Get a yacs CfgNode object with default values.""" 152 | config = _C.clone() 153 | update_config(config, args) 154 | 155 | return config 156 | -------------------------------------------------------------------------------- /configs/default_img_single.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | from yacs.config import CfgNode as CN 4 | 5 | 6 | _C = CN() 7 | # ----------------------------------------------------------------------------- 8 | # Data settings 9 | # ----------------------------------------------------------------------------- 10 | _C.DATA = CN() 11 | # Root path for dataset directory 12 | _C.DATA.ROOT = 'DATA_ROOT' 13 | # Dataset for evaluation 14 | _C.DATA.DATASET = 'ltcc' 15 | # Workers for dataloader 16 | _C.DATA.NUM_WORKERS = 4 17 | # Height of input image 18 | _C.DATA.HEIGHT = 
384 19 | # Width of input image 20 | _C.DATA.WIDTH = 192 21 | # Batch size for training 22 | _C.DATA.TRAIN_BATCH = 32 # org:32 23 | # Batch size for testing 24 | _C.DATA.TEST_BATCH = 128 # org:128 25 | # The number of instances per identity for training sampler 26 | _C.DATA.NUM_INSTANCES = 8 27 | # ----------------------------------------------------------------------------- 28 | # Augmentation settings 29 | # ----------------------------------------------------------------------------- 30 | _C.AUG = CN() 31 | # Random crop prob 32 | _C.AUG.RC_PROB = 0.5 33 | # Random erase prob 34 | _C.AUG.RE_PROB = 0.5 35 | # Random flip prob 36 | _C.AUG.RF_PROB = 0.5 37 | # ----------------------------------------------------------------------------- 38 | # Model settings 39 | # ----------------------------------------------------------------------------- 40 | _C.MODEL = CN() 41 | # Model name 42 | _C.MODEL.NAME = 'resnet50' 43 | # The stride for laery4 in resnet 44 | _C.MODEL.RES4_STRIDE = 1 45 | # feature dim 46 | _C.MODEL.FEATURE_DIM = 4096 # orgin: 4096 47 | # Model path for resuming 48 | _C.MODEL.RESUME = '' 49 | # Global pooling after the backbone 50 | _C.MODEL.POOLING = CN() 51 | # Choose in ['avg', 'max', 'gem', 'maxavg'] 52 | _C.MODEL.POOLING.NAME = 'maxavg' # orgin: maxavg 53 | # Initialized power for GeM pooling 54 | _C.MODEL.POOLING.P = 3 55 | # ----------------------------------------------------------------------------- 56 | # Model2 settings 57 | # ----------------------------------------------------------------------------- 58 | _C.MODEL2 = CN() 59 | _C.MODEL2.NAME = 'hpm' 60 | # ----------------------------------------------------------------------------- 61 | # Losses for training 62 | # ----------------------------------------------------------------------------- 63 | _C.LOSS = CN() 64 | # Classification loss 65 | _C.LOSS.CLA_LOSS = 'crossentropy' 66 | # Clothes classification loss 67 | _C.LOSS.CLOTHES_CLA_LOSS = 'cosface' 68 | # Scale for classification loss 
69 | _C.LOSS.CLA_S = 16. 70 | # Margin for classification loss 71 | _C.LOSS.CLA_M = 0. 72 | # Pairwise loss 73 | _C.LOSS.PAIR_LOSS = 'triplet' 74 | # The weight for pairwise loss 75 | _C.LOSS.PAIR_LOSS_WEIGHT = 0.0 76 | # Scale for pairwise loss 77 | _C.LOSS.PAIR_S = 16. 78 | # Margin for pairwise loss 79 | _C.LOSS.PAIR_M = 0.3 80 | # Clothes-based adversarial loss 81 | _C.LOSS.CAL = 'cal' 82 | # Epsilon for clothes-based adversarial loss 83 | _C.LOSS.EPSILON = 0.1 84 | # Momentum for clothes-based adversarial loss with memory bank 85 | _C.LOSS.MOMENTUM = 0. 86 | # ----------------------------------------------------------------------------- 87 | # Training settings 88 | # ----------------------------------------------------------------------------- 89 | _C.TRAIN = CN() 90 | _C.TRAIN.START_EPOCH = 0 91 | _C.TRAIN.MAX_EPOCH = 70 92 | # Start epoch for clothes classification 93 | _C.TRAIN.START_EPOCH_CC = 25 # org:25 94 | # Start epoch for adversarial training 95 | _C.TRAIN.START_EPOCH_ADV = 25 # org:25 96 | # Optimizer 97 | _C.TRAIN.OPTIMIZER = CN() 98 | _C.TRAIN.OPTIMIZER.NAME = 'adam' 99 | # Learning rate 100 | _C.TRAIN.OPTIMIZER.LR = 0.00035 101 | _C.TRAIN.OPTIMIZER.WEIGHT_DECAY = 5e-4 102 | # LR scheduler 103 | _C.TRAIN.LR_SCHEDULER = CN() 104 | # Stepsize to decay learning rate 105 | _C.TRAIN.LR_SCHEDULER.STEPSIZE = [20, 40] #, 60] 106 | # LR decay rate, used in StepLRScheduler 107 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 108 | # Using amp for training 109 | _C.TRAIN.AMP = False 110 | # ----------------------------------------------------------------------------- 111 | # Testing settings 112 | # ----------------------------------------------------------------------------- 113 | _C.TEST = CN() 114 | # Perform evaluation after every N epochs (set to -1 to test after training) 115 | _C.TEST.EVAL_STEP = 5 116 | # Start to evaluate after specific epoch 117 | _C.TEST.START_EVAL = 0 118 | # ----------------------------------------------------------------------------- 
119 | # Misc 120 | # ----------------------------------------------------------------------------- 121 | # Fixed random seed 122 | _C.SEED = 1 123 | # Perform evaluation only 124 | _C.EVAL_MODE = False 125 | # GPU device ids for CUDA_VISIBLE_DEVICES 126 | _C.GPU = '0' 127 | # Path to output folder, overwritten by command line argument 128 | _C.OUTPUT = 'OUTPUT_PATH' 129 | # Tag of experiment, overwritten by command line argument 130 | _C.TAG = 'eval' 131 | # ----------------------------------------------------------------------------- 132 | # Hyperparameters 133 | _C.k_cal = 1 134 | _C.k_kl = 1 135 | 136 | 137 | def update_config(config, args): 138 | config.defrost() 139 | config.merge_from_file(args.cfg) 140 | 141 | # output folder 142 | config.OUTPUT = os.path.join(config.OUTPUT, config.DATA.DATASET) 143 | config.freeze() 144 | 145 | 146 | def get_img_config(args): 147 | """Get a yacs CfgNode object with default values.""" 148 | config = _C.clone() 149 | update_config(config, args) 150 | 151 | return config 152 | -------------------------------------------------------------------------------- /configs/res50_cels_cal.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | NAME: resnet50 3 | LOSS: 4 | CLA_LOSS: crossentropylabelsmooth 5 | CAL: cal -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import data.img_transforms as T 2 | from data.dataloader import DataLoaderX 3 | from data.dataset_loader import ImageDataset 4 | from data.samplers import DistributedRandomIdentitySampler, DistributedInferenceSampler 5 | from data.datasets.ltcc import LTCC 6 | from data.datasets.prcc import PRCC 7 | 8 | __factory = { 9 | 'ltcc': LTCC, 10 | 'prcc': PRCC, 11 | } 12 | 13 | def get_names(): 14 | return list(__factory.keys()) 15 | 16 | 17 | def build_dataset(config): 18 | if config.DATA.DATASET not 
in __factory.keys(): 19 | raise KeyError("Invalid dataset, got '{}', but expected to be one of {}".format(config.DATA.DATASET, __factory.keys())) 20 | 21 | dataset = __factory[config.DATA.DATASET](root=config.DATA.ROOT) 22 | 23 | return dataset 24 | 25 | 26 | def build_img_transforms(config): 27 | transform_train = T.Compose([ 28 | T.Resize((config.DATA.HEIGHT, config.DATA.WIDTH)), 29 | T.RandomCroping(p=config.AUG.RC_PROB), 30 | T.RandomHorizontalFlip(p=config.AUG.RF_PROB), 31 | T.ToTensor(), 32 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 33 | T.RandomErasing(probability=config.AUG.RE_PROB) 34 | ]) 35 | transform_test = T.Compose([ 36 | T.Resize((config.DATA.HEIGHT, config.DATA.WIDTH)), 37 | T.ToTensor(), 38 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 39 | ]) 40 | 41 | return transform_train, transform_test 42 | 43 | 44 | def build_dataloader(config): 45 | dataset = build_dataset(config) 46 | transform_train, transform_test = build_img_transforms(config) 47 | train_sampler = DistributedRandomIdentitySampler(dataset.train, 48 | num_instances=config.DATA.NUM_INSTANCES, 49 | seed=config.SEED) 50 | trainloader = DataLoaderX(dataset=ImageDataset(dataset.train, transform=transform_train), 51 | sampler=train_sampler, 52 | batch_size=config.DATA.TRAIN_BATCH, num_workers=config.DATA.NUM_WORKERS, 53 | pin_memory=True, drop_last=True) 54 | 55 | galleryloader = DataLoaderX(dataset=ImageDataset(dataset.gallery, transform=transform_test), 56 | sampler=DistributedInferenceSampler(dataset.gallery), 57 | batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS, 58 | pin_memory=True, drop_last=False, shuffle=False) 59 | 60 | if config.DATA.DATASET == 'prcc': 61 | queryloader_same = DataLoaderX(dataset=ImageDataset(dataset.query_same, transform=transform_test), 62 | sampler=DistributedInferenceSampler(dataset.query_same), 63 | batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS, 64 | 
pin_memory=True, drop_last=False, shuffle=False) 65 | queryloader_diff = DataLoaderX(dataset=ImageDataset(dataset.query_diff, transform=transform_test), 66 | sampler=DistributedInferenceSampler(dataset.query_diff), 67 | batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS, 68 | pin_memory=True, drop_last=False, shuffle=False) 69 | 70 | return trainloader, queryloader_same, queryloader_diff, galleryloader, dataset, train_sampler 71 | else: 72 | queryloader = DataLoaderX(dataset=ImageDataset(dataset.query, transform=transform_test), 73 | sampler=DistributedInferenceSampler(dataset.query), 74 | batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS, 75 | pin_memory=True, drop_last=False, shuffle=False) 76 | 77 | return trainloader, queryloader, galleryloader, dataset, train_sampler -------------------------------------------------------------------------------- /data/dataloader.py: -------------------------------------------------------------------------------- 1 | # refer to: https://github.com/JDAI-CV/fast-reid/blob/master/fastreid/data/data_utils.py 2 | 3 | import torch 4 | import threading 5 | import queue 6 | from torch.utils.data import DataLoader 7 | from torch import distributed as dist 8 | 9 | 10 | """ 11 | #based on http://stackoverflow.com/questions/7323664/python-generator-pre-fetch 12 | This is a single-function package that transforms arbitrary generator into a background-thead generator that 13 | prefetches several batches of data in a parallel background thead. 14 | 15 | This is useful if you have a computationally heavy process (CPU or GPU) that 16 | iteratively processes minibatches from the generator while the generator 17 | consumes some other resource (disk IO / loading from database / more CPU if you have unused cores). 18 | 19 | By default these two processes will constantly wait for one another to finish. 
If you make generator work in 20 | prefetch mode (see examples below), they will work in parallel, potentially saving you your GPU time. 21 | We personally use the prefetch generator when iterating minibatches of data for deep learning with PyTorch etc. 22 | 23 | Quick usage example (ipython notebook) - https://github.com/justheuristic/prefetch_generator/blob/master/example.ipynb 24 | This package contains this object 25 | - BackgroundGenerator(any_other_generator[,max_prefetch = something]) 26 | """ 27 | 28 | 29 | class BackgroundGenerator(threading.Thread): 30 | """ 31 | the usage is below 32 | >> for batch in BackgroundGenerator(my_minibatch_iterator): 33 | >> doit() 34 | More details are written in the BackgroundGenerator doc 35 | >> help(BackgroundGenerator) 36 | """ 37 | 38 | def __init__(self, generator, local_rank, max_prefetch=10): 39 | """ 40 | This function transforms generator into a background-thead generator. 41 | :param generator: generator or genexp or any 42 | It can be used with any minibatch generator. 43 | 44 | It is quite lightweight, but not entirely weightless. 45 | Using global variables inside generator is not recommended (may raise GIL and zero-out the 46 | benefit of having a background thread.) 47 | The ideal use case is when everything it requires is store inside it and everything it 48 | outputs is passed through queue. 49 | 50 | There's no restriction on doing weird stuff, reading/writing files, retrieving 51 | URLs [or whatever] wlilst iterating. 52 | 53 | :param max_prefetch: defines, how many iterations (at most) can background generator keep 54 | stored at any moment of time. 55 | Whenever there's already max_prefetch batches stored in queue, the background process will halt until 56 | one of these batches is dequeued. 57 | 58 | !Default max_prefetch=1 is okay unless you deal with some weird file IO in your generator! 
59 | 60 | Setting max_prefetch to -1 lets it store as many batches as it can, which will work 61 | slightly (if any) faster, but will require storing 62 | all batches in memory. If you use infinite generator with max_prefetch=-1, it will exceed the RAM size 63 | unless dequeued quickly enough. 64 | """ 65 | super().__init__() 66 | self.queue = queue.Queue(max_prefetch) 67 | self.generator = generator 68 | self.local_rank = local_rank 69 | self.daemon = True 70 | self.exit_event = threading.Event() 71 | self.start() 72 | 73 | def run(self): 74 | torch.cuda.set_device(self.local_rank) 75 | for item in self.generator: 76 | if self.exit_event.is_set(): 77 | break 78 | self.queue.put(item) 79 | self.queue.put(None) 80 | 81 | def next(self): 82 | next_item = self.queue.get() 83 | if next_item is None: 84 | raise StopIteration 85 | return next_item 86 | 87 | # Python 3 compatibility 88 | def __next__(self): 89 | return self.next() 90 | 91 | def __iter__(self): 92 | return self 93 | 94 | 95 | class DataLoaderX(DataLoader): 96 | def __init__(self, **kwargs): 97 | super().__init__(**kwargs) 98 | local_rank = dist.get_rank() 99 | self.stream = torch.cuda.Stream(local_rank) # create a new cuda stream in each process 100 | self.local_rank = local_rank 101 | 102 | def __iter__(self): 103 | self.iter = super().__iter__() 104 | self.iter = BackgroundGenerator(self.iter, self.local_rank) 105 | self.preload() 106 | return self 107 | 108 | def _shutdown_background_thread(self): 109 | if not self.iter.is_alive(): 110 | # avoid re-entrance or ill-conditioned thread state 111 | return 112 | 113 | # Set exit event to True for background threading stopping 114 | self.iter.exit_event.set() 115 | 116 | # Exhaust all remaining elements, so that the queue becomes empty, 117 | # and the thread should quit 118 | for _ in self.iter: 119 | pass 120 | 121 | # Waiting for background thread to quit 122 | self.iter.join() 123 | 124 | def preload(self): 125 | self.batch = next(self.iter, None) 126 | 
if self.batch is None: 127 | return None 128 | with torch.cuda.stream(self.stream): 129 | # if isinstance(self.batch[0], torch.Tensor): 130 | # self.batch[0] = self.batch[0].to(device=self.local_rank, non_blocking=True) 131 | for k, v in enumerate(self.batch): 132 | if isinstance(self.batch[k], torch.Tensor): 133 | self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True) 134 | 135 | def __next__(self): 136 | torch.cuda.current_stream().wait_stream( 137 | self.stream 138 | ) # wait tensor to put on GPU 139 | batch = self.batch 140 | if batch is None: 141 | raise StopIteration 142 | self.preload() 143 | return batch 144 | 145 | # Signal for shutting down background thread 146 | def shutdown(self): 147 | # If the dataloader is to be freed, shutdown its BackgroundGenerator 148 | self._shutdown_background_thread() 149 | -------------------------------------------------------------------------------- /data/dataset_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import functools 3 | import os.path as osp 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | 7 | 8 | def read_image(img_path): 9 | """Keep reading image until succeed. 10 | This can avoid IOError incurred by heavy IO process.""" 11 | got_img = False 12 | if not osp.exists(img_path): 13 | raise IOError("{} does not exist".format(img_path)) 14 | while not got_img: 15 | try: 16 | img = Image.open(img_path).convert('RGB') 17 | got_img = True 18 | except IOError: 19 | print("IOError incurred when reading '{}'. Will redo. Don't worry. 
Just chill.".format(img_path)) 20 | pass 21 | return img 22 | 23 | 24 | class ImageDataset(Dataset): 25 | """Image Person ReID Dataset""" 26 | def __init__(self, dataset, transform=None): 27 | self.dataset = dataset 28 | self.transform = transform 29 | 30 | def __len__(self): 31 | return len(self.dataset) 32 | 33 | def __getitem__(self, index): 34 | img_path, pid, camid, clothes_id = self.dataset[index] 35 | img = read_image(img_path) 36 | if self.transform is not None: 37 | img = self.transform(img) 38 | return img, pid, camid, clothes_id, img_path 39 | 40 | 41 | def pil_loader(path): 42 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 43 | with open(path, 'rb') as f: 44 | with Image.open(f) as img: 45 | return img.convert('RGB') 46 | 47 | 48 | def accimage_loader(path): 49 | try: 50 | import accimage 51 | return accimage.Image(path) 52 | except IOError: 53 | # Potentially a decoding problem, fall back to PIL.Image 54 | return pil_loader(path) 55 | 56 | 57 | def get_default_image_loader(): 58 | from torchvision import get_image_backend 59 | if get_image_backend() == 'accimage': 60 | return accimage_loader 61 | else: 62 | return pil_loader 63 | 64 | 65 | def image_loader(path): 66 | from torchvision import get_image_backend 67 | if get_image_backend() == 'accimage': 68 | return accimage_loader(path) 69 | else: 70 | return pil_loader(path) 71 | -------------------------------------------------------------------------------- /data/datasets/ltcc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import glob 4 | import h5py 5 | import random 6 | import math 7 | import logging 8 | import numpy as np 9 | import os.path as osp 10 | from scipy.io import loadmat 11 | from tools.utils import mkdir_if_missing, write_json, read_json 12 | 13 | 14 | class LTCC(object): 15 | """ LTCC 16 | 17 | Reference: 18 | Qian et al. 
Long-Term Cloth-Changing Person Re-identification. arXiv:2005.12633, 2020. 19 | 20 | URL: https://naiq.github.io/LTCC_Perosn_ReID.html# 21 | """ 22 | dataset_dir = 'LTCC_ReID' 23 | def __init__(self, root='data', **kwargs): 24 | self.dataset_dir = osp.join(root, self.dataset_dir) 25 | self.train_dir = osp.join(self.dataset_dir, 'train') 26 | self.query_dir = osp.join(self.dataset_dir, 'query') 27 | self.gallery_dir = osp.join(self.dataset_dir, 'test') 28 | self._check_before_run() 29 | 30 | train, num_train_pids, num_train_imgs, num_train_clothes, pid2clothes = \ 31 | self._process_dir_train(self.train_dir) 32 | query, gallery, num_test_pids, num_query_imgs, num_gallery_imgs, num_test_clothes = \ 33 | self._process_dir_test(self.query_dir, self.gallery_dir) 34 | num_total_pids = num_train_pids + num_test_pids 35 | num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs 36 | num_test_imgs = num_query_imgs + num_gallery_imgs 37 | num_total_clothes = num_train_clothes + num_test_clothes 38 | 39 | logger = logging.getLogger('reid.dataset') 40 | logger.info("=> LTCC loaded") 41 | logger.info("Dataset statistics:") 42 | logger.info(" ----------------------------------------") 43 | logger.info(" subset | # ids | # images | # clothes") 44 | logger.info(" ----------------------------------------") 45 | logger.info(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_clothes)) 46 | logger.info(" test | {:5d} | {:8d} | {:9d}".format(num_test_pids, num_test_imgs, num_test_clothes)) 47 | logger.info(" query | {:5d} | {:8d} |".format(num_test_pids, num_query_imgs)) 48 | logger.info(" gallery | {:5d} | {:8d} |".format(num_test_pids, num_gallery_imgs)) 49 | logger.info(" ----------------------------------------") 50 | logger.info(" total | {:5d} | {:8d} | {:9d}".format(num_total_pids, num_total_imgs, num_total_clothes)) 51 | logger.info(" ----------------------------------------") 52 | 53 | self.train = train 54 | self.query = query 55 | 
self.gallery = gallery 56 | 57 | self.num_train_pids = num_train_pids 58 | self.num_train_clothes = num_train_clothes 59 | self.pid2clothes = pid2clothes 60 | 61 | def _check_before_run(self): 62 | """Check if all files are available before going deeper""" 63 | if not osp.exists(self.dataset_dir): 64 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 65 | if not osp.exists(self.train_dir): 66 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 67 | if not osp.exists(self.query_dir): 68 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 69 | if not osp.exists(self.gallery_dir): 70 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 71 | 72 | def _process_dir_train(self, dir_path): 73 | img_paths = glob.glob(osp.join(dir_path, '*.png')) 74 | img_paths.sort() 75 | pattern1 = re.compile(r'(\d+)_(\d+)_c(\d+)') 76 | pattern2 = re.compile(r'(\w+)_c') 77 | 78 | pid_container = set() 79 | clothes_container = set() 80 | for img_path in img_paths: 81 | pid, _, _ = map(int, pattern1.search(img_path).groups()) 82 | clothes_id = pattern2.search(img_path).group(1) 83 | pid_container.add(pid) 84 | clothes_container.add(clothes_id) 85 | pid_container = sorted(pid_container) 86 | clothes_container = sorted(clothes_container) 87 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 88 | clothes2label = {clothes_id:label for label, clothes_id in enumerate(clothes_container)} 89 | 90 | num_pids = len(pid_container) 91 | num_clothes = len(clothes_container) 92 | 93 | dataset = [] 94 | pid2clothes = np.zeros((num_pids, num_clothes)) 95 | for img_path in img_paths: 96 | pid, _, camid = map(int, pattern1.search(img_path).groups()) 97 | clothes = pattern2.search(img_path).group(1) 98 | camid -= 1 # index starts from 0 99 | pid = pid2label[pid] 100 | clothes_id = clothes2label[clothes] 101 | dataset.append((img_path, pid, camid, clothes_id)) 102 | pid2clothes[pid, clothes_id] = 1 103 | 104 | num_imgs = 
len(dataset) 105 | 106 | return dataset, num_pids, num_imgs, num_clothes, pid2clothes 107 | 108 | def _process_dir_test(self, query_path, gallery_path): 109 | query_img_paths = glob.glob(osp.join(query_path, '*.png')) 110 | gallery_img_paths = glob.glob(osp.join(gallery_path, '*.png')) 111 | query_img_paths.sort() 112 | gallery_img_paths.sort() 113 | pattern1 = re.compile(r'(\d+)_(\d+)_c(\d+)') 114 | pattern2 = re.compile(r'(\w+)_c') 115 | 116 | pid_container = set() 117 | clothes_container = set() 118 | for img_path in query_img_paths: 119 | pid, _, _ = map(int, pattern1.search(img_path).groups()) 120 | clothes_id = pattern2.search(img_path).group(1) 121 | pid_container.add(pid) 122 | clothes_container.add(clothes_id) 123 | for img_path in gallery_img_paths: 124 | pid, _, _ = map(int, pattern1.search(img_path).groups()) 125 | clothes_id = pattern2.search(img_path).group(1) 126 | pid_container.add(pid) 127 | clothes_container.add(clothes_id) 128 | pid_container = sorted(pid_container) 129 | clothes_container = sorted(clothes_container) 130 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 131 | clothes2label = {clothes_id:label for label, clothes_id in enumerate(clothes_container)} 132 | 133 | num_pids = len(pid_container) 134 | num_clothes = len(clothes_container) 135 | 136 | query_dataset = [] 137 | gallery_dataset = [] 138 | for img_path in query_img_paths: 139 | pid, _, camid = map(int, pattern1.search(img_path).groups()) 140 | clothes_id = pattern2.search(img_path).group(1) 141 | camid -= 1 # index starts from 0 142 | clothes_id = clothes2label[clothes_id] 143 | query_dataset.append((img_path, pid, camid, clothes_id)) 144 | 145 | for img_path in gallery_img_paths: 146 | pid, _, camid = map(int, pattern1.search(img_path).groups()) 147 | clothes_id = pattern2.search(img_path).group(1) 148 | camid -= 1 # index starts from 0 149 | clothes_id = clothes2label[clothes_id] 150 | gallery_dataset.append((img_path, pid, camid, clothes_id)) 151 | 152 | 
num_imgs_query = len(query_dataset) 153 | num_imgs_gallery = len(gallery_dataset) 154 | 155 | return query_dataset, gallery_dataset, num_pids, num_imgs_query, num_imgs_gallery, num_clothes 156 | 157 | -------------------------------------------------------------------------------- /data/datasets/prcc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import glob 4 | import shutil 5 | import h5py 6 | import random 7 | import math 8 | import logging 9 | import numpy as np 10 | import os.path as osp 11 | 12 | 13 | class PRCC(object): 14 | """ PRCC 15 | 16 | Reference: 17 | Yang et al. Person Re-identification by Contour Sketch under Moderate Clothing Change. TPAMI, 2019. 18 | 19 | URL: https://drive.google.com/file/d/1yTYawRm4ap3M-j0PjLQJ--xmZHseFDLz/view 20 | """ 21 | dataset_dir = 'prcc' 22 | def __init__(self, root='data', single_shot=False, seed_sel=0, **kwargs): 23 | self.dataset_dir = osp.join(root, self.dataset_dir) 24 | self.train_dir = osp.join(self.dataset_dir, 'rgb/train') 25 | self.val_dir = osp.join(self.dataset_dir, 'rgb/val') 26 | self.test_dir = osp.join(self.dataset_dir, 'rgb/test') 27 | 28 | self._check_before_run() 29 | train, num_train_pids, num_train_imgs, num_train_clothes, pid2clothes = \ 30 | self._process_dir_train(self.train_dir) 31 | val, num_val_pids, num_val_imgs, num_val_clothes, _ = \ 32 | self._process_dir_train(self.val_dir) 33 | query_same, query_diff, gallery, num_test_pids, \ 34 | num_query_imgs_same, num_query_imgs_diff, num_gallery_imgs, \ 35 | num_test_clothes, gallery_idx = self._process_dir_test(self.test_dir) 36 | 37 | num_total_pids = num_train_pids + num_test_pids 38 | num_test_imgs = num_query_imgs_same + num_query_imgs_diff + num_gallery_imgs 39 | num_total_imgs = num_train_imgs + num_val_imgs + num_test_imgs 40 | num_total_clothes = num_train_clothes + num_test_clothes 41 | 42 | logger = logging.getLogger('reid.dataset') 43 | logger.info("=> PRCC loaded") 
44 | logger.info("Dataset statistics:") 45 | logger.info(" --------------------------------------------") 46 | logger.info(" subset | # ids | # images | # clothes") 47 | logger.info(" --------------------------------------------") 48 | logger.info(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_clothes)) 49 | logger.info(" val | {:5d} | {:8d} | {:9d}".format(num_val_pids, num_val_imgs, num_val_clothes)) 50 | logger.info(" test | {:5d} | {:8d} | {:9d}".format(num_test_pids, num_test_imgs, num_test_clothes)) 51 | logger.info(" query(same) | {:5d} | {:8d} |".format(num_test_pids, num_query_imgs_same)) 52 | logger.info(" query(diff) | {:5d} | {:8d} |".format(num_test_pids, num_query_imgs_diff)) 53 | logger.info(" gallery | {:5d} | {:8d} |".format(num_test_pids, num_gallery_imgs)) 54 | logger.info(" --------------------------------------------") 55 | logger.info(" total | {:5d} | {:8d} | {:9d}".format(num_total_pids, num_total_imgs, num_total_clothes)) 56 | logger.info(" --------------------------------------------") 57 | 58 | self.train = train 59 | self.val = val 60 | self.query_same = query_same 61 | self.query_diff = query_diff 62 | self.gallery = gallery 63 | 64 | self.num_train_pids = num_train_pids 65 | self.num_train_clothes = num_train_clothes 66 | self.pid2clothes = pid2clothes 67 | self.gallery_idx = gallery_idx 68 | 69 | def _check_before_run(self): 70 | """Check if all files are available before going deeper""" 71 | if not osp.exists(self.dataset_dir): 72 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 73 | if not osp.exists(self.train_dir): 74 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 75 | if not osp.exists(self.val_dir): 76 | raise RuntimeError("'{}' is not available".format(self.val_dir)) 77 | if not osp.exists(self.test_dir): 78 | raise RuntimeError("'{}' is not available".format(self.test_dir)) 79 | 80 | def _process_dir_train(self, dir_path): 81 | pdirs = 
glob.glob(osp.join(dir_path, '*')) 82 | pdirs.sort() 83 | 84 | pid_container = set() 85 | clothes_container = set() 86 | for pdir in pdirs: 87 | pid = int(osp.basename(pdir)) 88 | pid_container.add(pid) 89 | img_dirs = glob.glob(osp.join(pdir, '*.jpg')) 90 | for img_dir in img_dirs: 91 | cam = osp.basename(img_dir)[0] # 'A' or 'B' or 'C' 92 | if cam in ['A', 'B']: 93 | clothes_container.add(osp.basename(pdir)) 94 | else: 95 | clothes_container.add(osp.basename(pdir)+osp.basename(img_dir)[0]) 96 | pid_container = sorted(pid_container) 97 | clothes_container = sorted(clothes_container) 98 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 99 | clothes2label = {clothes_id:label for label, clothes_id in enumerate(clothes_container)} 100 | cam2label = {'A': 0, 'B': 1, 'C': 2} 101 | 102 | num_pids = len(pid_container) 103 | num_clothes = len(clothes_container) 104 | 105 | dataset = [] 106 | pid2clothes = np.zeros((num_pids, num_clothes)) 107 | for pdir in pdirs: 108 | pid = int(osp.basename(pdir)) 109 | img_dirs = glob.glob(osp.join(pdir, '*.jpg')) 110 | for img_dir in img_dirs: 111 | cam = osp.basename(img_dir)[0] # 'A' or 'B' or 'C' 112 | label = pid2label[pid] 113 | camid = cam2label[cam] 114 | if cam in ['A', 'B']: 115 | clothes_id = clothes2label[osp.basename(pdir)] 116 | else: 117 | clothes_id = clothes2label[osp.basename(pdir)+osp.basename(img_dir)[0]] 118 | dataset.append((img_dir, label, camid, clothes_id)) 119 | pid2clothes[label, clothes_id] = 1 120 | 121 | num_imgs = len(dataset) 122 | 123 | return dataset, num_pids, num_imgs, num_clothes, pid2clothes 124 | 125 | def _process_dir_test(self, test_path): 126 | pdirs = glob.glob(osp.join(test_path, '*')) 127 | pdirs.sort() 128 | 129 | pid_container = set() 130 | for pdir in glob.glob(osp.join(test_path, 'A', '*')): 131 | pid = int(osp.basename(pdir)) 132 | pid_container.add(pid) 133 | pid_container = sorted(pid_container) 134 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 
135 | cam2label = {'A': 0, 'B': 1, 'C': 2} 136 | 137 | num_pids = len(pid_container) 138 | num_clothes = num_pids * 2 139 | 140 | query_dataset_same_clothes = [] 141 | query_dataset_diff_clothes = [] 142 | gallery_dataset = [] 143 | for cam in ['A', 'B', 'C']: 144 | pdirs = glob.glob(osp.join(test_path, cam, '*')) 145 | for pdir in pdirs: 146 | pid = int(osp.basename(pdir)) 147 | img_dirs = glob.glob(osp.join(pdir, '*.jpg')) 148 | for img_dir in img_dirs: 149 | # pid = pid2label[pid] 150 | camid = cam2label[cam] 151 | if cam == 'A': 152 | clothes_id = pid2label[pid] * 2 153 | gallery_dataset.append((img_dir, pid, camid, clothes_id)) 154 | elif cam == 'B': 155 | clothes_id = pid2label[pid] * 2 156 | query_dataset_same_clothes.append((img_dir, pid, camid, clothes_id)) 157 | else: 158 | clothes_id = pid2label[pid] * 2 + 1 159 | query_dataset_diff_clothes.append((img_dir, pid, camid, clothes_id)) 160 | 161 | pid2imgidx = {} 162 | for idx, (img_dir, pid, camid, clothes_id) in enumerate(gallery_dataset): 163 | if pid not in pid2imgidx: 164 | pid2imgidx[pid] = [] 165 | pid2imgidx[pid].append(idx) 166 | 167 | # get 10 gallery index to perform single-shot test 168 | gallery_idx = {} 169 | random.seed(3) 170 | for idx in range(0, 10): 171 | gallery_idx[idx] = [] 172 | for pid in pid2imgidx: 173 | gallery_idx[idx].append(random.choice(pid2imgidx[pid])) 174 | 175 | num_imgs_query_same = len(query_dataset_same_clothes) 176 | num_imgs_query_diff = len(query_dataset_diff_clothes) 177 | num_imgs_gallery = len(gallery_dataset) 178 | 179 | return query_dataset_same_clothes, query_dataset_diff_clothes, gallery_dataset, \ 180 | num_pids, num_imgs_query_same, num_imgs_query_diff, num_imgs_gallery, \ 181 | num_clothes, gallery_idx 182 | -------------------------------------------------------------------------------- /data/img_transforms.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import * 2 | from PIL import Image 
import random
import math


class ResizeWithEqualScale(object):
    """
    Resize an image to fit inside (height, width) while keeping its aspect ratio,
    then paste it centred on a canvas of exactly (height, width).

    Args:
        height (int): resized height.
        width (int): resized width.
        interpolation: interpolation passed to ``PIL.Image.resize``.
        fill_color (tuple): RGB color of the padding around the resized image.
    """
    def __init__(self, height, width, interpolation=Image.BILINEAR, fill_color=(0,0,0)):
        self.height = height
        self.width = width
        self.interpolation = interpolation
        self.fill_color = fill_color

    def __call__(self, img):
        width, height = img.size
        if self.height / self.width >= height / width:
            # Source is relatively wider than the target: width is the binding side.
            # max(1, ...) keeps extremely thin images from collapsing to size 0.
            height = max(1, int(self.width * (height / width)))
            width = self.width
        else:
            # Source is relatively taller: height is the binding side.
            width = max(1, int(self.height * (width / height)))
            height = self.height

        resized_img = img.resize((width, height), self.interpolation)
        new_img = Image.new('RGB', (self.width, self.height), self.fill_color)
        new_img.paste(resized_img, (int((self.width - width) / 2), int((self.height - height) / 2)))

        return new_img


class RandomCroping(object):
    """
    With probability ``p``, enlarge the image by 1.125 (1 + 1/8) and randomly crop
    a window of the original size back out of it.

    Args:
        p (float): probability of performing this transformation. Default: 0.5.
        interpolation: interpolation used for the enlarging resize.
    """
    def __init__(self, p=0.5, interpolation=Image.BILINEAR):
        self.p = p
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image of the same size as ``img``, or ``img`` itself
            when the transformation is not applied.
        """
        width, height = img.size
        if random.uniform(0, 1) >= self.p:
            return img

        new_width, new_height = int(round(width * 1.125)), int(round(height * 1.125))
        resized_img = img.resize((new_width, new_height), self.interpolation)
        x_maxrange = new_width - width
        y_maxrange = new_height - height
        x1 = int(round(random.uniform(0, x_maxrange)))
        y1 = int(round(random.uniform(0, y_maxrange)))
        croped_img = resized_img.crop((x1, y1, x1 + width, y1 + height))

        return croped_img


class RandomErasing(object):
    """
    Randomly selects a rectangle region in an image and erases its pixels.

    Reference:
        Zhong et al. Random Erasing Data Augmentation. arxiv: 1708.04896, 2017.

    The input is modified in place and is indexed as (channels, rows, cols) via
    ``img.size()``; a 3-channel input gets a per-channel fill, any other channel
    count has only channel 0 filled.

    Args:
        probability: The probability that the Random Erasing operation will be performed.
        sl: Minimum proportion of erased area against input image.
        sh: Maximum proportion of erased area against input image.
        r1: Minimum aspect ratio of erased area (the maximum is 1 / r1).
        mean: Erasing value per channel.
    """

    def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=(0.4914, 0.4822, 0.4465)):
        # `mean` defaults to an immutable tuple so the default cannot be shared-mutated.
        self.probability = probability
        self.mean = mean
        self.sl = sl
        self.sh = sh
        self.r1 = r1

    def __call__(self, img):

        if random.uniform(0, 1) >= self.probability:
            return img

        # Up to 100 attempts to draw a rectangle strictly smaller than the image;
        # if none fits, the image is returned unchanged.
        for _ in range(100):
            area = img.size()[1] * img.size()[2]

            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, 1/self.r1)

            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))

            if w < img.size()[2] and h < img.size()[1]:
                x1 = random.randint(0, img.size()[1] - h)
                y1 = random.randint(0, img.size()[2] - w)
                if img.size()[0] == 3:
                    img[0, x1:x1+h, y1:y1+w] = self.mean[0]
                    img[1, x1:x1+h, y1:y1+w] = self.mean[1]
                    img[2, x1:x1+h, y1:y1+w] = self.mean[2]
                else:
                    img[0, x1:x1+h, y1:y1+w] = self.mean[0]
                return img

        return img


# --------------------------------------------------------------------------------
# /data/samplers.py:
# --------------------------------------------------------------------------------
import copy
import math
import random
import numpy as np
from torch import distributed as dist
from collections import defaultdict
from torch.utils.data.sampler import Sampler


def _group_indices_by_pid(data_source):
    """Map each pid to the dataset indices carrying it (pids in first-seen order)."""
    index_dic = defaultdict(list)
    for index, (_, pid, _, _) in enumerate(data_source):
        index_dic[pid].append(index)
    return index_dic


def _epoch_length(index_dic, num_instances):
    """Indices per epoch: each pid yields max(n, K) rounded down to a multiple of K."""
    length = 0
    for idxs in index_dic.values():
        num = max(len(idxs), num_instances)
        length += num - num % num_instances
    return length


def _identity_chunks(index_dic, pids, num_instances, py_rng, np_rng):
    """Split each pid's indices into shuffled chunks of exactly num_instances, then shuffle the chunks.

    A pid with fewer than num_instances images is oversampled with replacement;
    indices that do not fill a whole chunk are dropped for this epoch.
    py_rng must provide ``shuffle`` (the ``random`` module or a ``random.Random``);
    np_rng must provide ``choice`` (``np.random`` or a ``np.random.RandomState``).
    """
    chunks = []
    for pid in pids:
        idxs = list(index_dic[pid])
        if len(idxs) < num_instances:
            idxs = np_rng.choice(idxs, size=num_instances, replace=True)
        py_rng.shuffle(idxs)
        for start in range(0, len(idxs) - num_instances + 1, num_instances):
            chunks.append(list(idxs[start:start + num_instances]))
    py_rng.shuffle(chunks)
    return chunks


class RandomIdentitySampler(Sampler):
    """
    Randomly sample N identities, then for each identity,
    randomly sample K instances, therefore batch size is N*K.

    Uses the global ``random`` / ``np.random`` state.

    Args:
        data_source (Dataset): dataset to sample from; items unpack as (_, pid, _, _).
        num_instances (int): number of instances per identity.
    """
    def __init__(self, data_source, num_instances=4):
        self.data_source = data_source
        self.num_instances = num_instances
        self.index_dic = _group_indices_by_pid(data_source)
        self.pids = list(self.index_dic.keys())
        self.num_identities = len(self.pids)

        # compute number of examples in an epoch
        self.length = _epoch_length(self.index_dic, self.num_instances)

    def __iter__(self):
        chunks = _identity_chunks(self.index_dic, self.pids, self.num_instances, random, np.random)
        return iter([idx for chunk in chunks for idx in chunk])

    def __len__(self):
        return self.length


class DistributedRandomIdentitySampler(Sampler):
    """
    Randomly sample N identities, then for each identity,
    randomly sample K instances, therefore batch size is N*K.

    Args:
    - data_source (Dataset): dataset to sample from.
    - num_instances (int): number of instances per identity.
    - num_replicas (int, optional): Number of processes participating in
        distributed training. By default, :attr:`world_size` is retrieved from the
        current distributed group.
    - rank (int, optional): Rank of the current process within :attr:`num_replicas`.
        By default, :attr:`rank` is retrieved from the current distributed group.
    - seed (int, optional): random seed used to shuffle the sampler.
        This number should be identical across all
        processes in the distributed group. Default: ``0``.
    """
    def __init__(self, data_source, num_instances=4,
                 num_replicas=None, rank=None, seed=0):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1))
        self.num_replicas = num_replicas
        self.rank = rank
        self.seed = seed
        self.epoch = 0

        self.data_source = data_source
        self.num_instances = num_instances
        self.index_dic = _group_indices_by_pid(data_source)
        self.pids = list(self.index_dic.keys())
        self.num_identities = len(self.pids)

        # compute number of examples in an epoch
        self.length = _epoch_length(self.index_dic, self.num_instances)
        assert self.length % self.num_instances == 0

        # Per-replica sample count; when the chunk count does not divide evenly,
        # the tail chunks are dropped so every replica gets the same number.
        if self.length // self.num_instances % self.num_replicas != 0:
            self.num_samples = math.ceil((self.length // self.num_instances - self.num_replicas) / self.num_replicas) * self.num_instances
        else:
            self.num_samples = math.ceil(self.length / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # Deterministically shuffle based on epoch and seed. Private generators give
        # the same draws as reseeding the global RNGs would, without resetting the
        # process-wide `random` / `np.random` state every epoch.
        seed = self.seed + self.epoch
        chunks = _identity_chunks(self.index_dic, self.pids, self.num_instances,
                                  random.Random(seed), np.random.RandomState(seed))

        # remove tail of data to make it evenly divisible.
        num_chunks = self.total_size // self.num_instances
        chunks = chunks[:num_chunks]
        assert len(chunks) == num_chunks

        # subsample
        chunks = chunks[self.rank:num_chunks:self.num_replicas]
        assert len(chunks) == self.num_samples // self.num_instances

        return iter([idx for chunk in chunks for idx in chunk])

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        """
        Sets the epoch for this sampler. This ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch


class DistributedInferenceSampler(Sampler):
    """
    refer to: https://github.com/huggingface/transformers/blob/447808c85f0e6d6b0aeeb07214942bf1e578f9d2/src/transformers/trainer_pt_utils.py

    Distributed Sampler that subsamples indices sequentially,
    making it easier to collate all results at the end.
    Even though we only use this sampler for eval and predict (no training),
    which means that the model params won't have to be synced (i.e. will not hang
    for synchronization even if varied number of forward passes), we still add extra
    samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
    to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
    """
    def __init__(self, dataset, rank=None, num_replicas=None):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank

        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        indices = list(range(len(self.dataset)))
        # add extra samples (repeats of the last index) to make it evenly divisible;
        # an empty dataset needs no padding and has no last index to repeat.
        if indices:
            indices += [indices[-1]] * (self.total_size - len(indices))
        # subsample a contiguous slice per rank
        indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
        return iter(indices)

    def __len__(self):
        return self.num_samples


# --------------------------------------------------------------------------------
# /demo.sh:
# --------------------------------------------------------------------------------
# # Example
# # python demo_single_image.py --cfg configs/res50_cels_cal.yaml --img_path /root/github/datasets/LTCC_ReID/query/001_1_c4_015861.png --weights /root/github/ReID_demo/CC-ReID/weights/ltcc.pth.tar --gpu 7
#
#
# #
# python demo_single_image.py --cfg configs/res50_cels_cal.yaml --img_path YOUR_IMAGE_PATH --weights YOUR_WEIGHT_PATH --gpu YOUR_GPU_ID
#
#
# --------------------------------------------------------------------------------
# /demo_single_image.py:
# --------------------------------------------------------------------------------
import os
import time
import datetime
import logging
import argparse
import numpy as np
import torch
import torch.nn.functional as F
from torch import distributed as dist
import torchvision
from torchvision import datasets, models, transforms
from configs.default_img_single import get_img_config
from models.img_resnet import ResNet50
from PIL import Image


def parse_option():
    """Parse command-line options and build the model config from ``--cfg``."""
    parser = argparse.ArgumentParser(
        description='Train clothes-changing re-id model with clothes-based adversarial loss')
    parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file')
    # Datasets
    parser.add_argument('--root', type=str, help="your root path to data directory")
    # Miscs
    parser.add_argument('--img_path', type=str, help='path to the image')
    parser.add_argument('--weights', type=str, help='path to the weights')
    parser.add_argument('--gpu', type=str, default='0', help='gpu id')

    args, _ = parser.parse_known_args()
    config = get_img_config(args)
    return config, args


@torch.no_grad()
def extract_img_feature(model, img):
    """Return the L2-normalised sum of the features of ``img`` and its horizontal flip.

    ``img`` is a batch indexed so that dim 3 is the image width (here (1, C, H, W));
    the model is expected to return a pair whose second item is the embedding.
    The result is returned on the CPU.
    """
    flip_img = torch.flip(img, [3])
    img, flip_img = img.cuda(), flip_img.cuda()
    _, batch_features = model(img)
    _, batch_features_flip = model(flip_img)
    batch_features += batch_features_flip
    batch_features = F.normalize(batch_features, p=2, dim=1)
    features = batch_features.cpu()

    return features


def main():
    """Extract and print the re-id feature of a single image (see demo.sh)."""
    config, args = parse_option()
    # Restrict the visible GPUs before any CUDA work is done.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # map_location='cpu': the checkpoint may have been saved from a GPU index that is
    # not visible here; the model is moved to the GPU after the weights are loaded.
    checkpoint = torch.load(args.weights, map_location='cpu', weights_only=True)
    model = ResNet50(config)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.cuda()
    model.eval()

    # ImageNet mean / std normalisation.
    data_transforms = transforms.Compose([
        transforms.Resize((config.DATA.HEIGHT, config.DATA.WIDTH)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # convert('RGB'): grayscale / RGBA files would otherwise break the 3-channel Normalize;
    # the context manager closes the underlying file.
    with Image.open(args.img_path) as image:
        image_tensor = data_transforms(image.convert('RGB'))
    input_batch = image_tensor.unsqueeze(0)  # Add a batch dimension

    feature = extract_img_feature(model, input_batch)

    print("Input Image:", args.img_path, " Output Feautre:", feature)


if __name__ == '__main__':
    main()


# --------------------------------------------------------------------------------
# /losses/__init__.py:
# --------------------------------------------------------------------------------
from torch import nn
from losses.cross_entropy_loss_with_label_smooth import CrossEntropyWithLabelSmooth
from losses.triplet_loss import TripletLoss
from losses.contrastive_loss import ContrastiveLoss
from losses.arcface_loss import ArcFaceLoss
from losses.cosface_loss import CosFaceLoss, PairwiseCosFaceLoss
from losses.circle_loss import CircleLoss, PairwiseCircleLoss
from losses.clothes_based_adversarial_loss import ClothesBasedAdversarialLoss


def build_losses(config, num_train_clothes):
    """Build the loss functions selected in ``config.LOSS``.

    Args:
        config: config node; reads LOSS.CLA_LOSS, CLA_S, CLA_M, PAIR_LOSS, PAIR_S,
            PAIR_M, CLOTHES_CLA_LOSS, CAL and EPSILON.
        num_train_clothes (int): number of training clothes classes (unused here).

    Returns:
        tuple: (criterion_cla, criterion_pair, criterion_clothes, criterion_cal, kl),
        where ``kl`` is ``torch.nn.functional.kl_div``.

    Raises:
        KeyError: if a configured loss name is not recognised.
    """
    # Build identity classification loss
    if config.LOSS.CLA_LOSS == 'crossentropy':
        criterion_cla = nn.CrossEntropyLoss()
    elif config.LOSS.CLA_LOSS == 'crossentropylabelsmooth':
        criterion_cla = CrossEntropyWithLabelSmooth()
    elif config.LOSS.CLA_LOSS == 'arcface':
        criterion_cla = ArcFaceLoss(scale=config.LOSS.CLA_S, margin=config.LOSS.CLA_M)
    elif config.LOSS.CLA_LOSS == 'cosface':
        criterion_cla = CosFaceLoss(scale=config.LOSS.CLA_S, margin=config.LOSS.CLA_M)
    elif config.LOSS.CLA_LOSS == 'circle':
        criterion_cla = CircleLoss(scale=config.LOSS.CLA_S, margin=config.LOSS.CLA_M)
    else:
        raise KeyError("Invalid classification loss: '{}'".format(config.LOSS.CLA_LOSS))

    # Build pairwise loss
    if config.LOSS.PAIR_LOSS == 'triplet':
        criterion_pair = TripletLoss(margin=config.LOSS.PAIR_M)
    elif config.LOSS.PAIR_LOSS == 'contrastive':
        criterion_pair = ContrastiveLoss(scale=config.LOSS.PAIR_S)
    elif config.LOSS.PAIR_LOSS == 'cosface':
        criterion_pair = PairwiseCosFaceLoss(scale=config.LOSS.PAIR_S, margin=config.LOSS.PAIR_M)
    elif config.LOSS.PAIR_LOSS == 'circle':
        criterion_pair = PairwiseCircleLoss(scale=config.LOSS.PAIR_S, margin=config.LOSS.PAIR_M)
    else:
        raise KeyError("Invalid pairwise loss: '{}'".format(config.LOSS.PAIR_LOSS))

    # Build clothes classification loss
    if config.LOSS.CLOTHES_CLA_LOSS == 'crossentropy':
        criterion_clothes = nn.CrossEntropyLoss()
    elif config.LOSS.CLOTHES_CLA_LOSS == 'cosface':
        criterion_clothes = CosFaceLoss(scale=config.LOSS.CLA_S, margin=0)
    else:
        raise KeyError("Invalid clothes classification loss: '{}'".format(config.LOSS.CLOTHES_CLA_LOSS))

    # Build clothes-based adversarial loss
    if config.LOSS.CAL == 'cal':
        criterion_cal = ClothesBasedAdversarialLoss(scale=config.LOSS.CLA_S, epsilon=config.LOSS.EPSILON)
    else:
        raise KeyError("Invalid clothes-based adversarial loss: '{}'".format(config.LOSS.CAL))

    kl = nn.functional.kl_div

    return criterion_cla, criterion_pair, criterion_clothes, criterion_cal, kl


# --------------------------------------------------------------------------------
# /losses/arcface_loss.py:
# --------------------------------------------------------------------------------
import math
import torch
import torch.nn.functional as F
from torch import nn


class ArcFaceLoss(nn.Module):
    """ ArcFace loss.

    Reference:
        Deng et al. ArcFace: Additive Angular Margin Loss for Deep Face Recognition. In CVPR, 2019.

    Args:
        scale (float): scaling factor.
        margin (float): pre-defined margin.
    """
    def __init__(self, scale=16, margin=0.1):
        super().__init__()
        self.s = scale
        self.m = margin

    def forward(self, inputs, targets):
        """
        Args:
            inputs: cosine-similarity logits (before softmax) with shape (batch_size, num_classes)
            targets: ground truth labels with shape (batch_size)
        """
        # boolean one-hot mask of the target class in each row
        index = torch.zeros_like(inputs)
        index.scatter_(1, targets.view(-1, 1), 1)
        index = index.bool()

        cos_m = math.cos(self.m)
        sin_m = math.sin(self.m)
        cos_t = inputs[index]
        # clamp: rounding can push |cos_t| to (or past) 1, where sqrt would give NaN
        # values or infinite gradients.
        sin_t = torch.sqrt((1.0 - cos_t * cos_t).clamp(min=1e-12))
        cos_t_add_m = cos_t * cos_m - sin_t * sin_m

        # Only apply the angular margin while theta + m stays below pi; otherwise
        # fall back to the additive-penalty form to keep the logit monotonic.
        keep = cos_t - math.sin(math.pi - self.m) * self.m
        cos_t_add_m = torch.where(cos_t > math.cos(math.pi - self.m), cos_t_add_m, keep)

        output = inputs.clone()
        output[index] = cos_t_add_m
        output = self.s * output

        return F.cross_entropy(output, targets)


# --------------------------------------------------------------------------------
# /losses/circle_loss.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch import nn
from torch import distributed as dist
from losses.gather import GatherLayer


class CircleLoss(nn.Module):
    """ Circle Loss based on the predictions of classifier.

    Reference:
        Sun et al. Circle Loss: A Unified Perspective of Pair Similarity Optimization. In CVPR, 2020.

    Args:
        scale (float): scaling factor.
        margin (float): pre-defined margin.
    """
    def __init__(self, scale=96, margin=0.3, **kwargs):
        super().__init__()
        self.s = scale
        self.m = margin

    def forward(self, inputs, targets):
        """
        Args:
            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
            targets: ground truth labels with shape (batch_size)
        """
        # one-hot target mask, created on the same device as `inputs`
        mask = torch.zeros_like(inputs)
        mask.scatter_(1, targets.view(-1, 1), 1.0)

        pos_scale = self.s * F.relu(1 + self.m - inputs.detach())
        neg_scale = self.s * F.relu(inputs.detach() + self.m)
        scale_matrix = pos_scale * mask + neg_scale * (1 - mask)

        scores = (inputs - (1 - self.m) * mask - self.m * (1 - mask)) * scale_matrix

        loss = F.cross_entropy(scores, targets)

        return loss


class PairwiseCircleLoss(nn.Module):
    """ Circle Loss among sample pairs.

    Reference:
        Sun et al. Circle Loss: A Unified Perspective of Pair Similarity Optimization. In CVPR, 2020.

    Args:
        scale (float): scaling factor.
        margin (float): pre-defined margin.
    """
    def __init__(self, scale=48, margin=0.35, **kwargs):
        super().__init__()
        self.s = scale
        self.m = margin

    def forward(self, inputs, targets):
        """
        Args:
            inputs: sample features (before classifier) with shape (batch_size, feat_dim)
            targets: ground truth labels with shape (batch_size)
        """
        # l2-normalize
        inputs = F.normalize(inputs, p=2, dim=1)

        # gather all samples from different GPUs as gallery to compute pairwise loss.
        gallery_inputs = torch.cat(GatherLayer.apply(inputs), dim=0)
        gallery_targets = torch.cat(GatherLayer.apply(targets), dim=0)
        m = targets.size(0)

        # compute cosine similarity
        similarities = torch.matmul(inputs, gallery_inputs.t())

        # get mask for pos/neg pairs; this rank's own samples sit in columns
        # [rank*m, (rank+1)*m) of the gathered gallery and are excluded as positives.
        targets, gallery_targets = targets.view(-1, 1), gallery_targets.view(-1, 1)
        mask = torch.eq(targets, gallery_targets.T).float()
        mask_self = torch.zeros_like(mask)
        rank = dist.get_rank()
        mask_self[:, rank * m:(rank + 1) * m] += torch.eye(m, device=mask.device)
        mask_pos = mask - mask_self
        mask_neg = 1 - mask

        pos_scale = self.s * F.relu(1 + self.m - similarities.detach())
        neg_scale = self.s * F.relu(similarities.detach() + self.m)
        scale_matrix = pos_scale * mask_pos + neg_scale * mask_neg

        scores = (similarities - self.m) * mask_neg + (1 - self.m - similarities) * mask_pos
        scores = scores * scale_matrix

        # large negative offsets exclude masked-out entries from each logsumexp
        neg_scores_LSE = torch.logsumexp(scores * mask_neg - 99999999 * (1 - mask_neg), dim=1)
        pos_scores_LSE = torch.logsumexp(scores * mask_pos - 99999999 * (1 - mask_pos), dim=1)

        loss = F.softplus(neg_scores_LSE + pos_scores_LSE).mean()

        return loss


# --------------------------------------------------------------------------------
# /losses/clothes_based_adversarial_loss.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch import nn
from losses.gather import GatherLayer


class ClothesBasedAdversarialLoss(nn.Module):
    """ Clothes-based Adversarial Loss.

    Reference:
        Gu et al. Clothes-Changing Person Re-identification with RGB Modality Only. In CVPR, 2022.

    Args:
        scale (float): scaling factor.
        epsilon (float): a trade-off hyper-parameter.
    """
    def __init__(self, scale=16, epsilon=0.1):
        super().__init__()
        self.scale = scale
        self.epsilon = epsilon

    def forward(self, inputs, targets, positive_mask):
        """
        Args:
            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
            targets: ground truth labels with shape (batch_size)
            positive_mask: positive mask matrix with shape (batch_size, num_classes). The clothes classes with
                the same identity as the anchor sample are defined as positive clothes classes and their mask
                values are 1. The clothes classes with different identities from the anchor sample are defined
                as negative clothes classes and their mask values in positive_mask are 0.
        """
        inputs = self.scale * inputs
        negtive_mask = 1 - positive_mask
        # float32 one-hot target mask built directly on the inputs' device
        # (no CPU round-trip).
        identity_mask = torch.zeros(inputs.size(), device=inputs.device).scatter_(1, targets.unsqueeze(1), 1)

        # each logit is normalised against itself plus all negative-clothes logits
        exp_logits = torch.exp(inputs)
        log_sum_exp_pos_and_all_neg = torch.log((exp_logits * negtive_mask).sum(1, keepdim=True) + exp_logits)
        log_prob = inputs - log_sum_exp_pos_and_all_neg

        mask = (1 - self.epsilon) * identity_mask + self.epsilon / positive_mask.sum(1, keepdim=True) * positive_mask
        loss = (- mask * log_prob).sum(1).mean()

        return loss


# --------------------------------------------------------------------------------
# /losses/contrastive_loss.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch import nn
from torch import distributed as dist
from losses.gather import GatherLayer


class ContrastiveLoss(nn.Module):
    """ Supervised Contrastive Learning Loss among sample pairs.

    Args:
        scale (float): scaling factor.
    """
    def __init__(self, scale=16, **kwargs):
        super().__init__()
        self.s = scale

    def forward(self, inputs, targets):
        """
        Args:
            inputs: sample features (before classifier) with shape (batch_size, feat_dim)
            targets: ground truth labels with shape (batch_size)
        """
        # l2-normalize
        inputs = F.normalize(inputs, p=2, dim=1)

        # gather all samples from different GPUs as gallery to compute pairwise loss.
        gallery_inputs = torch.cat(GatherLayer.apply(inputs), dim=0)
        gallery_targets = torch.cat(GatherLayer.apply(targets), dim=0)
        m = targets.size(0)

        # compute cosine similarity
        similarities = torch.matmul(inputs, gallery_inputs.t()) * self.s

        # get mask for pos/neg pairs; this rank's own samples sit in columns
        # [rank*m, (rank+1)*m) of the gathered gallery and are excluded.
        targets, gallery_targets = targets.view(-1, 1), gallery_targets.view(-1, 1)
        mask = torch.eq(targets, gallery_targets.T).float()
        mask_self = torch.zeros_like(mask)
        rank = dist.get_rank()
        mask_self[:, rank * m:(rank + 1) * m] += torch.eye(m, device=mask.device)
        mask_pos = mask - mask_self
        mask_neg = 1 - mask

        # compute log_prob: each positive is normalised against itself plus all negatives
        exp_logits = torch.exp(similarities) * (1 - mask_self)
        log_sum_exp_pos_and_all_neg = torch.log((exp_logits * mask_neg).sum(1, keepdim=True) + exp_logits)
        log_prob = similarities - log_sum_exp_pos_and_all_neg

        # compute mean of log-likelihood over positive; clamp avoids 0/0 = NaN for an
        # anchor without positives (it then contributes 0).
        loss = (mask_pos * log_prob).sum(1) / mask_pos.sum(1).clamp(min=1)

        loss = - loss.mean()

        return loss


# --------------------------------------------------------------------------------
# /losses/cosface_loss.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch import nn
from torch import distributed as dist
from losses.gather import GatherLayer
| 8 | class CosFaceLoss(nn.Module): 9 | """ CosFace Loss based on the predictions of classifier. 10 | 11 | Reference: 12 | Wang et al. CosFace: Large Margin Cosine Loss for Deep Face Recognition. In CVPR, 2018. 13 | 14 | Args: 15 | scale (float): scaling factor. 16 | margin (float): pre-defined margin. 17 | """ 18 | def __init__(self, scale=16, margin=0.1, **kwargs): 19 | super().__init__() 20 | self.s = scale 21 | self.m = margin 22 | 23 | def forward(self, inputs, targets): 24 | """ 25 | Args: 26 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 27 | targets: ground truth labels with shape (batch_size) 28 | """ 29 | one_hot = torch.zeros_like(inputs) 30 | one_hot.scatter_(1, targets.view(-1, 1), 1.0) 31 | 32 | output = self.s * (inputs - one_hot * self.m) 33 | 34 | return F.cross_entropy(output, targets) 35 | 36 | 37 | class PairwiseCosFaceLoss(nn.Module): 38 | """ CosFace Loss among sample pairs. 39 | 40 | Reference: 41 | Sun et al. Circle Loss: A Unified Perspective of Pair Similarity Optimization. In CVPR, 2020. 42 | 43 | Args: 44 | scale (float): scaling factor. 45 | margin (float): pre-defined margin. 46 | """ 47 | def __init__(self, scale=16, margin=0): 48 | super().__init__() 49 | self.s = scale 50 | self.m = margin 51 | 52 | def forward(self, inputs, targets): 53 | """ 54 | Args: 55 | inputs: sample features (before classifier) with shape (batch_size, feat_dim) 56 | targets: ground truth labels with shape (batch_size) 57 | """ 58 | # l2-normalize 59 | inputs = F.normalize(inputs, p=2, dim=1) 60 | 61 | # gather all samples from different GPUs as gallery to compute pairwise loss. 
62 | gallery_inputs = torch.cat(GatherLayer.apply(inputs), dim=0) 63 | gallery_targets = torch.cat(GatherLayer.apply(targets), dim=0) 64 | m, n = targets.size(0), gallery_targets.size(0) 65 | 66 | # compute cosine similarity 67 | similarities = torch.matmul(inputs, gallery_inputs.t()) 68 | 69 | # get mask for pos/neg pairs 70 | targets, gallery_targets = targets.view(-1, 1), gallery_targets.view(-1, 1) 71 | mask = torch.eq(targets, gallery_targets.T).float().cuda() 72 | mask_self = torch.zeros_like(mask) 73 | rank = dist.get_rank() 74 | mask_self[:, rank * m:(rank + 1) * m] += torch.eye(m).float().cuda() 75 | mask_pos = mask - mask_self 76 | mask_neg = 1 - mask 77 | 78 | scores = (similarities + self.m) * mask_neg - similarities * mask_pos 79 | scores = scores * self.s 80 | 81 | neg_scores_LSE = torch.logsumexp(scores * mask_neg - 99999999 * (1 - mask_neg), dim=1) 82 | pos_scores_LSE = torch.logsumexp(scores * mask_pos - 99999999 * (1 - mask_pos), dim=1) 83 | 84 | loss = F.softplus(neg_scores_LSE + pos_scores_LSE).mean() 85 | 86 | return loss -------------------------------------------------------------------------------- /losses/cross_entropy_loss_with_label_smooth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class CrossEntropyWithLabelSmooth(nn.Module): 6 | """ Cross entropy loss with label smoothing regularization. 7 | 8 | Reference: 9 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. In CVPR, 2016. 10 | Equation: 11 | y = (1 - epsilon) * y + epsilon / K. 12 | 13 | Args: 14 | epsilon (float): a hyper-parameter in the above equation. 
15 | """ 16 | def __init__(self, epsilon=0.1): 17 | super().__init__() 18 | self.epsilon = epsilon 19 | self.logsoftmax = nn.LogSoftmax(dim=1) 20 | 21 | def forward(self, inputs, targets): 22 | """ 23 | Args: 24 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 25 | targets: ground truth labels with shape (batch_size) 26 | """ 27 | _, num_classes = inputs.size() 28 | log_probs = self.logsoftmax(inputs) 29 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1).cuda() 30 | targets = (1 - self.epsilon) * targets + self.epsilon / num_classes 31 | loss = (- targets * log_probs).mean(0).sum() 32 | 33 | return loss 34 | -------------------------------------------------------------------------------- /losses/gather.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | 5 | class GatherLayer(torch.autograd.Function): 6 | """Gather tensors from all process, supporting backward propagation.""" 7 | 8 | @staticmethod 9 | def forward(ctx, input): 10 | ctx.save_for_backward(input) 11 | output = [torch.zeros_like(input) for _ in range(dist.get_world_size())] 12 | dist.all_gather(output, input) 13 | 14 | return tuple(output) 15 | 16 | @staticmethod 17 | def backward(ctx, *grads): 18 | (input,) = ctx.saved_tensors 19 | grad_out = torch.zeros_like(input) 20 | 21 | # dist.reduce_scatter(grad_out, list(grads)) 22 | # grad_out.div_(dist.get_world_size()) 23 | 24 | grad_out[:] = grads[dist.get_rank()] 25 | 26 | return grad_out -------------------------------------------------------------------------------- /losses/triplet_loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from losses.gather import GatherLayer 6 | 7 | 8 | class TripletLoss(nn.Module): 9 | """ Triplet loss with hard example mining. 
10 | 11 | Reference: 12 | Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737. 13 | 14 | Args: 15 | margin (float): pre-defined margin. 16 | 17 | Note that we use cosine similarity, rather than Euclidean distance in the original paper. 18 | """ 19 | def __init__(self, margin=0.3): 20 | super().__init__() 21 | self.m = margin 22 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 23 | 24 | def forward(self, inputs, targets): 25 | """ 26 | Args: 27 | inputs: sample features (before classifier) with shape (batch_size, feat_dim) 28 | targets: ground truth labels with shape (batch_size) 29 | """ 30 | # l2-normlize 31 | inputs = F.normalize(inputs, p=2, dim=1) 32 | 33 | # gather all samples from different GPUs as gallery to compute pairwise loss. 34 | gallery_inputs = torch.cat(GatherLayer.apply(inputs), dim=0) 35 | gallery_targets = torch.cat(GatherLayer.apply(targets), dim=0) 36 | 37 | # compute distance 38 | dist = 1 - torch.matmul(inputs, gallery_inputs.t()) # values in [0, 2] 39 | 40 | # get positive and negative masks 41 | targets, gallery_targets = targets.view(-1,1), gallery_targets.view(-1,1) 42 | mask_pos = torch.eq(targets, gallery_targets.T).float().cuda() 43 | mask_neg = 1 - mask_pos 44 | 45 | # For each anchor, find the hardest positive and negative pairs 46 | dist_ap, _ = torch.max((dist - mask_neg * 99999999.), dim=1) 47 | dist_an, _ = torch.min((dist + mask_pos * 99999999.), dim=1) 48 | 49 | # Compute ranking hinge loss 50 | y = torch.ones_like(dist_an) 51 | loss = self.ranking_loss(dist_an, dist_ap, y) 52 | 53 | return loss -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from threading import local 4 | import time 5 | import datetime 6 | import argparse 7 | import logging 8 | import os.path as osp 9 | import numpy as np 10 | import gc 11 | 12 | import 
torch 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | from torch.optim import lr_scheduler 16 | from torch import distributed as dist 17 | from apex import amp 18 | # from models.lr_scheduler import WarmupMultiStepLR 19 | 20 | from configs.default_img import get_img_config 21 | from data import build_dataloader 22 | from models import build_model 23 | from losses import build_losses 24 | from tools.utils import save_checkpoint, set_seed, get_logger 25 | from train import train_aim 26 | from test import test, test_prcc 27 | 28 | 29 | def parse_option(): 30 | parser = argparse.ArgumentParser(description='Train clothes-changing re-id model with clothes-based adversarial loss') 31 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file') 32 | # Datasets 33 | parser.add_argument('--root', type=str, help="your root path to data directory") 34 | parser.add_argument('--dataset', type=str, default='ltcc', help="ltcc, prcc, vcclothes, last, deepchange") 35 | # Miscs 36 | parser.add_argument('--output', type=str, help="your output path to save model and logs") 37 | parser.add_argument('--resume', type=str, metavar='PATH') 38 | parser.add_argument('--amp', action='store_true', help="automatic mixed precision") 39 | parser.add_argument('--eval', action='store_true', help="evaluation only") 40 | parser.add_argument('--tag', type=str, help='tag for log file') 41 | parser.add_argument('--name', type=str, help='your model name for record') 42 | parser.add_argument('--gpu', default='0', type=str, help='gpu device ids for CUDA_VISIBLE_DEVICES') 43 | # Options and Hyper-parameters 44 | parser.add_argument('--seed', type=str, help='seed for single-shot') 45 | parser.add_argument('--single_shot', action='store_true', help='single-shot option') 46 | parser.add_argument('--k_cal', type=str) 47 | parser.add_argument('--k_kl', type=str) 48 | 49 | args, unparsed = parser.parse_known_args() 50 | config = get_img_config(args) 51 | 52 | 
return config 53 | 54 | def main(config): 55 | # Build dataloader 56 | if config.DATA.DATASET == 'prcc': 57 | trainloader, queryloader_same, queryloader_diff, galleryloader, dataset, train_sampler = build_dataloader(config) 58 | else: 59 | trainloader, queryloader, galleryloader, dataset, train_sampler = build_dataloader(config) 60 | 61 | # Define a matrix pid2clothes with shape (num_pids, num_clothes). 62 | # pid2clothes[i, j] = 1 when j-th clothes belongs to i-th identity. Otherwise, pid2clothes[i, j] = 0. 63 | pid2clothes = torch.from_numpy(dataset.pid2clothes) 64 | 65 | # Build model 66 | model, model2, fuse, classifier, clothes_classifier, clothes_classifier2 = build_model(config, dataset.num_train_pids, dataset.num_train_clothes) 67 | print("model loaded") 68 | # Build identity classification loss, pairwise loss, clothes classificaiton loss, and adversarial loss. 69 | criterion_cla, criterion_pair, criterion_clothes, criterion_adv, kl = build_losses(config, dataset.num_train_clothes) 70 | print("loss built") 71 | # Build optimizer 72 | parameters = list(model.parameters()) + list(fuse.parameters()) + list(classifier.parameters()) 73 | parameters2 = list(model2.parameters()) + list(clothes_classifier2.parameters()) 74 | 75 | 76 | if config.TRAIN.OPTIMIZER.NAME == 'adam': 77 | optimizer = optim.Adam(parameters, lr=config.TRAIN.OPTIMIZER.LR, 78 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY) 79 | optimizer2 = optim.Adam(parameters2, lr=config.TRAIN.OPTIMIZER.LR, 80 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY) 81 | optimizer_cc = optim.Adam(clothes_classifier.parameters(), lr=config.TRAIN.OPTIMIZER.LR, 82 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY) 83 | elif config.TRAIN.OPTIMIZER.NAME == 'adamw': 84 | optimizer = optim.AdamW(parameters, lr=config.TRAIN.OPTIMIZER.LR, 85 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY) 86 | optimizer2 = optim.AdamW(parameters2, lr=config.TRAIN.OPTIMIZER.LR, 87 | 
weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY) 88 | optimizer_cc = optim.AdamW(clothes_classifier.parameters(), lr=config.TRAIN.OPTIMIZER.LR, 89 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY) 90 | elif config.TRAIN.OPTIMIZER.NAME == 'sgd': 91 | optimizer = optim.SGD(parameters, lr=config.TRAIN.OPTIMIZER.LR, momentum=0.9, 92 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY, nesterov=True) 93 | optimizer2 = optim.SGD(parameters2, lr=config.TRAIN.OPTIMIZER.LR, momentum=0.9, 94 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY, nesterov=True) 95 | optimizer_cc = optim.SGD(clothes_classifier.parameters(), lr=config.TRAIN.OPTIMIZER.LR, momentum=0.9, 96 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY, nesterov=True) 97 | else: 98 | raise KeyError("Unknown optimizer: {}".format(config.TRAIN.OPTIMIZER.NAME)) 99 | 100 | # Build lr_scheduler 101 | scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=config.TRAIN.LR_SCHEDULER.STEPSIZE, 102 | gamma=config.TRAIN.LR_SCHEDULER.DECAY_RATE) 103 | scheduler2 = lr_scheduler.MultiStepLR(optimizer2, milestones=config.TRAIN.LR_SCHEDULER.STEPSIZE, 104 | gamma=config.TRAIN.LR_SCHEDULER.DECAY_RATE) 105 | 106 | start_epoch = config.TRAIN.START_EPOCH 107 | 108 | if config.MODEL.RESUME: 109 | logger.info("Loading checkpoint from '{}'".format(config.MODEL.RESUME)) 110 | checkpoint = torch.load(config.MODEL.RESUME) 111 | model.load_state_dict(checkpoint['model_state_dict']) 112 | classifier.load_state_dict(checkpoint['classifier_state_dict']) 113 | fuse.load_state_dict(checkpoint['fuse_state_dict']) 114 | model2.load_state_dict(checkpoint['model2_state_dict']) 115 | clothes_classifier.load_state_dict(checkpoint['clothes_classifier_state_dict']) 116 | clothes_classifier2.load_state_dict(checkpoint['clothes_classifier2_state_dict']) 117 | start_epoch = checkpoint['epoch'] 118 | 119 | local_rank = dist.get_rank() 120 | model = model.cuda(local_rank) 121 | model2 = model2.cuda(local_rank) 122 | classifier = 
classifier.cuda(local_rank) 123 | clothes_classifier2 = clothes_classifier2.cuda(local_rank) 124 | fuse = fuse.cuda(local_rank) 125 | clothes_classifier = clothes_classifier.cuda(local_rank) 126 | torch.cuda.set_device(local_rank) 127 | 128 | if config.TRAIN.AMP: 129 | [model, fuse, classifier], optimizer = amp.initialize([model, fuse, classifier], optimizer, opt_level="O1") 130 | [model2, clothes_classifier2], optimizer2 = amp.initialize([model2, clothes_classifier2], optimizer2, opt_level="O1") 131 | 132 | model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank) 133 | fuse = nn.parallel.DistributedDataParallel(fuse, device_ids=[local_rank], output_device=local_rank) 134 | model2 = nn.parallel.DistributedDataParallel(model2, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True) 135 | classifier = nn.parallel.DistributedDataParallel(classifier, device_ids=[local_rank], output_device=local_rank) 136 | clothes_classifier2 = nn.parallel.DistributedDataParallel(clothes_classifier2, device_ids=[local_rank], output_device=local_rank) 137 | 138 | if config.EVAL_MODE: 139 | logger.info("Evaluate only") 140 | with torch.no_grad(): 141 | if config.DATA.DATASET == 'prcc': 142 | test_prcc(model, queryloader_same, queryloader_diff, galleryloader, dataset) 143 | else: 144 | test(config, model, queryloader, galleryloader, dataset) 145 | return 146 | 147 | start_time = time.time() 148 | train_time = 0 149 | best_rank1 = -np.inf 150 | best_epoch = 0 151 | logger.info("==> Start training") 152 | for epoch in range(start_epoch, config.TRAIN.MAX_EPOCH): 153 | train_sampler.set_epoch(epoch) 154 | start_train_time = time.time() 155 | 156 | train_aim(config, epoch, model, model2, classifier, clothes_classifier, clothes_classifier2, fuse, criterion_cla, criterion_pair, 157 | criterion_clothes, criterion_adv, optimizer, optimizer2, optimizer_cc, trainloader, pid2clothes, kl) 158 | 159 | train_time += round(time.time() - 
start_train_time) 160 | 161 | 162 | if (epoch+1) > config.TEST.START_EVAL and config.TEST.EVAL_STEP > 0 and \ 163 | (epoch+1) % config.TEST.EVAL_STEP == 0 or (epoch+1) == config.TRAIN.MAX_EPOCH: 164 | logger.info("==> Test") 165 | torch.cuda.empty_cache() 166 | if config.DATA.DATASET == 'prcc': 167 | rank1 = test_prcc(model, queryloader_same, queryloader_diff, galleryloader, dataset) 168 | else: 169 | rank1 = test(config, model, queryloader, galleryloader, dataset) 170 | torch.cuda.empty_cache() 171 | is_best = rank1 > best_rank1 172 | if is_best: 173 | best_rank1 = rank1 174 | best_epoch = epoch + 1 175 | 176 | model_state_dict = model.module.state_dict() 177 | model2_state_dict = model2.module.state_dict() 178 | fuse_state_dict = fuse.module.state_dict() 179 | classifier_state_dict = classifier.module.state_dict() 180 | clothes_classifier_state_dict = clothes_classifier.module.state_dict() 181 | clothes_classifier2_state_dict = clothes_classifier2.module.state_dict() 182 | 183 | if local_rank == 0: 184 | save_checkpoint({ 185 | 'model_state_dict': model_state_dict, 186 | 'model2_state_dict': model2_state_dict, 187 | 'fuse_state_dict': fuse_state_dict, 188 | 'classifier_state_dict': classifier_state_dict, 189 | 'clothes_classifier_state_dict': clothes_classifier_state_dict, 190 | 'clothes_classifier2_state_dict': clothes_classifier2_state_dict, 191 | 'rank1': rank1, 192 | 'epoch': epoch, 193 | }, is_best, osp.join(config.OUTPUT, 'checkpoint_ep' + str(epoch+1) + '.pth.tar')) 194 | scheduler.step() 195 | scheduler2.step() 196 | 197 | logger.info("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(best_rank1, best_epoch)) 198 | 199 | elapsed = round(time.time() - start_time) 200 | elapsed = str(datetime.timedelta(seconds=elapsed)) 201 | train_time = str(datetime.timedelta(seconds=train_time)) 202 | logger.info("Finished. Total elapsed time (h:m:s): {}. 
Training time (h:m:s): {}.".format(elapsed, train_time)) 203 | 204 | 205 | if __name__ == '__main__': 206 | gc.collect() 207 | torch.cuda.empty_cache() 208 | 209 | config = parse_option() 210 | # Set GPU 211 | os.environ['CUDA_VISIBLE_DEVICES'] = config.GPU 212 | # Init dist 213 | dist.init_process_group(backend="nccl", init_method='env://') 214 | local_rank = dist.get_rank() 215 | # Set random seed 216 | set_seed(config.SEED + local_rank) 217 | # get logger 218 | if not config.EVAL_MODE: 219 | output_file = osp.join(config.OUTPUT, 'log_train_.log') 220 | else: 221 | output_file = osp.join(config.OUTPUT, 'log_test.log') 222 | logger = get_logger(output_file, local_rank, 'reid') 223 | logger.info("Config:\n-----------------------------------------") 224 | logger.info(config) 225 | logger.info("-----------------------------------------") 226 | 227 | main(config) 228 | -------------------------------------------------------------------------------- /models/Fuse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import init 5 | 6 | EPSILON = 1e-12 7 | 8 | class BasicConv2d(nn.Module): 9 | 10 | def __init__(self, in_channels, out_channels, **kwargs): 11 | super(BasicConv2d, self).__init__() 12 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 13 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 14 | 15 | def forward(self, x): 16 | x = self.conv(x) 17 | x = self.bn(x) 18 | return F.relu(x, inplace=True) 19 | 20 | class Fuse(nn.Module): 21 | def __init__(self, feature_dim): 22 | super(Fuse, self).__init__() 23 | 24 | self.trans = nn.Linear(8*2048, feature_dim, bias=False) 25 | self.bn = nn.BatchNorm1d(feature_dim) 26 | self.pool = nn.AdaptiveMaxPool2d(1) 27 | self.M = 8 28 | self.attentions = BasicConv2d(2048, self.M, kernel_size=1) 29 | self.trans.weight.data.normal_(0, 0.001) 30 | init.normal_(self.bn.weight.data, 
1.0, 0.02) 31 | init.constant_(self.bn.bias.data, 0.0) 32 | 33 | def forward(self, feat, counter_feat_in): 34 | 35 | counter_feat = self.attentions(counter_feat_in) 36 | 37 | B, C, H, W = feat.size() 38 | _, M, AH, AW = counter_feat.size() 39 | 40 | x = (torch.einsum('imjk,injk->imn', (counter_feat, feat)) / float(H * W)).view(B, -1) 41 | x = torch.sign(x) * torch.sqrt(torch.abs(x) + EPSILON) 42 | x = F.normalize(x, dim=-1) 43 | x = self.trans(x) 44 | x = self.bn(x) 45 | 46 | return x -------------------------------------------------------------------------------- /models/Fusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import init 5 | 6 | EPSILON = 1e-12 7 | 8 | class BasicConv2d(nn.Module): 9 | 10 | def __init__(self, in_channels, out_channels, **kwargs): 11 | super(BasicConv2d, self).__init__() 12 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 13 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 14 | 15 | def forward(self, x): 16 | x = self.conv(x) 17 | x = self.bn(x) 18 | return F.relu(x, inplace=True) 19 | 20 | class Fusion(nn.Module): 21 | def __init__(self, feature_dim): 22 | super(Fusion, self).__init__() 23 | 24 | self.linear = nn.Linear(8*2048, feature_dim, bias=False) 25 | self.bn = nn.BatchNorm1d(feature_dim) 26 | self.pool = nn.AdaptiveMaxPool2d(1) 27 | self.M = 8 28 | self.attentions = BasicConv2d(2048, self.M, kernel_size=1) 29 | self.linear.weight.data.normal_(0, 0.001) 30 | init.normal_(self.bn.weight.data, 1.0, 0.02) 31 | init.constant_(self.bn.bias.data, 0.0) 32 | 33 | def forward(self, feat, feat2): 34 | feat2_att = self.attentions(feat2) 35 | 36 | B, C, H, W = feat.size() 37 | _, M, AH, AW = feat2_att.size() 38 | 39 | x = (torch.einsum('imjk,injk->imn', (feat2_att, feat)) / float(H * W)).view(B, -1) 40 | x = torch.sign(x) * torch.sqrt(torch.abs(x) + EPSILON) 41 | x = 
F.normalize(x, dim=-1) 42 | x = self.linear(x) 43 | x = self.bn(x) 44 | 45 | return x -------------------------------------------------------------------------------- /models/PM.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import init 5 | from torchvision import models 6 | from torch.autograd import Variable 7 | import torch.nn.functional as F 8 | from models.ResNet import resnet50 9 | 10 | 11 | 12 | ###################################################################### 13 | def weight_init(m): 14 | if isinstance(m, nn.Conv2d): 15 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 16 | m.weight.data.normal_(0, math.sqrt(2. / n)) 17 | elif isinstance(m, nn.BatchNorm2d): 18 | m.weight.data.fill_(1) 19 | m.bias.data.zero_() 20 | elif isinstance(m, nn.Linear): 21 | m.weight.data.normal_(0, 0.001) 22 | 23 | 24 | def pcb_block(num_ftrs, num_stripes, local_conv_out_channels, feature_dim, avg=False): 25 | if avg: 26 | pooling_list = nn.ModuleList([nn.AdaptiveAvgPool2d(1) for _ in range(num_stripes)]) 27 | else: 28 | pooling_list = nn.ModuleList([nn.AdaptiveMaxPool2d(1) for _ in range(num_stripes)]) 29 | conv_list = nn.ModuleList([nn.Conv2d(num_ftrs, local_conv_out_channels, 1, bias=False) for _ in range(num_stripes)]) 30 | batchnorm_list = nn.ModuleList([nn.BatchNorm2d(local_conv_out_channels) for _ in range(num_stripes)]) 31 | relu_list = nn.ModuleList([nn.ReLU(inplace=True) for _ in range(num_stripes)]) 32 | for m in conv_list: 33 | weight_init(m) 34 | for m in batchnorm_list: 35 | weight_init(m) 36 | return pooling_list, conv_list, batchnorm_list, relu_list 37 | 38 | 39 | def spp_vertical(feats, pool_list, conv_list, bn_list, relu_list, num_strides, feat_list=[]): 40 | for i in range(num_strides): 41 | pcb_feat = pool_list[i](feats[:, :, i * int(feats.size(2) / num_strides): (i + 1) * int(feats.size(2) / num_strides), :]) 42 | pcb_feat = 
conv_list[i](pcb_feat) 43 | pcb_feat = bn_list[i](pcb_feat) 44 | pcb_feat = relu_list[i](pcb_feat) 45 | pcb_feat = pcb_feat.view(pcb_feat.size(0), -1) 46 | feat_list.append(pcb_feat) 47 | return feat_list 48 | 49 | def global_pcb(feats, pool, conv, bn, relu, feat_list=[]): 50 | global_feat = pool(feats) 51 | global_feat = conv(global_feat) 52 | global_feat = bn(global_feat) 53 | global_feat = relu(global_feat) 54 | global_feat = global_feat.view(feats.size(0), -1) 55 | feat_list.append(global_feat) 56 | return feat_list 57 | 58 | class PM(nn.Module): 59 | def __init__(self, feature_dim, blocks=15, num_stripes=6, local_conv_out_channels=256, erase=0, loss={'htri'}, avg=False, **kwargs): 60 | super(PM, self).__init__() 61 | self.num_stripes = num_stripes 62 | 63 | model_ft = resnet50(pretrained=True, last_conv_stride=1) 64 | self.num_ftrs = list(model_ft.layer4)[-1].conv1.in_channels 65 | self.features = model_ft 66 | 67 | self.global_pooling = nn.AdaptiveMaxPool2d(1) 68 | self.global_conv = nn.Conv2d(self.num_ftrs, local_conv_out_channels, 1, bias=False) 69 | self.global_bn = nn.BatchNorm2d(local_conv_out_channels) 70 | self.global_relu = nn.ReLU(inplace=True) 71 | 72 | self.trans = nn.Linear(256*blocks, feature_dim, bias=False) 73 | self.bn = nn.BatchNorm1d(feature_dim) 74 | 75 | weight_init(self.global_conv) 76 | weight_init(self.global_bn) 77 | weight_init(self.trans) 78 | init.normal_(self.bn.weight.data, 1.0, 0.02) 79 | init.constant_(self.bn.bias.data, 0.0) 80 | 81 | self.pcb2_pool_list, self.pcb2_conv_list, self.pcb2_batchnorm_list, self.pcb2_relu_list = pcb_block(self.num_ftrs, 2, local_conv_out_channels, feature_dim, avg) 82 | self.pcb4_pool_list, self.pcb4_conv_list, self.pcb4_batchnorm_list, self.pcb4_relu_list = pcb_block(self.num_ftrs, 4, local_conv_out_channels, feature_dim, avg) 83 | self.pcb8_pool_list, self.pcb8_conv_list, self.pcb8_batchnorm_list, self.pcb8_relu_list = pcb_block(self.num_ftrs, 8, local_conv_out_channels, feature_dim, avg) 84 | 85 | 
def forward(self, x): 86 | feats = self.features(x) 87 | 88 | feat_list = global_pcb(feats, self.global_pooling, self.global_conv, self.global_bn, self.global_relu, []) 89 | feat_list = spp_vertical(feats, self.pcb2_pool_list, self.pcb2_conv_list, self.pcb2_batchnorm_list, self.pcb2_relu_list, 2, feat_list) 90 | feat_list = spp_vertical(feats, self.pcb4_pool_list, self.pcb4_conv_list, self.pcb4_batchnorm_list, self.pcb4_relu_list, 4, feat_list) 91 | feat_list = spp_vertical(feats, self.pcb8_pool_list, self.pcb8_conv_list, self.pcb8_batchnorm_list, self.pcb8_relu_list, 8, feat_list) 92 | 93 | ret = torch.cat(feat_list, dim=1) 94 | ret = self.trans(ret) 95 | ret = self.bn(ret) 96 | return feats, ret 97 | -------------------------------------------------------------------------------- /models/ResNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | from models.utils import pooling 5 | 6 | 7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 8 | 'resnet152'] 9 | 10 | 11 | model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 17 | } 18 | 19 | 20 | def conv3x3(in_planes, out_planes, stride=1): 21 | """3x3 convolution with padding""" 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 23 | padding=1, bias=False) 24 | 25 | 26 | class BasicBlock(nn.Module): 27 | expansion = 1 28 | 29 | def __init__(self, inplanes, planes, stride=1, downsample=None): 30 | super(BasicBlock, self).__init__() 31 | self.conv1 = conv3x3(inplanes, planes, stride) 32 | self.bn1 = 
nn.BatchNorm2d(planes) 33 | self.relu = nn.ReLU(inplace=True) 34 | self.conv2 = conv3x3(planes, planes) 35 | self.bn2 = nn.BatchNorm2d(planes) 36 | self.downsample = downsample 37 | self.stride = stride 38 | 39 | def forward(self, x): 40 | residual = x 41 | 42 | out = self.conv1(x) 43 | out = self.bn1(out) 44 | out = self.relu(out) 45 | 46 | out = self.conv2(out) 47 | out = self.bn2(out) 48 | 49 | if self.downsample is not None: 50 | residual = self.downsample(x) 51 | 52 | out += residual 53 | out = self.relu(out) 54 | 55 | return out 56 | 57 | 58 | class Bottleneck(nn.Module): 59 | expansion = 4 60 | 61 | def __init__(self, inplanes, planes, stride=1, downsample=None): 62 | super(Bottleneck, self).__init__() 63 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 64 | self.bn1 = nn.BatchNorm2d(planes) 65 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 66 | padding=1, bias=False) 67 | self.bn2 = nn.BatchNorm2d(planes) 68 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 69 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 70 | self.relu = nn.ReLU(inplace=True) 71 | self.downsample = downsample 72 | self.stride = stride 73 | 74 | def forward(self, x): 75 | residual = x 76 | 77 | out = self.conv1(x) 78 | out = self.bn1(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv2(out) 82 | out = self.bn2(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv3(out) 86 | out = self.bn3(out) 87 | 88 | if self.downsample is not None: 89 | residual = self.downsample(x) 90 | 91 | out += residual 92 | out = self.relu(out) 93 | 94 | return out 95 | 96 | 97 | class ResNet(nn.Module): 98 | 99 | def __init__(self, block, layers, last_conv_stride=2, num_classes=1000): 100 | self.inplanes = 64 101 | super(ResNet, self).__init__() 102 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 103 | self.bn1 = nn.BatchNorm2d(self.inplanes) 104 | self.relu = 
nn.ReLU(inplace=True) 105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 106 | self.layer1 = self._make_layer(block, 64, layers[0]) 107 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 108 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 109 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_conv_stride) 110 | self.avgpool = nn.AvgPool2d(7, stride=1) 111 | self.fc = nn.Linear(512 * block.expansion, num_classes) 112 | self.globalpool = pooling.MaxAvgPooling() 113 | 114 | for m in self.modules(): 115 | if isinstance(m, nn.Conv2d): 116 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 117 | elif isinstance(m, nn.BatchNorm2d): 118 | nn.init.constant_(m.weight, 1) 119 | nn.init.constant_(m.bias, 0) 120 | 121 | def _make_layer(self, block, planes, blocks, stride=1): 122 | downsample = None 123 | if stride != 1 or self.inplanes != planes * block.expansion: 124 | downsample = nn.Sequential( 125 | nn.Conv2d(self.inplanes, planes * block.expansion, 126 | kernel_size=1, stride=stride, bias=False), 127 | nn.BatchNorm2d(planes * block.expansion), 128 | ) 129 | 130 | layers = [] 131 | layers.append(block(self.inplanes, planes, stride, downsample)) 132 | self.inplanes = planes * block.expansion 133 | for i in range(1, blocks): 134 | layers.append(block(self.inplanes, planes)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | x = self.conv1(x) 140 | x = self.bn1(x) 141 | x = self.relu(x) 142 | x = self.maxpool(x) 143 | 144 | x = self.layer1(x) 145 | x = self.layer2(x) 146 | x = self.layer3(x) 147 | x = self.layer4(x) 148 | 149 | # x = self.avgpool(x) 150 | # x = self.globalpool(x) 151 | # x = x.view(x.size(0), -1) 152 | # x = self.fc(x) 153 | 154 | return x 155 | 156 | 157 | def resnet18(pretrained=False, **kwargs): 158 | """Constructs a ResNet-18 model. 
159 | 160 | Args: 161 | pretrained (bool): If True, returns a model pre-trained on ImageNet 162 | """ 163 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 164 | if pretrained: 165 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 166 | return model 167 | 168 | 169 | def resnet34(pretrained=False, **kwargs): 170 | """Constructs a ResNet-34 model. 171 | 172 | Args: 173 | pretrained (bool): If True, returns a model pre-trained on ImageNet 174 | """ 175 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 176 | if pretrained: 177 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 178 | return model 179 | 180 | 181 | def resnet50(pretrained=False, **kwargs): 182 | """Constructs a ResNet-50 model. 183 | 184 | Args: 185 | pretrained (bool): If True, returns a model pre-trained on ImageNet 186 | """ 187 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 188 | if pretrained: 189 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 190 | return model 191 | 192 | 193 | def resnet101(pretrained=False, **kwargs): 194 | """Constructs a ResNet-101 model. 195 | 196 | Args: 197 | pretrained (bool): If True, returns a model pre-trained on ImageNet 198 | """ 199 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 200 | if pretrained: 201 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 202 | return model 203 | 204 | 205 | def resnet152(pretrained=False, **kwargs): 206 | """Constructs a ResNet-152 model. 
207 | 208 | Args: 209 | pretrained (bool): If True, returns a model pre-trained on ImageNet 210 | """ 211 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 212 | if pretrained: 213 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 214 | return model -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from models.classifier import Classifier, NormalizedClassifier 3 | from models.img_resnet import ResNet50 4 | from models.PM import PM 5 | from models.Fusion import Fusion 6 | 7 | 8 | def build_model(config, num_identities, num_clothes): 9 | logger = logging.getLogger('reid.model') 10 | # Build backbone 11 | logger.info("Initializing model: {}".format(config.MODEL.NAME)) 12 | 13 | 14 | logger.info("Init model: '{}'".format(config.MODEL.NAME)) 15 | model = ResNet50(config) 16 | 17 | model2 = PM(feature_dim=config.MODEL.FEATURE_DIM) 18 | fusion = Fusion(feature_dim=config.MODEL.FEATURE_DIM) 19 | 20 | logger.info("Model size: {:.5f}M".format(sum(p.numel() for p in model.parameters())/1000000.0)) 21 | logger.info("Model2 size: {:.5f}M".format(sum(p.numel() for p in model2.parameters())/1000000.0)) 22 | 23 | # Build classifier 24 | if config.LOSS.CLA_LOSS in ['crossentropy', 'crossentropylabelsmooth']: 25 | identity_classifier = Classifier(feature_dim=config.MODEL.FEATURE_DIM, num_classes=num_identities) 26 | else: 27 | identity_classifier = NormalizedClassifier(feature_dim=config.MODEL.FEATURE_DIM, num_classes=num_identities) 28 | 29 | clothes_classifier = NormalizedClassifier(feature_dim=config.MODEL.FEATURE_DIM, num_classes=num_clothes) 30 | 31 | #classifier of new model 32 | clothes_classifier2 = NormalizedClassifier(feature_dim=config.MODEL.FEATURE_DIM, num_classes=num_clothes) 33 | 34 | return model, model2, fusion, identity_classifier, clothes_classifier, clothes_classifier2 
-------------------------------------------------------------------------------- /models/classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import init 4 | from torch.nn import functional as F 5 | from torch.nn import Parameter 6 | 7 | 8 | __all__ = ['Classifier', 'NormalizedClassifier'] 9 | 10 | 11 | class Classifier(nn.Module): 12 | def __init__(self, feature_dim, num_classes): 13 | super().__init__() 14 | self.classifier = nn.Linear(feature_dim, num_classes) 15 | init.normal_(self.classifier.weight.data, std=0.001) 16 | init.constant_(self.classifier.bias.data, 0.0) 17 | 18 | def forward(self, x): 19 | y = self.classifier(x) 20 | 21 | return y 22 | 23 | 24 | class NormalizedClassifier(nn.Module): 25 | def __init__(self, feature_dim, num_classes): 26 | super().__init__() 27 | self.weight = Parameter(torch.Tensor(num_classes, feature_dim)) 28 | self.weight.data.uniform_(-1, 1).renorm_(2,0,1e-5).mul_(1e5) 29 | 30 | def forward(self, x): 31 | w = self.weight 32 | 33 | x = F.normalize(x, p=2, dim=1) 34 | w = F.normalize(w, p=2, dim=1) 35 | 36 | return F.linear(x, w) 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /models/img_resnet.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | from torch import nn 3 | from torch.nn import init 4 | from models.utils import pooling 5 | 6 | 7 | class ResNet50(nn.Module): 8 | def __init__(self, config, **kwargs): 9 | super().__init__() 10 | 11 | resnet50 = torchvision.models.resnet50(pretrained=True) 12 | # or passing weights=ResNet50_Weights.IMAGENET1K_V1 instead of pretrained=True to haddle the warning 13 | if config.MODEL.RES4_STRIDE == 1: 14 | resnet50.layer4[0].conv2.stride=(1, 1) 15 | resnet50.layer4[0].downsample[0].stride=(1, 1) 16 | # self.base = nn.Sequential(*list(resnet50.children())[:-2]) 17 | self.conv1 
= resnet50.conv1 18 | self.bn1 = resnet50.bn1 19 | self.relu = resnet50.relu 20 | self.maxpool = resnet50.maxpool 21 | 22 | self.layer1 = resnet50.layer1 23 | self.layer2 = resnet50.layer2 24 | self.layer3 = resnet50.layer3 25 | self.layer4 = resnet50.layer4 26 | 27 | if config.MODEL.POOLING.NAME == 'avg': 28 | self.globalpooling = nn.AdaptiveAvgPool2d(1) 29 | elif config.MODEL.POOLING.NAME == 'max': 30 | self.globalpooling = nn.AdaptiveMaxPool2d(1) 31 | elif config.MODEL.POOLING.NAME == 'gem': 32 | self.globalpooling = pooling.GeMPooling(p=config.MODEL.POOLING.P) 33 | elif config.MODEL.POOLING.NAME == 'maxavg': 34 | self.globalpooling = pooling.MaxAvgPooling() 35 | else: 36 | raise KeyError("Invalid pooling: '{}'".format(config.MODEL.POOLING.NAME)) 37 | 38 | self.bn = nn.BatchNorm1d(config.MODEL.FEATURE_DIM) 39 | init.normal_(self.bn.weight.data, 1.0, 0.02) 40 | init.constant_(self.bn.bias.data, 0.0) 41 | 42 | def forward(self, tmp): 43 | tmp = self.conv1(tmp) 44 | tmp = self.bn1(tmp) 45 | tmp = self.relu(tmp) 46 | tmp = self.maxpool(tmp) 47 | 48 | tmp = self.layer1(tmp) 49 | tmp = self.layer2(tmp) 50 | tmp = self.layer3(tmp) 51 | old_x = self.layer4(tmp) # torch.Size([32, 2048, 24, 12]) 52 | 53 | # old_x = self.base(tmp) 54 | 55 | x = self.globalpooling(old_x) # torch.Size([32, 4096, 1, 1]) 56 | x = x.view(x.size(0), -1) 57 | f = self.bn(x) 58 | 59 | return old_x, f -------------------------------------------------------------------------------- /models/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | from bisect import bisect_right 7 | import torch 8 | 9 | 10 | # FIXME ideally this would be achieved with a CombinedLRScheduler, 11 | # separating MultiStepLR with WarmupLR 12 | # but the current LRScheduler design doesn't allow it 13 | 14 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 15 | def 
__init__( 16 | self, 17 | optimizer, 18 | milestones, 19 | gamma=0.1, 20 | warmup_factor=1.0 / 3, 21 | warmup_iters=500, 22 | warmup_method="linear", 23 | last_epoch=-1, 24 | ): 25 | if not list(milestones) == sorted(milestones): 26 | raise ValueError( 27 | "Milestones should be a list of" " increasing integers. Got {}", 28 | milestones, 29 | ) 30 | 31 | if warmup_method not in ("constant", "linear"): 32 | raise ValueError( 33 | "Only 'constant' or 'linear' warmup_method accepted" 34 | "got {}".format(warmup_method) 35 | ) 36 | self.milestones = milestones # (40, 70) 37 | self.gamma = gamma # 0.1 38 | self.warmup_factor = warmup_factor # 0.01 39 | self.warmup_iters = warmup_iters # 0 40 | self.warmup_method = warmup_method # linear 41 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 42 | 43 | def get_lr(self): 44 | warmup_factor = 1 45 | if self.last_epoch < self.warmup_iters: 46 | if self.warmup_method == "constant": 47 | warmup_factor = self.warmup_factor 48 | elif self.warmup_method == "linear": 49 | alpha = self.last_epoch / self.warmup_iters 50 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 51 | return [ 52 | base_lr 53 | * warmup_factor 54 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 55 | for base_lr in self.base_lrs 56 | ] 57 | -------------------------------------------------------------------------------- /models/utils/c3d_blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class APM(nn.Module): 7 | def __init__(self, in_channels, out_channels, time_dim=3, temperature=4, contrastive_att=True): 8 | super(APM, self).__init__() 9 | 10 | self.time_dim = time_dim 11 | self.temperature = temperature 12 | self.contrastive_att = contrastive_att 13 | 14 | padding = (0, 0, 0, 0, (time_dim-1)//2, (time_dim-1)//2) 15 | self.padding = nn.ConstantPad3d(padding, value=0) 16 | 17 | 
self.semantic_mapping = nn.Conv3d(in_channels, out_channels, \ 18 | kernel_size=1, bias=False) 19 | if self.contrastive_att: 20 | self.x_mapping = nn.Conv3d(in_channels, out_channels, \ 21 | kernel_size=1, bias=False) 22 | self.n_mapping = nn.Conv3d(in_channels, out_channels, \ 23 | kernel_size=1, bias=False) 24 | self.contrastive_att_net = nn.Sequential(nn.Conv3d(out_channels, 1, \ 25 | kernel_size=1, bias=False), nn.Sigmoid()) 26 | 27 | def forward(self, x): 28 | b, c, t, h, w = x.size() 29 | N = self.time_dim 30 | 31 | neighbor_time_index = torch.cat([(torch.arange(0,t)+i).unsqueeze(0) for i in range(N) if i!=N//2], dim=0).t().flatten().long() 32 | 33 | # feature map registration 34 | semantic = self.semantic_mapping(x) # (b, c/16, t, h, w) 35 | x_norm = F.normalize(semantic, p=2, dim=1) # (b, c/16, t, h, w) 36 | x_norm_padding = self.padding(x_norm) # (b, c/16, t+2, h, w) 37 | x_norm_expand = x_norm.unsqueeze(3).expand(-1, -1, -1, N-1, -1, -1).permute(0, 2, 3, 4, 5, 1).contiguous().view(-1, h*w, c//16) # (b*t*2, h*w, c/16) 38 | neighbor_norm = x_norm_padding[:, :, neighbor_time_index, :, :].permute(0, 2, 1, 3, 4).contiguous().view(-1, c//16, h*w) # (b*t*2, c/16, h*w) 39 | 40 | similarity = torch.matmul(x_norm_expand, neighbor_norm) * self.temperature # (b*t*2, h*w, h*w) 41 | similarity = F.softmax(similarity, dim=-1) # (b*t*2, h*w, h*w) 42 | 43 | x_padding = self.padding(x) 44 | neighbor = x_padding[:, :, neighbor_time_index, :, :].permute(0, 2, 3, 4, 1).contiguous().view(-1, h*w, c) 45 | neighbor_new = torch.matmul(similarity, neighbor).view(b, t*(N-1), h, w, c).permute(0, 4, 1, 2, 3) # (b, c, t*2, h, w) 46 | 47 | # contrastive attention 48 | if self.contrastive_att: 49 | x_att = self.x_mapping(x.unsqueeze(3).expand(-1, -1, -1, N-1, -1, -1).contiguous().view(b, c, (N-1)*t, h, w).detach()) 50 | n_att = self.n_mapping(neighbor_new.detach()) 51 | contrastive_att = self.contrastive_att_net(x_att * n_att) 52 | neighbor_new = neighbor_new * contrastive_att 53 | 54 | 
# integrating feature maps 55 | x_offset = torch.zeros([b, c, N*t, h, w], dtype=x.data.dtype, device=x.device.type) 56 | x_index = torch.tensor([i for i in range(t*N) if i%N==N//2]) 57 | neighbor_index = torch.tensor([i for i in range(t*N) if i%N!=N//2]) 58 | x_offset[:, :, x_index, :, :] += x 59 | x_offset[:, :, neighbor_index, :, :] += neighbor_new 60 | 61 | return x_offset 62 | 63 | 64 | class C2D(nn.Module): 65 | def __init__(self, conv2d, **kwargs): 66 | super(C2D, self).__init__() 67 | 68 | # conv3d kernel 69 | kernel_dim = (1, conv2d.kernel_size[0], conv2d.kernel_size[1]) 70 | stride = (1, conv2d.stride[0], conv2d.stride[0]) 71 | padding = (0, conv2d.padding[0], conv2d.padding[1]) 72 | self.conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \ 73 | kernel_size=kernel_dim, padding=padding, \ 74 | stride=stride, bias=conv2d.bias) 75 | 76 | # init the parameters of conv3d 77 | weight_2d = conv2d.weight.data 78 | weight_3d = torch.zeros(*weight_2d.shape) 79 | weight_3d = weight_3d.unsqueeze(2) 80 | weight_3d[:, :, 0, :, :] = weight_2d 81 | self.conv3d.weight = nn.Parameter(weight_3d) 82 | self.conv3d.bias = conv2d.bias 83 | 84 | def forward(self, x): 85 | out = self.conv3d(x) 86 | 87 | return out 88 | 89 | 90 | class I3D(nn.Module): 91 | def __init__(self, conv2d, time_dim=3, time_stride=1, **kwargs): 92 | super(I3D, self).__init__() 93 | 94 | # conv3d kernel 95 | kernel_dim = (time_dim, conv2d.kernel_size[0], conv2d.kernel_size[1]) 96 | stride = (time_stride, conv2d.stride[0], conv2d.stride[0]) 97 | padding = (time_dim//2, conv2d.padding[0], conv2d.padding[1]) 98 | self.conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \ 99 | kernel_size=kernel_dim, padding=padding, \ 100 | stride=stride, bias=conv2d.bias) 101 | 102 | # init the parameters of conv3d 103 | weight_2d = conv2d.weight.data 104 | weight_3d = torch.zeros(*weight_2d.shape) 105 | weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1) 106 | middle_idx = time_dim // 2 107 | 
weight_3d[:, :, middle_idx, :, :] = weight_2d 108 | self.conv3d.weight = nn.Parameter(weight_3d) 109 | self.conv3d.bias = conv2d.bias 110 | 111 | def forward(self, x): 112 | out = self.conv3d(x) 113 | 114 | return out 115 | 116 | 117 | class API3D(nn.Module): 118 | def __init__(self, conv2d, time_dim=3, time_stride=1, temperature=4, contrastive_att=True): 119 | super(API3D, self).__init__() 120 | 121 | self.APM = APM(conv2d.in_channels, conv2d.in_channels//16, \ 122 | time_dim=time_dim, temperature=temperature, contrastive_att=contrastive_att) 123 | 124 | # conv3d kernel 125 | kernel_dim = (time_dim, conv2d.kernel_size[0], conv2d.kernel_size[1]) 126 | stride = (time_stride*time_dim, conv2d.stride[0], conv2d.stride[0]) 127 | padding = (0, conv2d.padding[0], conv2d.padding[1]) 128 | self.conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \ 129 | kernel_size=kernel_dim, padding=padding, \ 130 | stride=stride, bias=conv2d.bias) 131 | 132 | # init the parameters of conv3d 133 | weight_2d = conv2d.weight.data 134 | weight_3d = torch.zeros(*weight_2d.shape) 135 | weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1) 136 | middle_idx = time_dim // 2 137 | weight_3d[:, :, middle_idx, :, :] = weight_2d 138 | self.conv3d.weight = nn.Parameter(weight_3d) 139 | self.conv3d.bias = conv2d.bias 140 | 141 | def forward(self, x): 142 | x_offset = self.APM(x) 143 | out = self.conv3d(x_offset) 144 | 145 | return out 146 | 147 | 148 | class P3DA(nn.Module): 149 | def __init__(self, conv2d, time_dim=3, time_stride=1, **kwargs): 150 | super(P3DA, self).__init__() 151 | 152 | # spatial conv3d kernel 153 | kernel_dim = (1, conv2d.kernel_size[0], conv2d.kernel_size[1]) 154 | stride = (1, conv2d.stride[0], conv2d.stride[0]) 155 | padding = (0, conv2d.padding[0], conv2d.padding[1]) 156 | self.spatial_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \ 157 | kernel_size=kernel_dim, padding=padding, \ 158 | stride=stride, bias=conv2d.bias) 159 | 160 | # init the 
parameters of spatial_conv3d 161 | weight_2d = conv2d.weight.data 162 | weight_3d = torch.zeros(*weight_2d.shape) 163 | weight_3d = weight_3d.unsqueeze(2) 164 | weight_3d[:, :, 0, :, :] = weight_2d 165 | self.spatial_conv3d.weight = nn.Parameter(weight_3d) 166 | self.spatial_conv3d.bias = conv2d.bias 167 | 168 | 169 | # temporal conv3d kernel 170 | kernel_dim = (time_dim, 1, 1) 171 | stride = (time_stride, 1, 1) 172 | padding = (time_dim//2, 0, 0) 173 | self.temporal_conv3d = nn.Conv3d(conv2d.out_channels, conv2d.out_channels, \ 174 | kernel_size=kernel_dim, padding=padding, \ 175 | stride=stride, bias=False) 176 | 177 | # init the parameters of temporal_conv3d 178 | weight_2d = torch.eye(conv2d.out_channels).unsqueeze(2).unsqueeze(2) 179 | weight_3d = torch.zeros(*weight_2d.shape) 180 | weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1) 181 | middle_idx = time_dim // 2 182 | weight_3d[:, :, middle_idx, :, :] = weight_2d 183 | self.temporal_conv3d.weight = nn.Parameter(weight_3d) 184 | 185 | 186 | def forward(self, x): 187 | x = self.spatial_conv3d(x) 188 | out = self.temporal_conv3d(x) 189 | 190 | return out 191 | 192 | 193 | class P3DB(nn.Module): 194 | def __init__(self, conv2d, time_dim=3, time_stride=1, **kwargs): 195 | super(P3DB, self).__init__() 196 | 197 | # spatial conv3d kernel 198 | kernel_dim = (1, conv2d.kernel_size[0], conv2d.kernel_size[1]) 199 | stride = (1, conv2d.stride[0], conv2d.stride[0]) 200 | padding = (0, conv2d.padding[0], conv2d.padding[1]) 201 | self.spatial_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \ 202 | kernel_size=kernel_dim, padding=padding, \ 203 | stride=stride, bias=conv2d.bias) 204 | 205 | # init the parameters of spatial_conv3d 206 | weight_2d = conv2d.weight.data 207 | weight_3d = torch.zeros(*weight_2d.shape) 208 | weight_3d = weight_3d.unsqueeze(2) 209 | weight_3d[:, :, 0, :, :] = weight_2d 210 | self.spatial_conv3d.weight = nn.Parameter(weight_3d) 211 | self.spatial_conv3d.bias = conv2d.bias 
212 | 213 | 214 | # temporal conv3d kernel 215 | kernel_dim = (time_dim, 1, 1) 216 | stride = (time_stride, conv2d.stride[0], conv2d.stride[0]) 217 | padding = (time_dim//2, 0, 0) 218 | self.temporal_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \ 219 | kernel_size=kernel_dim, padding=padding, \ 220 | stride=stride, bias=False) 221 | 222 | # init the parameters of temporal_conv3d 223 | nn.init.constant_(self.temporal_conv3d.weight, 0) 224 | 225 | 226 | def forward(self, x): 227 | # print(x.shape) 228 | out1 = self.spatial_conv3d(x) 229 | # print(out1.shape) 230 | out2 = self.temporal_conv3d(x) 231 | # print(out2.shape) 232 | out = out1 + out2 233 | 234 | return out 235 | 236 | 237 | class P3DC(nn.Module): 238 | def __init__(self, conv2d, time_dim=3, time_stride=1, **kwargs): 239 | super(P3DC, self).__init__() 240 | 241 | # spatial conv3d kernel 242 | kernel_dim = (1, conv2d.kernel_size[0], conv2d.kernel_size[1]) 243 | stride = (1, conv2d.stride[0], conv2d.stride[0]) 244 | padding = (0, conv2d.padding[0], conv2d.padding[1]) 245 | self.spatial_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \ 246 | kernel_size=kernel_dim, padding=padding, \ 247 | stride=stride, bias=conv2d.bias) 248 | 249 | # init the parameters of spatial_conv3d 250 | weight_2d = conv2d.weight.data 251 | weight_3d = torch.zeros(*weight_2d.shape) 252 | weight_3d = weight_3d.unsqueeze(2) 253 | weight_3d[:, :, 0, :, :] = weight_2d 254 | self.spatial_conv3d.weight = nn.Parameter(weight_3d) 255 | self.spatial_conv3d.bias = conv2d.bias 256 | 257 | 258 | # temporal conv3d kernel 259 | kernel_dim = (time_dim, 1, 1) 260 | stride = (time_stride, 1, 1) 261 | padding = (time_dim//2, 0, 0) 262 | self.temporal_conv3d = nn.Conv3d(conv2d.out_channels, conv2d.out_channels, \ 263 | kernel_size=kernel_dim, padding=padding, \ 264 | stride=stride, bias=False) 265 | 266 | # init the parameters of temporal_conv3d 267 | nn.init.constant_(self.temporal_conv3d.weight, 0) 268 | 269 | 270 | def 
forward(self, x): 271 | out = self.spatial_conv3d(x) 272 | residual = self.temporal_conv3d(out) 273 | out = out + residual 274 | 275 | return out 276 | 277 | 278 | class APP3DA(nn.Module): 279 | def __init__(self, conv2d, time_dim=3, time_stride=1, temperature=4, contrastive_att=True): 280 | super(APP3DA, self).__init__() 281 | 282 | self.APM = APM(conv2d.out_channels, conv2d.out_channels//16, \ 283 | time_dim=time_dim, temperature=temperature, contrastive_att=contrastive_att) 284 | 285 | # spatial conv3d kernel 286 | kernel_dim = (1, conv2d.kernel_size[0], conv2d.kernel_size[1]) 287 | stride = (1, conv2d.stride[0], conv2d.stride[0]) 288 | padding = (0, conv2d.padding[0], conv2d.padding[1]) 289 | self.spatial_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \ 290 | kernel_size=kernel_dim, padding=padding, \ 291 | stride=stride, bias=conv2d.bias) 292 | 293 | # init the parameters of spatial_conv3d 294 | weight_2d = conv2d.weight.data 295 | weight_3d = torch.zeros(*weight_2d.shape) 296 | weight_3d = weight_3d.unsqueeze(2) 297 | weight_3d[:, :, 0, :, :] = weight_2d 298 | self.spatial_conv3d.weight = nn.Parameter(weight_3d) 299 | self.spatial_conv3d.bias = conv2d.bias 300 | 301 | 302 | # temporal conv3d kernel 303 | kernel_dim = (time_dim, 1, 1) 304 | stride = (time_stride*time_dim, 1, 1) 305 | padding = (0, 0, 0) 306 | self.temporal_conv3d = nn.Conv3d(conv2d.out_channels, conv2d.out_channels, \ 307 | kernel_size=kernel_dim, padding=padding, \ 308 | stride=stride, bias=False) 309 | 310 | # init the parameters of temporal_conv3d 311 | weight_2d = torch.eye(conv2d.out_channels).unsqueeze(2).unsqueeze(2) 312 | weight_3d = torch.zeros(*weight_2d.shape) 313 | weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1) 314 | middle_idx = time_dim // 2 315 | weight_3d[:, :, middle_idx, :, :] = weight_2d 316 | self.temporal_conv3d.weight = nn.Parameter(weight_3d) 317 | 318 | 319 | def forward(self, x): 320 | x = self.spatial_conv3d(x) 321 | out = 
self.temporal_conv3d(self.APM(x)) 322 | 323 | return out 324 | 325 | 326 | class APP3DB(nn.Module): 327 | def __init__(self, conv2d, time_dim=3, time_stride=1, temperature=4, contrastive_att=True): 328 | super(APP3DB, self).__init__() 329 | 330 | self.APM = APM(conv2d.in_channels, conv2d.in_channels//16, \ 331 | time_dim=time_dim, temperature=temperature, contrastive_att=contrastive_att) 332 | 333 | # spatial conv3d kernel 334 | kernel_dim = (1, conv2d.kernel_size[0], conv2d.kernel_size[1]) 335 | stride = (1, conv2d.stride[0], conv2d.stride[0]) 336 | padding = (0, conv2d.padding[0], conv2d.padding[1]) 337 | self.spatial_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \ 338 | kernel_size=kernel_dim, padding=padding, \ 339 | stride=stride, bias=conv2d.bias) 340 | 341 | # init the parameters of spatial_conv3d 342 | weight_2d = conv2d.weight.data 343 | weight_3d = torch.zeros(*weight_2d.shape) 344 | weight_3d = weight_3d.unsqueeze(2) 345 | weight_3d[:, :, 0, :, :] = weight_2d 346 | self.spatial_conv3d.weight = nn.Parameter(weight_3d) 347 | self.spatial_conv3d.bias = conv2d.bias 348 | 349 | 350 | # temporal conv3d kernel 351 | kernel_dim = (time_dim, 1, 1) 352 | stride = (time_stride*time_dim, conv2d.stride[0], conv2d.stride[0]) 353 | padding = (0, 0, 0) 354 | self.temporal_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \ 355 | kernel_size=kernel_dim, padding=padding, \ 356 | stride=stride, bias=False) 357 | 358 | # init the parameters of temporal_conv3d 359 | nn.init.constant_(self.temporal_conv3d.weight, 0) 360 | 361 | 362 | def forward(self, x): 363 | out1 = self.spatial_conv3d(x) 364 | out2 = self.temporal_conv3d(self.APM(x)) 365 | out = out1 + out2 366 | 367 | return out 368 | 369 | 370 | class APP3DC(nn.Module): 371 | def __init__(self, conv2d, time_dim=3, time_stride=1, temperature=4, contrastive_att=True): 372 | super(APP3DC, self).__init__() 373 | 374 | self.APM = APM(conv2d.out_channels, conv2d.out_channels//16, \ 375 | 
time_dim=time_dim, temperature=temperature, contrastive_att=contrastive_att) 376 | 377 | # spatial conv3d kernel 378 | kernel_dim = (1, conv2d.kernel_size[0], conv2d.kernel_size[1]) 379 | stride = (1, conv2d.stride[0], conv2d.stride[0]) 380 | padding = (0, conv2d.padding[0], conv2d.padding[1]) 381 | self.spatial_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \ 382 | kernel_size=kernel_dim, padding=padding, \ 383 | stride=stride, bias=conv2d.bias) 384 | 385 | # init the parameters of spatial_conv3d 386 | weight_2d = conv2d.weight.data 387 | weight_3d = torch.zeros(*weight_2d.shape) 388 | weight_3d = weight_3d.unsqueeze(2) 389 | weight_3d[:, :, 0, :, :] = weight_2d 390 | self.spatial_conv3d.weight = nn.Parameter(weight_3d) 391 | self.spatial_conv3d.bias = conv2d.bias 392 | 393 | 394 | # temporal conv3d kernel 395 | kernel_dim = (time_dim, 1, 1) 396 | stride = (time_stride*time_dim, 1, 1) 397 | padding = (0, 0, 0) 398 | self.temporal_conv3d = nn.Conv3d(conv2d.out_channels, conv2d.out_channels, \ 399 | kernel_size=kernel_dim, padding=padding, \ 400 | stride=stride, bias=False) 401 | 402 | # init the parameters of temporal_conv3d 403 | nn.init.constant_(self.temporal_conv3d.weight, 0) 404 | 405 | 406 | def forward(self, x): 407 | out = self.spatial_conv3d(x) 408 | residual = self.temporal_conv3d(self.APM(out)) 409 | out = out + residual 410 | 411 | return out 412 | -------------------------------------------------------------------------------- /models/utils/inflate.py: -------------------------------------------------------------------------------- 1 | # inflate 2D modules to 3D modules 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | 6 | 7 | def inflate_conv(conv2d, 8 | time_dim=1, 9 | time_padding=0, 10 | time_stride=1, 11 | time_dilation=1, 12 | center=False): 13 | # To preserve activations, padding should be by continuity and not zero 14 | # or no padding in time dimension 15 | kernel_dim = (time_dim, 
conv2d.kernel_size[0], conv2d.kernel_size[1]) 16 | padding = (time_padding, conv2d.padding[0], conv2d.padding[1]) 17 | stride = (time_stride, conv2d.stride[0], conv2d.stride[0]) 18 | dilation = (time_dilation, conv2d.dilation[0], conv2d.dilation[1]) 19 | conv3d = nn.Conv3d( 20 | conv2d.in_channels, 21 | conv2d.out_channels, 22 | kernel_dim, 23 | padding=padding, 24 | dilation=dilation, 25 | stride=stride) 26 | # Repeat filter time_dim times along time dimension 27 | weight_2d = conv2d.weight.data 28 | if center: 29 | weight_3d = torch.zeros(*weight_2d.shape) 30 | weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1) 31 | middle_idx = time_dim // 2 32 | weight_3d[:, :, middle_idx, :, :] = weight_2d 33 | else: 34 | weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1) 35 | weight_3d = weight_3d / time_dim 36 | 37 | # Assign new params 38 | conv3d.weight = nn.Parameter(weight_3d) 39 | conv3d.bias = conv2d.bias 40 | return conv3d 41 | 42 | 43 | def inflate_linear(linear2d, time_dim): 44 | """ 45 | Args: 46 | time_dim: final time dimension of the features 47 | """ 48 | linear3d = nn.Linear(linear2d.in_features * time_dim, 49 | linear2d.out_features) 50 | weight3d = linear2d.weight.data.repeat(1, time_dim) 51 | weight3d = weight3d / time_dim 52 | 53 | linear3d.weight = nn.Parameter(weight3d) 54 | linear3d.bias = linear2d.bias 55 | return linear3d 56 | 57 | 58 | def inflate_batch_norm(batch2d): 59 | # In pytorch 0.2.0 the 2d and 3d versions of batch norm 60 | # work identically except for the check that verifies the 61 | # input dimensions 62 | 63 | batch3d = nn.BatchNorm3d(batch2d.num_features) 64 | # retrieve 3d _check_input_dim function 65 | batch2d._check_input_dim = batch3d._check_input_dim 66 | return batch2d 67 | 68 | 69 | def inflate_pool(pool2d, 70 | time_dim=1, 71 | time_padding=0, 72 | time_stride=None, 73 | time_dilation=1): 74 | kernel_dim = (time_dim, pool2d.kernel_size, pool2d.kernel_size) 75 | padding = (time_padding, pool2d.padding, 
pool2d.padding) 76 | if time_stride is None: 77 | time_stride = time_dim 78 | stride = (time_stride, pool2d.stride, pool2d.stride) 79 | if isinstance(pool2d, nn.MaxPool2d): 80 | dilation = (time_dilation, pool2d.dilation, pool2d.dilation) 81 | pool3d = nn.MaxPool3d( 82 | kernel_dim, 83 | padding=padding, 84 | dilation=dilation, 85 | stride=stride, 86 | ceil_mode=pool2d.ceil_mode) 87 | elif isinstance(pool2d, nn.AvgPool2d): 88 | pool3d = nn.AvgPool3d(kernel_dim, stride=stride) 89 | else: 90 | raise ValueError( 91 | '{} is not among known pooling classes'.format(type(pool2d))) 92 | return pool3d 93 | 94 | 95 | class MaxPool2dFor3dInput(nn.Module): 96 | """ 97 | Since nn.MaxPool3d is nondeterministic operation, using fixed random seeds can't get consistent results. 98 | So we attempt to use max_pool2d to implement MaxPool3d with kernelsize (1, kernel_size, kernel_size). 99 | """ 100 | def __init__(self, kernel_size, stride=None, padding=0, dilation=1): 101 | super().__init__() 102 | self.maxpool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation) 103 | def forward(self, x): 104 | b, c, t, h, w = x.size() 105 | x = x.permute(0, 2, 1, 3, 4).contiguous() # b, t, c, h, w 106 | x = x.view(b*t, c, h, w) 107 | # max pooling 108 | x = self.maxpool(x) 109 | _, _, h, w = x.size() 110 | x = x.view(b, t, c, h, w).permute(0, 2, 1, 3, 4).contiguous() 111 | 112 | return x -------------------------------------------------------------------------------- /models/utils/nonlocal_blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from models.utils import inflate 6 | 7 | 8 | class NonLocalBlockND(nn.Module): 9 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 10 | super(NonLocalBlockND, self).__init__() 11 | 12 | assert dimension in [1, 2, 3] 13 | 14 | 
self.dimension = dimension 15 | self.sub_sample = sub_sample 16 | self.in_channels = in_channels 17 | self.inter_channels = inter_channels 18 | 19 | if self.inter_channels is None: 20 | self.inter_channels = in_channels // 2 21 | if self.inter_channels == 0: 22 | self.inter_channels = 1 23 | 24 | if dimension == 3: 25 | conv_nd = nn.Conv3d 26 | # max_pool = inflate.MaxPool2dFor3dInput 27 | max_pool = nn.MaxPool3d 28 | bn = nn.BatchNorm3d 29 | elif dimension == 2: 30 | conv_nd = nn.Conv2d 31 | max_pool = nn.MaxPool2d 32 | bn = nn.BatchNorm2d 33 | else: 34 | conv_nd = nn.Conv1d 35 | max_pool = nn.MaxPool1d 36 | bn = nn.BatchNorm1d 37 | 38 | self.g = conv_nd(self.in_channels, self.inter_channels, 39 | kernel_size=1, stride=1, padding=0, bias=True) 40 | self.theta = conv_nd(self.in_channels, self.inter_channels, 41 | kernel_size=1, stride=1, padding=0, bias=True) 42 | self.phi = conv_nd(self.in_channels, self.inter_channels, 43 | kernel_size=1, stride=1, padding=0, bias=True) 44 | # if sub_sample: 45 | # self.g = nn.Sequential(self.g, max_pool(kernel_size=2)) 46 | # self.phi = nn.Sequential(self.phi, max_pool(kernel_size=2)) 47 | if sub_sample: 48 | if dimension == 3: 49 | self.g = nn.Sequential(self.g, max_pool((1, 2, 2))) 50 | self.phi = nn.Sequential(self.phi, max_pool((1, 2, 2))) 51 | else: 52 | self.g = nn.Sequential(self.g, max_pool(kernel_size=2)) 53 | self.phi = nn.Sequential(self.phi, max_pool(kernel_size=2)) 54 | 55 | if bn_layer: 56 | self.W = nn.Sequential( 57 | conv_nd(self.inter_channels, self.in_channels, 58 | kernel_size=1, stride=1, padding=0, bias=True), 59 | bn(self.in_channels) 60 | ) 61 | else: 62 | self.W = conv_nd(self.inter_channels, self.in_channels, 63 | kernel_size=1, stride=1, padding=0, bias=True) 64 | 65 | # init 66 | for m in self.modules(): 67 | if isinstance(m, conv_nd): 68 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 69 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 70 | elif isinstance(m, bn): 71 | m.weight.data.fill_(1) 72 | m.bias.data.zero_() 73 | 74 | if bn_layer: 75 | nn.init.constant_(self.W[1].weight.data, 0.0) 76 | nn.init.constant_(self.W[1].bias.data, 0.0) 77 | else: 78 | nn.init.constant_(self.W.weight.data, 0.0) 79 | nn.init.constant_(self.W.bias.data, 0.0) 80 | 81 | 82 | def forward(self, x): 83 | ''' 84 | :param x: (b, c, t, h, w) 85 | :return: 86 | ''' 87 | batch_size = x.size(0) 88 | 89 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 90 | g_x = g_x.permute(0, 2, 1) 91 | 92 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 93 | theta_x = theta_x.permute(0, 2, 1) 94 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 95 | f = torch.matmul(theta_x, phi_x) 96 | f = F.softmax(f, dim=-1) 97 | 98 | y = torch.matmul(f, g_x) 99 | y = y.permute(0, 2, 1).contiguous() 100 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 101 | y = self.W(y) 102 | z = y + x 103 | 104 | return z 105 | 106 | 107 | class NonLocalBlock1D(NonLocalBlockND): 108 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 109 | super(NonLocalBlock1D, self).__init__(in_channels, 110 | inter_channels=inter_channels, 111 | dimension=1, sub_sample=sub_sample, 112 | bn_layer=bn_layer) 113 | 114 | 115 | class NonLocalBlock2D(NonLocalBlockND): 116 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 117 | super(NonLocalBlock2D, self).__init__(in_channels, 118 | inter_channels=inter_channels, 119 | dimension=2, sub_sample=sub_sample, 120 | bn_layer=bn_layer) 121 | 122 | 123 | class NonLocalBlock3D(NonLocalBlockND): 124 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 125 | super(NonLocalBlock3D, self).__init__(in_channels, 126 | inter_channels=inter_channels, 127 | dimension=3, sub_sample=sub_sample, 128 | bn_layer=bn_layer) 129 | 
# --------------------------------------------------------------------------------
# /models/utils/pooling.py:
# --------------------------------------------------------------------------------
import torch
from torch import nn
from torch.nn import functional as F


class GeMPooling(nn.Module):
    """Generalized-mean pooling with a learnable exponent ``p``.

    Each channel is pooled over the full extent of ``x.size()[2:]`` via
    ``avg_pool2d``. This assumes a 4-D ``(b, c, h, w)`` input (TODO confirm at
    call sites). Inputs are clamped to ``eps`` before ``pow`` so the root is defined.
    """
    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        return F.avg_pool2d(x.clamp(min=self.eps).pow(self.p), x.size()[2:]).pow(1./self.p)


class MaxAvgPooling(nn.Module):
    """Concatenate global max- and average-pooled features along dim 1 (doubling channels)."""
    def __init__(self):
        super().__init__()
        self.maxpooling = nn.AdaptiveMaxPool2d(1)
        self.avgpooling = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        max_f = self.maxpooling(x)
        avg_f = self.avgpooling(x)

        return torch.cat((max_f, avg_f), 1)

# --------------------------------------------------------------------------------
# /test.py:
# --------------------------------------------------------------------------------
import time
import logging
import torch
import torch.nn.functional as F
from torch import distributed as dist
from tools.eval_metrics import evaluate, evaluate_with_clothes


def concat_all_gather(tensors, num_total_examples):
    '''
    Performs all_gather operation on the provided tensor list.

    Each tensor is gathered from every rank, concatenated along dim 0, moved
    to CPU and truncated to ``num_total_examples`` rows.
    '''
    outputs = []
    for tensor in tensors:
        tensor = tensor.cuda()
        tensors_gather = [tensor.clone() for _ in range(dist.get_world_size())]
        dist.all_gather(tensors_gather, tensor)
        output = torch.cat(tensors_gather, dim=0).cpu()
        # truncate the dummy elements added by DistributedInferenceSampler
        outputs.append(output[:num_total_examples])
    return outputs


@torch.no_grad()
def extract_img_feature(model, dataloader):
    """Embed every image in ``dataloader`` with flip test-time augmentation.

    For each batch the model is run on the images and on ``torch.flip(imgs, [3])``.
    Assuming ``(b, c, h, w)`` images, that is a horizontal flip. The second output
    of both runs is summed and L2-normalised per row.

    Returns:
        (features, pids, camids, clothes_ids): CPU tensors concatenated over
        all batches. Labels keep the dtype yielded by the dataloader.
    """
    features, pids, camids, clothes_ids = [], [], [], []
    for imgs, batch_pids, batch_camids, batch_clothes_ids, _ in dataloader:
        flip_imgs = torch.flip(imgs, [3])
        imgs, flip_imgs = imgs.cuda(), flip_imgs.cuda()
        _, batch_features = model(imgs)
        _, batch_features_flip = model(flip_imgs)
        batch_features += batch_features_flip
        batch_features = F.normalize(batch_features, p=2, dim=1)

        features.append(batch_features.cpu())
        pids.append(batch_pids.cpu())
        camids.append(batch_camids.cpu())
        clothes_ids.append(batch_clothes_ids.cpu())

    # Concatenate once instead of re-copying the growing tensors every batch.
    return (torch.cat(features, 0), torch.cat(pids, 0),
            torch.cat(camids, 0), torch.cat(clothes_ids, 0))


def _cosine_distmat(qf, gf):
    """Return ``-qf @ gf.T`` as a numpy (num_query, num_gallery) distance matrix.

    Rows of the features produced by ``extract_img_feature`` are L2-normalised,
    so this is the negated cosine similarity (smaller = more similar). The GPU
    product is done one query row at a time and written into a CPU matrix,
    presumably to bound GPU memory.
    """
    distmat = torch.zeros((qf.size(0), gf.size(0)))
    qf, gf = qf.cuda(), gf.cuda()
    for i in range(qf.size(0)):
        distmat[i] = (- torch.mm(qf[i:i+1], gf.t())).cpu()
    return distmat.numpy()


def _log_results(logger, cmc, mAP):
    """Log top-1/5/10/20 CMC and mAP in the shared results format."""
    logger.info("Results ---------------------------------------------------")
    logger.info('top1:{:.1%} top5:{:.1%} top10:{:.1%} top20:{:.1%} mAP:{:.1%}'.format(cmc[0], cmc[4], cmc[9], cmc[19], mAP))
    logger.info("-----------------------------------------------------------")


def test(config, model, queryloader, galleryloader, dataset):
    """Evaluate on a single query/gallery split (standard, same-clothes and clothes-changing).

    Returns:
        The top-1 accuracy of the clothes-changing ('CC') evaluation.
    """
    logger = logging.getLogger('reid.test')
    since = time.time()
    model.eval()
    # Extract features
    qf, q_pids, q_camids, q_clothes_ids = extract_img_feature(model, queryloader)
    gf, g_pids, g_camids, g_clothes_ids = extract_img_feature(model, galleryloader)
    # Gather samples from different GPUs
    torch.cuda.empty_cache()
    qf, q_pids, q_camids, q_clothes_ids = concat_all_gather([qf, q_pids, q_camids, q_clothes_ids], len(dataset.query))
    gf, g_pids, g_camids, g_clothes_ids = concat_all_gather([gf, g_pids, g_camids, g_clothes_ids], len(dataset.gallery))
    torch.cuda.empty_cache()
    time_elapsed = time.time() - since

    logger.info("Extracted features for query set, obtained {} matrix".format(qf.shape))
    logger.info("Extracted features for gallery set, obtained {} matrix".format(gf.shape))
    logger.info('Extracting features complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    # Compute distance matrix between query and gallery
    since = time.time()
    distmat = _cosine_distmat(qf, gf)
    q_pids, q_camids, q_clothes_ids = q_pids.numpy(), q_camids.numpy(), q_clothes_ids.numpy()
    g_pids, g_camids, g_clothes_ids = g_pids.numpy(), g_camids.numpy(), g_clothes_ids.numpy()
    time_elapsed = time.time() - since
    logger.info('Distance computing in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    since = time.time()
    logger.info("Computing CMC and mAP")
    cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
    _log_results(logger, cmc, mAP)
    time_elapsed = time.time() - since
    logger.info('Using {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    logger.info("Computing CMC and mAP only for the same clothes setting")
    cmc, mAP = evaluate_with_clothes(distmat, q_pids, g_pids, q_camids, g_camids, q_clothes_ids, g_clothes_ids, mode='SC')
    _log_results(logger, cmc, mAP)

    logger.info("Computing CMC and mAP only for clothes-changing")
    cmc, mAP = evaluate_with_clothes(distmat, q_pids, g_pids, q_camids, g_camids, q_clothes_ids, g_clothes_ids, mode='CC')
    _log_results(logger, cmc, mAP)

    return cmc[0]


def test_prcc(model, queryloader_same, queryloader_diff, galleryloader, dataset):
    """Evaluate PRCC, whose same-clothes and changed-clothes queries come from separate loaders.

    Returns:
        The top-1 accuracy of the clothes-changing query set.
    """
    logger = logging.getLogger('reid.test')
    since = time.time()
    model.eval()
    # Extract features for query set
    qsf, qs_pids, qs_camids, qs_clothes_ids = extract_img_feature(model, queryloader_same)
    qdf, qd_pids, qd_camids, qd_clothes_ids = extract_img_feature(model, queryloader_diff)
    # Extract features for gallery set
    gf, g_pids, g_camids, g_clothes_ids = extract_img_feature(model, galleryloader)
    # Gather samples from different GPUs
    torch.cuda.empty_cache()
    qsf, qs_pids, qs_camids, qs_clothes_ids = concat_all_gather([qsf, qs_pids, qs_camids, qs_clothes_ids], len(dataset.query_same))
    qdf, qd_pids, qd_camids, qd_clothes_ids = concat_all_gather([qdf, qd_pids, qd_camids, qd_clothes_ids], len(dataset.query_diff))
    gf, g_pids, g_camids, g_clothes_ids = concat_all_gather([gf, g_pids, g_camids, g_clothes_ids], len(dataset.gallery))
    time_elapsed = time.time() - since

    logger.info("Extracted features for query set (with same clothes), obtained {} matrix".format(qsf.shape))
    logger.info("Extracted features for query set (with different clothes), obtained {} matrix".format(qdf.shape))
    logger.info("Extracted features for gallery set, obtained {} matrix".format(gf.shape))
    logger.info('Extracting features complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    # Compute distance matrix between query and gallery
    gf = gf.cuda()  # move once; the helper's .cuda() is then a no-op
    distmat_same = _cosine_distmat(qsf, gf)
    distmat_diff = _cosine_distmat(qdf, gf)
    qs_pids, qs_camids = qs_pids.numpy(), qs_camids.numpy()
    qd_pids, qd_camids = qd_pids.numpy(), qd_camids.numpy()
    g_pids, g_camids = g_pids.numpy(), g_camids.numpy()

    logger.info("Computing CMC and mAP for the same clothes setting")
    cmc, mAP = evaluate(distmat_same, qs_pids, g_pids, qs_camids, g_camids)
    _log_results(logger, cmc, mAP)

    logger.info("Computing CMC and mAP only for clothes changing")
    cmc, mAP = evaluate(distmat_diff, qd_pids, g_pids, qd_camids, g_camids)
    _log_results(logger, cmc, mAP)

    return cmc[0]
# --------------------------------------------------------------------------------
# /test_AIM.sh:
# --------------------------------------------------------------------------------
# # For LTCC dataset
# python -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset ltcc --cfg configs/res50_cels_cal.yaml --gpu 0,1 --eval --resume ./ltcc.pth.tar #
# # For PRCC dataset
# python -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset prcc --cfg configs/res50_cels_cal.yaml --gpu 0,1 --eval --resume ./prcc.pth.tar #
# --------------------------------------------------------------------------------
# /tools/eval_metrics.py:
# --------------------------------------------------------------------------------
import logging
import numpy as np


def compute_ap_cmc(index, good_index, junk_index):
    """ Compute AP and CMC for each sample

    Args:
        index (numpy array): gallery indices for one query, sorted by
            ascending distance (as ``evaluate`` builds them with ``np.argsort``).
        good_index (numpy array): gallery indices that are correct matches.
        junk_index (numpy array): gallery indices dropped from the ranking
            before scoring.

    Returns:
        tuple: ``(ap, cmc)``.
        * ``ap``: the mean, over ``good_index``, of the precision at each
          correct match.
        * ``cmc``: a float array of length ``len(index)`` that is 1.0 from the
          rank of the first correct match onwards.
        If no correct match remains in the ranking, ``(0.0, zeros)`` is
        returned instead of raising.
    """
    cmc = np.zeros(len(index))

    # remove junk_index so those samples neither reward nor penalise the query
    index = index[np.isin(index, junk_index, invert=True)]

    # 0-based ranks (after junk removal) at which correct matches appear
    rows_good = np.flatnonzero(np.isin(index, good_index))
    if rows_good.size == 0:
        return 0.0, cmc

    cmc[rows_good[0]:] = 1.0
    # The i-th correct match found at rank r has precision (i + 1) / (r + 1);
    # each correct match contributes 1 / ngood of recall.
    ngood = len(good_index)
    precision = np.arange(1, rows_good.size + 1) / (rows_good + 1)
    ap = precision.sum() / ngood

    return ap, cmc


def _average_over_queries(CMC, AP, num_q, num_no_gt):
    """Log how many queries lacked groundtruth and average CMC/AP over the rest.

    Returns ``(CMC, 0)`` unchanged when no query had groundtruth instead of
    dividing by zero.
    """
    if num_no_gt > 0:
        logger = logging.getLogger('reid.evaluate')
        logger.info("{} query samples do not have groundtruth.".format(num_no_gt))

    num_valid = num_q - num_no_gt
    if num_valid == 0:
        return CMC, 0
    return CMC / num_valid, AP / num_valid


def evaluate(distmat, q_pids, g_pids, q_camids, g_camids):
    """ Compute CMC and mAP

    A gallery sample is a correct match when it shares the query's pid but
    not its camid. Same-pid/same-camid samples are ignored (junk).

    Args:
        distmat (numpy ndarray): distance matrix with shape (num_query, num_gallery).
        q_pids (numpy array): person IDs for query samples.
        g_pids (numpy array): person IDs for gallery samples.
        q_camids (numpy array): camera IDs for query samples.
        g_camids (numpy array): camera IDs for gallery samples.

    Returns:
        (CMC, mAP) averaged over queries that have at least one correct match.
        If no query does, CMC is all zeros and mAP is 0.
    """
    num_q = distmat.shape[0]
    index = np.argsort(distmat, axis=1)  # from small to large

    num_no_gt = 0  # num of query imgs without groundtruth
    CMC = np.zeros(len(g_pids))
    AP = 0

    for i in range(num_q):
        # groundtruth index
        query_index = np.argwhere(g_pids == q_pids[i])
        camera_index = np.argwhere(g_camids == q_camids[i])
        good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
        if good_index.size == 0:
            num_no_gt += 1
            continue
        # remove gallery samples that have the same pid and camid with query
        junk_index = np.intersect1d(query_index, camera_index)

        ap_tmp, CMC_tmp = compute_ap_cmc(index[i], good_index, junk_index)
        CMC = CMC + CMC_tmp
        AP += ap_tmp

    return _average_over_queries(CMC, AP, num_q, num_no_gt)


def evaluate_with_clothes(distmat, q_pids, g_pids, q_camids, g_camids, q_clothids, g_clothids, mode='CC'):
    """ Compute CMC and mAP with clothes

    Args:
        distmat (numpy ndarray): distance matrix with shape (num_query, num_gallery).
        q_pids (numpy array): person IDs for query samples.
        g_pids (numpy array): person IDs for gallery samples.
        q_camids (numpy array): camera IDs for query samples.
        g_camids (numpy array): camera IDs for gallery samples.
        q_clothids (numpy array): clothes IDs for query samples.
        g_clothids (numpy array): clothes IDs for gallery samples.
        mode: 'CC' for clothes-changing; 'SC' for the same clothes.

    Returns:
        (CMC, mAP) averaged over queries that have at least one correct match.
        If no query does, CMC is all zeros and mAP is 0.
    """
    assert mode in ['CC', 'SC']

    num_q = distmat.shape[0]
    index = np.argsort(distmat, axis=1)  # from small to large

    num_no_gt = 0  # num of query imgs without groundtruth
    CMC = np.zeros(len(g_pids))
    AP = 0

    for i in range(num_q):
        # groundtruth index
        query_index = np.argwhere(g_pids == q_pids[i])
        camera_index = np.argwhere(g_camids == q_camids[i])
        cloth_index = np.argwhere(g_clothids == q_clothids[i])
        good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
        if mode == 'CC':
            good_index = np.setdiff1d(good_index, cloth_index, assume_unique=True)
            # remove gallery samples that have the same (pid, camid) or (pid, clothid) with query
            junk_index1 = np.intersect1d(query_index, camera_index)
            junk_index2 = np.intersect1d(query_index, cloth_index)
        else:
            good_index = np.intersect1d(good_index, cloth_index)
            # remove gallery samples that have the same (pid, camid) or
            # (the same pid and different clothid) with query
            junk_index1 = np.intersect1d(query_index, camera_index)
            junk_index2 = np.setdiff1d(query_index, cloth_index)
        junk_index = np.union1d(junk_index1, junk_index2)

        if good_index.size == 0:
            num_no_gt += 1
            continue

        ap_tmp, CMC_tmp = compute_ap_cmc(index[i], good_index, junk_index)
        CMC = CMC + CMC_tmp
        AP += ap_tmp

    return _average_over_queries(CMC, AP, num_q, num_no_gt)
# --------------------------------------------------------------------------------
# /tools/utils.py:
# --------------------------------------------------------------------------------
import os
import sys
import shutil
import errno
import json
import os.path as osp
import torch
import random
import logging
import numpy as np


def set_seed(seed=None):
    """Seed Python, NumPy and torch (CPU + all CUDA devices) RNGs; no-op when ``seed`` is None.

    Also sets PYTHONHASHSEED and forces deterministic, non-benchmarking cuDNN.
    """
    if seed is None:
        return
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = ("%s" % seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def mkdir_if_missing(directory):
    """Create ``directory`` (including parents) if it does not exist yet.

    An empty string, which ``osp.dirname`` returns for a bare filename such as
    ``'checkpoint.pth.tar'``, means the current directory. That is a no-op;
    ``os.makedirs('')`` would otherwise raise FileNotFoundError. A concurrent
    creation (EEXIST) is tolerated; any other OSError propagates.
    """
    if not directory or osp.exists(directory):
        return
    try:
        os.makedirs(directory)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise


def read_json(fpath):
    """Load and return the JSON document stored at ``fpath`` (read as UTF-8)."""
    with open(fpath, 'r', encoding='utf-8') as f:
        obj = json.load(f)
    return obj


def write_json(obj, fpath):
    """Write ``obj`` as indented JSON to ``fpath`` (UTF-8), creating parent dirs as needed."""
    mkdir_if_missing(osp.dirname(fpath))
    with open(fpath, 'w', encoding='utf-8') as f:
        json.dump(obj, f, indent=4, separators=(',', ': '))


class AverageMeter(object):
    """Computes and stores the average and current value.

    Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` samples and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
    """``torch.save`` ``state`` to ``fpath``; if ``is_best`` also copy it to ``best_model.pth.tar`` beside it."""
    mkdir_if_missing(osp.dirname(fpath))
    torch.save(state, fpath)
    if is_best:
        shutil.copy(fpath, osp.join(osp.dirname(fpath), 'best_model.pth.tar'))

'''
class Logger(object):
    """
    Write console output to external text file.
    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.
    """
    def __init__(self, fpath=None):
        self.console = sys.stdout
        self.file = None
        if fpath is not None:
            mkdir_if_missing(os.path.dirname(fpath))
            self.file = open(fpath, 'w')

    def __del__(self):
        self.close()

    def __enter__(self):
        pass

    def __exit__(self, *args):
        self.close()

    def write(self, msg):
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            os.fsync(self.file.fileno())

    def close(self):
        self.console.close()
        if self.file is not None:
            self.file.close()
'''


def get_logger(fpath, local_rank=0, name=''):
    """Configure and return the logger ``name`` writing to stdout and, if given, ``fpath``.

    Ranks other than -1/0 log at WARN so only the main process prints INFO.
    The log file is opened in 'w' mode (truncated).

    NOTE(review): every call adds new handlers to the same named logger, so
    calling this twice for one name duplicates output.
    """
    # Create logger
    logger = logging.getLogger(name)
    level = logging.INFO if local_rank in [-1, 0] else logging.WARN
    logger.setLevel(level=level)

    # Output to console
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(level=level)
    console_handler.setFormatter(logging.Formatter('%(message)s'))
    logger.addHandler(console_handler)

    # Output to file
    if fpath is not None:
        mkdir_if_missing(os.path.dirname(fpath))
        file_handler = logging.FileHandler(fpath, mode='w')
        file_handler.setLevel(level=level)
        file_handler.setFormatter(logging.Formatter('%(message)s'))
        logger.addHandler(file_handler)

    return logger
# --------------------------------------------------------------------------------
# /train.py:
# --------------------------------------------------------------------------------
from multiprocessing import reduction  # NOTE(review): unused in train.py
import time
import datetime
import logging
import torch
from apex import amp
from tools.utils import AverageMeter

def train_aim(config, epoch, model, model2, classifier, clothes_classifier, clothes_classifier2, fuse, criterion_cla, criterion_pair,
              criterion_clothes, criterion_adv, optimizer, optimizer2, optimizer_cc, trainloader, pid2clothes, kl):
    """Run one AIM training epoch over ``trainloader`` and log epoch statistics.

    Per batch:
      1. From ``config.TRAIN.START_EPOCH_CC`` on, update ``clothes_classifier``
         (via ``optimizer_cc``) on detached ``model`` features.
      2. Update the clothes branch (``model2`` + ``clothes_classifier2``, via
         ``optimizer2``) with a clothes classification loss. From
         ``START_EPOCH_CC`` on, add ``config.k_kl`` times a symmetric KL term
         against the detached clothes predictions of the main branch.
      3. Update the main branch (via ``optimizer``) with the id loss, the pair
         loss and, from ``config.TRAIN.START_EPOCH_ADV`` on, the adversarial
         loss and the ``config.k_cal``-weighted id loss on ``outputs - outputs3``,
         where ``outputs3`` comes from the fused features.

    Args:
        pid2clothes: indexed by a batch of pids to give each sample's positive
            clothes classes (same identity); converted to a float mask.
        kl: callable used as ``kl(log_probs, target_probs, reduction='sum')``.
            Presumably ``torch.nn.functional.kl_div``; confirm at the call site.
        Other arguments are the models, criteria and optimizers used above.
    """
    logger = logging.getLogger('reid.train')
    batch_cla_loss = AverageMeter()
    batch_pair_loss = AverageMeter()
    batch_clo_loss = AverageMeter()
    batch_adv_loss = AverageMeter()
    batch_clothes_loss2 = AverageMeter()
    batch_loss2 = AverageMeter()
    batch_kl_loss = AverageMeter()
    corrects = AverageMeter()
    corrects2 = AverageMeter()
    corrects3 = AverageMeter()
    clothes_corrects = AverageMeter()
    clothes_corrects2 = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    model.train()
    model2.train()
    fuse.train()
    classifier.train()
    clothes_classifier.train()
    clothes_classifier2.train()

    end = time.time()
    for batch_idx, (imgs, pids, camids, clothes_ids, img_path) in enumerate(trainloader):
        # Get all positive clothes classes (belonging to the same identity) for each sample
        pos_mask = pid2clothes[pids]
        imgs, pids, clothes_ids, pos_mask = imgs.cuda(), pids.cuda(), clothes_ids.cuda(), pos_mask.float().cuda()
        # Measure data loading time
        data_time.update(time.time() - end)
        # Forward
        pri_feat, features = model(imgs)
        pri_feat2, features2 = model2(imgs)

        pri_feat2 = pri_feat2.clone().detach()
        features_fuse = fuse(pri_feat, pri_feat2)

        outputs = classifier(features)
        outputs2 = clothes_classifier2(features2)
        outputs3 = classifier(features_fuse)  # clothes score on id classifier

        pred_clothes = clothes_classifier(features.detach())  # no grad

        _, preds = torch.max(outputs.data, 1)  # return (max_value, index), 1 indicates dim=1
        _, preds3 = torch.max(outputs3.data, 1)

        # Update the clothes discriminator
        clothes_loss = criterion_clothes(pred_clothes, clothes_ids)
        if epoch >= config.TRAIN.START_EPOCH_CC:
            optimizer_cc.zero_grad()
            if config.TRAIN.AMP:
                with amp.scale_loss(clothes_loss, optimizer_cc) as scaled_loss:
                    scaled_loss.backward()
            else:
                clothes_loss.backward()
            optimizer_cc.step()

        # Update the backbone
        new_pred_clothes = clothes_classifier(features)
        _, clothes_preds = torch.max(new_pred_clothes.data, 1)

        _, pred_clothes2 = torch.max(outputs2.data, 1)

        # Symmetric KL between the main branch's (detached) clothes prediction Q
        # and the clothes branch's prediction P. log_softmax is used for the
        # log-probabilities: log(softmax(x)) underflows to -inf (NaN loss)
        # when a probability rounds to zero.
        Q_logits = new_pred_clothes.detach()
        Q = torch.nn.functional.softmax(Q_logits, dim=-1)
        P = torch.nn.functional.softmax(outputs2, dim=-1)
        log_Q = torch.nn.functional.log_softmax(Q_logits, dim=-1)
        log_P = torch.nn.functional.log_softmax(outputs2, dim=-1)

        # Update the clothes discriminator 2
        clothes_loss2 = criterion_clothes(outputs2, clothes_ids)

        kl_loss = kl(log_Q, P, reduction='sum') + kl(log_P, Q, reduction='sum')

        if epoch >= config.TRAIN.START_EPOCH_CC:
            loss2 = clothes_loss2 + config.k_kl * kl_loss
        else:
            loss2 = clothes_loss2

        optimizer2.zero_grad()
        if config.TRAIN.AMP:
            with amp.scale_loss(loss2, optimizer2) as scaled_loss2:
                scaled_loss2.backward()
        else:
            loss2.backward()
        optimizer2.step()

        GENERAL_EPOCH = config.TRAIN.START_EPOCH_ADV

        # Compute loss
        if epoch >= GENERAL_EPOCH:
            cla_loss = criterion_cla(outputs, pids) + config.k_cal * criterion_cla(outputs - outputs3, pids)
        else:
            cla_loss = criterion_cla(outputs, pids)
        pair_loss = criterion_pair(features, pids)
        adv_loss = criterion_adv(new_pred_clothes, clothes_ids, pos_mask)

        if epoch >= config.TRAIN.START_EPOCH_ADV:
            loss = cla_loss + adv_loss + config.LOSS.PAIR_LOSS_WEIGHT * pair_loss
        else:
            loss = cla_loss + config.LOSS.PAIR_LOSS_WEIGHT * pair_loss

        optimizer.zero_grad()
        if config.TRAIN.AMP:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        # statistics
        corrects.update(torch.sum(preds == pids.data).float()/pids.size(0), pids.size(0))
        corrects2.update(torch.sum(pred_clothes2 == clothes_ids.data).float()/clothes_ids.size(0), clothes_ids.size(0))
        corrects3.update(torch.sum(preds3 == pids.data).float()/pids.size(0), pids.size(0))
        clothes_corrects.update(torch.sum(clothes_preds == clothes_ids.data).float()/clothes_ids.size(0), clothes_ids.size(0))
        clothes_corrects2.update(torch.sum(pred_clothes2 == clothes_ids.data).float()/clothes_ids.size(0), clothes_ids.size(0))
        batch_cla_loss.update(cla_loss.item(), pids.size(0))
        batch_pair_loss.update(pair_loss.item(), pids.size(0))
        batch_clo_loss.update(clothes_loss.item(), clothes_ids.size(0))
        batch_adv_loss.update(adv_loss.item(), clothes_ids.size(0))
        batch_loss2.update(loss2.item(), clothes_ids.size(0))
        batch_clothes_loss2.update(clothes_loss2.item(), clothes_ids.size(0))
        batch_kl_loss.update(kl_loss.item(), clothes_ids.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    logger.info('Epoch{0} '
                'Time:{batch_time.sum:.1f}s '
                'Data:{data_time.sum:.1f}s '
                'ClaLoss:{cla_loss.avg:.4f} '
                'PairLoss:{pair_loss.avg:.4f} '
                'CloLoss:{clo_loss.avg:.4f} '
                'AdvLoss:{adv_loss.avg:.4f} '
                'clothes_loss2:{clothes_loss2.avg:.4f} '
                'loss2:{loss2.avg:.4f} '
                'kl_loss:{kl_loss.avg:.4f} '
                'Acc:{acc.avg:.2%} '
                'Acc2:{acc2.avg:.2%} '
                'Acc3:{acc3.avg:.2%} '
                'CloAcc:{clo_acc.avg:.2%} '
                'Clo2Acc:{clo2_acc.avg:.2%} '.format(
                    epoch+1, batch_time=batch_time, data_time=data_time,
                    cla_loss=batch_cla_loss, pair_loss=batch_pair_loss,
                    clo_loss=batch_clo_loss, adv_loss=batch_adv_loss,
                    clothes_loss2=batch_clothes_loss2,
                    loss2=batch_loss2, kl_loss=batch_kl_loss,
                    acc=corrects, acc2=corrects2, acc3=corrects3,
                    clo_acc=clothes_corrects, clo2_acc=clothes_corrects2))