├── Figures
│   ├── CVPR23.jpg
│   └── Framework.png
├── README.md
├── configs
│   ├── default_img.py
│   ├── default_img_single.py
│   └── res50_cels_cal.yaml
├── data
│   ├── __init__.py
│   ├── dataloader.py
│   ├── dataset_loader.py
│   ├── datasets
│   │   ├── ltcc.py
│   │   └── prcc.py
│   ├── img_transforms.py
│   └── samplers.py
├── demo.sh
├── demo_single_image.py
├── losses
│   ├── __init__.py
│   ├── arcface_loss.py
│   ├── circle_loss.py
│   ├── clothes_based_adversarial_loss.py
│   ├── contrastive_loss.py
│   ├── cosface_loss.py
│   ├── cross_entropy_loss_with_label_smooth.py
│   ├── gather.py
│   └── triplet_loss.py
├── main.py
├── models
│   ├── Fuse.py
│   ├── Fusion.py
│   ├── PM.py
│   ├── ResNet.py
│   ├── __init__.py
│   ├── classifier.py
│   ├── img_resnet.py
│   ├── lr_scheduler.py
│   └── utils
│       ├── c3d_blocks.py
│       ├── inflate.py
│       ├── nonlocal_blocks.py
│       └── pooling.py
├── test.py
├── test_AIM.sh
├── tools
│   ├── eval_metrics.py
│   └── utils.py
└── train.py
/Figures/CVPR23.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BoomShakaY/AIM-CCReID/c3bda2b54a3c5d81eb65ea838ae3502aecd61b67/Figures/CVPR23.jpg
--------------------------------------------------------------------------------
/Figures/Framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BoomShakaY/AIM-CCReID/c3bda2b54a3c5d81eb65ea838ae3502aecd61b67/Figures/Framework.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AIM-CCReID
2 | An official implementation of our CVPR'23 paper:
3 |
4 | 《Good is Bad: Causality Inspired Cloth-Debiasing for Cloth-Changing Person Re-Identification》
5 |
6 | [\[Paper Link\]](https://openaccess.thecvf.com/content/CVPR2023/papers/Yang_Good_Is_Bad_Causality_Inspired_Cloth-Debiasing_for_Cloth-Changing_Person_Re-Identification_CVPR_2023_paper.pdf)
7 | [\[Repo\]](https://github.com/BoomShakaY/AIM-CCReID)
8 | [\[About Me\]](https://gavinyoung1.github.io/)
9 |
10 | ![Framework](Figures/Framework.png)
11 |
12 | ![CVPR23](Figures/CVPR23.jpg)
13 |
14 |
15 |
16 | #### News
17 | 2023.10.24: The full code is released!
18 |
19 | #### Requirements
20 | - Python 3.6
21 | - PyTorch 1.6.0
22 | - yacs
23 | - apex
24 | (Note: apex is optional and not needed if you have enough GPU memory; in that case, simply comment out all AMP-related code.)
25 |
26 | ## Performance of AIM
27 |
28 |
29 | | Results | PRCC Standard (R@1 / mAP) | PRCC Cloth-Changing (R@1 / mAP) | LTCC Standard (R@1 / mAP) | LTCC Cloth-Changing (R@1 / mAP) |
30 | |---------|---------------------------|---------------------------------|---------------------------|---------------------------------|
31 | | Paper   | 100.0 / 99.9              | 57.9 / 58.3                     | 76.3 / 41.1               | 40.6 / 19.1                     |
32 | | Repo    | 100.0 / 99.8              | 58.2 / 58.0                     | 75.9 / 41.7               | 40.8 / 19.2                     |
33 |
34 | The results reproduced by this repo are broadly on par with those reported in the paper, and slightly better on some metrics.
75 |
76 | ## Datasets
77 | PRCC is available at [Here](https://drive.google.com/file/d/1yTYawRm4ap3M-j0PjLQJ--xmZHseFDLz/view).
78 |
79 | LTCC is available at [Here](https://naiq.github.io/LTCC_Perosn_ReID.html).
80 |
81 | LaST is available at [Here](https://github.com/shuxjweb/last).
82 |
83 | ## Testing
84 | The trained models (weights) are available at [Baidu Disk](https://pan.baidu.com/s/1Du1XgoCim6I_bZtNRm3yPw?pwd=v4ly) or [Google Drive](https://drive.google.com/drive/folders/1xohg_OAHjNyy7LLq3Fq_KowcEP9IlY8k?usp=sharing).
85 | The testing scripts for PRCC and LTCC are provided in `test_AIM.sh`; modify the resume path in the script to point to where you placed the weight files.
86 |
87 | Note that you also need to change `DATA.ROOT` and `OUTPUT` in `configs/default_img.py` to your own paths before testing.
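
For reference, a cloth-changing test run on PRCC might look like the sketch below. The flag names mirror the attributes read by `update_config` in `configs/default_img.py` and are assumptions here; check `test_AIM.sh` and the argument parser in `main.py` for the exact options.

```bash
python main.py --cfg configs/res50_cels_cal.yaml --dataset prcc \
    --root YOUR_DATA_ROOT --resume YOUR_WEIGHT_PATH --gpu 0 --eval
```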
88 |
89 | ## 📖 Citation
90 |
91 | If you find our work useful in your research, please consider citing:
92 |
93 | ```bibtex
94 | @inproceedings{yang2023good,
95 | title={Good is bad: Causality inspired cloth-debiasing for cloth-changing person re-identification},
96 | author={Yang, Zhengwei and Lin, Meng and Zhong, Xian and Wu, Yu and Wang, Zheng},
97 | booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},
98 | pages={1472--1481},
99 | year={2023}
100 | }
101 | ```
--------------------------------------------------------------------------------
/configs/default_img.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | from yacs.config import CfgNode as CN
4 | import time
5 |
6 |
7 | _C = CN()
8 | # -----------------------------------------------------------------------------
9 | # Data settings
10 | # -----------------------------------------------------------------------------
11 | _C.DATA = CN()
12 | # Root path for dataset directory
13 | _C.DATA.ROOT = 'DATA_ROOT'
14 | # Dataset for evaluation
15 | _C.DATA.DATASET = 'ltcc'
16 | # Workers for dataloader
17 | _C.DATA.NUM_WORKERS = 8
18 | # Height of input image
19 | _C.DATA.HEIGHT = 384
20 | # Width of input image
21 | _C.DATA.WIDTH = 192
22 | # Batch size for training
23 | _C.DATA.TRAIN_BATCH = 32
24 | # Batch size for testing
25 | _C.DATA.TEST_BATCH = 128
26 | # The number of instances per identity for training sampler
27 | _C.DATA.NUM_INSTANCES = 8
28 | # -----------------------------------------------------------------------------
29 | # Augmentation settings
30 | # -----------------------------------------------------------------------------
31 | _C.AUG = CN()
32 | # Random crop prob
33 | _C.AUG.RC_PROB = 0.5
34 | # Random erase prob
35 | _C.AUG.RE_PROB = 0.5
36 | # Random flip prob
37 | _C.AUG.RF_PROB = 0.5
38 | # -----------------------------------------------------------------------------
39 | # Model settings
40 | # -----------------------------------------------------------------------------
41 | _C.MODEL = CN()
42 | # Model name
43 | _C.MODEL.NAME = 'resnet50'
44 | # The stride for layer4 in resnet
45 | _C.MODEL.RES4_STRIDE = 1
46 | # feature dim
47 | _C.MODEL.FEATURE_DIM = 4096
48 | # Model path for resuming
49 | _C.MODEL.RESUME = ''
50 | # Global pooling after the backbone
51 | _C.MODEL.POOLING = CN()
52 | # Choose in ['avg', 'max', 'gem', 'maxavg']
53 | _C.MODEL.POOLING.NAME = 'maxavg'
54 | # Initialized power for GeM pooling
55 | _C.MODEL.POOLING.P = 3
56 | # -----------------------------------------------------------------------------
57 | # Losses for training
58 | # -----------------------------------------------------------------------------
59 | _C.LOSS = CN()
60 | # Classification loss
61 | _C.LOSS.CLA_LOSS = 'crossentropylabelsmooth'
62 | # Clothes classification loss
63 | _C.LOSS.CLOTHES_CLA_LOSS = 'cosface'
64 | # Scale for classification loss
65 | _C.LOSS.CLA_S = 16.
66 | # Margin for classification loss
67 | _C.LOSS.CLA_M = 0.
68 | # Clothes-based adversarial loss
69 | _C.LOSS.CAL = 'cal'
70 | # Epsilon for clothes-based adversarial loss
71 | _C.LOSS.EPSILON = 0.1
72 | # Momentum for clothes-based adversarial loss with memory bank
73 | _C.LOSS.MOMENTUM = 0.
74 | # -----------------------------------------------------------------------------
75 | # Training settings
76 | # -----------------------------------------------------------------------------
77 | _C.TRAIN = CN()
78 | _C.TRAIN.START_EPOCH = 0
79 | _C.TRAIN.MAX_EPOCH = 100
80 | # Start epoch for clothes classification
81 | _C.TRAIN.START_EPOCH_CC = 25
82 | # Start epoch for adversarial training
83 | _C.TRAIN.START_EPOCH_ADV = 25
84 | # Start epoch for debias
85 | _C.TRAIN.START_EPOCH_GENERAL = 25
86 | # Optimizer
87 | _C.TRAIN.OPTIMIZER = CN()
88 | _C.TRAIN.OPTIMIZER.NAME = 'adam'
89 | # Learning rate
90 | _C.TRAIN.OPTIMIZER.LR = 0.00035
91 | _C.TRAIN.OPTIMIZER.WEIGHT_DECAY = 5e-4
92 | # LR scheduler
93 | _C.TRAIN.LR_SCHEDULER = CN()
94 | # Stepsize to decay learning rate
95 | _C.TRAIN.LR_SCHEDULER.STEPSIZE = [20, 40]
96 | # LR decay rate, used in StepLRScheduler
97 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
98 | # -----------------------------------------------------------------------------
99 | # Testing settings
100 | # -----------------------------------------------------------------------------
101 | _C.TEST = CN()
102 | # Perform evaluation after every N epochs (set to -1 to test after training)
103 | _C.TEST.EVAL_STEP = 5
104 | # Start to evaluate after specific epoch
105 | _C.TEST.START_EVAL = 0
106 | # -----------------------------------------------------------------------------
107 | # Misc
108 | # -----------------------------------------------------------------------------
109 | # Fixed random seed
110 | _C.SEED = 1
111 | # Perform evaluation only
112 | _C.EVAL_MODE = False
113 | # GPU device ids for CUDA_VISIBLE_DEVICES
114 | _C.GPU = '0'
115 | # Path to output folder, overwritten by command line argument
116 | _C.OUTPUT = 'OUTPUT_PATH'
117 | # Tag of experiment, overwritten by command line argument
118 | _C.TAG = 'eval'
119 | # -----------------------------------------------------------------------------
120 | # Hyperparameters
121 | _C.k_cal = 1.0
122 | _C.k_kl = 1.0
123 | # -----------------------------------------------------------------------------
124 |
125 | def update_config(config, args):
126 | config.defrost()
127 | config.merge_from_file(args.cfg)
128 |
129 | # merge from specific arguments
130 | if args.root:
131 | config.DATA.ROOT = args.root
132 | if args.output:
133 | config.OUTPUT = args.output
134 | if args.resume:
135 | config.MODEL.RESUME = args.resume
136 | if args.eval:
137 | config.EVAL_MODE = True
138 | if args.tag:
139 | config.TAG = args.tag
140 | if args.dataset:
141 | config.DATA.DATASET = args.dataset
142 | if args.gpu:
143 | config.GPU = args.gpu
144 |
145 | # output folder
146 | config.OUTPUT = os.path.join(config.OUTPUT, config.DATA.DATASET, config.TAG)
147 | config.freeze()
148 |
149 |
150 | def get_img_config(args):
151 | """Get a yacs CfgNode object with default values."""
152 | config = _C.clone()
153 | update_config(config, args)
154 |
155 | return config
156 |
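157 |
158 | if __name__ == '__main__':
159 |     # Minimal usage sketch (illustrative only; the real argument parsing lives in the
160 |     # training/testing scripts, which are not shown in this section). update_config()
161 |     # reads args.cfg plus every optional override accessed above, so all of those
162 |     # attributes must exist on the namespace, even if empty.
163 |     from argparse import Namespace
164 |     demo_args = Namespace(cfg='configs/res50_cels_cal.yaml', root='', output='',
165 |                           resume='', eval=False, tag='', dataset='', gpu='')
166 |     demo_config = get_img_config(demo_args)
167 |     print(demo_config.MODEL.NAME, demo_config.OUTPUT)  # resnet50 OUTPUT_PATH/ltcc/eval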
--------------------------------------------------------------------------------
/configs/default_img_single.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | from yacs.config import CfgNode as CN
4 |
5 |
6 | _C = CN()
7 | # -----------------------------------------------------------------------------
8 | # Data settings
9 | # -----------------------------------------------------------------------------
10 | _C.DATA = CN()
11 | # Root path for dataset directory
12 | _C.DATA.ROOT = 'DATA_ROOT'
13 | # Dataset for evaluation
14 | _C.DATA.DATASET = 'ltcc'
15 | # Workers for dataloader
16 | _C.DATA.NUM_WORKERS = 4
17 | # Height of input image
18 | _C.DATA.HEIGHT = 384
19 | # Width of input image
20 | _C.DATA.WIDTH = 192
21 | # Batch size for training
22 | _C.DATA.TRAIN_BATCH = 32 # org:32
23 | # Batch size for testing
24 | _C.DATA.TEST_BATCH = 128 # org:128
25 | # The number of instances per identity for training sampler
26 | _C.DATA.NUM_INSTANCES = 8
27 | # -----------------------------------------------------------------------------
28 | # Augmentation settings
29 | # -----------------------------------------------------------------------------
30 | _C.AUG = CN()
31 | # Random crop prob
32 | _C.AUG.RC_PROB = 0.5
33 | # Random erase prob
34 | _C.AUG.RE_PROB = 0.5
35 | # Random flip prob
36 | _C.AUG.RF_PROB = 0.5
37 | # -----------------------------------------------------------------------------
38 | # Model settings
39 | # -----------------------------------------------------------------------------
40 | _C.MODEL = CN()
41 | # Model name
42 | _C.MODEL.NAME = 'resnet50'
43 | # The stride for layer4 in resnet
44 | _C.MODEL.RES4_STRIDE = 1
45 | # feature dim
46 | _C.MODEL.FEATURE_DIM = 4096 # original: 4096
47 | # Model path for resuming
48 | _C.MODEL.RESUME = ''
49 | # Global pooling after the backbone
50 | _C.MODEL.POOLING = CN()
51 | # Choose in ['avg', 'max', 'gem', 'maxavg']
52 | _C.MODEL.POOLING.NAME = 'maxavg' # original: maxavg
53 | # Initialized power for GeM pooling
54 | _C.MODEL.POOLING.P = 3
55 | # -----------------------------------------------------------------------------
56 | # Model2 settings
57 | # -----------------------------------------------------------------------------
58 | _C.MODEL2 = CN()
59 | _C.MODEL2.NAME = 'hpm'
60 | # -----------------------------------------------------------------------------
61 | # Losses for training
62 | # -----------------------------------------------------------------------------
63 | _C.LOSS = CN()
64 | # Classification loss
65 | _C.LOSS.CLA_LOSS = 'crossentropy'
66 | # Clothes classification loss
67 | _C.LOSS.CLOTHES_CLA_LOSS = 'cosface'
68 | # Scale for classification loss
69 | _C.LOSS.CLA_S = 16.
70 | # Margin for classification loss
71 | _C.LOSS.CLA_M = 0.
72 | # Pairwise loss
73 | _C.LOSS.PAIR_LOSS = 'triplet'
74 | # The weight for pairwise loss
75 | _C.LOSS.PAIR_LOSS_WEIGHT = 0.0
76 | # Scale for pairwise loss
77 | _C.LOSS.PAIR_S = 16.
78 | # Margin for pairwise loss
79 | _C.LOSS.PAIR_M = 0.3
80 | # Clothes-based adversarial loss
81 | _C.LOSS.CAL = 'cal'
82 | # Epsilon for clothes-based adversarial loss
83 | _C.LOSS.EPSILON = 0.1
84 | # Momentum for clothes-based adversarial loss with memory bank
85 | _C.LOSS.MOMENTUM = 0.
86 | # -----------------------------------------------------------------------------
87 | # Training settings
88 | # -----------------------------------------------------------------------------
89 | _C.TRAIN = CN()
90 | _C.TRAIN.START_EPOCH = 0
91 | _C.TRAIN.MAX_EPOCH = 70
92 | # Start epoch for clothes classification
93 | _C.TRAIN.START_EPOCH_CC = 25 # org:25
94 | # Start epoch for adversarial training
95 | _C.TRAIN.START_EPOCH_ADV = 25 # org:25
96 | # Optimizer
97 | _C.TRAIN.OPTIMIZER = CN()
98 | _C.TRAIN.OPTIMIZER.NAME = 'adam'
99 | # Learning rate
100 | _C.TRAIN.OPTIMIZER.LR = 0.00035
101 | _C.TRAIN.OPTIMIZER.WEIGHT_DECAY = 5e-4
102 | # LR scheduler
103 | _C.TRAIN.LR_SCHEDULER = CN()
104 | # Stepsize to decay learning rate
105 | _C.TRAIN.LR_SCHEDULER.STEPSIZE = [20, 40] #, 60]
106 | # LR decay rate, used in StepLRScheduler
107 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
108 | # Using amp for training
109 | _C.TRAIN.AMP = False
110 | # -----------------------------------------------------------------------------
111 | # Testing settings
112 | # -----------------------------------------------------------------------------
113 | _C.TEST = CN()
114 | # Perform evaluation after every N epochs (set to -1 to test after training)
115 | _C.TEST.EVAL_STEP = 5
116 | # Start to evaluate after specific epoch
117 | _C.TEST.START_EVAL = 0
118 | # -----------------------------------------------------------------------------
119 | # Misc
120 | # -----------------------------------------------------------------------------
121 | # Fixed random seed
122 | _C.SEED = 1
123 | # Perform evaluation only
124 | _C.EVAL_MODE = False
125 | # GPU device ids for CUDA_VISIBLE_DEVICES
126 | _C.GPU = '0'
127 | # Path to output folder, overwritten by command line argument
128 | _C.OUTPUT = 'OUTPUT_PATH'
129 | # Tag of experiment, overwritten by command line argument
130 | _C.TAG = 'eval'
131 | # -----------------------------------------------------------------------------
132 | # Hyperparameters
133 | _C.k_cal = 1
134 | _C.k_kl = 1
135 |
136 |
137 | def update_config(config, args):
138 | config.defrost()
139 | config.merge_from_file(args.cfg)
140 |
141 | # output folder
142 | config.OUTPUT = os.path.join(config.OUTPUT, config.DATA.DATASET)
143 | config.freeze()
144 |
145 |
146 | def get_img_config(args):
147 | """Get a yacs CfgNode object with default values."""
148 | config = _C.clone()
149 | update_config(config, args)
150 |
151 | return config
152 |
--------------------------------------------------------------------------------
/configs/res50_cels_cal.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | NAME: resnet50
3 | LOSS:
4 | CLA_LOSS: crossentropylabelsmooth
5 | CAL: cal
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | import data.img_transforms as T
2 | from data.dataloader import DataLoaderX
3 | from data.dataset_loader import ImageDataset
4 | from data.samplers import DistributedRandomIdentitySampler, DistributedInferenceSampler
5 | from data.datasets.ltcc import LTCC
6 | from data.datasets.prcc import PRCC
7 |
8 | __factory = {
9 | 'ltcc': LTCC,
10 | 'prcc': PRCC,
11 | }
12 |
13 | def get_names():
14 | return list(__factory.keys())
15 |
16 |
17 | def build_dataset(config):
18 | if config.DATA.DATASET not in __factory.keys():
19 | raise KeyError("Invalid dataset, got '{}', but expected to be one of {}".format(config.DATA.DATASET, __factory.keys()))
20 |
21 | dataset = __factory[config.DATA.DATASET](root=config.DATA.ROOT)
22 |
23 | return dataset
24 |
25 |
26 | def build_img_transforms(config):
27 | transform_train = T.Compose([
28 | T.Resize((config.DATA.HEIGHT, config.DATA.WIDTH)),
29 | T.RandomCroping(p=config.AUG.RC_PROB),
30 | T.RandomHorizontalFlip(p=config.AUG.RF_PROB),
31 | T.ToTensor(),
32 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
33 | T.RandomErasing(probability=config.AUG.RE_PROB)
34 | ])
35 | transform_test = T.Compose([
36 | T.Resize((config.DATA.HEIGHT, config.DATA.WIDTH)),
37 | T.ToTensor(),
38 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
39 | ])
40 |
41 | return transform_train, transform_test
42 |
43 |
44 | def build_dataloader(config):
45 | dataset = build_dataset(config)
46 | transform_train, transform_test = build_img_transforms(config)
47 | train_sampler = DistributedRandomIdentitySampler(dataset.train,
48 | num_instances=config.DATA.NUM_INSTANCES,
49 | seed=config.SEED)
50 | trainloader = DataLoaderX(dataset=ImageDataset(dataset.train, transform=transform_train),
51 | sampler=train_sampler,
52 | batch_size=config.DATA.TRAIN_BATCH, num_workers=config.DATA.NUM_WORKERS,
53 | pin_memory=True, drop_last=True)
54 |
55 | galleryloader = DataLoaderX(dataset=ImageDataset(dataset.gallery, transform=transform_test),
56 | sampler=DistributedInferenceSampler(dataset.gallery),
57 | batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS,
58 | pin_memory=True, drop_last=False, shuffle=False)
59 |
60 | if config.DATA.DATASET == 'prcc':
61 | queryloader_same = DataLoaderX(dataset=ImageDataset(dataset.query_same, transform=transform_test),
62 | sampler=DistributedInferenceSampler(dataset.query_same),
63 | batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS,
64 | pin_memory=True, drop_last=False, shuffle=False)
65 | queryloader_diff = DataLoaderX(dataset=ImageDataset(dataset.query_diff, transform=transform_test),
66 | sampler=DistributedInferenceSampler(dataset.query_diff),
67 | batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS,
68 | pin_memory=True, drop_last=False, shuffle=False)
69 |
70 | return trainloader, queryloader_same, queryloader_diff, galleryloader, dataset, train_sampler
71 | else:
72 | queryloader = DataLoaderX(dataset=ImageDataset(dataset.query, transform=transform_test),
73 | sampler=DistributedInferenceSampler(dataset.query),
74 | batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS,
75 | pin_memory=True, drop_last=False, shuffle=False)
76 |
77 | return trainloader, queryloader, galleryloader, dataset, train_sampler
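78 |
79 |
80 | # Illustrative note (the caller, e.g. main.py, is not shown in this section; the
81 | # unpacking below is a sketch, not a quote from it): build_dataloader returns an extra
82 | # query loader for PRCC because its protocol evaluates same-clothes and cloth-changing
83 | # queries separately. The distributed samplers also require torch.distributed to be
84 | # initialized before this function is called.
85 | #
86 | #   if config.DATA.DATASET == 'prcc':
87 | #       trainloader, queryloader_same, queryloader_diff, galleryloader, dataset, train_sampler = build_dataloader(config)
88 | #   else:
89 | #       trainloader, queryloader, galleryloader, dataset, train_sampler = build_dataloader(config)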
--------------------------------------------------------------------------------
/data/dataloader.py:
--------------------------------------------------------------------------------
1 | # refer to: https://github.com/JDAI-CV/fast-reid/blob/master/fastreid/data/data_utils.py
2 |
3 | import torch
4 | import threading
5 | import queue
6 | from torch.utils.data import DataLoader
7 | from torch import distributed as dist
8 |
9 |
10 | """
11 | #based on http://stackoverflow.com/questions/7323664/python-generator-pre-fetch
12 | This is a single-function package that transforms an arbitrary generator into a background-thread generator that
13 | prefetches several batches of data in a parallel background thread.
14 |
15 | This is useful if you have a computationally heavy process (CPU or GPU) that
16 | iteratively processes minibatches from the generator while the generator
17 | consumes some other resource (disk IO / loading from database / more CPU if you have unused cores).
18 |
19 | By default these two processes will constantly wait for one another to finish. If you make the generator work in
20 | prefetch mode (see examples below), they will work in parallel, potentially saving your GPU time.
21 | We personally use the prefetch generator when iterating minibatches of data for deep learning with PyTorch etc.
22 |
23 | Quick usage example (ipython notebook) - https://github.com/justheuristic/prefetch_generator/blob/master/example.ipynb
24 | This package contains this object
25 | - BackgroundGenerator(any_other_generator[,max_prefetch = something])
26 | """
27 |
28 |
29 | class BackgroundGenerator(threading.Thread):
30 | """
31 | the usage is below
32 | >> for batch in BackgroundGenerator(my_minibatch_iterator):
33 | >> doit()
34 | More details are written in the BackgroundGenerator doc
35 | >> help(BackgroundGenerator)
36 | """
37 |
38 | def __init__(self, generator, local_rank, max_prefetch=10):
39 | """
40 | This function transforms a generator into a background-thread generator.
41 | :param generator: generator or genexp or any
42 | It can be used with any minibatch generator.
43 |
44 | It is quite lightweight, but not entirely weightless.
45 | Using global variables inside the generator is not recommended (it may trigger GIL contention and cancel out the
46 | benefit of having a background thread.)
47 | The ideal use case is when everything it requires is stored inside it and everything it
48 | outputs is passed through a queue.
49 |
50 | There's no restriction on doing weird stuff, reading/writing files, retrieving
51 | URLs [or whatever] whilst iterating.
52 |
53 | :param max_prefetch: defines, how many iterations (at most) can background generator keep
54 | stored at any moment of time.
55 | Whenever there's already max_prefetch batches stored in queue, the background process will halt until
56 | one of these batches is dequeued.
57 |
58 | !Default max_prefetch=1 is okay unless you deal with some weird file IO in your generator!
59 |
60 | Setting max_prefetch to -1 lets it store as many batches as it can, which will work
61 | slightly (if any) faster, but will require storing
62 | all batches in memory. If you use infinite generator with max_prefetch=-1, it will exceed the RAM size
63 | unless dequeued quickly enough.
64 | """
65 | super().__init__()
66 | self.queue = queue.Queue(max_prefetch)
67 | self.generator = generator
68 | self.local_rank = local_rank
69 | self.daemon = True
70 | self.exit_event = threading.Event()
71 | self.start()
72 |
73 | def run(self):
74 | torch.cuda.set_device(self.local_rank)
75 | for item in self.generator:
76 | if self.exit_event.is_set():
77 | break
78 | self.queue.put(item)
79 | self.queue.put(None)
80 |
81 | def next(self):
82 | next_item = self.queue.get()
83 | if next_item is None:
84 | raise StopIteration
85 | return next_item
86 |
87 | # Python 3 compatibility
88 | def __next__(self):
89 | return self.next()
90 |
91 | def __iter__(self):
92 | return self
93 |
94 |
95 | class DataLoaderX(DataLoader):
96 | def __init__(self, **kwargs):
97 | super().__init__(**kwargs)
98 | local_rank = dist.get_rank()
99 | self.stream = torch.cuda.Stream(local_rank) # create a new cuda stream in each process
100 | self.local_rank = local_rank
101 |
102 | def __iter__(self):
103 | self.iter = super().__iter__()
104 | self.iter = BackgroundGenerator(self.iter, self.local_rank)
105 | self.preload()
106 | return self
107 |
108 | def _shutdown_background_thread(self):
109 | if not self.iter.is_alive():
110 | # avoid re-entrance or ill-conditioned thread state
111 | return
112 |
113 | # Set exit event to True for background threading stopping
114 | self.iter.exit_event.set()
115 |
116 | # Exhaust all remaining elements, so that the queue becomes empty,
117 | # and the thread should quit
118 | for _ in self.iter:
119 | pass
120 |
121 | # Waiting for background thread to quit
122 | self.iter.join()
123 |
124 | def preload(self):
125 | self.batch = next(self.iter, None)
126 | if self.batch is None:
127 | return None
128 | with torch.cuda.stream(self.stream):
129 | # if isinstance(self.batch[0], torch.Tensor):
130 | # self.batch[0] = self.batch[0].to(device=self.local_rank, non_blocking=True)
131 | for k, v in enumerate(self.batch):
132 | if isinstance(self.batch[k], torch.Tensor):
133 | self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True)
134 |
135 | def __next__(self):
136 | torch.cuda.current_stream().wait_stream(
137 | self.stream
138 | ) # wait tensor to put on GPU
139 | batch = self.batch
140 | if batch is None:
141 | raise StopIteration
142 | self.preload()
143 | return batch
144 |
145 | # Signal for shutting down background thread
146 | def shutdown(self):
147 | # If the dataloader is to be freed, shutdown its BackgroundGenerator
148 | self._shutdown_background_thread()
149 |
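150 |
151 | # Illustrative note (the launch code is not shown in this section): DataLoaderX calls
152 | # dist.get_rank(), so torch.distributed must be initialized before constructing it,
153 | # even for single-GPU runs. One possible single-process setup (values are assumptions):
154 | #
155 | #   import os
156 | #   os.environ.setdefault('MASTER_ADDR', 'localhost')
157 | #   os.environ.setdefault('MASTER_PORT', '12355')
158 | #   dist.init_process_group(backend='nccl', rank=0, world_size=1)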
--------------------------------------------------------------------------------
/data/dataset_loader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import functools
3 | import os.path as osp
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 |
7 |
8 | def read_image(img_path):
9 | """Keep reading image until succeed.
10 | This can avoid IOError incurred by heavy IO process."""
11 | got_img = False
12 | if not osp.exists(img_path):
13 | raise IOError("{} does not exist".format(img_path))
14 | while not got_img:
15 | try:
16 | img = Image.open(img_path).convert('RGB')
17 | got_img = True
18 | except IOError:
19 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
20 | pass
21 | return img
22 |
23 |
24 | class ImageDataset(Dataset):
25 | """Image Person ReID Dataset"""
26 | def __init__(self, dataset, transform=None):
27 | self.dataset = dataset
28 | self.transform = transform
29 |
30 | def __len__(self):
31 | return len(self.dataset)
32 |
33 | def __getitem__(self, index):
34 | img_path, pid, camid, clothes_id = self.dataset[index]
35 | img = read_image(img_path)
36 | if self.transform is not None:
37 | img = self.transform(img)
38 | return img, pid, camid, clothes_id, img_path
39 |
40 |
41 | def pil_loader(path):
42 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
43 | with open(path, 'rb') as f:
44 | with Image.open(f) as img:
45 | return img.convert('RGB')
46 |
47 |
48 | def accimage_loader(path):
49 | try:
50 | import accimage
51 | return accimage.Image(path)
52 | except IOError:
53 | # Potentially a decoding problem, fall back to PIL.Image
54 | return pil_loader(path)
55 |
56 |
57 | def get_default_image_loader():
58 | from torchvision import get_image_backend
59 | if get_image_backend() == 'accimage':
60 | return accimage_loader
61 | else:
62 | return pil_loader
63 |
64 |
65 | def image_loader(path):
66 | from torchvision import get_image_backend
67 | if get_image_backend() == 'accimage':
68 | return accimage_loader(path)
69 | else:
70 | return pil_loader(path)
71 |
--------------------------------------------------------------------------------
/data/datasets/ltcc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import glob
4 | import h5py
5 | import random
6 | import math
7 | import logging
8 | import numpy as np
9 | import os.path as osp
10 | from scipy.io import loadmat
11 | from tools.utils import mkdir_if_missing, write_json, read_json
12 |
13 |
14 | class LTCC(object):
15 | """ LTCC
16 |
17 | Reference:
18 | Qian et al. Long-Term Cloth-Changing Person Re-identification. arXiv:2005.12633, 2020.
19 |
20 | URL: https://naiq.github.io/LTCC_Perosn_ReID.html#
21 | """
22 | dataset_dir = 'LTCC_ReID'
23 | def __init__(self, root='data', **kwargs):
24 | self.dataset_dir = osp.join(root, self.dataset_dir)
25 | self.train_dir = osp.join(self.dataset_dir, 'train')
26 | self.query_dir = osp.join(self.dataset_dir, 'query')
27 | self.gallery_dir = osp.join(self.dataset_dir, 'test')
28 | self._check_before_run()
29 |
30 | train, num_train_pids, num_train_imgs, num_train_clothes, pid2clothes = \
31 | self._process_dir_train(self.train_dir)
32 | query, gallery, num_test_pids, num_query_imgs, num_gallery_imgs, num_test_clothes = \
33 | self._process_dir_test(self.query_dir, self.gallery_dir)
34 | num_total_pids = num_train_pids + num_test_pids
35 | num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs
36 | num_test_imgs = num_query_imgs + num_gallery_imgs
37 | num_total_clothes = num_train_clothes + num_test_clothes
38 |
39 | logger = logging.getLogger('reid.dataset')
40 | logger.info("=> LTCC loaded")
41 | logger.info("Dataset statistics:")
42 | logger.info(" ----------------------------------------")
43 | logger.info(" subset | # ids | # images | # clothes")
44 | logger.info(" ----------------------------------------")
45 | logger.info(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_clothes))
46 | logger.info(" test | {:5d} | {:8d} | {:9d}".format(num_test_pids, num_test_imgs, num_test_clothes))
47 | logger.info(" query | {:5d} | {:8d} |".format(num_test_pids, num_query_imgs))
48 | logger.info(" gallery | {:5d} | {:8d} |".format(num_test_pids, num_gallery_imgs))
49 | logger.info(" ----------------------------------------")
50 | logger.info(" total | {:5d} | {:8d} | {:9d}".format(num_total_pids, num_total_imgs, num_total_clothes))
51 | logger.info(" ----------------------------------------")
52 |
53 | self.train = train
54 | self.query = query
55 | self.gallery = gallery
56 |
57 | self.num_train_pids = num_train_pids
58 | self.num_train_clothes = num_train_clothes
59 | self.pid2clothes = pid2clothes
60 |
61 | def _check_before_run(self):
62 | """Check if all files are available before going deeper"""
63 | if not osp.exists(self.dataset_dir):
64 | raise RuntimeError("'{}' is not available".format(self.dataset_dir))
65 | if not osp.exists(self.train_dir):
66 | raise RuntimeError("'{}' is not available".format(self.train_dir))
67 | if not osp.exists(self.query_dir):
68 | raise RuntimeError("'{}' is not available".format(self.query_dir))
69 | if not osp.exists(self.gallery_dir):
70 | raise RuntimeError("'{}' is not available".format(self.gallery_dir))
71 |
72 | def _process_dir_train(self, dir_path):
73 | img_paths = glob.glob(osp.join(dir_path, '*.png'))
74 | img_paths.sort()
75 | pattern1 = re.compile(r'(\d+)_(\d+)_c(\d+)')
76 | pattern2 = re.compile(r'(\w+)_c')
77 |
78 | pid_container = set()
79 | clothes_container = set()
80 | for img_path in img_paths:
81 | pid, _, _ = map(int, pattern1.search(img_path).groups())
82 | clothes_id = pattern2.search(img_path).group(1)
83 | pid_container.add(pid)
84 | clothes_container.add(clothes_id)
85 | pid_container = sorted(pid_container)
86 | clothes_container = sorted(clothes_container)
87 | pid2label = {pid:label for label, pid in enumerate(pid_container)}
88 | clothes2label = {clothes_id:label for label, clothes_id in enumerate(clothes_container)}
89 |
90 | num_pids = len(pid_container)
91 | num_clothes = len(clothes_container)
92 |
93 | dataset = []
94 | pid2clothes = np.zeros((num_pids, num_clothes))
95 | for img_path in img_paths:
96 | pid, _, camid = map(int, pattern1.search(img_path).groups())
97 | clothes = pattern2.search(img_path).group(1)
98 | camid -= 1 # index starts from 0
99 | pid = pid2label[pid]
100 | clothes_id = clothes2label[clothes]
101 | dataset.append((img_path, pid, camid, clothes_id))
102 | pid2clothes[pid, clothes_id] = 1
103 |
104 | num_imgs = len(dataset)
105 |
106 | return dataset, num_pids, num_imgs, num_clothes, pid2clothes
107 |
108 | def _process_dir_test(self, query_path, gallery_path):
109 | query_img_paths = glob.glob(osp.join(query_path, '*.png'))
110 | gallery_img_paths = glob.glob(osp.join(gallery_path, '*.png'))
111 | query_img_paths.sort()
112 | gallery_img_paths.sort()
113 | pattern1 = re.compile(r'(\d+)_(\d+)_c(\d+)')
114 | pattern2 = re.compile(r'(\w+)_c')
115 |
116 | pid_container = set()
117 | clothes_container = set()
118 | for img_path in query_img_paths:
119 | pid, _, _ = map(int, pattern1.search(img_path).groups())
120 | clothes_id = pattern2.search(img_path).group(1)
121 | pid_container.add(pid)
122 | clothes_container.add(clothes_id)
123 | for img_path in gallery_img_paths:
124 | pid, _, _ = map(int, pattern1.search(img_path).groups())
125 | clothes_id = pattern2.search(img_path).group(1)
126 | pid_container.add(pid)
127 | clothes_container.add(clothes_id)
128 | pid_container = sorted(pid_container)
129 | clothes_container = sorted(clothes_container)
130 | pid2label = {pid:label for label, pid in enumerate(pid_container)}
131 | clothes2label = {clothes_id:label for label, clothes_id in enumerate(clothes_container)}
132 |
133 | num_pids = len(pid_container)
134 | num_clothes = len(clothes_container)
135 |
136 | query_dataset = []
137 | gallery_dataset = []
138 | for img_path in query_img_paths:
139 | pid, _, camid = map(int, pattern1.search(img_path).groups())
140 | clothes_id = pattern2.search(img_path).group(1)
141 | camid -= 1 # index starts from 0
142 | clothes_id = clothes2label[clothes_id]
143 | query_dataset.append((img_path, pid, camid, clothes_id))
144 |
145 | for img_path in gallery_img_paths:
146 | pid, _, camid = map(int, pattern1.search(img_path).groups())
147 | clothes_id = pattern2.search(img_path).group(1)
148 | camid -= 1 # index starts from 0
149 | clothes_id = clothes2label[clothes_id]
150 | gallery_dataset.append((img_path, pid, camid, clothes_id))
151 |
152 | num_imgs_query = len(query_dataset)
153 | num_imgs_gallery = len(gallery_dataset)
154 |
155 | return query_dataset, gallery_dataset, num_pids, num_imgs_query, num_imgs_gallery, num_clothes
156 |
157 |
--------------------------------------------------------------------------------
/data/datasets/prcc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import glob
4 | import shutil
5 | import h5py
6 | import random
7 | import math
8 | import logging
9 | import numpy as np
10 | import os.path as osp
11 |
12 |
13 | class PRCC(object):
14 | """ PRCC
15 |
16 | Reference:
17 | Yang et al. Person Re-identification by Contour Sketch under Moderate Clothing Change. TPAMI, 2019.
18 |
19 | URL: https://drive.google.com/file/d/1yTYawRm4ap3M-j0PjLQJ--xmZHseFDLz/view
20 | """
21 | dataset_dir = 'prcc'
22 | def __init__(self, root='data', single_shot=False, seed_sel=0, **kwargs):
23 | self.dataset_dir = osp.join(root, self.dataset_dir)
24 | self.train_dir = osp.join(self.dataset_dir, 'rgb/train')
25 | self.val_dir = osp.join(self.dataset_dir, 'rgb/val')
26 | self.test_dir = osp.join(self.dataset_dir, 'rgb/test')
27 |
28 | self._check_before_run()
29 | train, num_train_pids, num_train_imgs, num_train_clothes, pid2clothes = \
30 | self._process_dir_train(self.train_dir)
31 | val, num_val_pids, num_val_imgs, num_val_clothes, _ = \
32 | self._process_dir_train(self.val_dir)
33 | query_same, query_diff, gallery, num_test_pids, \
34 | num_query_imgs_same, num_query_imgs_diff, num_gallery_imgs, \
35 | num_test_clothes, gallery_idx = self._process_dir_test(self.test_dir)
36 |
37 | num_total_pids = num_train_pids + num_test_pids
38 | num_test_imgs = num_query_imgs_same + num_query_imgs_diff + num_gallery_imgs
39 | num_total_imgs = num_train_imgs + num_val_imgs + num_test_imgs
40 | num_total_clothes = num_train_clothes + num_test_clothes
41 |
42 | logger = logging.getLogger('reid.dataset')
43 | logger.info("=> PRCC loaded")
44 | logger.info("Dataset statistics:")
45 | logger.info(" --------------------------------------------")
46 | logger.info(" subset | # ids | # images | # clothes")
47 | logger.info(" --------------------------------------------")
48 | logger.info(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_clothes))
49 | logger.info(" val | {:5d} | {:8d} | {:9d}".format(num_val_pids, num_val_imgs, num_val_clothes))
50 | logger.info(" test | {:5d} | {:8d} | {:9d}".format(num_test_pids, num_test_imgs, num_test_clothes))
51 | logger.info(" query(same) | {:5d} | {:8d} |".format(num_test_pids, num_query_imgs_same))
52 | logger.info(" query(diff) | {:5d} | {:8d} |".format(num_test_pids, num_query_imgs_diff))
53 | logger.info(" gallery | {:5d} | {:8d} |".format(num_test_pids, num_gallery_imgs))
54 | logger.info(" --------------------------------------------")
55 | logger.info(" total | {:5d} | {:8d} | {:9d}".format(num_total_pids, num_total_imgs, num_total_clothes))
56 | logger.info(" --------------------------------------------")
57 |
58 | self.train = train
59 | self.val = val
60 | self.query_same = query_same
61 | self.query_diff = query_diff
62 | self.gallery = gallery
63 |
64 | self.num_train_pids = num_train_pids
65 | self.num_train_clothes = num_train_clothes
66 | self.pid2clothes = pid2clothes
67 | self.gallery_idx = gallery_idx
68 |
69 | def _check_before_run(self):
70 | """Check if all files are available before going deeper"""
71 | if not osp.exists(self.dataset_dir):
72 | raise RuntimeError("'{}' is not available".format(self.dataset_dir))
73 | if not osp.exists(self.train_dir):
74 | raise RuntimeError("'{}' is not available".format(self.train_dir))
75 | if not osp.exists(self.val_dir):
76 | raise RuntimeError("'{}' is not available".format(self.val_dir))
77 | if not osp.exists(self.test_dir):
78 | raise RuntimeError("'{}' is not available".format(self.test_dir))
79 |
80 | def _process_dir_train(self, dir_path):
81 | pdirs = glob.glob(osp.join(dir_path, '*'))
82 | pdirs.sort()
83 |
84 | pid_container = set()
85 | clothes_container = set()
86 | for pdir in pdirs:
87 | pid = int(osp.basename(pdir))
88 | pid_container.add(pid)
89 | img_dirs = glob.glob(osp.join(pdir, '*.jpg'))
90 | for img_dir in img_dirs:
91 | cam = osp.basename(img_dir)[0] # 'A' or 'B' or 'C'
92 | if cam in ['A', 'B']:
93 | clothes_container.add(osp.basename(pdir))
94 | else:
95 | clothes_container.add(osp.basename(pdir)+osp.basename(img_dir)[0])
96 | pid_container = sorted(pid_container)
97 | clothes_container = sorted(clothes_container)
98 | pid2label = {pid:label for label, pid in enumerate(pid_container)}
99 | clothes2label = {clothes_id:label for label, clothes_id in enumerate(clothes_container)}
100 | cam2label = {'A': 0, 'B': 1, 'C': 2}
101 |
102 | num_pids = len(pid_container)
103 | num_clothes = len(clothes_container)
104 |
105 | dataset = []
106 | pid2clothes = np.zeros((num_pids, num_clothes))
107 | for pdir in pdirs:
108 | pid = int(osp.basename(pdir))
109 | img_dirs = glob.glob(osp.join(pdir, '*.jpg'))
110 | for img_dir in img_dirs:
111 | cam = osp.basename(img_dir)[0] # 'A' or 'B' or 'C'
112 | label = pid2label[pid]
113 | camid = cam2label[cam]
114 | if cam in ['A', 'B']:
115 | clothes_id = clothes2label[osp.basename(pdir)]
116 | else:
117 | clothes_id = clothes2label[osp.basename(pdir)+osp.basename(img_dir)[0]]
118 | dataset.append((img_dir, label, camid, clothes_id))
119 | pid2clothes[label, clothes_id] = 1
120 |
121 | num_imgs = len(dataset)
122 |
123 | return dataset, num_pids, num_imgs, num_clothes, pid2clothes
124 |
125 | def _process_dir_test(self, test_path):
126 | pdirs = glob.glob(osp.join(test_path, '*'))
127 | pdirs.sort()
128 |
129 | pid_container = set()
130 | for pdir in glob.glob(osp.join(test_path, 'A', '*')):
131 | pid = int(osp.basename(pdir))
132 | pid_container.add(pid)
133 | pid_container = sorted(pid_container)
134 | pid2label = {pid:label for label, pid in enumerate(pid_container)}
135 | cam2label = {'A': 0, 'B': 1, 'C': 2}
136 |
137 | num_pids = len(pid_container)
138 | num_clothes = num_pids * 2
139 |
140 | query_dataset_same_clothes = []
141 | query_dataset_diff_clothes = []
142 | gallery_dataset = []
143 | for cam in ['A', 'B', 'C']:
144 | pdirs = glob.glob(osp.join(test_path, cam, '*'))
145 | for pdir in pdirs:
146 | pid = int(osp.basename(pdir))
147 | img_dirs = glob.glob(osp.join(pdir, '*.jpg'))
148 | for img_dir in img_dirs:
149 | # pid = pid2label[pid]
150 | camid = cam2label[cam]
151 | if cam == 'A':
152 | clothes_id = pid2label[pid] * 2
153 | gallery_dataset.append((img_dir, pid, camid, clothes_id))
154 | elif cam == 'B':
155 | clothes_id = pid2label[pid] * 2
156 | query_dataset_same_clothes.append((img_dir, pid, camid, clothes_id))
157 | else:
158 | clothes_id = pid2label[pid] * 2 + 1
159 | query_dataset_diff_clothes.append((img_dir, pid, camid, clothes_id))
160 |
161 | pid2imgidx = {}
162 | for idx, (img_dir, pid, camid, clothes_id) in enumerate(gallery_dataset):
163 | if pid not in pid2imgidx:
164 | pid2imgidx[pid] = []
165 | pid2imgidx[pid].append(idx)
166 |
167 | # get 10 gallery index to perform single-shot test
168 | gallery_idx = {}
169 | random.seed(3)
170 | for idx in range(0, 10):
171 | gallery_idx[idx] = []
172 | for pid in pid2imgidx:
173 | gallery_idx[idx].append(random.choice(pid2imgidx[pid]))
174 |
175 | num_imgs_query_same = len(query_dataset_same_clothes)
176 | num_imgs_query_diff = len(query_dataset_diff_clothes)
177 | num_imgs_gallery = len(gallery_dataset)
178 |
179 | return query_dataset_same_clothes, query_dataset_diff_clothes, gallery_dataset, \
180 | num_pids, num_imgs_query_same, num_imgs_query_diff, num_imgs_gallery, \
181 | num_clothes, gallery_idx
182 |
--------------------------------------------------------------------------------
/data/img_transforms.py:
--------------------------------------------------------------------------------
1 | from torchvision.transforms import *
2 | from PIL import Image
3 | import random
4 | import math
5 |
6 |
7 | class ResizeWithEqualScale(object):
8 | """
9 | Resize an image while preserving its original aspect ratio, then pad it to the target size.
10 |
11 | Args:
12 | height (int): resized height.
13 | width (int): resized width.
14 | interpolation: interpolation manner.
15 | fill_color (tuple): color for padding.
16 | """
17 | def __init__(self, height, width, interpolation=Image.BILINEAR, fill_color=(0,0,0)):
18 | self.height = height
19 | self.width = width
20 | self.interpolation = interpolation
21 | self.fill_color = fill_color
22 |
23 | def __call__(self, img):
24 | width, height = img.size
25 | if self.height / self.width >= height / width:
26 | height = int(self.width * (height / width))
27 | width = self.width
28 | else:
29 | width = int(self.height * (width / height))
30 | height = self.height
31 |
32 | resized_img = img.resize((width, height), self.interpolation)
33 | new_img = Image.new('RGB', (self.width, self.height), self.fill_color)
34 | new_img.paste(resized_img, (int((self.width - width) / 2), int((self.height - height) / 2)))
35 |
36 | return new_img
37 |
38 |
39 | class RandomCroping(object):
40 | """
41 | With a probability, first increase image size to (1 + 1/8), and then perform random crop.
42 |
43 | Args:
44 | p (float): probability of performing this transformation. Default: 0.5.
45 | """
46 | def __init__(self, p=0.5, interpolation=Image.BILINEAR):
47 | self.p = p
48 | self.interpolation = interpolation
49 |
50 | def __call__(self, img):
51 | """
52 | Args:
53 | img (PIL Image): Image to be cropped.
54 |
55 | Returns:
56 | PIL Image: Cropped image.
57 | """
58 | width, height = img.size
59 | if random.uniform(0, 1) >= self.p:
60 | return img
61 |
62 | new_width, new_height = int(round(width * 1.125)), int(round(height * 1.125))
63 | resized_img = img.resize((new_width, new_height), self.interpolation)
64 | x_maxrange = new_width - width
65 | y_maxrange = new_height - height
66 | x1 = int(round(random.uniform(0, x_maxrange)))
67 | y1 = int(round(random.uniform(0, y_maxrange)))
68 | croped_img = resized_img.crop((x1, y1, x1 + width, y1 + height))
69 |
70 | return croped_img
71 |
72 |
73 | class RandomErasing(object):
74 | """
75 | Randomly selects a rectangle region in an image and erases its pixels.
76 |
77 | Reference:
78 | Zhong et al. Random Erasing Data Augmentation. arxiv: 1708.04896, 2017.
79 |
80 | Args:
81 | probability: The probability that the Random Erasing operation will be performed.
82 | sl: Minimum proportion of erased area against input image.
83 | sh: Maximum proportion of erased area against input image.
84 | r1: Minimum aspect ratio of erased area.
85 | mean: Erasing value.
86 | """
87 |
88 | def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]):
89 | self.probability = probability
90 | self.mean = mean
91 | self.sl = sl
92 | self.sh = sh
93 | self.r1 = r1
94 |
95 | def __call__(self, img):
96 |
97 | if random.uniform(0, 1) >= self.probability:
98 | return img
99 |
100 | for attempt in range(100):
101 | area = img.size()[1] * img.size()[2]
102 |
103 | target_area = random.uniform(self.sl, self.sh) * area
104 | aspect_ratio = random.uniform(self.r1, 1/self.r1)
105 |
106 | h = int(round(math.sqrt(target_area * aspect_ratio)))
107 | w = int(round(math.sqrt(target_area / aspect_ratio)))
108 |
109 | if w < img.size()[2] and h < img.size()[1]:
110 | x1 = random.randint(0, img.size()[1] - h)
111 | y1 = random.randint(0, img.size()[2] - w)
112 | if img.size()[0] == 3:
113 | img[0, x1:x1+h, y1:y1+w] = self.mean[0]
114 | img[1, x1:x1+h, y1:y1+w] = self.mean[1]
115 | img[2, x1:x1+h, y1:y1+w] = self.mean[2]
116 | else:
117 | img[0, x1:x1+h, y1:y1+w] = self.mean[0]
118 | return img
119 |
120 | return img
--------------------------------------------------------------------------------
/data/samplers.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | import random
4 | import numpy as np
5 | from torch import distributed as dist
6 | from collections import defaultdict
7 | from torch.utils.data.sampler import Sampler
8 |
9 |
10 | class RandomIdentitySampler(Sampler):
11 | """
12 | Randomly sample N identities, then for each identity,
13 | randomly sample K instances, therefore batch size is N*K.
14 |
15 | Args:
16 | data_source (Dataset): dataset to sample from.
17 | num_instances (int): number of instances per identity.
18 | """
19 | def __init__(self, data_source, num_instances=4):
20 | self.data_source = data_source
21 | self.num_instances = num_instances
22 | self.index_dic = defaultdict(list)
23 | for index, (_, pid, _, _) in enumerate(data_source):
24 | self.index_dic[pid].append(index)
25 | self.pids = list(self.index_dic.keys())
26 | self.num_identities = len(self.pids)
27 |
28 | # compute number of examples in an epoch
29 | self.length = 0
30 | for pid in self.pids:
31 | idxs = self.index_dic[pid]
32 | num = len(idxs)
33 | if num < self.num_instances:
34 | num = self.num_instances
35 | self.length += num - num % self.num_instances
36 |
37 | def __iter__(self):
38 | list_container = []
39 |
40 | for pid in self.pids:
41 | idxs = copy.deepcopy(self.index_dic[pid])
42 | if len(idxs) < self.num_instances:
43 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True)
44 | random.shuffle(idxs)
45 | batch_idxs = []
46 | for idx in idxs:
47 | batch_idxs.append(idx)
48 | if len(batch_idxs) == self.num_instances:
49 | list_container.append(batch_idxs)
50 | batch_idxs = []
51 |
52 | random.shuffle(list_container)
53 |
54 | ret = []
55 | for batch_idxs in list_container:
56 | ret.extend(batch_idxs)
57 |
58 | return iter(ret)
59 |
60 | def __len__(self):
61 | return self.length
62 |
63 |
64 | class DistributedRandomIdentitySampler(Sampler):
65 | """
66 | Randomly sample N identities, then for each identity,
67 | randomly sample K instances, therefore batch size is N*K.
68 |
69 | Args:
70 | - data_source (Dataset): dataset to sample from.
71 | - num_instances (int): number of instances per identity.
72 | - num_replicas (int, optional): Number of processes participating in
73 | distributed training. By default, :attr:`world_size` is retrieved from the
74 | current distributed group.
75 | - rank (int, optional): Rank of the current process within :attr:`num_replicas`.
76 | By default, :attr:`rank` is retrieved from the current distributed group.
77 | - seed (int, optional): random seed used to shuffle the sampler.
78 | This number should be identical across all
79 | processes in the distributed group. Default: ``0``.
80 | """
81 | def __init__(self, data_source, num_instances=4,
82 | num_replicas=None, rank=None, seed=0):
83 | if num_replicas is None:
84 | if not dist.is_available():
85 | raise RuntimeError("Requires distributed package to be available")
86 | num_replicas = dist.get_world_size()
87 | if rank is None:
88 | if not dist.is_available():
89 | raise RuntimeError("Requires distributed package to be available")
90 | rank = dist.get_rank()
91 | if rank >= num_replicas or rank < 0:
92 | raise ValueError(
93 | "Invalid rank {}, rank should be in the interval"
94 | " [0, {}]".format(rank, num_replicas - 1))
95 | self.num_replicas = num_replicas
96 | self.rank = rank
97 | self.seed = seed
98 | self.epoch = 0
99 |
100 | self.data_source = data_source
101 | self.num_instances = num_instances
102 | self.index_dic = defaultdict(list)
103 | for index, (_, pid, _, _) in enumerate(data_source):
104 | self.index_dic[pid].append(index)
105 | self.pids = list(self.index_dic.keys())
106 | self.num_identities = len(self.pids)
107 |
108 | # compute number of examples in an epoch
109 | self.length = 0
110 | for pid in self.pids:
111 | idxs = self.index_dic[pid]
112 | num = len(idxs)
113 | if num < self.num_instances:
114 | num = self.num_instances
115 | self.length += num - num % self.num_instances
116 | assert self.length % self.num_instances == 0
117 |
118 | if self.length // self.num_instances % self.num_replicas != 0:
119 | self.num_samples = math.ceil((self.length // self.num_instances - self.num_replicas) / self.num_replicas) * self.num_instances
120 | else:
121 | self.num_samples = math.ceil(self.length / self.num_replicas)
122 | self.total_size = self.num_samples * self.num_replicas
123 |
124 | def __iter__(self):
125 | # deterministically shuffle based on epoch and seed
126 | random.seed(self.seed + self.epoch)
127 | np.random.seed(self.seed + self.epoch)
128 |
129 | list_container = []
130 | for pid in self.pids:
131 | idxs = copy.deepcopy(self.index_dic[pid])
132 | if len(idxs) < self.num_instances:
133 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True)
134 | random.shuffle(idxs)
135 | batch_idxs = []
136 | for idx in idxs:
137 | batch_idxs.append(idx)
138 | if len(batch_idxs) == self.num_instances:
139 | list_container.append(batch_idxs)
140 | batch_idxs = []
141 | random.shuffle(list_container)
142 |
143 | # remove tail of data to make it evenly divisible.
144 | list_container = list_container[:self.total_size//self.num_instances]
145 | assert len(list_container) == self.total_size//self.num_instances
146 |
147 | # subsample
148 | list_container = list_container[self.rank:self.total_size//self.num_instances:self.num_replicas]
149 | assert len(list_container) == self.num_samples//self.num_instances
150 |
151 | ret = []
152 | for batch_idxs in list_container:
153 | ret.extend(batch_idxs)
154 |
155 | return iter(ret)
156 |
157 | def __len__(self):
158 | return self.num_samples
159 |
160 | def set_epoch(self, epoch):
161 | """
162 | Sets the epoch for this sampler. This ensures all replicas
163 | use a different random ordering for each epoch. Otherwise, the next iteration of this
164 | sampler will yield the same ordering.
165 |
166 | Args:
167 | epoch (int): Epoch number.
168 | """
169 | self.epoch = epoch
170 |
171 |
172 | class DistributedInferenceSampler(Sampler):
173 | """
174 | refer to: https://github.com/huggingface/transformers/blob/447808c85f0e6d6b0aeeb07214942bf1e578f9d2/src/transformers/trainer_pt_utils.py
175 |
176 | Distributed Sampler that subsamples indices sequentially,
177 | making it easier to collate all results at the end.
178 | Even though we only use this sampler for eval and predict (no training),
179 | which means that the model params won't have to be synced (i.e. will not hang
180 | for synchronization even if varied number of forward passes), we still add extra
181 | samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
182 | to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
183 | """
184 | def __init__(self, dataset, rank=None, num_replicas=None):
185 | if num_replicas is None:
186 | if not dist.is_available():
187 | raise RuntimeError("Requires distributed package to be available")
188 | num_replicas = dist.get_world_size()
189 | if rank is None:
190 | if not dist.is_available():
191 | raise RuntimeError("Requires distributed package to be available")
192 | rank = dist.get_rank()
193 | self.dataset = dataset
194 | self.num_replicas = num_replicas
195 | self.rank = rank
196 |
197 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
198 | self.total_size = self.num_samples * self.num_replicas
199 |
200 | def __iter__(self):
201 | indices = list(range(len(self.dataset)))
202 | # add extra samples to make it evenly divisible
203 | indices += [indices[-1]] * (self.total_size - len(indices))
204 | # subsample
205 | indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
206 | return iter(indices)
207 |
208 | def __len__(self):
209 | return self.num_samples
--------------------------------------------------------------------------------
/demo.sh:
--------------------------------------------------------------------------------
1 | # Example
2 | # python demo_single_image.py --cfg configs/res50_cels_cal.yaml --img_path /root/github/datasets/LTCC_ReID/query/001_1_c4_015861.png --weights /root/github/ReID_demo/CC-ReID/weights/ltcc.pth.tar --gpu 7
3 |
4 |
5 | #
6 | python demo_single_image.py --cfg configs/res50_cels_cal.yaml --img_path YOUR_IMAGE_PATH --weights YOUR_WEIGHT_PATH --gpu YOUR_GPU_ID
7 |
8 |
--------------------------------------------------------------------------------
/demo_single_image.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import datetime
4 | import logging
5 | import argparse
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | from torch import distributed as dist
10 | import torchvision
11 | from torchvision import datasets, models, transforms
12 | from configs.default_img_single import get_img_config
13 | from models.img_resnet import ResNet50
14 | from PIL import Image
15 |
16 | def parse_option():
17 | parser = argparse.ArgumentParser(
18 | description='Train clothes-changing re-id model with clothes-based adversarial loss')
19 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file')
20 | # Datasets
21 | parser.add_argument('--root', type=str, help="your root path to data directory")
22 | # Miscs
23 | parser.add_argument('--img_path', type=str, help='path to the image')
24 | parser.add_argument('--weights', type=str, help='path to the weights')
25 | parser.add_argument('--gpu', type=str, default='0', help='gpu id')
26 |
27 | args, unparsed = parser.parse_known_args()
28 | config = get_img_config(args)
29 | return config, args
30 |
31 | @torch.no_grad()
32 | def extract_img_feature(model, img):
33 | flip_img = torch.flip(img, [3])
34 | img, flip_img = img.cuda(), flip_img.cuda()
35 | _, batch_features = model(img)
36 | _, batch_features_flip = model(flip_img)
37 | batch_features += batch_features_flip
38 | batch_features = F.normalize(batch_features, p=2, dim=1)
39 | features = batch_features.cpu()
40 |
41 | return features
42 |
43 | config, args = parse_option()
44 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
45 |
46 | checkpoint = torch.load(args.weights, weights_only=True)
47 | model = ResNet50(config)
48 | model.load_state_dict(checkpoint['model_state_dict'])
49 | model = model.cuda()
50 | model.eval()
51 |
52 | # IMAGENET_MEAN = [0.485, 0.456, 0.406]
53 | # IMAGENET_STD = [0.229, 0.224, 0.225]
54 | # GRID_SPACING = 10
55 |
56 | data_transforms = transforms.Compose([
57 | transforms.Resize((config.DATA.HEIGHT, config.DATA.WIDTH)),
58 | transforms.ToTensor(),
59 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
60 | ])
61 |
62 | image = Image.open(args.img_path)
63 | image_tensor = data_transforms(image)
64 | input_batch = image_tensor.unsqueeze(0) # Add a batch dimension
65 |
66 | feature = extract_img_feature(model, input_batch)
67 |
68 | print("Input Image:", args.img_path, " Output Feature:", feature)
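69 |
70 | # A minimal, hypothetical extension (not part of the original script): comparing two images by
71 | # the cosine similarity of their features. PATH_A / PATH_B are placeholders, and since
72 | # extract_img_feature already l2-normalizes, a plain dot product equals the cosine similarity.
73 | # feature_a = extract_img_feature(model, data_transforms(Image.open(PATH_A).convert('RGB')).unsqueeze(0))
74 | # feature_b = extract_img_feature(model, data_transforms(Image.open(PATH_B).convert('RGB')).unsqueeze(0))
75 | # similarity = torch.mm(feature_a, feature_b.t()).item()  # closer to 1 => more likely the same identity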
--------------------------------------------------------------------------------
/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from losses.cross_entropy_loss_with_label_smooth import CrossEntropyWithLabelSmooth
3 | from losses.triplet_loss import TripletLoss
4 | from losses.contrastive_loss import ContrastiveLoss
5 | from losses.arcface_loss import ArcFaceLoss
6 | from losses.cosface_loss import CosFaceLoss, PairwiseCosFaceLoss
7 | from losses.circle_loss import CircleLoss, PairwiseCircleLoss
8 | from losses.clothes_based_adversarial_loss import ClothesBasedAdversarialLoss
9 |
10 | def build_losses(config, num_train_clothes):
11 | # Build identity classification loss
12 | if config.LOSS.CLA_LOSS == 'crossentropy':
13 | criterion_cla = nn.CrossEntropyLoss()
14 | elif config.LOSS.CLA_LOSS == 'crossentropylabelsmooth':
15 | criterion_cla = CrossEntropyWithLabelSmooth()
16 | elif config.LOSS.CLA_LOSS == 'arcface':
17 | criterion_cla = ArcFaceLoss(scale=config.LOSS.CLA_S, margin=config.LOSS.CLA_M)
18 | elif config.LOSS.CLA_LOSS == 'cosface':
19 | criterion_cla = CosFaceLoss(scale=config.LOSS.CLA_S, margin=config.LOSS.CLA_M)
20 | elif config.LOSS.CLA_LOSS == 'circle':
21 | criterion_cla = CircleLoss(scale=config.LOSS.CLA_S, margin=config.LOSS.CLA_M)
22 | else:
23 | raise KeyError("Invalid classification loss: '{}'".format(config.LOSS.CLA_LOSS))
24 |
25 | # Build pairwise loss
26 | if config.LOSS.PAIR_LOSS == 'triplet':
27 | criterion_pair = TripletLoss(margin=config.LOSS.PAIR_M)
28 | elif config.LOSS.PAIR_LOSS == 'contrastive':
29 | criterion_pair = ContrastiveLoss(scale=config.LOSS.PAIR_S)
30 | elif config.LOSS.PAIR_LOSS == 'cosface':
31 | criterion_pair = PairwiseCosFaceLoss(scale=config.LOSS.PAIR_S, margin=config.LOSS.PAIR_M)
32 | elif config.LOSS.PAIR_LOSS == 'circle':
33 | criterion_pair = PairwiseCircleLoss(scale=config.LOSS.PAIR_S, margin=config.LOSS.PAIR_M)
34 | else:
35 | raise KeyError("Invalid pairwise loss: '{}'".format(config.LOSS.PAIR_LOSS))
36 |
37 |
38 | # Build clothes classification loss
39 | if config.LOSS.CLOTHES_CLA_LOSS == 'crossentropy':
40 | criterion_clothes = nn.CrossEntropyLoss()
41 | elif config.LOSS.CLOTHES_CLA_LOSS == 'cosface':
42 | criterion_clothes = CosFaceLoss(scale=config.LOSS.CLA_S, margin=0)
43 | else:
44 | raise KeyError("Invalid clothes classification loss: '{}'".format(config.LOSS.CLOTHES_CLA_LOSS))
45 |
46 | # Build clothes-based adversarial loss
47 | if config.LOSS.CAL == 'cal':
48 | criterion_cal = ClothesBasedAdversarialLoss(scale=config.LOSS.CLA_S, epsilon=config.LOSS.EPSILON)
49 | else:
50 |         raise KeyError("Invalid clothes-based adversarial loss: '{}'".format(config.LOSS.CAL))
51 |
52 | kl = nn.functional.kl_div
53 |
54 | return criterion_cla, criterion_pair, criterion_clothes, criterion_cal, kl
55 |
--------------------------------------------------------------------------------
/losses/arcface_loss.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn.functional as F
4 | from torch import nn
5 |
6 |
7 | class ArcFaceLoss(nn.Module):
8 | """ ArcFace loss.
9 |
10 | Reference:
11 | Deng et al. ArcFace: Additive Angular Margin Loss for Deep Face Recognition. In CVPR, 2019.
12 |
13 | Args:
14 | scale (float): scaling factor.
15 | margin (float): pre-defined margin.
16 | """
17 | def __init__(self, scale=16, margin=0.1):
18 | super().__init__()
19 | self.s = scale
20 | self.m = margin
21 |
22 | def forward(self, inputs, targets):
23 | """
24 | Args:
25 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
26 | targets: ground truth labels with shape (batch_size)
27 | """
28 | # get a one-hot index
29 | index = inputs.data * 0.0
30 | index.scatter_(1, targets.data.view(-1, 1), 1)
31 | index = index.bool()
32 |
33 | cos_m = math.cos(self.m)
34 | sin_m = math.sin(self.m)
35 | cos_t = inputs[index]
36 | sin_t = torch.sqrt(1.0 - cos_t * cos_t)
37 | cos_t_add_m = cos_t * cos_m - sin_t * sin_m
38 |
39 | cond_v = cos_t - math.cos(math.pi - self.m)
40 | cond = F.relu(cond_v)
41 | keep = cos_t - math.sin(math.pi - self.m) * self.m
42 |
43 | cos_t_add_m = torch.where(cond.bool(), cos_t_add_m, keep)
44 |
45 | output = inputs * 1.0
46 | output[index] = cos_t_add_m
47 | output = self.s * output
48 |
49 | return F.cross_entropy(output, targets)
50 |
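51 | # Note on the margin arithmetic above: for the target logit cos(t), the angle-addition identity
52 | # gives cos(t + m) = cos(t) * cos(m) - sin(t) * sin(m). The cond/keep branch falls back to the
53 | # linear surrogate cos(t) - m * sin(pi - m) once t + m would pass pi (i.e. cos(t) < cos(pi - m)),
54 | # keeping the adjusted logit monotonically decreasing in the angle t.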
--------------------------------------------------------------------------------
/losses/circle_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 | from torch import distributed as dist
5 | from losses.gather import GatherLayer
6 |
7 |
8 | class CircleLoss(nn.Module):
9 | """ Circle Loss based on the predictions of classifier.
10 |
11 | Reference:
12 | Sun et al. Circle Loss: A Unified Perspective of Pair Similarity Optimization. In CVPR, 2020.
13 |
14 | Args:
15 | scale (float): scaling factor.
16 | margin (float): pre-defined margin.
17 | """
18 | def __init__(self, scale=96, margin=0.3, **kwargs):
19 | super().__init__()
20 | self.s = scale
21 | self.m = margin
22 |
23 | def forward(self, inputs, targets):
24 | """
25 | Args:
26 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
27 | targets: ground truth labels with shape (batch_size)
28 | """
29 | mask = torch.zeros_like(inputs).cuda()
30 | mask.scatter_(1, targets.view(-1, 1), 1.0)
31 |
32 | pos_scale = self.s * F.relu(1 + self.m - inputs.detach())
33 | neg_scale = self.s * F.relu(inputs.detach() + self.m)
34 | scale_matrix = pos_scale * mask + neg_scale * (1 - mask)
35 |
36 | scores = (inputs - (1 - self.m) * mask - self.m * (1 - mask)) * scale_matrix
37 |
38 | loss = F.cross_entropy(scores, targets)
39 |
40 | return loss
41 |
42 |
43 | class PairwiseCircleLoss(nn.Module):
44 | """ Circle Loss among sample pairs.
45 |
46 | Reference:
47 | Sun et al. Circle Loss: A Unified Perspective of Pair Similarity Optimization. In CVPR, 2020.
48 |
49 | Args:
50 | scale (float): scaling factor.
51 | margin (float): pre-defined margin.
52 | """
53 | def __init__(self, scale=48, margin=0.35, **kwargs):
54 | super().__init__()
55 | self.s = scale
56 | self.m = margin
57 |
58 | def forward(self, inputs, targets):
59 | """
60 | Args:
61 | inputs: sample features (before classifier) with shape (batch_size, feat_dim)
62 | targets: ground truth labels with shape (batch_size)
63 | """
64 | # l2-normalize
65 | inputs = F.normalize(inputs, p=2, dim=1)
66 |
67 | # gather all samples from different GPUs as gallery to compute pairwise loss.
68 | gallery_inputs = torch.cat(GatherLayer.apply(inputs), dim=0)
69 | gallery_targets = torch.cat(GatherLayer.apply(targets), dim=0)
70 | m, n = targets.size(0), gallery_targets.size(0)
71 |
72 | # compute cosine similarity
73 | similarities = torch.matmul(inputs, gallery_inputs.t())
74 |
75 | # get mask for pos/neg pairs
76 | targets, gallery_targets = targets.view(-1, 1), gallery_targets.view(-1, 1)
77 | mask = torch.eq(targets, gallery_targets.T).float().cuda()
78 | mask_self = torch.zeros_like(mask)
79 | rank = dist.get_rank()
80 | mask_self[:, rank * m:(rank + 1) * m] += torch.eye(m).float().cuda()
81 | mask_pos = mask - mask_self
82 | mask_neg = 1 - mask
83 |
84 | pos_scale = self.s * F.relu(1 + self.m - similarities.detach())
85 | neg_scale = self.s * F.relu(similarities.detach() + self.m)
86 | scale_matrix = pos_scale * mask_pos + neg_scale * mask_neg
87 |
88 | scores = (similarities - self.m) * mask_neg + (1 - self.m - similarities) * mask_pos
89 | scores = scores * scale_matrix
90 |
91 | neg_scores_LSE = torch.logsumexp(scores * mask_neg - 99999999 * (1 - mask_neg), dim=1)
92 | pos_scores_LSE = torch.logsumexp(scores * mask_pos - 99999999 * (1 - mask_pos), dim=1)
93 |
94 | loss = F.softplus(neg_scores_LSE + pos_scores_LSE).mean()
95 |
96 | return loss
97 |
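98 | # Note: softplus(neg_scores_LSE + pos_scores_LSE) = log(1 + sum_j exp(neg_j) * sum_i exp(pos_i))
99 | # over the scaled scores computed above, i.e. the unified pairwise objective of the Circle Loss
100 | # paper; the -99999999 offsets only push entries outside each pos/neg set out of the logsumexp.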
--------------------------------------------------------------------------------
/losses/clothes_based_adversarial_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 | from losses.gather import GatherLayer
5 |
6 |
7 | class ClothesBasedAdversarialLoss(nn.Module):
8 | """ Clothes-based Adversarial Loss.
9 |
10 | Reference:
11 | Gu et al. Clothes-Changing Person Re-identification with RGB Modality Only. In CVPR, 2022.
12 |
13 | Args:
14 | scale (float): scaling factor.
15 | epsilon (float): a trade-off hyper-parameter.
16 | """
17 | def __init__(self, scale=16, epsilon=0.1):
18 | super().__init__()
19 | self.scale = scale
20 | self.epsilon = epsilon
21 |
22 | def forward(self, inputs, targets, positive_mask):
23 | """
24 | Args:
25 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
26 | targets: ground truth labels with shape (batch_size)
27 | positive_mask: positive mask matrix with shape (batch_size, num_classes). The clothes classes with
28 | the same identity as the anchor sample are defined as positive clothes classes and their mask
29 | values are 1. The clothes classes with different identities from the anchor sample are defined
30 | as negative clothes classes and their mask values in positive_mask are 0.
31 | """
32 | inputs = self.scale * inputs
33 |         negative_mask = 1 - positive_mask
34 | identity_mask = torch.zeros(inputs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1).cuda()
35 |
36 | exp_logits = torch.exp(inputs)
37 |         log_sum_exp_pos_and_all_neg = torch.log((exp_logits * negative_mask).sum(1, keepdim=True) + exp_logits)
38 | log_prob = inputs - log_sum_exp_pos_and_all_neg
39 |
40 | mask = (1 - self.epsilon) * identity_mask + self.epsilon / positive_mask.sum(1, keepdim=True) * positive_mask
41 | loss = (- mask * log_prob).sum(1).mean()
42 |
43 | return loss
44 |
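45 | # Hypothetical usage sketch (the names below are placeholders, not from this file): given the
46 | # (num_pids, num_clothes) matrix pid2clothes built in main.py, the batch identity labels `pids`,
47 | # clothes labels `clothes_ids` and clothes-classifier logits `pred_clothes`:
48 | # positive_mask = pid2clothes[pids].float().cuda()
49 | # loss_adv = criterion_adv(pred_clothes, clothes_ids, positive_mask)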
--------------------------------------------------------------------------------
/losses/contrastive_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 | from torch import distributed as dist
5 | from losses.gather import GatherLayer
6 |
7 |
8 | class ContrastiveLoss(nn.Module):
9 | """ Supervised Contrastive Learning Loss among sample pairs.
10 |
11 | Args:
12 | scale (float): scaling factor.
13 | """
14 | def __init__(self, scale=16, **kwargs):
15 | super().__init__()
16 | self.s = scale
17 |
18 | def forward(self, inputs, targets):
19 | """
20 | Args:
21 | inputs: sample features (before classifier) with shape (batch_size, feat_dim)
22 | targets: ground truth labels with shape (batch_size)
23 | """
24 | # l2-normalize
25 | inputs = F.normalize(inputs, p=2, dim=1)
26 |
27 | # gather all samples from different GPUs as gallery to compute pairwise loss.
28 | gallery_inputs = torch.cat(GatherLayer.apply(inputs), dim=0)
29 | gallery_targets = torch.cat(GatherLayer.apply(targets), dim=0)
30 | m, n = targets.size(0), gallery_targets.size(0)
31 |
32 | # compute cosine similarity
33 | similarities = torch.matmul(inputs, gallery_inputs.t()) * self.s
34 |
35 | # get mask for pos/neg pairs
36 | targets, gallery_targets = targets.view(-1, 1), gallery_targets.view(-1, 1)
37 | mask = torch.eq(targets, gallery_targets.T).float().cuda()
38 | mask_self = torch.zeros_like(mask)
39 | rank = dist.get_rank()
40 | mask_self[:, rank * m:(rank + 1) * m] += torch.eye(m).float().cuda()
41 | mask_pos = mask - mask_self
42 | mask_neg = 1 - mask
43 |
44 | # compute log_prob
45 | exp_logits = torch.exp(similarities) * (1 - mask_self)
46 | # log_prob = similarities - torch.log(exp_logits.sum(1, keepdim=True))
47 | log_sum_exp_pos_and_all_neg = torch.log((exp_logits * mask_neg).sum(1, keepdim=True) + exp_logits)
48 | log_prob = similarities - log_sum_exp_pos_and_all_neg
49 |
50 | # compute mean of log-likelihood over positive
51 | loss = (mask_pos * log_prob).sum(1) / mask_pos.sum(1)
52 |
53 | loss = - loss.mean()
54 |
55 | return loss
--------------------------------------------------------------------------------
/losses/cosface_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 | from torch import distributed as dist
5 | from losses.gather import GatherLayer
6 |
7 |
8 | class CosFaceLoss(nn.Module):
9 | """ CosFace Loss based on the predictions of classifier.
10 |
11 | Reference:
12 | Wang et al. CosFace: Large Margin Cosine Loss for Deep Face Recognition. In CVPR, 2018.
13 |
14 | Args:
15 | scale (float): scaling factor.
16 | margin (float): pre-defined margin.
17 | """
18 | def __init__(self, scale=16, margin=0.1, **kwargs):
19 | super().__init__()
20 | self.s = scale
21 | self.m = margin
22 |
23 | def forward(self, inputs, targets):
24 | """
25 | Args:
26 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
27 | targets: ground truth labels with shape (batch_size)
28 | """
29 | one_hot = torch.zeros_like(inputs)
30 | one_hot.scatter_(1, targets.view(-1, 1), 1.0)
31 |
32 | output = self.s * (inputs - one_hot * self.m)
33 |
34 | return F.cross_entropy(output, targets)
35 |
36 |
37 | class PairwiseCosFaceLoss(nn.Module):
38 | """ CosFace Loss among sample pairs.
39 |
40 | Reference:
41 | Sun et al. Circle Loss: A Unified Perspective of Pair Similarity Optimization. In CVPR, 2020.
42 |
43 | Args:
44 | scale (float): scaling factor.
45 | margin (float): pre-defined margin.
46 | """
47 | def __init__(self, scale=16, margin=0):
48 | super().__init__()
49 | self.s = scale
50 | self.m = margin
51 |
52 | def forward(self, inputs, targets):
53 | """
54 | Args:
55 | inputs: sample features (before classifier) with shape (batch_size, feat_dim)
56 | targets: ground truth labels with shape (batch_size)
57 | """
58 | # l2-normalize
59 | inputs = F.normalize(inputs, p=2, dim=1)
60 |
61 | # gather all samples from different GPUs as gallery to compute pairwise loss.
62 | gallery_inputs = torch.cat(GatherLayer.apply(inputs), dim=0)
63 | gallery_targets = torch.cat(GatherLayer.apply(targets), dim=0)
64 | m, n = targets.size(0), gallery_targets.size(0)
65 |
66 | # compute cosine similarity
67 | similarities = torch.matmul(inputs, gallery_inputs.t())
68 |
69 | # get mask for pos/neg pairs
70 | targets, gallery_targets = targets.view(-1, 1), gallery_targets.view(-1, 1)
71 | mask = torch.eq(targets, gallery_targets.T).float().cuda()
72 | mask_self = torch.zeros_like(mask)
73 | rank = dist.get_rank()
74 | mask_self[:, rank * m:(rank + 1) * m] += torch.eye(m).float().cuda()
75 | mask_pos = mask - mask_self
76 | mask_neg = 1 - mask
77 |
78 | scores = (similarities + self.m) * mask_neg - similarities * mask_pos
79 | scores = scores * self.s
80 |
81 | neg_scores_LSE = torch.logsumexp(scores * mask_neg - 99999999 * (1 - mask_neg), dim=1)
82 | pos_scores_LSE = torch.logsumexp(scores * mask_pos - 99999999 * (1 - mask_pos), dim=1)
83 |
84 | loss = F.softplus(neg_scores_LSE + pos_scores_LSE).mean()
85 |
86 | return loss
--------------------------------------------------------------------------------
/losses/cross_entropy_loss_with_label_smooth.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class CrossEntropyWithLabelSmooth(nn.Module):
6 | """ Cross entropy loss with label smoothing regularization.
7 |
8 | Reference:
9 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. In CVPR, 2016.
10 | Equation:
11 | y = (1 - epsilon) * y + epsilon / K.
12 |
13 | Args:
14 | epsilon (float): a hyper-parameter in the above equation.
15 | """
16 | def __init__(self, epsilon=0.1):
17 | super().__init__()
18 | self.epsilon = epsilon
19 | self.logsoftmax = nn.LogSoftmax(dim=1)
20 |
21 | def forward(self, inputs, targets):
22 | """
23 | Args:
24 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
25 | targets: ground truth labels with shape (batch_size)
26 | """
27 | _, num_classes = inputs.size()
28 | log_probs = self.logsoftmax(inputs)
29 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1).cuda()
30 | targets = (1 - self.epsilon) * targets + self.epsilon / num_classes
31 | loss = (- targets * log_probs).mean(0).sum()
32 |
33 | return loss
34 |
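35 | # Worked example of the smoothing equation: with epsilon = 0.1 and K = 10 classes, the one-hot
36 | # target 1 becomes (1 - 0.1) * 1 + 0.1 / 10 = 0.91, every other class becomes 0 + 0.1 / 10 = 0.01,
37 | # and the smoothed targets still sum to 1.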
--------------------------------------------------------------------------------
/losses/gather.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 |
4 |
5 | class GatherLayer(torch.autograd.Function):
6 | """Gather tensors from all process, supporting backward propagation."""
7 |
8 | @staticmethod
9 | def forward(ctx, input):
10 | ctx.save_for_backward(input)
11 | output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
12 | dist.all_gather(output, input)
13 |
14 | return tuple(output)
15 |
16 | @staticmethod
17 | def backward(ctx, *grads):
18 | (input,) = ctx.saved_tensors
19 | grad_out = torch.zeros_like(input)
20 |
21 | # dist.reduce_scatter(grad_out, list(grads))
22 | # grad_out.div_(dist.get_world_size())
23 |
24 | grad_out[:] = grads[dist.get_rank()]
25 |
26 | return grad_out
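27 |
28 | # Why a custom autograd Function: dist.all_gather does not propagate gradients back to its
29 | # inputs, so backward() returns, as the gradient of this rank's input, the slice of the incoming
30 | # gradients that corresponds to this rank's own samples (the commented-out reduce_scatter variant
31 | # would instead average the contributions from every rank).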
--------------------------------------------------------------------------------
/losses/triplet_loss.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn.functional as F
4 | from torch import nn
5 | from losses.gather import GatherLayer
6 |
7 |
8 | class TripletLoss(nn.Module):
9 | """ Triplet loss with hard example mining.
10 |
11 | Reference:
12 | Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
13 |
14 | Args:
15 | margin (float): pre-defined margin.
16 |
17 |     Note that cosine distance is used here rather than the Euclidean distance used in the original paper.
18 | """
19 | def __init__(self, margin=0.3):
20 | super().__init__()
21 | self.m = margin
22 | self.ranking_loss = nn.MarginRankingLoss(margin=margin)
23 |
24 | def forward(self, inputs, targets):
25 | """
26 | Args:
27 | inputs: sample features (before classifier) with shape (batch_size, feat_dim)
28 | targets: ground truth labels with shape (batch_size)
29 | """
30 |         # l2-normalize
31 | inputs = F.normalize(inputs, p=2, dim=1)
32 |
33 | # gather all samples from different GPUs as gallery to compute pairwise loss.
34 | gallery_inputs = torch.cat(GatherLayer.apply(inputs), dim=0)
35 | gallery_targets = torch.cat(GatherLayer.apply(targets), dim=0)
36 |
37 | # compute distance
38 | dist = 1 - torch.matmul(inputs, gallery_inputs.t()) # values in [0, 2]
39 |
40 | # get positive and negative masks
41 | targets, gallery_targets = targets.view(-1,1), gallery_targets.view(-1,1)
42 | mask_pos = torch.eq(targets, gallery_targets.T).float().cuda()
43 | mask_neg = 1 - mask_pos
44 |
45 | # For each anchor, find the hardest positive and negative pairs
46 | dist_ap, _ = torch.max((dist - mask_neg * 99999999.), dim=1)
47 | dist_an, _ = torch.min((dist + mask_pos * 99999999.), dim=1)
48 |
49 | # Compute ranking hinge loss
50 | y = torch.ones_like(dist_an)
51 | loss = self.ranking_loss(dist_an, dist_ap, y)
52 |
53 | return loss
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from threading import local
4 | import time
5 | import datetime
6 | import argparse
7 | import logging
8 | import os.path as osp
9 | import numpy as np
10 | import gc
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.optim as optim
15 | from torch.optim import lr_scheduler
16 | from torch import distributed as dist
17 | from apex import amp
18 | # from models.lr_scheduler import WarmupMultiStepLR
19 |
20 | from configs.default_img import get_img_config
21 | from data import build_dataloader
22 | from models import build_model
23 | from losses import build_losses
24 | from tools.utils import save_checkpoint, set_seed, get_logger
25 | from train import train_aim
26 | from test import test, test_prcc
27 |
28 |
29 | def parse_option():
30 | parser = argparse.ArgumentParser(description='Train clothes-changing re-id model with clothes-based adversarial loss')
31 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file')
32 | # Datasets
33 | parser.add_argument('--root', type=str, help="your root path to data directory")
34 | parser.add_argument('--dataset', type=str, default='ltcc', help="ltcc, prcc, vcclothes, last, deepchange")
35 | # Miscs
36 | parser.add_argument('--output', type=str, help="your output path to save model and logs")
37 | parser.add_argument('--resume', type=str, metavar='PATH')
38 | parser.add_argument('--amp', action='store_true', help="automatic mixed precision")
39 | parser.add_argument('--eval', action='store_true', help="evaluation only")
40 | parser.add_argument('--tag', type=str, help='tag for log file')
41 | parser.add_argument('--name', type=str, help='your model name for record')
42 | parser.add_argument('--gpu', default='0', type=str, help='gpu device ids for CUDA_VISIBLE_DEVICES')
43 | # Options and Hyper-parameters
44 | parser.add_argument('--seed', type=str, help='seed for single-shot')
45 | parser.add_argument('--single_shot', action='store_true', help='single-shot option')
46 | parser.add_argument('--k_cal', type=str)
47 | parser.add_argument('--k_kl', type=str)
48 |
49 | args, unparsed = parser.parse_known_args()
50 | config = get_img_config(args)
51 |
52 | return config
53 |
54 | def main(config):
55 | # Build dataloader
56 | if config.DATA.DATASET == 'prcc':
57 | trainloader, queryloader_same, queryloader_diff, galleryloader, dataset, train_sampler = build_dataloader(config)
58 | else:
59 | trainloader, queryloader, galleryloader, dataset, train_sampler = build_dataloader(config)
60 |
61 | # Define a matrix pid2clothes with shape (num_pids, num_clothes).
62 | # pid2clothes[i, j] = 1 when j-th clothes belongs to i-th identity. Otherwise, pid2clothes[i, j] = 0.
63 | pid2clothes = torch.from_numpy(dataset.pid2clothes)
64 |
65 | # Build model
66 | model, model2, fuse, classifier, clothes_classifier, clothes_classifier2 = build_model(config, dataset.num_train_pids, dataset.num_train_clothes)
67 | print("model loaded")
68 |     # Build identity classification loss, pairwise loss, clothes classification loss, and adversarial loss.
69 | criterion_cla, criterion_pair, criterion_clothes, criterion_adv, kl = build_losses(config, dataset.num_train_clothes)
70 | print("loss built")
71 | # Build optimizer
72 | parameters = list(model.parameters()) + list(fuse.parameters()) + list(classifier.parameters())
73 | parameters2 = list(model2.parameters()) + list(clothes_classifier2.parameters())
74 |
75 |
76 | if config.TRAIN.OPTIMIZER.NAME == 'adam':
77 | optimizer = optim.Adam(parameters, lr=config.TRAIN.OPTIMIZER.LR,
78 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY)
79 | optimizer2 = optim.Adam(parameters2, lr=config.TRAIN.OPTIMIZER.LR,
80 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY)
81 | optimizer_cc = optim.Adam(clothes_classifier.parameters(), lr=config.TRAIN.OPTIMIZER.LR,
82 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY)
83 | elif config.TRAIN.OPTIMIZER.NAME == 'adamw':
84 | optimizer = optim.AdamW(parameters, lr=config.TRAIN.OPTIMIZER.LR,
85 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY)
86 | optimizer2 = optim.AdamW(parameters2, lr=config.TRAIN.OPTIMIZER.LR,
87 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY)
88 | optimizer_cc = optim.AdamW(clothes_classifier.parameters(), lr=config.TRAIN.OPTIMIZER.LR,
89 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY)
90 | elif config.TRAIN.OPTIMIZER.NAME == 'sgd':
91 | optimizer = optim.SGD(parameters, lr=config.TRAIN.OPTIMIZER.LR, momentum=0.9,
92 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY, nesterov=True)
93 | optimizer2 = optim.SGD(parameters2, lr=config.TRAIN.OPTIMIZER.LR, momentum=0.9,
94 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY, nesterov=True)
95 | optimizer_cc = optim.SGD(clothes_classifier.parameters(), lr=config.TRAIN.OPTIMIZER.LR, momentum=0.9,
96 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY, nesterov=True)
97 | else:
98 | raise KeyError("Unknown optimizer: {}".format(config.TRAIN.OPTIMIZER.NAME))
99 |
100 | # Build lr_scheduler
101 | scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=config.TRAIN.LR_SCHEDULER.STEPSIZE,
102 | gamma=config.TRAIN.LR_SCHEDULER.DECAY_RATE)
103 | scheduler2 = lr_scheduler.MultiStepLR(optimizer2, milestones=config.TRAIN.LR_SCHEDULER.STEPSIZE,
104 | gamma=config.TRAIN.LR_SCHEDULER.DECAY_RATE)
105 |
106 | start_epoch = config.TRAIN.START_EPOCH
107 |
108 | if config.MODEL.RESUME:
109 | logger.info("Loading checkpoint from '{}'".format(config.MODEL.RESUME))
110 | checkpoint = torch.load(config.MODEL.RESUME)
111 | model.load_state_dict(checkpoint['model_state_dict'])
112 | classifier.load_state_dict(checkpoint['classifier_state_dict'])
113 | fuse.load_state_dict(checkpoint['fuse_state_dict'])
114 | model2.load_state_dict(checkpoint['model2_state_dict'])
115 | clothes_classifier.load_state_dict(checkpoint['clothes_classifier_state_dict'])
116 | clothes_classifier2.load_state_dict(checkpoint['clothes_classifier2_state_dict'])
117 | start_epoch = checkpoint['epoch']
118 |
119 | local_rank = dist.get_rank()
120 | model = model.cuda(local_rank)
121 | model2 = model2.cuda(local_rank)
122 | classifier = classifier.cuda(local_rank)
123 | clothes_classifier2 = clothes_classifier2.cuda(local_rank)
124 | fuse = fuse.cuda(local_rank)
125 | clothes_classifier = clothes_classifier.cuda(local_rank)
126 | torch.cuda.set_device(local_rank)
127 |
128 | if config.TRAIN.AMP:
129 | [model, fuse, classifier], optimizer = amp.initialize([model, fuse, classifier], optimizer, opt_level="O1")
130 | [model2, clothes_classifier2], optimizer2 = amp.initialize([model2, clothes_classifier2], optimizer2, opt_level="O1")
131 |
132 | model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
133 | fuse = nn.parallel.DistributedDataParallel(fuse, device_ids=[local_rank], output_device=local_rank)
134 | model2 = nn.parallel.DistributedDataParallel(model2, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
135 | classifier = nn.parallel.DistributedDataParallel(classifier, device_ids=[local_rank], output_device=local_rank)
136 | clothes_classifier2 = nn.parallel.DistributedDataParallel(clothes_classifier2, device_ids=[local_rank], output_device=local_rank)
137 |
138 | if config.EVAL_MODE:
139 | logger.info("Evaluate only")
140 | with torch.no_grad():
141 | if config.DATA.DATASET == 'prcc':
142 | test_prcc(model, queryloader_same, queryloader_diff, galleryloader, dataset)
143 | else:
144 | test(config, model, queryloader, galleryloader, dataset)
145 | return
146 |
147 | start_time = time.time()
148 | train_time = 0
149 | best_rank1 = -np.inf
150 | best_epoch = 0
151 | logger.info("==> Start training")
152 | for epoch in range(start_epoch, config.TRAIN.MAX_EPOCH):
153 | train_sampler.set_epoch(epoch)
154 | start_train_time = time.time()
155 |
156 | train_aim(config, epoch, model, model2, classifier, clothes_classifier, clothes_classifier2, fuse, criterion_cla, criterion_pair,
157 | criterion_clothes, criterion_adv, optimizer, optimizer2, optimizer_cc, trainloader, pid2clothes, kl)
158 |
159 | train_time += round(time.time() - start_train_time)
160 |
161 |
162 | if (epoch+1) > config.TEST.START_EVAL and config.TEST.EVAL_STEP > 0 and \
163 | (epoch+1) % config.TEST.EVAL_STEP == 0 or (epoch+1) == config.TRAIN.MAX_EPOCH:
164 | logger.info("==> Test")
165 | torch.cuda.empty_cache()
166 | if config.DATA.DATASET == 'prcc':
167 | rank1 = test_prcc(model, queryloader_same, queryloader_diff, galleryloader, dataset)
168 | else:
169 | rank1 = test(config, model, queryloader, galleryloader, dataset)
170 | torch.cuda.empty_cache()
171 | is_best = rank1 > best_rank1
172 | if is_best:
173 | best_rank1 = rank1
174 | best_epoch = epoch + 1
175 |
176 | model_state_dict = model.module.state_dict()
177 | model2_state_dict = model2.module.state_dict()
178 | fuse_state_dict = fuse.module.state_dict()
179 | classifier_state_dict = classifier.module.state_dict()
180 | clothes_classifier_state_dict = clothes_classifier.module.state_dict()
181 | clothes_classifier2_state_dict = clothes_classifier2.module.state_dict()
182 |
183 | if local_rank == 0:
184 | save_checkpoint({
185 | 'model_state_dict': model_state_dict,
186 | 'model2_state_dict': model2_state_dict,
187 | 'fuse_state_dict': fuse_state_dict,
188 | 'classifier_state_dict': classifier_state_dict,
189 | 'clothes_classifier_state_dict': clothes_classifier_state_dict,
190 | 'clothes_classifier2_state_dict': clothes_classifier2_state_dict,
191 | 'rank1': rank1,
192 | 'epoch': epoch,
193 | }, is_best, osp.join(config.OUTPUT, 'checkpoint_ep' + str(epoch+1) + '.pth.tar'))
194 | scheduler.step()
195 | scheduler2.step()
196 |
197 | logger.info("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(best_rank1, best_epoch))
198 |
199 | elapsed = round(time.time() - start_time)
200 | elapsed = str(datetime.timedelta(seconds=elapsed))
201 | train_time = str(datetime.timedelta(seconds=train_time))
202 | logger.info("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time))
203 |
204 |
205 | if __name__ == '__main__':
206 | gc.collect()
207 | torch.cuda.empty_cache()
208 |
209 | config = parse_option()
210 | # Set GPU
211 | os.environ['CUDA_VISIBLE_DEVICES'] = config.GPU
212 | # Init dist
213 | dist.init_process_group(backend="nccl", init_method='env://')
214 | local_rank = dist.get_rank()
215 | # Set random seed
216 | set_seed(config.SEED + local_rank)
217 | # get logger
218 | if not config.EVAL_MODE:
219 | output_file = osp.join(config.OUTPUT, 'log_train_.log')
220 | else:
221 | output_file = osp.join(config.OUTPUT, 'log_test.log')
222 | logger = get_logger(output_file, local_rank, 'reid')
223 | logger.info("Config:\n-----------------------------------------")
224 | logger.info(config)
225 | logger.info("-----------------------------------------")
226 |
227 | main(config)
228 |
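229 | # A hypothetical launch example (the GPU count and paths are placeholders): because
230 | # init_process_group above uses init_method='env://', main.py is expected to be started
231 | # through a PyTorch distributed launcher, e.g.
232 | #   python -m torch.distributed.launch --nproc_per_node=2 main.py \
233 | #       --cfg configs/res50_cels_cal.yaml --root YOUR_DATA_ROOT --dataset prcc --gpu 0,1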
--------------------------------------------------------------------------------
/models/Fuse.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import init
5 |
6 | EPSILON = 1e-12
7 |
8 | class BasicConv2d(nn.Module):
9 |
10 | def __init__(self, in_channels, out_channels, **kwargs):
11 | super(BasicConv2d, self).__init__()
12 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
13 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
14 |
15 | def forward(self, x):
16 | x = self.conv(x)
17 | x = self.bn(x)
18 | return F.relu(x, inplace=True)
19 |
20 | class Fuse(nn.Module):
21 | def __init__(self, feature_dim):
22 | super(Fuse, self).__init__()
23 |
24 | self.trans = nn.Linear(8*2048, feature_dim, bias=False)
25 | self.bn = nn.BatchNorm1d(feature_dim)
26 | self.pool = nn.AdaptiveMaxPool2d(1)
27 | self.M = 8
28 | self.attentions = BasicConv2d(2048, self.M, kernel_size=1)
29 | self.trans.weight.data.normal_(0, 0.001)
30 | init.normal_(self.bn.weight.data, 1.0, 0.02)
31 | init.constant_(self.bn.bias.data, 0.0)
32 |
33 | def forward(self, feat, counter_feat_in):
34 |
35 | counter_feat = self.attentions(counter_feat_in)
36 |
37 | B, C, H, W = feat.size()
38 | _, M, AH, AW = counter_feat.size()
39 |
40 | x = (torch.einsum('imjk,injk->imn', (counter_feat, feat)) / float(H * W)).view(B, -1)
41 | x = torch.sign(x) * torch.sqrt(torch.abs(x) + EPSILON)
42 | x = F.normalize(x, dim=-1)
43 | x = self.trans(x)
44 | x = self.bn(x)
45 |
46 | return x
--------------------------------------------------------------------------------
/models/Fusion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import init
5 |
6 | EPSILON = 1e-12
7 |
8 | class BasicConv2d(nn.Module):
9 |
10 | def __init__(self, in_channels, out_channels, **kwargs):
11 | super(BasicConv2d, self).__init__()
12 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
13 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
14 |
15 | def forward(self, x):
16 | x = self.conv(x)
17 | x = self.bn(x)
18 | return F.relu(x, inplace=True)
19 |
20 | class Fusion(nn.Module):
21 | def __init__(self, feature_dim):
22 | super(Fusion, self).__init__()
23 |
24 | self.linear = nn.Linear(8*2048, feature_dim, bias=False)
25 | self.bn = nn.BatchNorm1d(feature_dim)
26 | self.pool = nn.AdaptiveMaxPool2d(1)
27 | self.M = 8
28 | self.attentions = BasicConv2d(2048, self.M, kernel_size=1)
29 | self.linear.weight.data.normal_(0, 0.001)
30 | init.normal_(self.bn.weight.data, 1.0, 0.02)
31 | init.constant_(self.bn.bias.data, 0.0)
32 |
33 | def forward(self, feat, feat2):
34 | feat2_att = self.attentions(feat2)
35 |
36 | B, C, H, W = feat.size()
37 | _, M, AH, AW = feat2_att.size()
38 |
39 | x = (torch.einsum('imjk,injk->imn', (feat2_att, feat)) / float(H * W)).view(B, -1)
40 | x = torch.sign(x) * torch.sqrt(torch.abs(x) + EPSILON)
41 | x = F.normalize(x, dim=-1)
42 | x = self.linear(x)
43 | x = self.bn(x)
44 |
45 | return x
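46 |
47 | # Shape note for the einsum above: feat2_att is (B, M=8, H, W) attention maps and feat is
48 | # (B, C=2048, H, W); 'imjk,injk->imn' contracts the spatial dims into a (B, M, C) bilinear
49 | # descriptor (normalized by H * W), which is then signed-sqrt and l2 normalized and projected
50 | # by the 8*2048 -> feature_dim linear layer (bilinear attention pooling).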
--------------------------------------------------------------------------------
/models/PM.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import init
5 | from torchvision import models
6 | from torch.autograd import Variable
7 | import torch.nn.functional as F
8 | from models.ResNet import resnet50
9 |
10 |
11 |
12 | ######################################################################
13 | def weight_init(m):
14 | if isinstance(m, nn.Conv2d):
15 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
16 | m.weight.data.normal_(0, math.sqrt(2. / n))
17 | elif isinstance(m, nn.BatchNorm2d):
18 | m.weight.data.fill_(1)
19 | m.bias.data.zero_()
20 | elif isinstance(m, nn.Linear):
21 | m.weight.data.normal_(0, 0.001)
22 |
23 |
24 | def pcb_block(num_ftrs, num_stripes, local_conv_out_channels, feature_dim, avg=False):
25 | if avg:
26 | pooling_list = nn.ModuleList([nn.AdaptiveAvgPool2d(1) for _ in range(num_stripes)])
27 | else:
28 | pooling_list = nn.ModuleList([nn.AdaptiveMaxPool2d(1) for _ in range(num_stripes)])
29 | conv_list = nn.ModuleList([nn.Conv2d(num_ftrs, local_conv_out_channels, 1, bias=False) for _ in range(num_stripes)])
30 | batchnorm_list = nn.ModuleList([nn.BatchNorm2d(local_conv_out_channels) for _ in range(num_stripes)])
31 | relu_list = nn.ModuleList([nn.ReLU(inplace=True) for _ in range(num_stripes)])
32 | for m in conv_list:
33 | weight_init(m)
34 | for m in batchnorm_list:
35 | weight_init(m)
36 | return pooling_list, conv_list, batchnorm_list, relu_list
37 |
38 |
39 | def spp_vertical(feats, pool_list, conv_list, bn_list, relu_list, num_strides, feat_list=[]):
40 | for i in range(num_strides):
41 | pcb_feat = pool_list[i](feats[:, :, i * int(feats.size(2) / num_strides): (i + 1) * int(feats.size(2) / num_strides), :])
42 | pcb_feat = conv_list[i](pcb_feat)
43 | pcb_feat = bn_list[i](pcb_feat)
44 | pcb_feat = relu_list[i](pcb_feat)
45 | pcb_feat = pcb_feat.view(pcb_feat.size(0), -1)
46 | feat_list.append(pcb_feat)
47 | return feat_list
48 |
49 | def global_pcb(feats, pool, conv, bn, relu, feat_list=[]):
50 | global_feat = pool(feats)
51 | global_feat = conv(global_feat)
52 | global_feat = bn(global_feat)
53 | global_feat = relu(global_feat)
54 | global_feat = global_feat.view(feats.size(0), -1)
55 | feat_list.append(global_feat)
56 | return feat_list
57 |
58 | class PM(nn.Module):
59 | def __init__(self, feature_dim, blocks=15, num_stripes=6, local_conv_out_channels=256, erase=0, loss={'htri'}, avg=False, **kwargs):
60 | super(PM, self).__init__()
61 | self.num_stripes = num_stripes
62 |
63 | model_ft = resnet50(pretrained=True, last_conv_stride=1)
64 | self.num_ftrs = list(model_ft.layer4)[-1].conv1.in_channels
65 | self.features = model_ft
66 |
67 | self.global_pooling = nn.AdaptiveMaxPool2d(1)
68 | self.global_conv = nn.Conv2d(self.num_ftrs, local_conv_out_channels, 1, bias=False)
69 | self.global_bn = nn.BatchNorm2d(local_conv_out_channels)
70 | self.global_relu = nn.ReLU(inplace=True)
71 |
72 | self.trans = nn.Linear(256*blocks, feature_dim, bias=False)
73 | self.bn = nn.BatchNorm1d(feature_dim)
74 |
75 | weight_init(self.global_conv)
76 | weight_init(self.global_bn)
77 | weight_init(self.trans)
78 | init.normal_(self.bn.weight.data, 1.0, 0.02)
79 | init.constant_(self.bn.bias.data, 0.0)
80 |
81 | self.pcb2_pool_list, self.pcb2_conv_list, self.pcb2_batchnorm_list, self.pcb2_relu_list = pcb_block(self.num_ftrs, 2, local_conv_out_channels, feature_dim, avg)
82 | self.pcb4_pool_list, self.pcb4_conv_list, self.pcb4_batchnorm_list, self.pcb4_relu_list = pcb_block(self.num_ftrs, 4, local_conv_out_channels, feature_dim, avg)
83 | self.pcb8_pool_list, self.pcb8_conv_list, self.pcb8_batchnorm_list, self.pcb8_relu_list = pcb_block(self.num_ftrs, 8, local_conv_out_channels, feature_dim, avg)
84 |
85 | def forward(self, x):
86 | feats = self.features(x)
87 |
88 | feat_list = global_pcb(feats, self.global_pooling, self.global_conv, self.global_bn, self.global_relu, [])
89 | feat_list = spp_vertical(feats, self.pcb2_pool_list, self.pcb2_conv_list, self.pcb2_batchnorm_list, self.pcb2_relu_list, 2, feat_list)
90 | feat_list = spp_vertical(feats, self.pcb4_pool_list, self.pcb4_conv_list, self.pcb4_batchnorm_list, self.pcb4_relu_list, 4, feat_list)
91 | feat_list = spp_vertical(feats, self.pcb8_pool_list, self.pcb8_conv_list, self.pcb8_batchnorm_list, self.pcb8_relu_list, 8, feat_list)
92 |
93 | ret = torch.cat(feat_list, dim=1)
94 | ret = self.trans(ret)
95 | ret = self.bn(ret)
96 | return feats, ret
97 |
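98 | # Stripe bookkeeping: feat_list concatenates 1 global + 2 + 4 + 8 = 15 pooled descriptors of
99 | # local_conv_out_channels (= 256) dims each, matching the blocks=15 default and the
100 | # nn.Linear(256 * blocks, feature_dim) projection above.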
--------------------------------------------------------------------------------
/models/ResNet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch.utils.model_zoo as model_zoo
4 | from models.utils import pooling
5 |
6 |
7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
8 | 'resnet152']
9 |
10 |
11 | model_urls = {
12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
17 | }
18 |
19 |
20 | def conv3x3(in_planes, out_planes, stride=1):
21 | """3x3 convolution with padding"""
22 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
23 | padding=1, bias=False)
24 |
25 |
26 | class BasicBlock(nn.Module):
27 | expansion = 1
28 |
29 | def __init__(self, inplanes, planes, stride=1, downsample=None):
30 | super(BasicBlock, self).__init__()
31 | self.conv1 = conv3x3(inplanes, planes, stride)
32 | self.bn1 = nn.BatchNorm2d(planes)
33 | self.relu = nn.ReLU(inplace=True)
34 | self.conv2 = conv3x3(planes, planes)
35 | self.bn2 = nn.BatchNorm2d(planes)
36 | self.downsample = downsample
37 | self.stride = stride
38 |
39 | def forward(self, x):
40 | residual = x
41 |
42 | out = self.conv1(x)
43 | out = self.bn1(out)
44 | out = self.relu(out)
45 |
46 | out = self.conv2(out)
47 | out = self.bn2(out)
48 |
49 | if self.downsample is not None:
50 | residual = self.downsample(x)
51 |
52 | out += residual
53 | out = self.relu(out)
54 |
55 | return out
56 |
57 |
58 | class Bottleneck(nn.Module):
59 | expansion = 4
60 |
61 | def __init__(self, inplanes, planes, stride=1, downsample=None):
62 | super(Bottleneck, self).__init__()
63 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
64 | self.bn1 = nn.BatchNorm2d(planes)
65 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
66 | padding=1, bias=False)
67 | self.bn2 = nn.BatchNorm2d(planes)
68 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
69 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
70 | self.relu = nn.ReLU(inplace=True)
71 | self.downsample = downsample
72 | self.stride = stride
73 |
74 | def forward(self, x):
75 | residual = x
76 |
77 | out = self.conv1(x)
78 | out = self.bn1(out)
79 | out = self.relu(out)
80 |
81 | out = self.conv2(out)
82 | out = self.bn2(out)
83 | out = self.relu(out)
84 |
85 | out = self.conv3(out)
86 | out = self.bn3(out)
87 |
88 | if self.downsample is not None:
89 | residual = self.downsample(x)
90 |
91 | out += residual
92 | out = self.relu(out)
93 |
94 | return out
95 |
96 |
97 | class ResNet(nn.Module):
98 |
99 | def __init__(self, block, layers, last_conv_stride=2, num_classes=1000):
100 | self.inplanes = 64
101 | super(ResNet, self).__init__()
102 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
103 | self.bn1 = nn.BatchNorm2d(self.inplanes)
104 | self.relu = nn.ReLU(inplace=True)
105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
106 | self.layer1 = self._make_layer(block, 64, layers[0])
107 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
108 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
109 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_conv_stride)
110 | self.avgpool = nn.AvgPool2d(7, stride=1)
111 | self.fc = nn.Linear(512 * block.expansion, num_classes)
112 | self.globalpool = pooling.MaxAvgPooling()
113 |
114 | for m in self.modules():
115 | if isinstance(m, nn.Conv2d):
116 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
117 | elif isinstance(m, nn.BatchNorm2d):
118 | nn.init.constant_(m.weight, 1)
119 | nn.init.constant_(m.bias, 0)
120 |
121 | def _make_layer(self, block, planes, blocks, stride=1):
122 | downsample = None
123 | if stride != 1 or self.inplanes != planes * block.expansion:
124 | downsample = nn.Sequential(
125 | nn.Conv2d(self.inplanes, planes * block.expansion,
126 | kernel_size=1, stride=stride, bias=False),
127 | nn.BatchNorm2d(planes * block.expansion),
128 | )
129 |
130 | layers = []
131 | layers.append(block(self.inplanes, planes, stride, downsample))
132 | self.inplanes = planes * block.expansion
133 | for i in range(1, blocks):
134 | layers.append(block(self.inplanes, planes))
135 |
136 | return nn.Sequential(*layers)
137 |
138 | def forward(self, x):
139 | x = self.conv1(x)
140 | x = self.bn1(x)
141 | x = self.relu(x)
142 | x = self.maxpool(x)
143 |
144 | x = self.layer1(x)
145 | x = self.layer2(x)
146 | x = self.layer3(x)
147 | x = self.layer4(x)
148 |
149 | # x = self.avgpool(x)
150 | # x = self.globalpool(x)
151 | # x = x.view(x.size(0), -1)
152 | # x = self.fc(x)
153 |
154 | return x
155 |
156 |
157 | def resnet18(pretrained=False, **kwargs):
158 | """Constructs a ResNet-18 model.
159 |
160 | Args:
161 | pretrained (bool): If True, returns a model pre-trained on ImageNet
162 | """
163 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
164 | if pretrained:
165 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
166 | return model
167 |
168 |
169 | def resnet34(pretrained=False, **kwargs):
170 | """Constructs a ResNet-34 model.
171 |
172 | Args:
173 | pretrained (bool): If True, returns a model pre-trained on ImageNet
174 | """
175 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
176 | if pretrained:
177 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
178 | return model
179 |
180 |
181 | def resnet50(pretrained=False, **kwargs):
182 | """Constructs a ResNet-50 model.
183 |
184 | Args:
185 | pretrained (bool): If True, returns a model pre-trained on ImageNet
186 | """
187 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
188 | if pretrained:
189 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
190 | return model
191 |
192 |
193 | def resnet101(pretrained=False, **kwargs):
194 | """Constructs a ResNet-101 model.
195 |
196 | Args:
197 | pretrained (bool): If True, returns a model pre-trained on ImageNet
198 | """
199 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
200 | if pretrained:
201 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
202 | return model
203 |
204 |
205 | def resnet152(pretrained=False, **kwargs):
206 | """Constructs a ResNet-152 model.
207 |
208 | Args:
209 | pretrained (bool): If True, returns a model pre-trained on ImageNet
210 | """
211 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
212 | if pretrained:
213 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
214 | return model
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from models.classifier import Classifier, NormalizedClassifier
3 | from models.img_resnet import ResNet50
4 | from models.PM import PM
5 | from models.Fusion import Fusion
6 |
7 |
8 | def build_model(config, num_identities, num_clothes):
9 | logger = logging.getLogger('reid.model')
10 | # Build backbone
11 | logger.info("Initializing model: {}".format(config.MODEL.NAME))
12 |
13 |
14 | logger.info("Init model: '{}'".format(config.MODEL.NAME))
15 | model = ResNet50(config)
16 |
17 | model2 = PM(feature_dim=config.MODEL.FEATURE_DIM)
18 | fusion = Fusion(feature_dim=config.MODEL.FEATURE_DIM)
19 |
20 | logger.info("Model size: {:.5f}M".format(sum(p.numel() for p in model.parameters())/1000000.0))
21 | logger.info("Model2 size: {:.5f}M".format(sum(p.numel() for p in model2.parameters())/1000000.0))
22 |
23 | # Build classifier
24 | if config.LOSS.CLA_LOSS in ['crossentropy', 'crossentropylabelsmooth']:
25 | identity_classifier = Classifier(feature_dim=config.MODEL.FEATURE_DIM, num_classes=num_identities)
26 | else:
27 | identity_classifier = NormalizedClassifier(feature_dim=config.MODEL.FEATURE_DIM, num_classes=num_identities)
28 |
29 | clothes_classifier = NormalizedClassifier(feature_dim=config.MODEL.FEATURE_DIM, num_classes=num_clothes)
30 |
31 |     # Clothes classifier for the second model (model2)
32 | clothes_classifier2 = NormalizedClassifier(feature_dim=config.MODEL.FEATURE_DIM, num_classes=num_clothes)
33 |
34 | return model, model2, fusion, identity_classifier, clothes_classifier, clothes_classifier2
--------------------------------------------------------------------------------
/models/classifier.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import init
4 | from torch.nn import functional as F
5 | from torch.nn import Parameter
6 |
7 |
8 | __all__ = ['Classifier', 'NormalizedClassifier']
9 |
10 |
11 | class Classifier(nn.Module):
12 | def __init__(self, feature_dim, num_classes):
13 | super().__init__()
14 | self.classifier = nn.Linear(feature_dim, num_classes)
15 | init.normal_(self.classifier.weight.data, std=0.001)
16 | init.constant_(self.classifier.bias.data, 0.0)
17 |
18 | def forward(self, x):
19 | y = self.classifier(x)
20 |
21 | return y
22 |
23 |
24 | class NormalizedClassifier(nn.Module):
25 | def __init__(self, feature_dim, num_classes):
26 | super().__init__()
27 | self.weight = Parameter(torch.Tensor(num_classes, feature_dim))
28 | self.weight.data.uniform_(-1, 1).renorm_(2,0,1e-5).mul_(1e5)
29 |
30 | def forward(self, x):
31 | w = self.weight
32 |
33 | x = F.normalize(x, p=2, dim=1)
34 | w = F.normalize(w, p=2, dim=1)
35 |
36 | return F.linear(x, w)
37 |
38 |
39 |
40 |
--------------------------------------------------------------------------------
/models/img_resnet.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | from torch import nn
3 | from torch.nn import init
4 | from models.utils import pooling
5 |
6 |
7 | class ResNet50(nn.Module):
8 | def __init__(self, config, **kwargs):
9 | super().__init__()
10 |
11 | resnet50 = torchvision.models.resnet50(pretrained=True)
12 |         # or pass weights=ResNet50_Weights.IMAGENET1K_V1 instead of pretrained=True to avoid the deprecation warning
13 | if config.MODEL.RES4_STRIDE == 1:
14 | resnet50.layer4[0].conv2.stride=(1, 1)
15 | resnet50.layer4[0].downsample[0].stride=(1, 1)
16 | # self.base = nn.Sequential(*list(resnet50.children())[:-2])
17 | self.conv1 = resnet50.conv1
18 | self.bn1 = resnet50.bn1
19 | self.relu = resnet50.relu
20 | self.maxpool = resnet50.maxpool
21 |
22 | self.layer1 = resnet50.layer1
23 | self.layer2 = resnet50.layer2
24 | self.layer3 = resnet50.layer3
25 | self.layer4 = resnet50.layer4
26 |
27 | if config.MODEL.POOLING.NAME == 'avg':
28 | self.globalpooling = nn.AdaptiveAvgPool2d(1)
29 | elif config.MODEL.POOLING.NAME == 'max':
30 | self.globalpooling = nn.AdaptiveMaxPool2d(1)
31 | elif config.MODEL.POOLING.NAME == 'gem':
32 | self.globalpooling = pooling.GeMPooling(p=config.MODEL.POOLING.P)
33 | elif config.MODEL.POOLING.NAME == 'maxavg':
34 | self.globalpooling = pooling.MaxAvgPooling()
35 | else:
36 | raise KeyError("Invalid pooling: '{}'".format(config.MODEL.POOLING.NAME))
37 |
38 | self.bn = nn.BatchNorm1d(config.MODEL.FEATURE_DIM)
39 | init.normal_(self.bn.weight.data, 1.0, 0.02)
40 | init.constant_(self.bn.bias.data, 0.0)
41 |
42 | def forward(self, tmp):
43 | tmp = self.conv1(tmp)
44 | tmp = self.bn1(tmp)
45 | tmp = self.relu(tmp)
46 | tmp = self.maxpool(tmp)
47 |
48 | tmp = self.layer1(tmp)
49 | tmp = self.layer2(tmp)
50 | tmp = self.layer3(tmp)
51 | old_x = self.layer4(tmp) # torch.Size([32, 2048, 24, 12])
52 |
53 | # old_x = self.base(tmp)
54 |
55 | x = self.globalpooling(old_x) # torch.Size([32, 4096, 1, 1])
56 | x = x.view(x.size(0), -1)
57 | f = self.bn(x)
58 |
59 | return old_x, f
--------------------------------------------------------------------------------
/models/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | from bisect import bisect_right
7 | import torch
8 |
9 |
10 | # FIXME ideally this would be achieved with a CombinedLRScheduler,
11 | # separating MultiStepLR with WarmupLR
12 | # but the current LRScheduler design doesn't allow it
13 |
14 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
15 | def __init__(
16 | self,
17 | optimizer,
18 | milestones,
19 | gamma=0.1,
20 | warmup_factor=1.0 / 3,
21 | warmup_iters=500,
22 | warmup_method="linear",
23 | last_epoch=-1,
24 | ):
25 | if not list(milestones) == sorted(milestones):
26 | raise ValueError(
27 |                 "Milestones should be a list of "
28 |                 "increasing integers. Got {}".format(milestones),
29 | )
30 |
31 | if warmup_method not in ("constant", "linear"):
32 | raise ValueError(
33 |                 "Only 'constant' or 'linear' warmup_method accepted, "
34 | "got {}".format(warmup_method)
35 | )
36 | self.milestones = milestones # (40, 70)
37 | self.gamma = gamma # 0.1
38 | self.warmup_factor = warmup_factor # 0.01
39 | self.warmup_iters = warmup_iters # 0
40 | self.warmup_method = warmup_method # linear
41 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)
42 |
43 | def get_lr(self):
44 | warmup_factor = 1
45 | if self.last_epoch < self.warmup_iters:
46 | if self.warmup_method == "constant":
47 | warmup_factor = self.warmup_factor
48 | elif self.warmup_method == "linear":
49 | alpha = self.last_epoch / self.warmup_iters
50 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
51 | return [
52 | base_lr
53 | * warmup_factor
54 | * self.gamma ** bisect_right(self.milestones, self.last_epoch)
55 | for base_lr in self.base_lrs
56 | ]
57 |
--------------------------------------------------------------------------------
/models/utils/c3d_blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class APM(nn.Module):
7 | def __init__(self, in_channels, out_channels, time_dim=3, temperature=4, contrastive_att=True):
8 | super(APM, self).__init__()
9 |
10 | self.time_dim = time_dim
11 | self.temperature = temperature
12 | self.contrastive_att = contrastive_att
13 |
14 | padding = (0, 0, 0, 0, (time_dim-1)//2, (time_dim-1)//2)
15 | self.padding = nn.ConstantPad3d(padding, value=0)
16 |
17 | self.semantic_mapping = nn.Conv3d(in_channels, out_channels, \
18 | kernel_size=1, bias=False)
19 | if self.contrastive_att:
20 | self.x_mapping = nn.Conv3d(in_channels, out_channels, \
21 | kernel_size=1, bias=False)
22 | self.n_mapping = nn.Conv3d(in_channels, out_channels, \
23 | kernel_size=1, bias=False)
24 | self.contrastive_att_net = nn.Sequential(nn.Conv3d(out_channels, 1, \
25 | kernel_size=1, bias=False), nn.Sigmoid())
26 |
27 | def forward(self, x):
28 | b, c, t, h, w = x.size()
29 | N = self.time_dim
30 |
31 | neighbor_time_index = torch.cat([(torch.arange(0,t)+i).unsqueeze(0) for i in range(N) if i!=N//2], dim=0).t().flatten().long()
32 |
33 | # feature map registration
34 | semantic = self.semantic_mapping(x) # (b, c/16, t, h, w)
35 | x_norm = F.normalize(semantic, p=2, dim=1) # (b, c/16, t, h, w)
36 | x_norm_padding = self.padding(x_norm) # (b, c/16, t+2, h, w)
37 | x_norm_expand = x_norm.unsqueeze(3).expand(-1, -1, -1, N-1, -1, -1).permute(0, 2, 3, 4, 5, 1).contiguous().view(-1, h*w, c//16) # (b*t*2, h*w, c/16)
38 | neighbor_norm = x_norm_padding[:, :, neighbor_time_index, :, :].permute(0, 2, 1, 3, 4).contiguous().view(-1, c//16, h*w) # (b*t*2, c/16, h*w)
39 |
40 | similarity = torch.matmul(x_norm_expand, neighbor_norm) * self.temperature # (b*t*2, h*w, h*w)
41 | similarity = F.softmax(similarity, dim=-1) # (b*t*2, h*w, h*w)
42 |
43 | x_padding = self.padding(x)
44 | neighbor = x_padding[:, :, neighbor_time_index, :, :].permute(0, 2, 3, 4, 1).contiguous().view(-1, h*w, c)
45 | neighbor_new = torch.matmul(similarity, neighbor).view(b, t*(N-1), h, w, c).permute(0, 4, 1, 2, 3) # (b, c, t*2, h, w)
46 |
47 | # contrastive attention
48 | if self.contrastive_att:
49 | x_att = self.x_mapping(x.unsqueeze(3).expand(-1, -1, -1, N-1, -1, -1).contiguous().view(b, c, (N-1)*t, h, w).detach())
50 | n_att = self.n_mapping(neighbor_new.detach())
51 | contrastive_att = self.contrastive_att_net(x_att * n_att)
52 | neighbor_new = neighbor_new * contrastive_att
53 |
54 | # integrating feature maps
55 |         x_offset = torch.zeros([b, c, N*t, h, w], dtype=x.dtype, device=x.device)
56 | x_index = torch.tensor([i for i in range(t*N) if i%N==N//2])
57 | neighbor_index = torch.tensor([i for i in range(t*N) if i%N!=N//2])
58 | x_offset[:, :, x_index, :, :] += x
59 | x_offset[:, :, neighbor_index, :, :] += neighbor_new
60 |
61 | return x_offset
62 |
63 |
64 | class C2D(nn.Module):
65 | def __init__(self, conv2d, **kwargs):
66 | super(C2D, self).__init__()
67 |
68 | # conv3d kernel
69 | kernel_dim = (1, conv2d.kernel_size[0], conv2d.kernel_size[1])
70 | stride = (1, conv2d.stride[0], conv2d.stride[0])
71 | padding = (0, conv2d.padding[0], conv2d.padding[1])
72 | self.conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \
73 | kernel_size=kernel_dim, padding=padding, \
74 | stride=stride, bias=conv2d.bias)
75 |
76 | # init the parameters of conv3d
77 | weight_2d = conv2d.weight.data
78 | weight_3d = torch.zeros(*weight_2d.shape)
79 | weight_3d = weight_3d.unsqueeze(2)
80 | weight_3d[:, :, 0, :, :] = weight_2d
81 | self.conv3d.weight = nn.Parameter(weight_3d)
82 | self.conv3d.bias = conv2d.bias
83 |
84 | def forward(self, x):
85 | out = self.conv3d(x)
86 |
87 | return out
88 |
89 |
90 | class I3D(nn.Module):
91 | def __init__(self, conv2d, time_dim=3, time_stride=1, **kwargs):
92 | super(I3D, self).__init__()
93 |
94 | # conv3d kernel
95 | kernel_dim = (time_dim, conv2d.kernel_size[0], conv2d.kernel_size[1])
96 | stride = (time_stride, conv2d.stride[0], conv2d.stride[0])
97 | padding = (time_dim//2, conv2d.padding[0], conv2d.padding[1])
98 | self.conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \
99 | kernel_size=kernel_dim, padding=padding, \
100 | stride=stride, bias=conv2d.bias)
101 |
102 | # init the parameters of conv3d
103 | weight_2d = conv2d.weight.data
104 | weight_3d = torch.zeros(*weight_2d.shape)
105 | weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
106 | middle_idx = time_dim // 2
107 | weight_3d[:, :, middle_idx, :, :] = weight_2d
108 | self.conv3d.weight = nn.Parameter(weight_3d)
109 | self.conv3d.bias = conv2d.bias
110 |
111 | def forward(self, x):
112 | out = self.conv3d(x)
113 |
114 | return out
115 |
116 |
117 | class API3D(nn.Module):
118 | def __init__(self, conv2d, time_dim=3, time_stride=1, temperature=4, contrastive_att=True):
119 | super(API3D, self).__init__()
120 |
121 | self.APM = APM(conv2d.in_channels, conv2d.in_channels//16, \
122 | time_dim=time_dim, temperature=temperature, contrastive_att=contrastive_att)
123 |
124 | # conv3d kernel
125 | kernel_dim = (time_dim, conv2d.kernel_size[0], conv2d.kernel_size[1])
126 |         stride = (time_stride*time_dim, conv2d.stride[0], conv2d.stride[1])
127 | padding = (0, conv2d.padding[0], conv2d.padding[1])
128 | self.conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \
129 | kernel_size=kernel_dim, padding=padding, \
130 |                                 stride=stride, bias=conv2d.bias is not None)
131 |
132 | # init the parameters of conv3d
133 | weight_2d = conv2d.weight.data
134 | weight_3d = torch.zeros(*weight_2d.shape)
135 | weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
136 | middle_idx = time_dim // 2
137 | weight_3d[:, :, middle_idx, :, :] = weight_2d
138 | self.conv3d.weight = nn.Parameter(weight_3d)
139 | self.conv3d.bias = conv2d.bias
140 |
141 | def forward(self, x):
142 | x_offset = self.APM(x)
143 | out = self.conv3d(x_offset)
144 |
145 | return out
146 |
147 |
148 | class P3DA(nn.Module):
149 | def __init__(self, conv2d, time_dim=3, time_stride=1, **kwargs):
150 | super(P3DA, self).__init__()
151 |
152 | # spatial conv3d kernel
153 | kernel_dim = (1, conv2d.kernel_size[0], conv2d.kernel_size[1])
154 |         stride = (1, conv2d.stride[0], conv2d.stride[1])
155 | padding = (0, conv2d.padding[0], conv2d.padding[1])
156 | self.spatial_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \
157 | kernel_size=kernel_dim, padding=padding, \
158 |                                 stride=stride, bias=conv2d.bias is not None)
159 |
160 | # init the parameters of spatial_conv3d
161 | weight_2d = conv2d.weight.data
162 | weight_3d = torch.zeros(*weight_2d.shape)
163 | weight_3d = weight_3d.unsqueeze(2)
164 | weight_3d[:, :, 0, :, :] = weight_2d
165 | self.spatial_conv3d.weight = nn.Parameter(weight_3d)
166 | self.spatial_conv3d.bias = conv2d.bias
167 |
168 |
169 | # temporal conv3d kernel
170 | kernel_dim = (time_dim, 1, 1)
171 | stride = (time_stride, 1, 1)
172 | padding = (time_dim//2, 0, 0)
173 | self.temporal_conv3d = nn.Conv3d(conv2d.out_channels, conv2d.out_channels, \
174 | kernel_size=kernel_dim, padding=padding, \
175 | stride=stride, bias=False)
176 |
177 | # init the parameters of temporal_conv3d
178 | weight_2d = torch.eye(conv2d.out_channels).unsqueeze(2).unsqueeze(2)
179 | weight_3d = torch.zeros(*weight_2d.shape)
180 | weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
181 | middle_idx = time_dim // 2
182 | weight_3d[:, :, middle_idx, :, :] = weight_2d
183 | self.temporal_conv3d.weight = nn.Parameter(weight_3d)
184 |
185 |
186 | def forward(self, x):
187 | x = self.spatial_conv3d(x)
188 | out = self.temporal_conv3d(x)
189 |
190 | return out
191 |
192 |
193 | class P3DB(nn.Module):
194 | def __init__(self, conv2d, time_dim=3, time_stride=1, **kwargs):
195 | super(P3DB, self).__init__()
196 |
197 | # spatial conv3d kernel
198 | kernel_dim = (1, conv2d.kernel_size[0], conv2d.kernel_size[1])
199 |         stride = (1, conv2d.stride[0], conv2d.stride[1])
200 | padding = (0, conv2d.padding[0], conv2d.padding[1])
201 | self.spatial_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \
202 | kernel_size=kernel_dim, padding=padding, \
203 |                                 stride=stride, bias=conv2d.bias is not None)
204 |
205 | # init the parameters of spatial_conv3d
206 | weight_2d = conv2d.weight.data
207 | weight_3d = torch.zeros(*weight_2d.shape)
208 | weight_3d = weight_3d.unsqueeze(2)
209 | weight_3d[:, :, 0, :, :] = weight_2d
210 | self.spatial_conv3d.weight = nn.Parameter(weight_3d)
211 | self.spatial_conv3d.bias = conv2d.bias
212 |
213 |
214 | # temporal conv3d kernel
215 | kernel_dim = (time_dim, 1, 1)
216 |         stride = (time_stride, conv2d.stride[0], conv2d.stride[1])
217 | padding = (time_dim//2, 0, 0)
218 | self.temporal_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \
219 | kernel_size=kernel_dim, padding=padding, \
220 | stride=stride, bias=False)
221 |
222 | # init the parameters of temporal_conv3d
223 | nn.init.constant_(self.temporal_conv3d.weight, 0)
224 |
225 |
226 | def forward(self, x):
227 | # print(x.shape)
228 | out1 = self.spatial_conv3d(x)
229 | # print(out1.shape)
230 | out2 = self.temporal_conv3d(x)
231 | # print(out2.shape)
232 | out = out1 + out2
233 |
234 | return out
235 |
236 |
237 | class P3DC(nn.Module):
238 | def __init__(self, conv2d, time_dim=3, time_stride=1, **kwargs):
239 | super(P3DC, self).__init__()
240 |
241 | # spatial conv3d kernel
242 | kernel_dim = (1, conv2d.kernel_size[0], conv2d.kernel_size[1])
243 |         stride = (1, conv2d.stride[0], conv2d.stride[1])
244 | padding = (0, conv2d.padding[0], conv2d.padding[1])
245 | self.spatial_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \
246 | kernel_size=kernel_dim, padding=padding, \
247 |                                 stride=stride, bias=conv2d.bias is not None)
248 |
249 | # init the parameters of spatial_conv3d
250 | weight_2d = conv2d.weight.data
251 | weight_3d = torch.zeros(*weight_2d.shape)
252 | weight_3d = weight_3d.unsqueeze(2)
253 | weight_3d[:, :, 0, :, :] = weight_2d
254 | self.spatial_conv3d.weight = nn.Parameter(weight_3d)
255 | self.spatial_conv3d.bias = conv2d.bias
256 |
257 |
258 | # temporal conv3d kernel
259 | kernel_dim = (time_dim, 1, 1)
260 | stride = (time_stride, 1, 1)
261 | padding = (time_dim//2, 0, 0)
262 | self.temporal_conv3d = nn.Conv3d(conv2d.out_channels, conv2d.out_channels, \
263 | kernel_size=kernel_dim, padding=padding, \
264 | stride=stride, bias=False)
265 |
266 | # init the parameters of temporal_conv3d
267 | nn.init.constant_(self.temporal_conv3d.weight, 0)
268 |
269 |
270 | def forward(self, x):
271 | out = self.spatial_conv3d(x)
272 | residual = self.temporal_conv3d(out)
273 | out = out + residual
274 |
275 | return out
276 |
277 |
278 | class APP3DA(nn.Module):
279 | def __init__(self, conv2d, time_dim=3, time_stride=1, temperature=4, contrastive_att=True):
280 | super(APP3DA, self).__init__()
281 |
282 | self.APM = APM(conv2d.out_channels, conv2d.out_channels//16, \
283 | time_dim=time_dim, temperature=temperature, contrastive_att=contrastive_att)
284 |
285 | # spatial conv3d kernel
286 | kernel_dim = (1, conv2d.kernel_size[0], conv2d.kernel_size[1])
287 |         stride = (1, conv2d.stride[0], conv2d.stride[1])
288 | padding = (0, conv2d.padding[0], conv2d.padding[1])
289 | self.spatial_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \
290 | kernel_size=kernel_dim, padding=padding, \
291 |                                 stride=stride, bias=conv2d.bias is not None)
292 |
293 | # init the parameters of spatial_conv3d
294 | weight_2d = conv2d.weight.data
295 | weight_3d = torch.zeros(*weight_2d.shape)
296 | weight_3d = weight_3d.unsqueeze(2)
297 | weight_3d[:, :, 0, :, :] = weight_2d
298 | self.spatial_conv3d.weight = nn.Parameter(weight_3d)
299 | self.spatial_conv3d.bias = conv2d.bias
300 |
301 |
302 | # temporal conv3d kernel
303 | kernel_dim = (time_dim, 1, 1)
304 | stride = (time_stride*time_dim, 1, 1)
305 | padding = (0, 0, 0)
306 | self.temporal_conv3d = nn.Conv3d(conv2d.out_channels, conv2d.out_channels, \
307 | kernel_size=kernel_dim, padding=padding, \
308 | stride=stride, bias=False)
309 |
310 | # init the parameters of temporal_conv3d
311 | weight_2d = torch.eye(conv2d.out_channels).unsqueeze(2).unsqueeze(2)
312 | weight_3d = torch.zeros(*weight_2d.shape)
313 | weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
314 | middle_idx = time_dim // 2
315 | weight_3d[:, :, middle_idx, :, :] = weight_2d
316 | self.temporal_conv3d.weight = nn.Parameter(weight_3d)
317 |
318 |
319 | def forward(self, x):
320 | x = self.spatial_conv3d(x)
321 | out = self.temporal_conv3d(self.APM(x))
322 |
323 | return out
324 |
325 |
326 | class APP3DB(nn.Module):
327 | def __init__(self, conv2d, time_dim=3, time_stride=1, temperature=4, contrastive_att=True):
328 | super(APP3DB, self).__init__()
329 |
330 | self.APM = APM(conv2d.in_channels, conv2d.in_channels//16, \
331 | time_dim=time_dim, temperature=temperature, contrastive_att=contrastive_att)
332 |
333 | # spatial conv3d kernel
334 | kernel_dim = (1, conv2d.kernel_size[0], conv2d.kernel_size[1])
335 |         stride = (1, conv2d.stride[0], conv2d.stride[1])
336 | padding = (0, conv2d.padding[0], conv2d.padding[1])
337 | self.spatial_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \
338 | kernel_size=kernel_dim, padding=padding, \
339 |                                 stride=stride, bias=conv2d.bias is not None)
340 |
341 | # init the parameters of spatial_conv3d
342 | weight_2d = conv2d.weight.data
343 | weight_3d = torch.zeros(*weight_2d.shape)
344 | weight_3d = weight_3d.unsqueeze(2)
345 | weight_3d[:, :, 0, :, :] = weight_2d
346 | self.spatial_conv3d.weight = nn.Parameter(weight_3d)
347 | self.spatial_conv3d.bias = conv2d.bias
348 |
349 |
350 | # temporal conv3d kernel
351 | kernel_dim = (time_dim, 1, 1)
352 |         stride = (time_stride*time_dim, conv2d.stride[0], conv2d.stride[1])
353 | padding = (0, 0, 0)
354 | self.temporal_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \
355 | kernel_size=kernel_dim, padding=padding, \
356 | stride=stride, bias=False)
357 |
358 | # init the parameters of temporal_conv3d
359 | nn.init.constant_(self.temporal_conv3d.weight, 0)
360 |
361 |
362 | def forward(self, x):
363 | out1 = self.spatial_conv3d(x)
364 | out2 = self.temporal_conv3d(self.APM(x))
365 | out = out1 + out2
366 |
367 | return out
368 |
369 |
370 | class APP3DC(nn.Module):
371 | def __init__(self, conv2d, time_dim=3, time_stride=1, temperature=4, contrastive_att=True):
372 | super(APP3DC, self).__init__()
373 |
374 | self.APM = APM(conv2d.out_channels, conv2d.out_channels//16, \
375 | time_dim=time_dim, temperature=temperature, contrastive_att=contrastive_att)
376 |
377 | # spatial conv3d kernel
378 | kernel_dim = (1, conv2d.kernel_size[0], conv2d.kernel_size[1])
379 |         stride = (1, conv2d.stride[0], conv2d.stride[1])
380 | padding = (0, conv2d.padding[0], conv2d.padding[1])
381 | self.spatial_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \
382 | kernel_size=kernel_dim, padding=padding, \
383 |                                 stride=stride, bias=conv2d.bias is not None)
384 |
385 | # init the parameters of spatial_conv3d
386 | weight_2d = conv2d.weight.data
387 | weight_3d = torch.zeros(*weight_2d.shape)
388 | weight_3d = weight_3d.unsqueeze(2)
389 | weight_3d[:, :, 0, :, :] = weight_2d
390 | self.spatial_conv3d.weight = nn.Parameter(weight_3d)
391 | self.spatial_conv3d.bias = conv2d.bias
392 |
393 |
394 | # temporal conv3d kernel
395 | kernel_dim = (time_dim, 1, 1)
396 | stride = (time_stride*time_dim, 1, 1)
397 | padding = (0, 0, 0)
398 | self.temporal_conv3d = nn.Conv3d(conv2d.out_channels, conv2d.out_channels, \
399 | kernel_size=kernel_dim, padding=padding, \
400 | stride=stride, bias=False)
401 |
402 | # init the parameters of temporal_conv3d
403 | nn.init.constant_(self.temporal_conv3d.weight, 0)
404 |
405 |
406 | def forward(self, x):
407 | out = self.spatial_conv3d(x)
408 | residual = self.temporal_conv3d(self.APM(out))
409 | out = out + residual
410 |
411 | return out
412 |
--------------------------------------------------------------------------------
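A quick shape check can help when reading the blocks above. The sketch below is illustrative only and assumes this file is importable as `models.utils.c3d_blocks`; it wraps a typical ResNet-style 3x3 convolution with C2D, I3D, and P3DC and confirms that all three preserve the clip shape. The APM-based variants (API3D, APP3DA/B/C) follow the same wrapping pattern but additionally reconstruct neighbouring frames.

```python
# Minimal shape-check sketch; the import path is an assumption, not part of the repo's scripts.
import torch
import torch.nn as nn
from models.utils.c3d_blocks import C2D, I3D, P3DC

conv2d = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False)  # a ResNet-style 3x3 conv
x = torch.randn(2, 64, 4, 56, 56)                                 # (batch, channels, frames, H, W)

print(C2D(conv2d)(x).shape)   # torch.Size([2, 64, 4, 56, 56]); spatial-only convolution
print(I3D(conv2d)(x).shape)   # torch.Size([2, 64, 4, 56, 56]); inflated 3x3x3 convolution
print(P3DC(conv2d)(x).shape)  # torch.Size([2, 64, 4, 56, 56]); spatial conv plus a temporal residual
```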
/models/utils/inflate.py:
--------------------------------------------------------------------------------
1 | # inflate 2D modules to 3D modules
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import functional as F
5 |
6 |
7 | def inflate_conv(conv2d,
8 | time_dim=1,
9 | time_padding=0,
10 | time_stride=1,
11 | time_dilation=1,
12 | center=False):
13 |     # To preserve activations, temporal padding should be done by continuity (replication)
14 |     # rather than with zeros, or no temporal padding should be used at all
15 | kernel_dim = (time_dim, conv2d.kernel_size[0], conv2d.kernel_size[1])
16 | padding = (time_padding, conv2d.padding[0], conv2d.padding[1])
17 |     stride = (time_stride, conv2d.stride[0], conv2d.stride[1])
18 | dilation = (time_dilation, conv2d.dilation[0], conv2d.dilation[1])
19 | conv3d = nn.Conv3d(
20 | conv2d.in_channels,
21 | conv2d.out_channels,
22 | kernel_dim,
23 | padding=padding,
24 | dilation=dilation,
25 | stride=stride)
26 | # Repeat filter time_dim times along time dimension
27 | weight_2d = conv2d.weight.data
28 | if center:
29 | weight_3d = torch.zeros(*weight_2d.shape)
30 | weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
31 | middle_idx = time_dim // 2
32 | weight_3d[:, :, middle_idx, :, :] = weight_2d
33 | else:
34 | weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
35 | weight_3d = weight_3d / time_dim
36 |
37 | # Assign new params
38 | conv3d.weight = nn.Parameter(weight_3d)
39 | conv3d.bias = conv2d.bias
40 | return conv3d
41 |
42 |
43 | def inflate_linear(linear2d, time_dim):
44 | """
45 | Args:
46 | time_dim: final time dimension of the features
47 | """
48 | linear3d = nn.Linear(linear2d.in_features * time_dim,
49 | linear2d.out_features)
50 | weight3d = linear2d.weight.data.repeat(1, time_dim)
51 | weight3d = weight3d / time_dim
52 |
53 | linear3d.weight = nn.Parameter(weight3d)
54 | linear3d.bias = linear2d.bias
55 | return linear3d
56 |
57 |
58 | def inflate_batch_norm(batch2d):
59 | # In pytorch 0.2.0 the 2d and 3d versions of batch norm
60 | # work identically except for the check that verifies the
61 | # input dimensions
62 |
63 | batch3d = nn.BatchNorm3d(batch2d.num_features)
64 | # retrieve 3d _check_input_dim function
65 | batch2d._check_input_dim = batch3d._check_input_dim
66 | return batch2d
67 |
68 |
69 | def inflate_pool(pool2d,
70 | time_dim=1,
71 | time_padding=0,
72 | time_stride=None,
73 | time_dilation=1):
74 | kernel_dim = (time_dim, pool2d.kernel_size, pool2d.kernel_size)
75 | padding = (time_padding, pool2d.padding, pool2d.padding)
76 | if time_stride is None:
77 | time_stride = time_dim
78 | stride = (time_stride, pool2d.stride, pool2d.stride)
79 | if isinstance(pool2d, nn.MaxPool2d):
80 | dilation = (time_dilation, pool2d.dilation, pool2d.dilation)
81 | pool3d = nn.MaxPool3d(
82 | kernel_dim,
83 | padding=padding,
84 | dilation=dilation,
85 | stride=stride,
86 | ceil_mode=pool2d.ceil_mode)
87 | elif isinstance(pool2d, nn.AvgPool2d):
88 | pool3d = nn.AvgPool3d(kernel_dim, stride=stride)
89 | else:
90 | raise ValueError(
91 | '{} is not among known pooling classes'.format(type(pool2d)))
92 | return pool3d
93 |
94 |
95 | class MaxPool2dFor3dInput(nn.Module):
96 | """
97 |     Since nn.MaxPool3d is a nondeterministic operation, fixed random seeds cannot guarantee consistent results.
98 |     So we use nn.MaxPool2d to emulate a MaxPool3d with kernel size (1, kernel_size, kernel_size).
99 | """
100 | def __init__(self, kernel_size, stride=None, padding=0, dilation=1):
101 | super().__init__()
102 | self.maxpool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation)
103 | def forward(self, x):
104 | b, c, t, h, w = x.size()
105 | x = x.permute(0, 2, 1, 3, 4).contiguous() # b, t, c, h, w
106 | x = x.view(b*t, c, h, w)
107 | # max pooling
108 | x = self.maxpool(x)
109 | _, _, h, w = x.size()
110 | x = x.view(b, t, c, h, w).permute(0, 2, 1, 3, 4).contiguous()
111 |
112 | return x
--------------------------------------------------------------------------------
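As a rough usage sketch (the `models.utils.inflate` import path is assumed, and the layer sizes are illustrative), the helpers above can inflate the stem of a 2D ResNet into a clip-level 3D stem while reusing the pretrained 2D weights:

```python
# Hedged sketch of inflating one ResNet-style 2D stem into 3D.
import torch
import torch.nn as nn
from models.utils import inflate

conv2d = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
bn2d = nn.BatchNorm2d(64)
pool2d = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

conv3d = inflate.inflate_conv(conv2d, time_dim=1)   # 2D weights copied (averaged over time if time_dim > 1)
bn3d = inflate.inflate_batch_norm(bn2d)             # reuses the 2D statistics with a 3D input check
pool3d = inflate.inflate_pool(pool2d, time_dim=1)

x = torch.randn(2, 3, 4, 224, 224)                  # (batch, channels, frames, H, W)
out = pool3d(bn3d(conv3d(x)))
print(out.shape)                                    # torch.Size([2, 64, 4, 56, 56])
```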
/models/utils/nonlocal_blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from models.utils import inflate
6 |
7 |
8 | class NonLocalBlockND(nn.Module):
9 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
10 | super(NonLocalBlockND, self).__init__()
11 |
12 | assert dimension in [1, 2, 3]
13 |
14 | self.dimension = dimension
15 | self.sub_sample = sub_sample
16 | self.in_channels = in_channels
17 | self.inter_channels = inter_channels
18 |
19 | if self.inter_channels is None:
20 | self.inter_channels = in_channels // 2
21 | if self.inter_channels == 0:
22 | self.inter_channels = 1
23 |
24 | if dimension == 3:
25 | conv_nd = nn.Conv3d
26 | # max_pool = inflate.MaxPool2dFor3dInput
27 | max_pool = nn.MaxPool3d
28 | bn = nn.BatchNorm3d
29 | elif dimension == 2:
30 | conv_nd = nn.Conv2d
31 | max_pool = nn.MaxPool2d
32 | bn = nn.BatchNorm2d
33 | else:
34 | conv_nd = nn.Conv1d
35 | max_pool = nn.MaxPool1d
36 | bn = nn.BatchNorm1d
37 |
38 | self.g = conv_nd(self.in_channels, self.inter_channels,
39 | kernel_size=1, stride=1, padding=0, bias=True)
40 | self.theta = conv_nd(self.in_channels, self.inter_channels,
41 | kernel_size=1, stride=1, padding=0, bias=True)
42 | self.phi = conv_nd(self.in_channels, self.inter_channels,
43 | kernel_size=1, stride=1, padding=0, bias=True)
44 | # if sub_sample:
45 | # self.g = nn.Sequential(self.g, max_pool(kernel_size=2))
46 | # self.phi = nn.Sequential(self.phi, max_pool(kernel_size=2))
47 | if sub_sample:
48 | if dimension == 3:
49 | self.g = nn.Sequential(self.g, max_pool((1, 2, 2)))
50 | self.phi = nn.Sequential(self.phi, max_pool((1, 2, 2)))
51 | else:
52 | self.g = nn.Sequential(self.g, max_pool(kernel_size=2))
53 | self.phi = nn.Sequential(self.phi, max_pool(kernel_size=2))
54 |
55 | if bn_layer:
56 | self.W = nn.Sequential(
57 | conv_nd(self.inter_channels, self.in_channels,
58 | kernel_size=1, stride=1, padding=0, bias=True),
59 | bn(self.in_channels)
60 | )
61 | else:
62 | self.W = conv_nd(self.inter_channels, self.in_channels,
63 | kernel_size=1, stride=1, padding=0, bias=True)
64 |
65 | # init
66 | for m in self.modules():
67 | if isinstance(m, conv_nd):
68 |                 n = m.out_channels * int(torch.prod(torch.tensor(m.kernel_size)))  # works for 1D/2D/3D kernels
69 | m.weight.data.normal_(0, math.sqrt(2. / n))
70 | elif isinstance(m, bn):
71 | m.weight.data.fill_(1)
72 | m.bias.data.zero_()
73 |
74 | if bn_layer:
75 | nn.init.constant_(self.W[1].weight.data, 0.0)
76 | nn.init.constant_(self.W[1].bias.data, 0.0)
77 | else:
78 | nn.init.constant_(self.W.weight.data, 0.0)
79 | nn.init.constant_(self.W.bias.data, 0.0)
80 |
81 |
82 | def forward(self, x):
83 | '''
84 | :param x: (b, c, t, h, w)
85 | :return:
86 | '''
87 | batch_size = x.size(0)
88 |
89 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
90 | g_x = g_x.permute(0, 2, 1)
91 |
92 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
93 | theta_x = theta_x.permute(0, 2, 1)
94 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
95 | f = torch.matmul(theta_x, phi_x)
96 | f = F.softmax(f, dim=-1)
97 |
98 | y = torch.matmul(f, g_x)
99 | y = y.permute(0, 2, 1).contiguous()
100 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
101 | y = self.W(y)
102 | z = y + x
103 |
104 | return z
105 |
106 |
107 | class NonLocalBlock1D(NonLocalBlockND):
108 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
109 | super(NonLocalBlock1D, self).__init__(in_channels,
110 | inter_channels=inter_channels,
111 | dimension=1, sub_sample=sub_sample,
112 | bn_layer=bn_layer)
113 |
114 |
115 | class NonLocalBlock2D(NonLocalBlockND):
116 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
117 | super(NonLocalBlock2D, self).__init__(in_channels,
118 | inter_channels=inter_channels,
119 | dimension=2, sub_sample=sub_sample,
120 | bn_layer=bn_layer)
121 |
122 |
123 | class NonLocalBlock3D(NonLocalBlockND):
124 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
125 | super(NonLocalBlock3D, self).__init__(in_channels,
126 | inter_channels=inter_channels,
127 | dimension=3, sub_sample=sub_sample,
128 | bn_layer=bn_layer)
129 |
--------------------------------------------------------------------------------
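For reference, a small sketch of how a 3D non-local block is typically dropped into a clip-level feature map (the import path and channel sizes are assumptions); the residual connection keeps the input shape unchanged:

```python
# Illustrative use of NonLocalBlock3D on a (batch, channels, frames, H, W) tensor.
import torch
from models.utils.nonlocal_blocks import NonLocalBlock3D

block = NonLocalBlock3D(in_channels=256, sub_sample=True, bn_layer=True)
x = torch.randn(2, 256, 4, 16, 8)
out = block(x)
print(out.shape)   # torch.Size([2, 256, 4, 16, 8]); z = W(y) + x preserves the input shape
```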
/models/utils/pooling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 |
6 | class GeMPooling(nn.Module):
7 | def __init__(self, p=3, eps=1e-6):
8 | super().__init__()
9 | self.p = nn.Parameter(torch.ones(1) * p)
10 | self.eps = eps
11 |
12 | def forward(self, x):
13 | return F.avg_pool2d(x.clamp(min=self.eps).pow(self.p), x.size()[2:]).pow(1./self.p)
14 |
15 |
16 | class MaxAvgPooling(nn.Module):
17 | def __init__(self):
18 | super().__init__()
19 | self.maxpooling = nn.AdaptiveMaxPool2d(1)
20 | self.avgpooling = nn.AdaptiveAvgPool2d(1)
21 |
22 | def forward(self, x):
23 | max_f = self.maxpooling(x)
24 | avg_f = self.avgpooling(x)
25 |
26 | return torch.cat((max_f, avg_f), 1)
27 |
--------------------------------------------------------------------------------
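Both heads act on a 4D feature map and reduce it to a single spatial position; a quick sketch (import path and feature sizes assumed):

```python
# Illustrative pooling-head usage on a ResNet-style feature map.
import torch
from models.utils.pooling import GeMPooling, MaxAvgPooling

feat = torch.randn(8, 2048, 24, 12)        # (batch, channels, H, W)
print(GeMPooling(p=3)(feat).shape)         # torch.Size([8, 2048, 1, 1]); generalized-mean pooling
print(MaxAvgPooling()(feat).shape)         # torch.Size([8, 4096, 1, 1]); max and avg pooled, concatenated
```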
/test.py:
--------------------------------------------------------------------------------
1 | import time
2 | import logging
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import distributed as dist
6 | from tools.eval_metrics import evaluate, evaluate_with_clothes
7 |
8 |
9 | def concat_all_gather(tensors, num_total_examples):
10 | '''
11 | Performs all_gather operation on the provided tensor list.
12 | '''
13 | outputs = []
14 | for tensor in tensors:
15 | tensor = tensor.cuda()
16 | tensors_gather = [tensor.clone() for _ in range(dist.get_world_size())]
17 | dist.all_gather(tensors_gather, tensor)
18 | output = torch.cat(tensors_gather, dim=0).cpu()
19 | # truncate the dummy elements added by DistributedInferenceSampler
20 | outputs.append(output[:num_total_examples])
21 | return outputs
22 |
23 |
24 | @torch.no_grad()
25 | def extract_img_feature(model, dataloader):
26 | features, pids, camids, clothes_ids = [], torch.tensor([]), torch.tensor([]), torch.tensor([])
27 | for batch_idx, (imgs, batch_pids, batch_camids, batch_clothes_ids, batch_img_path) in enumerate(dataloader):
28 | flip_imgs = torch.flip(imgs, [3])
29 | imgs, flip_imgs = imgs.cuda(), flip_imgs.cuda()
30 | _, batch_features = model(imgs)
31 | _, batch_features_flip = model(flip_imgs)
32 | batch_features += batch_features_flip
33 | batch_features = F.normalize(batch_features, p=2, dim=1)
34 |
35 | features.append(batch_features.cpu())
36 | pids = torch.cat((pids, batch_pids.cpu()), dim=0)
37 | camids = torch.cat((camids, batch_camids.cpu()), dim=0)
38 | clothes_ids = torch.cat((clothes_ids, batch_clothes_ids.cpu()), dim=0)
39 | features = torch.cat(features, 0)
40 |
41 | return features, pids, camids, clothes_ids
42 |
43 |
44 | def test(config, model, queryloader, galleryloader, dataset):
45 | logger = logging.getLogger('reid.test')
46 | since = time.time()
47 | model.eval()
48 | local_rank = dist.get_rank()
49 | # Extract features
50 | qf, q_pids, q_camids, q_clothes_ids = extract_img_feature(model, queryloader)
51 | gf, g_pids, g_camids, g_clothes_ids = extract_img_feature(model, galleryloader)
52 | # Gather samples from different GPUs
53 | torch.cuda.empty_cache()
54 | qf, q_pids, q_camids, q_clothes_ids = concat_all_gather([qf, q_pids, q_camids, q_clothes_ids], len(dataset.query))
55 | gf, g_pids, g_camids, g_clothes_ids = concat_all_gather([gf, g_pids, g_camids, g_clothes_ids], len(dataset.gallery))
56 | torch.cuda.empty_cache()
57 | time_elapsed = time.time() - since
58 |
59 | logger.info("Extracted features for query set, obtained {} matrix".format(qf.shape))
60 | logger.info("Extracted features for gallery set, obtained {} matrix".format(gf.shape))
61 | logger.info('Extracting features complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
62 | # Compute distance matrix between query and gallery
63 | since = time.time()
64 | m, n = qf.size(0), gf.size(0)
65 | distmat = torch.zeros((m,n))
66 | qf, gf = qf.cuda(), gf.cuda()
67 | # Cosine similarity
68 | for i in range(m):
69 | distmat[i] = (- torch.mm(qf[i:i+1], gf.t())).cpu()
70 | distmat = distmat.numpy()
71 | q_pids, q_camids, q_clothes_ids = q_pids.numpy(), q_camids.numpy(), q_clothes_ids.numpy()
72 | g_pids, g_camids, g_clothes_ids = g_pids.numpy(), g_camids.numpy(), g_clothes_ids.numpy()
73 | time_elapsed = time.time() - since
74 | logger.info('Distance computing in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
75 |
76 | since = time.time()
77 | logger.info("Computing CMC and mAP")
78 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
79 | logger.info("Results ---------------------------------------------------")
80 | logger.info('top1:{:.1%} top5:{:.1%} top10:{:.1%} top20:{:.1%} mAP:{:.1%}'.format(cmc[0], cmc[4], cmc[9], cmc[19], mAP))
81 | logger.info("-----------------------------------------------------------")
82 | time_elapsed = time.time() - since
83 | logger.info('Using {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
84 |
85 | logger.info("Computing CMC and mAP only for the same clothes setting")
86 | cmc, mAP = evaluate_with_clothes(distmat, q_pids, g_pids, q_camids, g_camids, q_clothes_ids, g_clothes_ids, mode='SC')
87 | logger.info("Results ---------------------------------------------------")
88 | logger.info('top1:{:.1%} top5:{:.1%} top10:{:.1%} top20:{:.1%} mAP:{:.1%}'.format(cmc[0], cmc[4], cmc[9], cmc[19], mAP))
89 | logger.info("-----------------------------------------------------------")
90 |
91 | logger.info("Computing CMC and mAP only for clothes-changing")
92 | cmc, mAP = evaluate_with_clothes(distmat, q_pids, g_pids, q_camids, g_camids, q_clothes_ids, g_clothes_ids, mode='CC')
93 | logger.info("Results ---------------------------------------------------")
94 | logger.info('top1:{:.1%} top5:{:.1%} top10:{:.1%} top20:{:.1%} mAP:{:.1%}'.format(cmc[0], cmc[4], cmc[9], cmc[19], mAP))
95 | logger.info("-----------------------------------------------------------")
96 |
97 | return cmc[0]
98 |
99 |
100 | def test_prcc(model, queryloader_same, queryloader_diff, galleryloader, dataset):
101 | logger = logging.getLogger('reid.test')
102 | since = time.time()
103 | model.eval()
104 | local_rank = dist.get_rank()
105 | # Extract features for query set
106 | qsf, qs_pids, qs_camids, qs_clothes_ids = extract_img_feature(model, queryloader_same)
107 | qdf, qd_pids, qd_camids, qd_clothes_ids = extract_img_feature(model, queryloader_diff)
108 | # Extract features for gallery set
109 | gf, g_pids, g_camids, g_clothes_ids = extract_img_feature(model, galleryloader)
110 | # Gather samples from different GPUs
111 | torch.cuda.empty_cache()
112 | qsf, qs_pids, qs_camids, qs_clothes_ids = concat_all_gather([qsf, qs_pids, qs_camids, qs_clothes_ids], len(dataset.query_same))
113 | qdf, qd_pids, qd_camids, qd_clothes_ids = concat_all_gather([qdf, qd_pids, qd_camids, qd_clothes_ids], len(dataset.query_diff))
114 | gf, g_pids, g_camids, g_clothes_ids = concat_all_gather([gf, g_pids, g_camids, g_clothes_ids], len(dataset.gallery))
115 | time_elapsed = time.time() - since
116 |
117 | logger.info("Extracted features for query set (with same clothes), obtained {} matrix".format(qsf.shape))
118 | logger.info("Extracted features for query set (with different clothes), obtained {} matrix".format(qdf.shape))
119 | logger.info("Extracted features for gallery set, obtained {} matrix".format(gf.shape))
120 | logger.info('Extracting features complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
121 | # Compute distance matrix between query and gallery
122 | m, n, k = qsf.size(0), qdf.size(0), gf.size(0)
123 | distmat_same = torch.zeros((m, k))
124 | distmat_diff = torch.zeros((n, k))
125 | qsf, qdf, gf = qsf.cuda(), qdf.cuda(), gf.cuda()
126 | # Cosine similarity
127 | for i in range(m):
128 | distmat_same[i] = (- torch.mm(qsf[i:i+1], gf.t())).cpu()
129 | for i in range(n):
130 | distmat_diff[i] = (- torch.mm(qdf[i:i+1], gf.t())).cpu()
131 | distmat_same = distmat_same.numpy()
132 | distmat_diff = distmat_diff.numpy()
133 | qs_pids, qs_camids, qs_clothes_ids = qs_pids.numpy(), qs_camids.numpy(), qs_clothes_ids.numpy()
134 | qd_pids, qd_camids, qd_clothes_ids = qd_pids.numpy(), qd_camids.numpy(), qd_clothes_ids.numpy()
135 | g_pids, g_camids, g_clothes_ids = g_pids.numpy(), g_camids.numpy(), g_clothes_ids.numpy()
136 |
137 | logger.info("Computing CMC and mAP for the same clothes setting")
138 | cmc, mAP = evaluate(distmat_same, qs_pids, g_pids, qs_camids, g_camids)
139 | logger.info("Results ---------------------------------------------------")
140 | logger.info('top1:{:.1%} top5:{:.1%} top10:{:.1%} top20:{:.1%} mAP:{:.1%}'.format(cmc[0], cmc[4], cmc[9], cmc[19], mAP))
141 | logger.info("-----------------------------------------------------------")
142 |
143 | logger.info("Computing CMC and mAP only for clothes changing")
144 | cmc, mAP = evaluate(distmat_diff, qd_pids, g_pids, qd_camids, g_camids)
145 | logger.info("Results ---------------------------------------------------")
146 | logger.info('top1:{:.1%} top5:{:.1%} top10:{:.1%} top20:{:.1%} mAP:{:.1%}'.format(cmc[0], cmc[4], cmc[9], cmc[19], mAP))
147 | logger.info("-----------------------------------------------------------")
148 |
149 | return cmc[0]
--------------------------------------------------------------------------------
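The distance convention above is worth spelling out: because features are L2-normalized in `extract_img_feature`, the row-by-row negative inner product differs from cosine distance only by a constant, so the ranking is identical. A small self-contained check (toy sizes, illustrative only):

```python
# Sketch of the negative-cosine-similarity distance used in test() and test_prcc().
import torch
import torch.nn.functional as F

qf = F.normalize(torch.randn(3, 8), p=2, dim=1)   # 3 query features
gf = F.normalize(torch.randn(5, 8), p=2, dim=1)   # 5 gallery features

distmat = torch.zeros(3, 5)
for i in range(qf.size(0)):                       # same row-by-row loop as above, which limits memory use
    distmat[i] = -torch.mm(qf[i:i+1], gf.t())
assert torch.allclose(distmat, -qf @ gf.t(), atol=1e-6)
print(distmat.argsort(dim=1))                     # smaller distance means a more similar gallery item
```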
/test_AIM.sh:
--------------------------------------------------------------------------------
1 | # For LTCC dataset
2 | python -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset ltcc --cfg configs/res50_cels_cal.yaml --gpu 0,1 --eval --resume ./ltcc.pth.tar #
3 | # For PRCC dataset
4 | python -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset prcc --cfg configs/res50_cels_cal.yaml --gpu 0,1 --eval --resume ./prcc.pth.tar #
--------------------------------------------------------------------------------
/tools/eval_metrics.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import numpy as np
3 |
4 |
5 | def compute_ap_cmc(index, good_index, junk_index):
6 | """ Compute AP and CMC for each sample
7 | """
8 | ap = 0
9 | cmc = np.zeros(len(index))
10 |
11 | # remove junk_index
12 | mask = np.in1d(index, junk_index, invert=True)
13 | index = index[mask]
14 |
15 | # find good_index index
16 | ngood = len(good_index)
17 | mask = np.in1d(index, good_index)
18 | rows_good = np.argwhere(mask==True)
19 | rows_good = rows_good.flatten()
20 |
21 | cmc[rows_good[0]:] = 1.0
22 | for i in range(ngood):
23 | d_recall = 1.0/ngood
24 | precision = (i+1)*1.0/(rows_good[i]+1)
25 | ap = ap + d_recall*precision
26 |
27 | return ap, cmc
28 |
29 |
30 | def evaluate(distmat, q_pids, g_pids, q_camids, g_camids):
31 | """ Compute CMC and mAP
32 |
33 | Args:
34 | distmat (numpy ndarray): distance matrix with shape (num_query, num_gallery).
35 | q_pids (numpy array): person IDs for query samples.
36 | g_pids (numpy array): person IDs for gallery samples.
37 | q_camids (numpy array): camera IDs for query samples.
38 | g_camids (numpy array): camera IDs for gallery samples.
39 | """
40 | num_q, num_g = distmat.shape
41 | index = np.argsort(distmat, axis=1) # from small to large
42 |
43 | num_no_gt = 0 # num of query imgs without groundtruth
44 | num_r1 = 0
45 | CMC = np.zeros(len(g_pids))
46 | AP = 0
47 |
48 | for i in range(num_q):
49 | # groundtruth index
50 | query_index = np.argwhere(g_pids==q_pids[i])
51 | camera_index = np.argwhere(g_camids==q_camids[i])
52 | good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
53 | if good_index.size == 0:
54 | num_no_gt += 1
55 | continue
56 | # remove gallery samples that have the same pid and camid with query
57 | junk_index = np.intersect1d(query_index, camera_index)
58 |
59 | ap_tmp, CMC_tmp = compute_ap_cmc(index[i], good_index, junk_index)
60 | if CMC_tmp[0]==1:
61 | num_r1 += 1
62 | CMC = CMC + CMC_tmp
63 | AP += ap_tmp
64 |
65 | if num_no_gt > 0:
66 | logger = logging.getLogger('reid.evaluate')
67 | logger.info("{} query samples do not have groundtruth.".format(num_no_gt))
68 |
69 | CMC = CMC / (num_q - num_no_gt)
70 | mAP = AP / (num_q - num_no_gt)
71 |
72 | return CMC, mAP
73 |
74 |
75 | def evaluate_with_clothes(distmat, q_pids, g_pids, q_camids, g_camids, q_clothids, g_clothids, mode='CC'):
76 | """ Compute CMC and mAP with clothes
77 |
78 | Args:
79 | distmat (numpy ndarray): distance matrix with shape (num_query, num_gallery).
80 | q_pids (numpy array): person IDs for query samples.
81 | g_pids (numpy array): person IDs for gallery samples.
82 | q_camids (numpy array): camera IDs for query samples.
83 | g_camids (numpy array): camera IDs for gallery samples.
84 | q_clothids (numpy array): clothes IDs for query samples.
85 | g_clothids (numpy array): clothes IDs for gallery samples.
86 | mode: 'CC' for clothes-changing; 'SC' for the same clothes.
87 | """
88 | assert mode in ['CC', 'SC']
89 |
90 | num_q, num_g = distmat.shape
91 | index = np.argsort(distmat, axis=1) # from small to large
92 |
93 | num_no_gt = 0 # num of query imgs without groundtruth
94 | num_r1 = 0
95 | CMC = np.zeros(len(g_pids))
96 | AP = 0
97 |
98 | for i in range(num_q):
99 | # groundtruth index
100 | query_index = np.argwhere(g_pids==q_pids[i])
101 | camera_index = np.argwhere(g_camids==q_camids[i])
102 | cloth_index = np.argwhere(g_clothids==q_clothids[i])
103 | good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
104 | if mode == 'CC':
105 | good_index = np.setdiff1d(good_index, cloth_index, assume_unique=True)
106 | # remove gallery samples that have the same (pid, camid) or (pid, clothid) with query
107 | junk_index1 = np.intersect1d(query_index, camera_index)
108 | junk_index2 = np.intersect1d(query_index, cloth_index)
109 | junk_index = np.union1d(junk_index1, junk_index2)
110 | else:
111 | good_index = np.intersect1d(good_index, cloth_index)
112 | # remove gallery samples that have the same (pid, camid) or
113 | # (the same pid and different clothid) with query
114 | junk_index1 = np.intersect1d(query_index, camera_index)
115 | junk_index2 = np.setdiff1d(query_index, cloth_index)
116 | junk_index = np.union1d(junk_index1, junk_index2)
117 |
118 | if good_index.size == 0:
119 | num_no_gt += 1
120 | continue
121 |
122 | ap_tmp, CMC_tmp = compute_ap_cmc(index[i], good_index, junk_index)
123 | if CMC_tmp[0]==1:
124 | num_r1 += 1
125 | CMC = CMC + CMC_tmp
126 | AP += ap_tmp
127 |
128 | if num_no_gt > 0:
129 | logger = logging.getLogger('reid.evaluate')
130 | logger.info("{} query samples do not have groundtruth.".format(num_no_gt))
131 |
132 | if (num_q - num_no_gt) != 0:
133 | CMC = CMC / (num_q - num_no_gt)
134 | mAP = AP / (num_q - num_no_gt)
135 | else:
136 | mAP = 0
137 |
138 | return CMC, mAP
--------------------------------------------------------------------------------
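A toy example may help to see the conventions used by `evaluate` (the `tools.eval_metrics` import path is assumed and the numbers are purely illustrative). Gallery items sharing both pid and camid with the query are junked; in the case below the cameras differ, so nothing is junked and both queries rank their ground truth first:

```python
# Toy check of the CMC/mAP conventions.
import numpy as np
from tools.eval_metrics import evaluate

distmat = np.array([[0.1, 0.9, 0.2],    # query 0 is closest to gallery 0
                    [0.8, 0.3, 0.7]])   # query 1 is closest to gallery 1
q_pids = np.array([0, 1])
g_pids = np.array([0, 1, 0])
q_camids = np.array([0, 0])
g_camids = np.array([1, 1, 1])          # all different from the query cameras
cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
print(cmc[:3], mAP)                      # rank-1 is 1.0 here; mAP averages the per-query AP
```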
/tools/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import shutil
4 | import errno
5 | import json
6 | import os.path as osp
7 | import torch
8 | import random
9 | import logging
10 | import numpy as np
11 |
12 |
13 | def set_seed(seed=None):
14 | if seed is None:
15 | return
16 | random.seed(seed)
17 | os.environ['PYTHONHASHSEED'] = ("%s" % seed)
18 | np.random.seed(seed)
19 | torch.manual_seed(seed)
20 | torch.cuda.manual_seed(seed)
21 | torch.cuda.manual_seed_all(seed)
22 | torch.backends.cudnn.benchmark = False
23 | torch.backends.cudnn.deterministic = True
24 |
25 |
26 | def mkdir_if_missing(directory):
27 | if not osp.exists(directory):
28 | try:
29 | os.makedirs(directory)
30 | except OSError as e:
31 | if e.errno != errno.EEXIST:
32 | raise
33 |
34 |
35 | def read_json(fpath):
36 | with open(fpath, 'r') as f:
37 | obj = json.load(f)
38 | return obj
39 |
40 |
41 | def write_json(obj, fpath):
42 | mkdir_if_missing(osp.dirname(fpath))
43 | with open(fpath, 'w') as f:
44 | json.dump(obj, f, indent=4, separators=(',', ': '))
45 |
46 |
47 | class AverageMeter(object):
48 | """Computes and stores the average and current value.
49 |
50 | Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
51 | """
52 | def __init__(self):
53 | self.reset()
54 |
55 | def reset(self):
56 | self.val = 0
57 | self.avg = 0
58 | self.sum = 0
59 | self.count = 0
60 |
61 | def update(self, val, n=1):
62 | self.val = val
63 | self.sum += val * n
64 | self.count += n
65 | self.avg = self.sum / self.count
66 |
67 |
68 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
69 | mkdir_if_missing(osp.dirname(fpath))
70 | torch.save(state, fpath)
71 | if is_best:
72 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'best_model.pth.tar'))
73 |
74 | '''
75 | class Logger(object):
76 | """
77 | Write console output to external text file.
78 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.
79 | """
80 | def __init__(self, fpath=None):
81 | self.console = sys.stdout
82 | self.file = None
83 | if fpath is not None:
84 | mkdir_if_missing(os.path.dirname(fpath))
85 | self.file = open(fpath, 'w')
86 |
87 | def __del__(self):
88 | self.close()
89 |
90 | def __enter__(self):
91 | pass
92 |
93 | def __exit__(self, *args):
94 | self.close()
95 |
96 | def write(self, msg):
97 | self.console.write(msg)
98 | if self.file is not None:
99 | self.file.write(msg)
100 |
101 | def flush(self):
102 | self.console.flush()
103 | if self.file is not None:
104 | self.file.flush()
105 | os.fsync(self.file.fileno())
106 |
107 | def close(self):
108 | self.console.close()
109 | if self.file is not None:
110 | self.file.close()
111 | '''
112 |
113 |
114 | def get_logger(fpath, local_rank=0, name=''):
115 |     # Create logger
116 | logger = logging.getLogger(name)
117 | level = logging.INFO if local_rank in [-1, 0] else logging.WARN
118 | logger.setLevel(level=level)
119 |
120 | # Output to console
121 | console_handler = logging.StreamHandler(sys.stdout)
122 | console_handler.setLevel(level=level)
123 | console_handler.setFormatter(logging.Formatter('%(message)s'))
124 | logger.addHandler(console_handler)
125 |
126 | # Output to file
127 | if fpath is not None:
128 | mkdir_if_missing(os.path.dirname(fpath))
129 | file_handler = logging.FileHandler(fpath, mode='w')
130 | file_handler.setLevel(level=level)
131 | file_handler.setFormatter(logging.Formatter('%(message)s'))
132 | logger.addHandler(file_handler)
133 |
134 | return logger
--------------------------------------------------------------------------------
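A short usage sketch of the helpers above (the log path and loss values are illustrative):

```python
# Illustrative use of set_seed, get_logger, and AverageMeter.
from tools.utils import set_seed, get_logger, AverageMeter

set_seed(42)                                         # also makes cuDNN deterministic as a side effect
logger = get_logger('/tmp/aim_demo/log.txt', local_rank=0, name='reid.demo')
meter = AverageMeter()
for loss in [0.9, 0.7, 0.5]:
    meter.update(loss, n=1)
logger.info('avg loss: {:.3f}'.format(meter.avg))    # 0.700
```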
/train.py:
--------------------------------------------------------------------------------
1 | # Training loop for AIM (causality-inspired cloth-debiasing)
2 | import time
3 | import datetime
4 | import logging
5 | import torch
6 | from apex import amp
7 | from tools.utils import AverageMeter
8 |
9 |
10 | def train_aim(config, epoch, model, model2, classifier, clothes_classifier, clothes_classifier2, fuse, criterion_cla, criterion_pair,
11 | criterion_clothes, criterion_adv, optimizer, optimizer2, optimizer_cc, trainloader, pid2clothes, kl):
12 | logger = logging.getLogger('reid.train')
13 | batch_cla_loss = AverageMeter()
14 | batch_pair_loss = AverageMeter()
15 | batch_clo_loss = AverageMeter()
16 | batch_adv_loss = AverageMeter()
17 | batch_clothes_loss2 = AverageMeter()
18 | batch_loss2 = AverageMeter()
19 | batch_kl_loss = AverageMeter()
20 | corrects = AverageMeter()
21 | corrects2 = AverageMeter()
22 | corrects3 = AverageMeter()
23 | clothes_corrects = AverageMeter()
24 | clothes_corrects2 = AverageMeter()
25 | batch_time = AverageMeter()
26 | data_time = AverageMeter()
27 |
28 | model.train()
29 | model2.train()
30 | fuse.train()
31 | classifier.train()
32 | clothes_classifier.train()
33 | clothes_classifier2.train()
34 |
35 | end = time.time()
36 | for batch_idx, (imgs, pids, camids, clothes_ids, img_path) in enumerate(trainloader):
37 | # Get all positive clothes classes (belonging to the same identity) for each sample
38 | pos_mask = pid2clothes[pids]
39 | imgs, pids, clothes_ids, pos_mask = imgs.cuda(), pids.cuda(), clothes_ids.cuda(), pos_mask.float().cuda()
40 |         # Measure data loading time
41 | data_time.update(time.time() - end)
42 | # Forward
43 | pri_feat, features = model(imgs) # torch.size([32,4096])
44 | pri_feat2, features2 = model2(imgs)
45 |
46 | pri_feat2 = pri_feat2.clone().detach()
47 | features_fuse = fuse(pri_feat, pri_feat2)
48 |
49 | outputs = classifier(features)
50 | outputs2 = clothes_classifier2(features2)
51 |         outputs3 = classifier(features_fuse) # id-classifier logits on the fused (clothes-biased) features
52 |
53 | # new_pred_clothes2 = clothes_classifier2(features2)
54 | # loss2 = criterion_adv(new_pred_clothes2, clothes_ids, pos_mask)
55 |
56 | pred_clothes = clothes_classifier(features.detach()) # no grad
57 |
58 |         _, preds = torch.max(outputs.data, 1) # torch.max returns (values, indices); take the argmax over dim=1
59 | _, preds3 = torch.max(outputs3.data, 1)
60 |
61 | # Update the clothes discriminator
62 | clothes_loss = criterion_clothes(pred_clothes, clothes_ids)
63 | if epoch >= config.TRAIN.START_EPOCH_CC:
64 | optimizer_cc.zero_grad()
65 | if config.TRAIN.AMP:
66 | with amp.scale_loss(clothes_loss, optimizer_cc) as scaled_loss:
67 | scaled_loss.backward()
68 | else:
69 | clothes_loss.backward()
70 | optimizer_cc.step()
71 |
72 | # Update the backbone
73 | new_pred_clothes = clothes_classifier(features)
74 | _, clothes_preds = torch.max(new_pred_clothes.data, 1)
75 |
76 | _, pred_clothes2 = torch.max(outputs2.data, 1)
77 | # outputs2_no_grad = clothes_classifier2(features2.detach())
78 |
79 | Q = new_pred_clothes.clone().detach()
80 | P = outputs2.clone()
81 | Q = torch.nn.functional.softmax(Q, dim=-1)
82 | P = torch.nn.functional.softmax(P, dim=-1)
83 |
84 | # Update the clothes discriminator 2
85 |
86 | clothes_loss2 = criterion_clothes(outputs2, clothes_ids)
87 |
88 | kl_loss = kl(torch.log(Q), P, reduction='sum') + kl(torch.log(P), Q, reduction='sum')
89 |
90 | if epoch >= config.TRAIN.START_EPOCH_CC:
91 | loss2 = clothes_loss2 + config.k_kl * kl_loss
92 | else:
93 | loss2 = clothes_loss2
94 |
95 | optimizer2.zero_grad()
96 | if config.TRAIN.AMP:
97 | with amp.scale_loss(loss2, optimizer2) as scaled_loss2:
98 | scaled_loss2.backward()
99 | else:
100 | loss2.backward()
101 | optimizer2.step()
102 |
103 | GENERAL_EPOCH = config.TRAIN.START_EPOCH_ADV
104 |
105 | # Compute loss
106 | if epoch >= GENERAL_EPOCH:
107 | cla_loss = criterion_cla(outputs, pids) + config.k_cal * criterion_cla(outputs - outputs3, pids)
108 | else:
109 | cla_loss = criterion_cla(outputs, pids)
110 | pair_loss = criterion_pair(features, pids)
111 | adv_loss = criterion_adv(new_pred_clothes, clothes_ids, pos_mask)
112 |
113 | if epoch >= config.TRAIN.START_EPOCH_ADV:
114 | loss = cla_loss + adv_loss + config.LOSS.PAIR_LOSS_WEIGHT * pair_loss
115 | else:
116 | loss = cla_loss + config.LOSS.PAIR_LOSS_WEIGHT * pair_loss
117 |
118 | optimizer.zero_grad()
119 | if config.TRAIN.AMP:
120 | with amp.scale_loss(loss, optimizer) as scaled_loss:
121 | scaled_loss.backward()
122 | else:
123 | loss.backward()
124 | optimizer.step()
125 |
126 | # statistics
127 | corrects.update(torch.sum(preds == pids.data).float()/pids.size(0), pids.size(0))
128 | corrects2.update(torch.sum(pred_clothes2 == clothes_ids.data).float()/clothes_ids.size(0), clothes_ids.size(0))
129 | corrects3.update(torch.sum(preds3 == pids.data).float()/pids.size(0), pids.size(0))
130 | clothes_corrects.update(torch.sum(clothes_preds == clothes_ids.data).float()/clothes_ids.size(0), clothes_ids.size(0))
131 | clothes_corrects2.update(torch.sum(pred_clothes2 == clothes_ids.data).float()/clothes_ids.size(0), clothes_ids.size(0))
132 | batch_cla_loss.update(cla_loss.item(), pids.size(0))
133 | batch_pair_loss.update(pair_loss.item(), pids.size(0))
134 | batch_clo_loss.update(clothes_loss.item(), clothes_ids.size(0))
135 | batch_adv_loss.update(adv_loss.item(), clothes_ids.size(0))
136 | batch_loss2.update(loss2.item(), clothes_ids.size(0))
137 | batch_clothes_loss2.update(clothes_loss2.item(), clothes_ids.size(0))
138 | batch_kl_loss.update(kl_loss.item(), clothes_ids.size(0))
139 |
140 | # measure elapsed time
141 | batch_time.update(time.time() - end)
142 | end = time.time()
143 |
144 | logger.info('Epoch{0} '
145 | 'Time:{batch_time.sum:.1f}s '
146 | 'Data:{data_time.sum:.1f}s '
147 | 'ClaLoss:{cla_loss.avg:.4f} '
148 | 'PairLoss:{pair_loss.avg:.4f} '
149 | 'CloLoss:{clo_loss.avg:.4f} '
150 | 'AdvLoss:{adv_loss.avg:.4f} '
151 | 'clothes_loss2:{clothes_loss2.avg:.4f} '
152 | 'loss2:{loss2.avg:.4f} '
153 | 'kl_loss:{kl_loss.avg:.4f} '
154 | 'Acc:{acc.avg:.2%} '
155 | 'Acc2:{acc2.avg:.2%} '
156 | 'Acc3:{acc3.avg:.2%} '
157 | 'CloAcc:{clo_acc.avg:.2%} '
158 | 'Clo2Acc:{clo2_acc.avg:.2%} '.format(
159 | epoch+1, batch_time=batch_time, data_time=data_time,
160 | cla_loss=batch_cla_loss, pair_loss=batch_pair_loss,
161 | clo_loss=batch_clo_loss, adv_loss=batch_adv_loss,
162 | clothes_loss2=batch_clothes_loss2,
163 | loss2=batch_loss2, kl_loss=batch_kl_loss,
164 | acc=corrects, acc2=corrects2, acc3=corrects3,
165 | clo_acc=clothes_corrects, clo2_acc=clothes_corrects2))
--------------------------------------------------------------------------------
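The bidirectional KL term computed in the loop above pulls the bias branch's clothes distribution towards the main branch's (detached) clothes-classifier distribution once `START_EPOCH_CC` is reached. A minimal sketch of that term, assuming the `kl` argument passed into `train_aim` behaves like `torch.nn.functional.kl_div` (the actual callable is defined outside this file):

```python
# Sketch of the symmetric KL term: KL(P || Q) + KL(Q || P) on softmaxed clothes logits.
import torch
import torch.nn.functional as F

logits_a = torch.randn(4, 10)                       # clothes logits from the detached branch (gives Q)
logits_b = torch.randn(4, 10, requires_grad=True)   # clothes logits from the bias branch (gives P)
Q = F.softmax(logits_a.detach(), dim=-1)
P = F.softmax(logits_b, dim=-1)
kl_loss = F.kl_div(torch.log(Q), P, reduction='sum') + F.kl_div(torch.log(P), Q, reduction='sum')
kl_loss.backward()                                  # gradients flow only through P, as in the training loop
```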