├── IDKL ├── .idea │ ├── IDKL.iml │ ├── deployment.xml │ ├── inspectionProfiles │ │ └── profiles_settings.xml │ ├── modules.xml │ ├── vcs.xml │ └── workspace.xml ├── README.md ├── configs │ ├── LLCM.yml │ ├── RegDB.yml │ ├── SYSU.yml │ └── default │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── dataset.cpython-37.pyc │ │ └── strategy.cpython-37.pyc │ │ ├── dataset.py │ │ └── strategy.py ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── clone.cpython-37.pyc │ │ ├── dataset.cpython-36.pyc │ │ ├── dataset.cpython-37.pyc │ │ ├── sampler.cpython-36.pyc │ │ └── sampler.cpython-37.pyc │ ├── dataset.py │ └── sampler.py ├── engine │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── engine.cpython-37.pyc │ │ └── metric.cpython-37.pyc │ ├── engine.py │ └── metric.py ├── grad_cam │ ├── README.md │ ├── both.png │ ├── imagenet1k_classes.txt │ ├── imagenet21k_classes.txt │ ├── main_cnn.py │ ├── main_swin.py │ ├── main_vit.py │ ├── swin_model.py │ ├── utils.py │ └── vit_model.py ├── layers │ ├── __init__.py │ ├── __pycache__ │ │ └── __init__.cpython-37.pyc │ ├── loss │ │ ├── JSD.py │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── JSD.cpython-37.pyc │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── am_softmax.cpython-37.pyc │ │ │ ├── center_loss.cpython-37.pyc │ │ │ ├── crossquad_loss.cpython-37.pyc │ │ │ ├── crosstriplet_loss.cpython-37.pyc │ │ │ ├── local_center_loss.cpython-37.pyc │ │ │ ├── mixtriplet_loss.cpython-37.pyc │ │ │ ├── trapezoid_loss.cpython-37.pyc │ │ │ └── triplet_loss.cpython-37.pyc │ │ ├── am_softmax.py │ │ ├── center_loss.py │ │ ├── local_center_loss.py │ │ ├── rerank_loss.py │ │ └── triplet_loss.py │ └── module │ │ ├── CBAM.py │ │ ├── NonLocal.py │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── CBAM.cpython-37.pyc │ │ ├── NonLocal.cpython-37.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── norm_linear.cpython-37.pyc │ │ └── reverse_grad.cpython-37.pyc │ │ ├── norm_linear.py │ │ └── reverse_grad.py ├── models │ ├── __pycache__ │ │ ├── baseline.cpython-36.pyc │ │ ├── baseline.cpython-37.pyc │ │ ├── resnet.cpython-36.pyc │ │ └── resnet.cpython-37.pyc │ ├── baseline.py │ └── resnet.py ├── train.py └── utils │ ├── __pycache__ │ ├── calc_acc.cpython-37.pyc │ ├── eval_regdb.cpython-37.pyc │ ├── eval_sysu.cpython-37.pyc │ ├── neighbor.cpython-37.pyc │ └── rerank.cpython-37.pyc │ ├── calc_acc.py │ ├── eval_llcm.py │ ├── eval_regdb.py │ ├── eval_sysu.py │ ├── rerank.py │ └── tsne.py └── README.md /IDKL/.idea/IDKL.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /IDKL/.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /IDKL/.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /IDKL/.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 
-------------------------------------------------------------------------------- /IDKL/.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /IDKL/.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 10 | 11 | 13 | 14 | 16 | 17 | 18 | 19 | 22 | 30 | 31 | 32 | 33 | 34 | 1723144514458 35 | 40 | 41 | 42 | 43 | 52 | 53 | -------------------------------------------------------------------------------- /IDKL/README.md: -------------------------------------------------------------------------------- 1 | [CVPR2024] IDKL: Implicit Discriminative Knowledge Learning for Visible-Infrared Person Re-Identification. (https://arxiv.org/abs/2403.11708) 2 | ## Environment requirements: 3 | 4 | python == 3.7 5 | PyTorch == 1.10.1 6 | ignite == 0.2.1 7 | torchvision == 0.11.2 8 | apex == 0.1 9 | 10 | ## Training: 11 | 12 | To train the model, use the following commands: 13 | 14 | SYSU-MM01: 15 | ```Shell 16 | python train.py --cfg ./configs/SYSU.yml 17 | ``` 18 | 19 | RegDB: 20 | ```Shell 21 | python train.py --cfg ./configs/RegDB.yml 22 | ``` 23 | 24 | LLCM: 25 | ```Shell 26 | python train.py --cfg ./configs/LLCM.yml 27 | ``` 28 | 29 | -------------------------------------------------------------------------------- /IDKL/configs/LLCM.yml: -------------------------------------------------------------------------------- 1 | prefix: LLCM 2 | fp16: true 3 | 4 | # dataset 5 | sample_method: identity_random 6 | image_size: (384, 144) 7 | p_size: 12 8 | k_size: 10 9 | 10 | dataset: llcm 11 | 12 | # loss 13 | bg_kl: true 14 | sm_kl: true 15 | IP: true 16 | decompose: true 17 | distalign: true 18 | classification: true 19 | center_cluster: false 20 | triplet: true 21 | center: false 22 | fb_dt: false 23 | 24 | # parameters 25 | margin: 1.3 #0.7 26 | # pattern attention 27 | num_parts: 6 28 | weight_sep: 0.5 29 | # mutual learning 30 | update_rate: 0.2 31 | weight_sid: 0.5 32 | weight_KL: 2.5 33 | 34 | # architecture 35 | drop_last_stride: true 36 | 37 | # optimizer 38 | lr: 0.00035 39 | optimizer: adam 40 | num_epoch: 160 #160 41 | lr_step: [55, 95] 42 | 43 | # augmentation 44 | random_flip: true 45 | random_crop: true 46 | random_erase: true 47 | color_jitter: false 48 | padding: 10 49 | 50 | # log 51 | log_period: 150 52 | start_eval: 200 53 | eval_interval: 5 54 | -------------------------------------------------------------------------------- /IDKL/configs/RegDB.yml: -------------------------------------------------------------------------------- 1 | prefix: RegDB 2 | 3 | fp16: true 4 | 5 | # dataset 6 | sample_method: identity_random 7 | image_size: (256, 128) 8 | p_size: 6 9 | k_size: 10 10 | 11 | dataset: regdb 12 | 13 | # loss 14 | bg_kl: true 15 | sm_kl: true 16 | decompose: true 17 | IP: true 18 | distalign: false 19 | classification: true 20 | center_cluster: false 21 | triplet: true 22 | fb_dt: false #true 23 | center: false 24 | 25 | # parameters 26 | margin: 1.3 27 | 28 | num_parts: 6 29 | weight_sep: 0.5 30 | 31 | update_rate: 0.2 32 | weight_sid: 0.5 33 | weight_KL: 2.5 34 | 35 | # architecture 36 | #mutual learning 37 | #rerank: false 38 | #pattern attention 39 | 40 | drop_last_stride: true 41 | pattern_attention: false 42 | mutual_learning: false 43 | modality_attention: 0 44 | 45 | # optimizer 46 | lr: 0.00035 47 | optimizer: adam 48 | num_epoch: 160 49 | lr_step: [55, 95] 50 | 51
| # augmentation 52 | random_flip: true 53 | random_crop: true 54 | random_erase: true 55 | color_jitter: false 56 | padding: 10 57 | 58 | # log 59 | log_period: 20 60 | start_eval: 0 61 | eval_interval: 5 62 | -------------------------------------------------------------------------------- /IDKL/configs/SYSU.yml: -------------------------------------------------------------------------------- 1 | prefix: SYSU 2 | fp16: true 3 | 4 | # dataset 5 | sample_method: identity_random #identity_uniform #identity_random 6 | image_size: (384, 144) #(384, 144) 7 | p_size: 12 8 | k_size: 10 9 | 10 | dataset: sysu 11 | 12 | # loss 13 | bg_kl: true 14 | sm_kl: true 15 | decompose: true 16 | distalign: true 17 | IP: true 18 | classification: true 19 | center_cluster: false 20 | triplet: true 21 | center: false 22 | fb_dt: false 23 | 24 | # parameters 25 | margin: 1.3 26 | # pattern attention 27 | num_parts: 6 28 | weight_sep: 0.5 29 | # mutual learning 30 | update_rate: 0.2 31 | weight_sid: 0.5 32 | weight_KL: 2.5 33 | 34 | # architecture 35 | drop_last_stride: true 36 | pattern_attention: false 37 | mutual_learning: false 38 | modality_attention: 0 39 | 40 | # optimizer 41 | lr: 0.00035 42 | optimizer: adam 43 | num_epoch: 160 #160 44 | lr_step: [55, 95] 45 | 46 | # augmentation 47 | random_flip: true 48 | random_crop: true 49 | random_erase: true 50 | color_jitter: false 51 | padding: 10 52 | 53 | # log 54 | log_period: 150 55 | start_eval: 200 56 | eval_interval: 5 57 | -------------------------------------------------------------------------------- /IDKL/configs/default/__init__.py: -------------------------------------------------------------------------------- 1 | from configs.default.dataset import dataset_cfg 2 | from configs.default.strategy import strategy_cfg 3 | 4 | __all__ = ["dataset_cfg", "strategy_cfg"] 5 | -------------------------------------------------------------------------------- /IDKL/configs/default/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/configs/default/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /IDKL/configs/default/__pycache__/dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/configs/default/__pycache__/dataset.cpython-37.pyc -------------------------------------------------------------------------------- /IDKL/configs/default/__pycache__/strategy.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/configs/default/__pycache__/strategy.cpython-37.pyc -------------------------------------------------------------------------------- /IDKL/configs/default/dataset.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode 2 | 3 | dataset_cfg = CfgNode() 4 | 5 | # config for dataset 6 | dataset_cfg.sysu = CfgNode() 7 | dataset_cfg.sysu.num_id = 395 8 | dataset_cfg.sysu.num_cam = 6 9 | dataset_cfg.sysu.data_root = "../dataset/SYSU-MM01" 10 | 11 | dataset_cfg.regdb = CfgNode() 12 | dataset_cfg.regdb.num_id = 206 13 | dataset_cfg.regdb.num_cam = 2 14 | dataset_cfg.regdb.data_root = "../dataset/RegDB" 15 | 16 
| dataset_cfg.llcm = CfgNode() 17 | dataset_cfg.llcm.num_id = 713 18 | dataset_cfg.llcm.num_cam = 2 19 | dataset_cfg.llcm.data_root = "../dataset/LLCM" 20 | 21 | -------------------------------------------------------------------------------- /IDKL/configs/default/strategy.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode 2 | 3 | strategy_cfg = CfgNode() 4 | 5 | strategy_cfg.prefix = "baseline" 6 | 7 | # setting for loader 8 | strategy_cfg.sample_method = "random" 9 | strategy_cfg.batch_size = 128 10 | strategy_cfg.p_size = 16 11 | strategy_cfg.k_size = 8 12 | 13 | # setting for loss 14 | strategy_cfg.classification = True 15 | strategy_cfg.triplet = False 16 | strategy_cfg.center_cluster = False 17 | strategy_cfg.center = False 18 | strategy_cfg.sm_kl = False 19 | strategy_cfg.bg_kl = False 20 | strategy_cfg.IP = False 21 | strategy_cfg.decompose = False 22 | strategy_cfg.fb_dt = False 23 | strategy_cfg.distalign = False 24 | 25 | # setting for metric learning 26 | strategy_cfg.margin = 0.3 27 | strategy_cfg.weight_KL = 3.0 28 | strategy_cfg.weight_sid = 1.0 29 | strategy_cfg.weight_sep = 1.0 30 | strategy_cfg.update_rate = 1.0 31 | 32 | # settings for optimizer 33 | strategy_cfg.optimizer = "sgd" 34 | strategy_cfg.lr = 0.1 35 | strategy_cfg.wd = 5e-3 #5e-3 36 | ##5e-4 37 | strategy_cfg.lr_step = [40] 38 | 39 | strategy_cfg.fp16 = False 40 | 41 | strategy_cfg.num_epoch = 60 42 | 43 | # settings for dataset 44 | strategy_cfg.dataset = "sysu" 45 | strategy_cfg.image_size = (384, 128) 46 | 47 | # settings for augmentation 48 | strategy_cfg.random_flip = True 49 | strategy_cfg.random_crop = True 50 | strategy_cfg.random_erase = True 51 | strategy_cfg.color_jitter = False 52 | strategy_cfg.padding = 10 53 | 54 | # settings for base architecture 55 | strategy_cfg.drop_last_stride = False 56 | strategy_cfg.pattern_attention = False 57 | strategy_cfg.modality_attention = 0 58 | strategy_cfg.mutual_learning = False 59 | strategy_cfg.rerank = False 60 | strategy_cfg.num_parts = 6 61 | 62 | # logging 63 | strategy_cfg.eval_interval = -1 64 | strategy_cfg.start_eval = 60 65 | strategy_cfg.log_period = 10 66 | 67 | # testing 68 | strategy_cfg.resume = '' 69 | #/home/zhang/E/RKJ/MAPnet/MPA-LL2-cvpr/checkpoints/regdb/RegDB/model_best.pth 70 | #/root/MPANet/MPA-LL2-cvpr/checkpoints/sysu/SYSU/model_best.pth 71 | #/root/MPANet/MPA-LL2-cvpr/checkpoints/llcm/LLCM/model_best.pth 72 | #/home/zhang/E/RKJ/MAPnet/MPA-cvpr/checkpoints/llcm/LLCM/model_best.pth -------------------------------------------------------------------------------- /IDKL/data/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torchvision.transforms as T 5 | 6 | from torch.utils.data import DataLoader 7 | from data.dataset import SYSUDataset 8 | from data.dataset import RegDBDataset 9 | from data.dataset import LLCMData 10 | from data.dataset import MarketDataset 11 | 12 | from data.sampler import CrossModalityIdentitySampler 13 | from data.sampler import CrossModalityRandomSampler 14 | from data.sampler import RandomIdentitySampler 15 | from data.sampler import NormTripletSampler 16 | import random 17 | 18 | 19 | def collate_fn(batch): # img, label, cam_id, img_path, img_id 20 | samples = list(zip(*batch)) 21 | 22 | data = [torch.stack(x, 0) for i, x in enumerate(samples) if i != 3] 23 | data.insert(3, samples[3]) 24 | return data 25 | 26 | 27 | class ChannelAdapGray(object): 28 | 
""" Adaptive selects a channel or two channels. 29 | Args: 30 | probability: The probability that the Random Erasing operation will be performed. 31 | sl: Minimum proportion of erased area against input image. 32 | sh: Maximum proportion of erased area against input image. 33 | r1: Minimum aspect ratio of erased area. 34 | mean: Erasing value. 35 | """ 36 | 37 | def __init__(self, probability=0.5): 38 | self.probability = probability 39 | 40 | def __call__(self, img): 41 | 42 | # if random.uniform(0, 1) > self.probability: 43 | # return img 44 | 45 | idx = random.randint(0, 3) 46 | 47 | if idx == 0: 48 | # random select R Channel 49 | img[1, :, :] = img[0, :, :] 50 | img[2, :, :] = img[0, :, :] 51 | elif idx == 1: 52 | # random select B Channel 53 | img[0, :, :] = img[1, :, :] 54 | img[2, :, :] = img[1, :, :] 55 | elif idx == 2: 56 | # random select G Channel 57 | img[0, :, :] = img[2, :, :] 58 | img[1, :, :] = img[2, :, :] 59 | else: 60 | if random.uniform(0, 1) > self.probability: 61 | # return img 62 | img = img 63 | else: 64 | tmp_img = 0.2989 * img[0, :, :] + 0.5870 * img[1, :, :] + 0.1140 * img[2, :, :] 65 | img[0, :, :] = tmp_img 66 | img[1, :, :] = tmp_img 67 | img[2, :, :] = tmp_img 68 | return img 69 | 70 | def get_train_loader(dataset, root, sample_method, batch_size, p_size, k_size, image_size, random_flip=False, random_crop=False, 71 | random_erase=False, color_jitter=False, padding=0, num_workers=4): 72 | if True==False: #tsne 73 | transform = T.Compose([ 74 | T.Resize(image_size), 75 | T.ToTensor(), 76 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 77 | ]) 78 | 79 | else: 80 | # data pre-processing 81 | t = [T.Resize(image_size)] 82 | 83 | if random_flip: 84 | t.append(T.RandomHorizontalFlip()) 85 | 86 | if color_jitter: 87 | t.append(T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0)) 88 | 89 | if random_crop: 90 | t.extend([T.Pad(padding, fill=127), T.RandomCrop(image_size)]) 91 | 92 | t.extend([T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 93 | 94 | if random_erase: 95 | t.append(T.RandomErasing()) 96 | #t.append(ChannelAdapGray(probability=0.5)) ###58 97 | # t.append(Jigsaw()) 98 | 99 | transform = T.Compose(t) 100 | # # data pre-processing 101 | # t = [T.Resize(image_size)] 102 | # 103 | # if random_flip: 104 | # t.append(T.RandomHorizontalFlip()) 105 | # 106 | # if color_jitter: 107 | # t.append(T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0)) 108 | # 109 | # if random_crop: 110 | # t.extend([T.Pad(padding, fill=127), T.RandomCrop(image_size)]) 111 | # 112 | # t.extend([T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 113 | # 114 | # if random_erase: 115 | # t.append(T.RandomErasing()) 116 | # # t.append(Jigsaw()) 117 | # 118 | # transform = T.Compose(t) 119 | 120 | # dataset 121 | if dataset == 'sysu': 122 | train_dataset = SYSUDataset(root, mode='train', transform=transform) 123 | elif dataset == 'regdb': 124 | train_dataset = RegDBDataset(root, mode='train', transform=transform) 125 | elif dataset == 'llcm': 126 | train_dataset = LLCMData(root, mode='train', transform=transform) 127 | elif dataset == 'market': 128 | train_dataset = MarketDataset(root, mode='train', transform=transform) 129 | 130 | # sampler 131 | assert sample_method in ['random', 'identity_uniform', 'identity_random', 'norm_triplet'] 132 | if sample_method == 'identity_uniform': 133 | batch_size = p_size * k_size 134 | sampler = CrossModalityIdentitySampler(train_dataset, 
p_size, k_size) 135 | elif sample_method == 'identity_random': 136 | batch_size = p_size * k_size 137 | sampler = RandomIdentitySampler(train_dataset, p_size * k_size, k_size) 138 | elif sample_method == 'norm_triplet': 139 | batch_size = p_size * k_size 140 | sampler = NormTripletSampler(train_dataset, p_size * k_size, k_size) 141 | else: 142 | sampler = CrossModalityRandomSampler(train_dataset, batch_size) 143 | 144 | # loader 145 | train_loader = DataLoader(train_dataset, batch_size, sampler=sampler, drop_last=True, pin_memory=True, 146 | collate_fn=collate_fn, num_workers=num_workers) 147 | 148 | return train_loader 149 | 150 | 151 | def get_test_loader(dataset, root, batch_size, image_size, num_workers=4): 152 | # transform 153 | transform = T.Compose([ 154 | T.Resize(image_size), 155 | T.ToTensor(), 156 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 157 | ]) 158 | 159 | # dataset 160 | if dataset == 'sysu': 161 | gallery_dataset = SYSUDataset(root, mode='gallery', transform=transform) 162 | query_dataset = SYSUDataset(root, mode='query', transform=transform) 163 | elif dataset == 'regdb': 164 | gallery_dataset = RegDBDataset(root, mode='gallery', transform=transform) 165 | query_dataset = RegDBDataset(root, mode='query', transform=transform) 166 | elif dataset == 'llcm': 167 | gallery_dataset = LLCMData(root, mode='gallery', transform=transform) 168 | query_dataset = LLCMData(root, mode='query', transform=transform) 169 | elif dataset == 'market': 170 | gallery_dataset = MarketDataset(root, mode='gallery', transform=transform) 171 | query_dataset = MarketDataset(root, mode='query', transform=transform) 172 | 173 | # dataloader 174 | query_loader = DataLoader(dataset=query_dataset, 175 | batch_size=batch_size, 176 | shuffle=False, 177 | pin_memory=True, 178 | drop_last=False, 179 | collate_fn=collate_fn, 180 | num_workers=num_workers) 181 | 182 | gallery_loader = DataLoader(dataset=gallery_dataset, 183 | batch_size=batch_size, 184 | shuffle=False, 185 | pin_memory=True, 186 | drop_last=False, 187 | collate_fn=collate_fn, 188 | num_workers=num_workers) 189 | 190 | return gallery_loader, query_loader 191 | -------------------------------------------------------------------------------- /IDKL/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /IDKL/data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /IDKL/data/__pycache__/clone.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/data/__pycache__/clone.cpython-37.pyc -------------------------------------------------------------------------------- /IDKL/data/__pycache__/dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/data/__pycache__/dataset.cpython-36.pyc 
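
The loader factories above in data/__init__.py are the data-pipeline entry points. A minimal usage sketch, not code from the repository, showing how they can be driven from the yacs configs: configs/SYSU.yml is merged over the defaults in configs/default/strategy.py and the data root comes from configs/default/dataset.py; the evaluation batch size here is an assumption, and for identity-based sample methods the batch_size argument is recomputed internally as p_size * k_size.

```python
# Hypothetical usage sketch: merge a dataset config over the yacs defaults and
# build the train/test loaders defined in data/__init__.py above.
from configs.default import strategy_cfg, dataset_cfg
from data import get_train_loader, get_test_loader

cfg = strategy_cfg.clone()
cfg.merge_from_file('configs/SYSU.yml')   # YAML values override the defaults in strategy.py
cfg.freeze()
data_root = dataset_cfg.sysu.data_root    # '../dataset/SYSU-MM01'

train_loader = get_train_loader(
    dataset=cfg.dataset, root=data_root,
    sample_method=cfg.sample_method, batch_size=cfg.batch_size,
    p_size=cfg.p_size, k_size=cfg.k_size, image_size=cfg.image_size,
    random_flip=cfg.random_flip, random_crop=cfg.random_crop,
    random_erase=cfg.random_erase, color_jitter=cfg.color_jitter,
    padding=cfg.padding)

gallery_loader, query_loader = get_test_loader(
    dataset=cfg.dataset, root=data_root,
    batch_size=512,                       # evaluation batch size: an assumption
    image_size=cfg.image_size)

# Each training batch follows collate_fn: (images, labels, cam_ids, img_paths, img_ids).
```
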
-------------------------------------------------------------------------------- /IDKL/data/__pycache__/dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/data/__pycache__/dataset.cpython-37.pyc -------------------------------------------------------------------------------- /IDKL/data/__pycache__/sampler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/data/__pycache__/sampler.cpython-36.pyc -------------------------------------------------------------------------------- /IDKL/data/__pycache__/sampler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/data/__pycache__/sampler.cpython-37.pyc -------------------------------------------------------------------------------- /IDKL/data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import os.path as osp 4 | from glob import glob 5 | import numpy as np 6 | import torch 7 | from PIL import Image 8 | from torch.utils.data import Dataset 9 | 10 | ''' 11 | Specific dataset classes for person re-identification dataset. 12 | ''' 13 | 14 | 15 | class SYSUDataset(Dataset): 16 | def __init__(self, root, mode='train', transform=None): 17 | assert os.path.isdir(root) 18 | assert mode in ['train', 'gallery', 'query'] 19 | 20 | if mode == 'train': 21 | train_ids = open(os.path.join(root, 'exp', 'train_id.txt')).readline() 22 | val_ids = open(os.path.join(root, 'exp', 'val_id.txt')).readline() 23 | 24 | train_ids = train_ids.strip('\n').split(',') 25 | val_ids = val_ids.strip('\n').split(',') 26 | selected_ids = train_ids + val_ids 27 | else: 28 | test_ids = open(os.path.join(root, 'exp', 'test_id.txt')).readline() 29 | selected_ids = test_ids.strip('\n').split(',') 30 | 31 | selected_ids = [int(i) for i in selected_ids] 32 | num_ids = len(selected_ids) 33 | 34 | img_paths = glob(os.path.join(root, '**/*.jpg'), recursive=True) 35 | img_paths = [path for path in img_paths if int(path.split('/')[-2]) in selected_ids] 36 | 37 | if mode == 'gallery': 38 | img_paths = [path for path in img_paths if int(path.split('/')[-3][-1]) in (1, 2, 4, 5)] 39 | elif mode == 'query': 40 | img_paths = [path for path in img_paths if int(path.split('/')[-3][-1]) in (3, 6)] 41 | 42 | img_paths = sorted(img_paths) 43 | self.img_paths = img_paths 44 | self.cam_ids = [int(path.split('/')[-3][-1]) for path in img_paths] 45 | self.num_ids = num_ids 46 | self.transform = transform 47 | 48 | if mode == 'train': 49 | id_map = dict(zip(selected_ids, range(num_ids))) 50 | self.ids = [id_map[int(path.split('/')[-2])] for path in img_paths] 51 | else: 52 | self.ids = [int(path.split('/')[-2]) for path in img_paths] 53 | 54 | def __len__(self): 55 | return len(self.img_paths) 56 | 57 | def __getitem__(self, item): 58 | path = self.img_paths[item] 59 | img = Image.open(path) 60 | if self.transform is not None: 61 | img = self.transform(img) 62 | 63 | label = torch.tensor(self.ids[item], dtype=torch.long) 64 | cam = torch.tensor(self.cam_ids[item], dtype=torch.long) 65 | item = torch.tensor(item, dtype=torch.long) 66 | 67 | return img, label, cam, path, item 68 | 69 | class RegDBDataset(Dataset): 70 | 
def __init__(self, root, mode='train', transform=None): 71 | assert os.path.isdir(root) 72 | assert mode in ['train', 'gallery', 'query'] 73 | 74 | def loadIdx(index): 75 | Lines = index.readlines() 76 | idx = [] 77 | for line in Lines: 78 | tmp = line.strip('\n') 79 | tmp = tmp.split(' ') 80 | idx.append(tmp) 81 | return idx 82 | 83 | num = '1' 84 | if mode == 'train': 85 | index_RGB = loadIdx(open(root + '/idx/train_visible_'+num+'.txt','r')) 86 | index_IR = loadIdx(open(root + '/idx/train_thermal_'+num+'.txt','r')) 87 | else: 88 | index_RGB = loadIdx(open(root + '/idx/test_visible_'+num+'.txt','r')) 89 | index_IR = loadIdx(open(root + '/idx/test_thermal_'+num+'.txt','r')) 90 | 91 | if mode == 'gallery': 92 | img_paths = [root + '/' + path for path, _ in index_RGB] 93 | elif mode == 'query': 94 | img_paths = [root + '/' + path for path, _ in index_IR] 95 | else: 96 | img_paths = [root + '/' + path for path, _ in index_RGB] + [root + '/' + path for path, _ in index_IR] 97 | 98 | selected_ids = [int(path.split('/')[-2]) for path in img_paths] 99 | selected_ids = list(set(selected_ids)) 100 | num_ids = len(selected_ids) 101 | 102 | img_paths = sorted(img_paths) 103 | self.img_paths = img_paths 104 | self.cam_ids = [int(path.split('/')[-3] == 'Thermal') + 2 for path in img_paths] 105 | # the visible cams are 1 2 4 5 and thermal cams are 3 6 in sysu 106 | # to simplify the code, visible cam is 2 and thermal cam is 3 in regdb 107 | self.num_ids = num_ids 108 | self.transform = transform 109 | 110 | if mode == 'train': 111 | id_map = dict(zip(selected_ids, range(num_ids))) 112 | self.ids = [id_map[int(path.split('/')[-2])] for path in img_paths] 113 | else: 114 | self.ids = [int(path.split('/')[-2]) for path in img_paths] 115 | 116 | def __len__(self): 117 | return len(self.img_paths) 118 | 119 | def __getitem__(self, item): 120 | path = self.img_paths[item] 121 | img = Image.open(path) 122 | if self.transform is not None: 123 | img = self.transform(img) 124 | 125 | label = torch.tensor(self.ids[item], dtype=torch.long) 126 | cam = torch.tensor(self.cam_ids[item], dtype=torch.long) 127 | item = torch.tensor(item, dtype=torch.long) 128 | 129 | return img, label, cam, path, item 130 | 131 | class LLCMData(Dataset): 132 | def __init__(self, root, mode='train', transform=None, colorIndex=None, thermalIndex=None): 133 | # Load training images (path) and labels 134 | assert os.path.isdir(root) 135 | assert mode in ['train', 'gallery', 'query'] 136 | 137 | def loadIdx(index): 138 | Lines = index.readlines() 139 | idx = [] 140 | for line in Lines: 141 | tmp = line.strip('\n') 142 | tmp = tmp.split(' ') 143 | idx.append(tmp) 144 | return idx 145 | 146 | if mode == 'train': 147 | index_RGB = loadIdx(open(root + '/idx/train_vis.txt','r')) 148 | index_IR = loadIdx(open(root + '/idx/train_nir.txt','r')) 149 | else: 150 | index_RGB = loadIdx(open(root + '/idx/test_vis.txt','r')) 151 | index_IR = loadIdx(open(root + '/idx/test_nir.txt','r')) 152 | 153 | 154 | if mode == 'gallery': 155 | img_paths = [root + '/' + path for path, _ in index_RGB] 156 | elif mode == 'query': 157 | img_paths = [root + '/' + path for path, _ in index_IR] 158 | else: 159 | img_paths = [root + '/' + path for path, _ in index_RGB] + [root + '/' + path for path, _ in index_IR] 160 | 161 | selected_ids = [int(path.split('/')[-2]) for path in img_paths] 162 | selected_ids = list(set(selected_ids)) 163 | num_ids = len(selected_ids) 164 | # path = '/home/zhang/E/RKJ/MAPnet/dataset/LLCM/nir/0351/0351_c06_s200656_f4830_nir.jpg' 165 | # img = 
Image.open(path).convert('RGB') 166 | # img = np.array(img, dtype=np.uint8) 167 | # import pdb 168 | # pdb.set_trace() 169 | 170 | img_paths = sorted(img_paths) 171 | self.img_paths = img_paths 172 | self.cam_ids = [int(path.split('/')[-3] == 'nir') + 2 for path in img_paths] 173 | self.num_ids = num_ids 174 | self.transform = transform 175 | 176 | if mode == 'train': 177 | id_map = dict(zip(selected_ids, range(num_ids))) 178 | 179 | self.ids = [id_map[int(path.split('/')[-2])] for path in img_paths] 180 | else: 181 | self.ids = [int(path.split('/')[-2]) for path in img_paths] 182 | 183 | def __len__(self): 184 | return len(self.img_paths) 185 | 186 | def __getitem__(self, item): 187 | path = self.img_paths[item] 188 | img = Image.open(path) 189 | if self.transform is not None: 190 | img = self.transform(img) 191 | 192 | label = torch.tensor(self.ids[item], dtype=torch.long) 193 | cam = torch.tensor(self.cam_ids[item], dtype=torch.long) 194 | item = torch.tensor(item, dtype=torch.long) 195 | 196 | return img, label, cam, path, item 197 | 198 | class MarketDataset(Dataset): 199 | def __init__(self, root, mode='train', transform=None): 200 | assert os.path.isdir(root) 201 | assert mode in ['train', 'gallery', 'query'] 202 | 203 | self.transform = transform 204 | 205 | if mode == 'train': 206 | img_paths = glob(os.path.join(root, 'bounding_box_train/*.jpg'), recursive=True) 207 | elif mode == 'gallery': 208 | img_paths = glob(os.path.join(root, 'bounding_box_test/*.jpg'), recursive=True) 209 | elif mode == 'query': 210 | img_paths = glob(os.path.join(root, 'query/*.jpg'), recursive=True) 211 | 212 | pattern = re.compile(r'([-\d]+)_c(\d)') 213 | all_pids = {} 214 | relabel = mode == 'train' 215 | self.img_paths = [] 216 | self.cam_ids = [] 217 | self.ids = [] 218 | for fpath in img_paths: 219 | fname = osp.basename(fpath) 220 | pid, cam = map(int, pattern.search(fname).groups()) 221 | if pid == -1: continue 222 | if relabel: 223 | if pid not in all_pids: 224 | all_pids[pid] = len(all_pids) 225 | else: 226 | if pid not in all_pids: 227 | all_pids[pid] = pid 228 | self.img_paths.append(fpath) 229 | self.ids.append(all_pids[pid]) 230 | self.cam_ids.append(cam - 1) 231 | 232 | def __len__(self): 233 | return len(self.img_paths) 234 | 235 | def __getitem__(self, item): 236 | path = self.img_paths[item] 237 | img = Image.open(path) 238 | if self.transform is not None: 239 | img = self.transform(img) 240 | 241 | label = torch.tensor(self.ids[item], dtype=torch.long) 242 | cam = torch.tensor(self.cam_ids[item], dtype=torch.long) 243 | item = torch.tensor(item, dtype=torch.long) 244 | 245 | return img, label, cam, path, item 246 | -------------------------------------------------------------------------------- /IDKL/data/sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import copy 4 | from torch.utils.data import Sampler 5 | from collections import defaultdict 6 | 7 | 8 | class CrossModalityRandomSampler(Sampler): 9 | def __init__(self, dataset, batch_size): 10 | self.dataset = dataset 11 | self.batch_size = batch_size 12 | 13 | self.rgb_list = [] 14 | self.ir_list = [] 15 | for i, cam in enumerate(dataset.cam_ids): 16 | if cam in [3, 6]: 17 | self.ir_list.append(i) 18 | else: 19 | self.rgb_list.append(i) 20 | 21 | def __len__(self): 22 | return max(len(self.rgb_list), len(self.ir_list)) * 2 23 | 24 | def __iter__(self): 25 | sample_list = [] 26 | rgb_list = np.random.permutation(self.rgb_list).tolist() 27 | ir_list = 
np.random.permutation(self.ir_list).tolist() 28 | 29 | rgb_size = len(self.rgb_list) 30 | ir_size = len(self.ir_list) 31 | if rgb_size >= ir_size: 32 | diff = rgb_size - ir_size 33 | reps = diff // ir_size 34 | pad_size = diff % ir_size 35 | for _ in range(reps): 36 | ir_list.extend(np.random.permutation(self.ir_list).tolist()) 37 | ir_list.extend(np.random.choice(self.ir_list, pad_size, replace=False).tolist()) 38 | else: 39 | diff = ir_size - rgb_size 40 | reps = diff // ir_size 41 | pad_size = diff % ir_size 42 | for _ in range(reps): 43 | rgb_list.extend(np.random.permutation(self.rgb_list).tolist()) 44 | rgb_list.extend(np.random.choice(self.rgb_list, pad_size, replace=False).tolist()) 45 | 46 | assert len(rgb_list) == len(ir_list) 47 | 48 | half_bs = self.batch_size // 2 49 | for start in range(0, len(rgb_list), half_bs): 50 | sample_list.extend(rgb_list[start:start + half_bs]) 51 | sample_list.extend(ir_list[start:start + half_bs]) 52 | 53 | return iter(sample_list) 54 | 55 | 56 | class CrossModalityIdentitySampler(Sampler): 57 | def __init__(self, dataset, p_size, k_size): 58 | self.dataset = dataset 59 | self.p_size = p_size 60 | self.k_size = k_size // 2 61 | self.batch_size = p_size * k_size * 2 62 | 63 | self.id2idx_rgb = defaultdict(list) 64 | self.id2idx_ir = defaultdict(list) 65 | for i, identity in enumerate(dataset.ids): 66 | if dataset.cam_ids[i] in [3, 6]: 67 | self.id2idx_ir[identity].append(i) 68 | else: 69 | self.id2idx_rgb[identity].append(i) 70 | 71 | def __len__(self): 72 | return self.dataset.num_ids * self.k_size * 2 73 | 74 | def __iter__(self): 75 | sample_list = [] 76 | 77 | id_perm = np.random.permutation(self.dataset.num_ids) 78 | for start in range(0, self.dataset.num_ids, self.p_size): 79 | selected_ids = id_perm[start:start + self.p_size] 80 | 81 | sample = [] 82 | for identity in selected_ids: 83 | replace = len(self.id2idx_rgb[identity]) < self.k_size 84 | s = np.random.choice(self.id2idx_rgb[identity], size=self.k_size, replace=replace) 85 | sample.extend(s) 86 | 87 | sample_list.extend(sample) 88 | 89 | sample.clear() 90 | for identity in selected_ids: 91 | replace = len(self.id2idx_ir[identity]) < self.k_size 92 | s = np.random.choice(self.id2idx_ir[identity], size=self.k_size, replace=replace) 93 | sample.extend(s) 94 | 95 | sample_list.extend(sample) 96 | 97 | return iter(sample_list) 98 | 99 | 100 | class RandomIdentitySampler(Sampler): 101 | def __init__(self, data_source, batch_size, num_instances): 102 | self.data_source = data_source 103 | self.batch_size = batch_size 104 | self.num_instances = num_instances 105 | self.num_pids_per_batch = self.batch_size // self.num_instances 106 | self.index_dic_R = defaultdict(list) 107 | self.index_dic_I = defaultdict(list) 108 | for i, identity in enumerate(data_source.ids): 109 | if data_source.cam_ids[i] in [3, 6]: 110 | self.index_dic_I[identity].append(i) 111 | else: 112 | self.index_dic_R[identity].append(i) 113 | self.pids = list(self.index_dic_I.keys()) 114 | 115 | # estimate number of examples in an epoch 116 | self.length = 0 117 | for pid in self.pids: 118 | idxs = self.index_dic_I[pid] 119 | num = len(idxs) 120 | if num < self.num_instances: 121 | num = self.num_instances 122 | self.length += num - num % self.num_instances 123 | 124 | def __iter__(self): 125 | batch_idxs_dict = defaultdict(list) 126 | 127 | for pid in self.pids: 128 | idxs_I = copy.deepcopy(self.index_dic_I[pid]) 129 | idxs_R = copy.deepcopy(self.index_dic_R[pid]) 130 | if len(idxs_I) < self.num_instances // 2 and len(idxs_R) < 
self.num_instances // 2: 131 | idxs_I = np.random.choice(idxs_I, size=self.num_instances // 2, replace=True) 132 | idxs_R = np.random.choice(idxs_R, size=self.num_instances // 2, replace=True) 133 | if len(idxs_I) > len(idxs_R): 134 | idxs_I = np.random.choice(idxs_I, size=len(idxs_R), replace=False) 135 | if len(idxs_R) > len(idxs_I): 136 | idxs_R = np.random.choice(idxs_R, size=len(idxs_I), replace=False) 137 | np.random.shuffle(idxs_I) 138 | np.random.shuffle(idxs_R) 139 | batch_idxs = [] 140 | for idx_I, idx_R in zip(idxs_I, idxs_R): 141 | batch_idxs.append(idx_I) 142 | batch_idxs.append(idx_R) 143 | if len(batch_idxs) == self.num_instances: 144 | batch_idxs_dict[pid].append(batch_idxs) 145 | batch_idxs = [] 146 | 147 | avai_pids = copy.deepcopy(self.pids) 148 | final_idxs = [] 149 | 150 | while len(avai_pids) >= self.num_pids_per_batch: 151 | selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False) 152 | for pid in selected_pids: 153 | batch_idxs = batch_idxs_dict[pid].pop(0) 154 | final_idxs.extend(batch_idxs) 155 | if len(batch_idxs_dict[pid]) == 0: 156 | avai_pids.remove(pid) 157 | 158 | self.length = len(final_idxs) 159 | return iter(final_idxs) 160 | 161 | def __len__(self): 162 | return self.length 163 | 164 | 165 | class NormTripletSampler(Sampler): 166 | """ 167 | Randomly sample N identities, then for each identity, 168 | randomly sample K instances, therefore batch size is N*K. 169 | Args: 170 | - data_source (list): list of (img_path, pid, camid). 171 | - num_instances (int): number of instances per identity in a batch. 172 | - batch_size (int): number of examples in a batch. 173 | """ 174 | 175 | def __init__(self, data_source, batch_size, num_instances): 176 | self.data_source = data_source 177 | self.batch_size = batch_size 178 | self.num_instances = num_instances 179 | self.num_pids_per_batch = self.batch_size // self.num_instances 180 | self.index_dic = defaultdict(list) 181 | for index, pid in enumerate(self.data_source.ids): 182 | self.index_dic[pid].append(index) 183 | self.pids = list(self.index_dic.keys()) 184 | 185 | # estimate number of examples in an epoch 186 | self.length = 0 187 | for pid in self.pids: 188 | idxs = self.index_dic[pid] 189 | num = len(idxs) 190 | if num < self.num_instances: 191 | num = self.num_instances 192 | self.length += num - num % self.num_instances 193 | 194 | def __iter__(self): 195 | batch_idxs_dict = defaultdict(list) 196 | 197 | for pid in self.pids: 198 | idxs = copy.deepcopy(self.index_dic[pid]) 199 | if len(idxs) < self.num_instances: 200 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 201 | np.random.shuffle(idxs) 202 | batch_idxs = [] 203 | for idx in idxs: 204 | batch_idxs.append(idx) 205 | if len(batch_idxs) == self.num_instances: 206 | batch_idxs_dict[pid].append(batch_idxs) 207 | batch_idxs = [] 208 | 209 | avai_pids = copy.deepcopy(self.pids) 210 | final_idxs = [] 211 | 212 | while len(avai_pids) >= self.num_pids_per_batch: 213 | selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False) 214 | for pid in selected_pids: 215 | batch_idxs = batch_idxs_dict[pid].pop(0) 216 | final_idxs.extend(batch_idxs) 217 | if len(batch_idxs_dict[pid]) == 0: 218 | avai_pids.remove(pid) 219 | 220 | self.length = len(final_idxs) 221 | return iter(final_idxs) 222 | 223 | def __len__(self): 224 | return self.length -------------------------------------------------------------------------------- /IDKL/engine/__init__.py: 
-------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import numpy as np 4 | import torch 5 | import scipy.io as sio 6 | 7 | from ignite.engine import Events 8 | from ignite.handlers import ModelCheckpoint 9 | from ignite.handlers import Timer 10 | 11 | from engine.engine import create_eval_engine 12 | from engine.engine import create_train_engine 13 | from engine.metric import AutoKVMetric 14 | from utils.eval_sysu import eval_sysu 15 | from utils.eval_regdb import eval_regdb 16 | from utils.eval_llcm import eval_llcm 17 | from configs.default.dataset import dataset_cfg 18 | from configs.default.strategy import strategy_cfg 19 | 20 | def get_trainer(dataset, model, optimizer, lr_scheduler=None, logger=None, writer=None, non_blocking=False, log_period=10, 21 | save_dir="checkpoints", prefix="model", gallery_loader=None, query_loader=None, 22 | eval_interval=None, start_eval=None, rerank=False): 23 | if logger is None: 24 | logger = logging.getLogger() 25 | logger.setLevel(logging.WARN) 26 | 27 | # trainer 28 | trainer = create_train_engine(model, optimizer, non_blocking) 29 | 30 | setattr(trainer, "rerank", rerank) 31 | 32 | # checkpoint handler 33 | handler = ModelCheckpoint(save_dir, prefix, save_interval=eval_interval, n_saved=3, create_dir=True, 34 | save_as_state_dict=True, require_empty=False) 35 | trainer.add_event_handler(Events.EPOCH_COMPLETED, handler, {"model": model}) 36 | 37 | # metric 38 | timer = Timer(average=True) 39 | rank = True 40 | 41 | kv_metric = AutoKVMetric() 42 | 43 | # evaluator 44 | evaluator = None 45 | if not type(eval_interval) == int: 46 | raise TypeError("The parameter 'validate_interval' must be type INT.") 47 | if not type(start_eval) == int: 48 | raise TypeError("The parameter 'start_eval' must be type INT.") 49 | if eval_interval > 0 and gallery_loader is not None and query_loader is not None: 50 | evaluator = create_eval_engine(model, non_blocking) 51 | 52 | @trainer.on(Events.STARTED) 53 | def train_start(engine): 54 | setattr(engine.state, "best_rank1", 0.0) 55 | 56 | @trainer.on(Events.COMPLETED) 57 | def train_completed(engine): 58 | torch.cuda.empty_cache() 59 | 60 | # extract query feature 61 | evaluator.run(query_loader) 62 | 63 | q_feats = torch.cat(evaluator.state.feat_list, dim=0) 64 | q_ids = torch.cat(evaluator.state.id_list, dim=0).numpy() 65 | q_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy() 66 | q_img_paths = np.concatenate(evaluator.state.img_path_list, axis=0) 67 | 68 | # extract gallery feature 69 | evaluator.run(gallery_loader) 70 | 71 | g_feats = torch.cat(evaluator.state.feat_list, dim=0) 72 | g_ids = torch.cat(evaluator.state.id_list, dim=0).numpy() 73 | g_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy() 74 | g_img_paths = np.concatenate(evaluator.state.img_path_list, axis=0) 75 | 76 | # print("best rank1={:.2f}%".format(engine.state.best_rank1)) 77 | 78 | if dataset == 'sysu': 79 | perm = sio.loadmat(os.path.join(dataset_cfg.sysu.data_root, 'exp', 'rand_perm_cam.mat'))[ 80 | 'rand_perm_cam'] 81 | eval_sysu(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, perm, mode='all', num_shots=1, rerank=rank) 82 | eval_sysu(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, perm, mode='all', num_shots=10, rerank=rank) 83 | eval_sysu(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, perm, mode='indoor', num_shots=1, rerank=rank) 84 | eval_sysu(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, perm, mode='indoor', 
num_shots=10, rerank=rank) 85 | elif dataset == 'regdb': 86 | print('infrared to visible') 87 | eval_regdb(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, rerank=engine.rerank) 88 | print('visible to infrared') 89 | eval_regdb(g_feats, g_ids, g_cams, q_feats, q_ids, q_cams, q_img_paths, rerank=engine.rerank) 90 | elif dataset == 'llcm': 91 | print('infrared to visible') 92 | eval_llcm(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, rerank=rank) 93 | print('visible to infrared') 94 | eval_llcm(g_feats, g_ids, g_cams, q_feats, q_ids, q_cams, q_img_paths, rerank=rank) 95 | 96 | 97 | evaluator.state.feat_list.clear() 98 | evaluator.state.id_list.clear() 99 | evaluator.state.cam_list.clear() 100 | evaluator.state.img_path_list.clear() 101 | del q_feats, q_ids, q_cams, g_feats, g_ids, g_cams 102 | 103 | torch.cuda.empty_cache() 104 | 105 | @trainer.on(Events.EPOCH_STARTED) 106 | def epoch_started_callback(engine): 107 | 108 | epoch = engine.state.epoch 109 | if model.mutual_learning: 110 | model.update_rate = min(100 / (epoch + 1), 1.0) * model.update_rate_ 111 | 112 | kv_metric.reset() 113 | timer.reset() 114 | 115 | @trainer.on(Events.EPOCH_COMPLETED) 116 | def epoch_completed_callback(engine): 117 | epoch = engine.state.epoch 118 | 119 | if lr_scheduler is not None: 120 | lr_scheduler.step() 121 | 122 | if epoch % eval_interval == 0: 123 | logger.info("Model saved at {}/{}_model_{}.pth".format(save_dir, prefix, epoch)) 124 | 125 | if evaluator and epoch % eval_interval == 0 and epoch > start_eval: 126 | torch.cuda.empty_cache() 127 | 128 | # extract query feature 129 | evaluator.run(query_loader) 130 | 131 | q_feats = torch.cat(evaluator.state.feat_list, dim=0) 132 | q_ids = torch.cat(evaluator.state.id_list, dim=0).numpy() 133 | q_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy() 134 | q_img_paths = np.concatenate(evaluator.state.img_path_list, axis=0) 135 | 136 | # extract gallery feature 137 | evaluator.run(gallery_loader) 138 | 139 | g_feats = torch.cat(evaluator.state.feat_list, dim=0) 140 | g_ids = torch.cat(evaluator.state.id_list, dim=0).numpy() 141 | g_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy() 142 | g_img_paths = np.concatenate(evaluator.state.img_path_list, axis=0) 143 | 144 | if dataset == 'sysu': 145 | perm = sio.loadmat(os.path.join(dataset_cfg.sysu.data_root, 'exp', 'rand_perm_cam.mat'))[ 146 | 'rand_perm_cam'] 147 | mAP, r1, r5, _, _ = eval_sysu(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, perm, mode='all', num_shots=1, rerank=rank) 148 | elif dataset == 'regdb': 149 | print('infrared to visible') 150 | mAP, r1, r5, _, _ = eval_regdb(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, rerank=engine.rerank) 151 | print('visible to infrared') 152 | mAP, r1_, r5, _, _ = eval_regdb(g_feats, g_ids, g_cams, q_feats, q_ids, q_cams, q_img_paths, rerank=engine.rerank) 153 | r1 = (r1 + r1_) / 2 154 | elif dataset == 'llcm': 155 | print('infrared to visible') 156 | mAP, r1, r5, _, _ = eval_llcm(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, rerank=rank) 157 | #new_all_cmc,mAP, _ = eval_llcm(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, rerank=engine.rerank) 158 | print('visible to infrared') 159 | mAP, r1_, r5, _, _ = eval_llcm(g_feats, g_ids, g_cams, q_feats, q_ids, q_cams, q_img_paths, rerank=rank) 160 | r1 = (r1 + r1_) / 2 161 | 162 | # new_all_cmc,mAP_, _= eval_llcm(g_feats, g_ids, g_cams, q_feats, q_ids, q_cams, q_img_paths, 163 | # rerank=engine.rerank) 164 | # r1 = (mAP + mAP_) / 
2 165 | # import pdb 166 | # pdb.set_trace() 167 | 168 | if r1 > engine.state.best_rank1: 169 | engine.state.best_rank1 = r1 170 | torch.save(model.state_dict(), "{}/model_best.pth".format(save_dir)) 171 | 172 | if writer is not None: 173 | writer.add_scalar('eval/mAP', mAP, epoch) 174 | writer.add_scalar('eval/r1', r1, epoch) 175 | writer.add_scalar('eval/r5', r5, epoch) 176 | 177 | evaluator.state.feat_list.clear() 178 | evaluator.state.id_list.clear() 179 | evaluator.state.cam_list.clear() 180 | evaluator.state.img_path_list.clear() 181 | del q_feats, q_ids, q_cams, g_feats, g_ids, g_cams 182 | 183 | torch.cuda.empty_cache() 184 | 185 | @trainer.on(Events.ITERATION_COMPLETED) 186 | def iteration_complete_callback(engine): 187 | timer.step() 188 | 189 | # print(engine.state.output) 190 | kv_metric.update(engine.state.output) 191 | 192 | epoch = engine.state.epoch 193 | iteration = engine.state.iteration 194 | iter_in_epoch = iteration - (epoch - 1) * len(engine.state.dataloader) 195 | 196 | if iter_in_epoch % log_period == 0 and iter_in_epoch > 0: 197 | batch_size = engine.state.batch[0].size(0) 198 | speed = batch_size / timer.value() 199 | 200 | msg = "Epoch[%d] Batch [%d]\tSpeed: %.2f samples/sec" % (epoch, iter_in_epoch, speed) 201 | 202 | metric_dict = kv_metric.compute() 203 | 204 | # log output information 205 | if logger is not None: 206 | for k in sorted(metric_dict.keys()): 207 | msg += "\t%s: %.4f" % (k, metric_dict[k]) 208 | if writer is not None: 209 | writer.add_scalar('metric/{}'.format(k), metric_dict[k], iteration) 210 | 211 | logger.info(msg) 212 | 213 | kv_metric.reset() 214 | timer.reset() 215 | 216 | return trainer 217 | -------------------------------------------------------------------------------- /IDKL/engine/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/engine/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /IDKL/engine/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/engine/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /IDKL/engine/__pycache__/engine.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/engine/__pycache__/engine.cpython-37.pyc -------------------------------------------------------------------------------- /IDKL/engine/__pycache__/metric.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/engine/__pycache__/metric.cpython-37.pyc -------------------------------------------------------------------------------- /IDKL/engine/engine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | import numpy as np 5 | import os 6 | from apex import amp 7 | from ignite.engine import Engine 8 | from ignite.engine import Events 9 | from torch.autograd import no_grad 10 | from torch.nn import functional as F 11 | import torchvision.transforms as T 
12 | import cv2 13 | from torchvision.io.image import read_image 14 | from PIL import Image 15 | from torchvision.transforms.functional import normalize, resize, to_pil_image 16 | 17 | from torchvision import transforms 18 | from grad_cam.utils import GradCAM, show_cam_on_image, center_crop_img 19 | import copy 20 | from torch.optim.lr_scheduler import LambdaLR 21 | 22 | 23 | # import torch 24 | # import numpy as np 25 | # import os 26 | # from apex import amp 27 | # from ignite.engine import Engine 28 | # from ignite.engine import Events 29 | # from torch.autograd import no_grad 30 | # from torchvision import transforms 31 | # from PIL import Image 32 | # import cv2 33 | # from grad_cam.utils import GradCAM, show_cam_on_image, center_crop_img 34 | # from thop import profile 35 | # from thop import clever_format 36 | # 37 | # from utils.calc_acc import calc_acc 38 | # from torch.nn import functional as F 39 | # #import objgraph 40 | 41 | 42 | def some_function(epoch, initial_weight_decay): 43 | if epoch > 15: 44 | new_weight_decay = initial_weight_decay/100 45 | elif epoch > 5 and epoch <= 15: 46 | new_weight_decay = initial_weight_decay*1/10 47 | else: 48 | new_weight_decay = initial_weight_decay 49 | return new_weight_decay 50 | 51 | def create_train_engine(model, optimizer, non_blocking=False): 52 | device = torch.device("cuda") #"cuda", torch.cuda.current_device() 53 | 54 | def _process_func(engine, batch): 55 | model.train() 56 | #model.eval() 57 | 58 | data, labels, cam_ids, img_paths, img_ids = batch 59 | epoch = engine.state.epoch 60 | iteration = engine.state.iteration 61 | 62 | data = data.to(device, non_blocking=non_blocking) 63 | labels = labels.to(device, non_blocking=non_blocking) 64 | cam_ids = cam_ids.to(device, non_blocking=non_blocking) 65 | 66 | warmup = False 67 | if warmup == True: #学习率warmup 68 | if epoch < 21: 69 | # 进行warmup,逐渐增加学习率 70 | warm_iteration = 30 * 213 71 | lr = 0.00035 * iteration / warm_iteration 72 | for param_group in optimizer.param_groups: 73 | param_group['lr'] = lr 74 | if True: #正则化参数warmup 75 | new_weight_decay = some_function(epoch, 0.5) 76 | for param_group in optimizer.param_groups: 77 | param_group['weight_decay'] = new_weight_decay 78 | 79 | optimizer.zero_grad() 80 | 81 | loss, metric = model(data, labels, 82 | cam_ids=cam_ids, 83 | epoch=epoch) 84 | 85 | 86 | with amp.scale_loss(loss, optimizer) as scaled_loss: 87 | scaled_loss.backward() 88 | optimizer.step() 89 | 90 | return metric 91 | 92 | return Engine(_process_func) 93 | 94 | 95 | def create_eval_engine(model, non_blocking=False): 96 | device = torch.device("cuda", torch.cuda.current_device()) 97 | 98 | def _process_func(engine, batch): 99 | model.eval() 100 | 101 | data, labels, cam_ids, img_paths = batch[:4] 102 | 103 | data = data.to(device, non_blocking=non_blocking) 104 | 105 | with no_grad(): 106 | feat = model(data, cam_ids=cam_ids.to(device, non_blocking=non_blocking)) 107 | 108 | return feat.data.float().cpu(), labels, cam_ids, np.array(img_paths) 109 | 110 | engine = Engine(_process_func) 111 | 112 | @engine.on(Events.EPOCH_STARTED) 113 | def clear_data(engine): 114 | # feat list 115 | if not hasattr(engine.state, "feat_list"): 116 | setattr(engine.state, "feat_list", []) 117 | else: 118 | engine.state.feat_list.clear() 119 | 120 | # id_list 121 | if not hasattr(engine.state, "id_list"): 122 | setattr(engine.state, "id_list", []) 123 | else: 124 | engine.state.id_list.clear() 125 | 126 | # cam list 127 | if not hasattr(engine.state, "cam_list"): 128 | setattr(engine.state, 
"cam_list", []) 129 | else: 130 | engine.state.cam_list.clear() 131 | 132 | # img path list 133 | if not hasattr(engine.state, "img_path_list"): 134 | setattr(engine.state, "img_path_list", []) 135 | else: 136 | engine.state.img_path_list.clear() 137 | 138 | @engine.on(Events.ITERATION_COMPLETED) 139 | def store_data(engine): 140 | engine.state.feat_list.append(engine.state.output[0]) 141 | engine.state.id_list.append(engine.state.output[1]) 142 | engine.state.cam_list.append(engine.state.output[2]) 143 | engine.state.img_path_list.append(engine.state.output[3]) 144 | 145 | return engine 146 | -------------------------------------------------------------------------------- /IDKL/engine/metric.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import torch 4 | from ignite.exceptions import NotComputableError 5 | from ignite.metrics import Metric, Accuracy 6 | 7 | 8 | class ScalarMetric(Metric): 9 | 10 | def update(self, value): 11 | self.sum_metric += value 12 | self.sum_inst += 1 13 | 14 | def reset(self): 15 | self.sum_inst = 0 16 | self.sum_metric = 0 17 | 18 | def compute(self): 19 | if self.sum_inst == 0: 20 | raise NotComputableError('Accuracy must have at least one example before it can be computed') 21 | return self.sum_metric / self.sum_inst 22 | 23 | 24 | class IgnoreAccuracy(Accuracy): 25 | def __init__(self, ignore_index=-1): 26 | super(IgnoreAccuracy, self).__init__() 27 | 28 | self.ignore_index = ignore_index 29 | 30 | def reset(self): 31 | self._num_correct = 0 32 | self._num_examples = 0 33 | 34 | def update(self, output): 35 | 36 | y_pred, y = self._check_shape(output) 37 | self._check_type((y_pred, y)) 38 | 39 | if self._type == "binary": 40 | indices = torch.round(y_pred).type(y.type()) 41 | elif self._type == "multiclass": 42 | indices = torch.max(y_pred, dim=1)[1] 43 | 44 | correct = torch.eq(indices, y).view(-1) 45 | ignore = torch.eq(y, self.ignore_index).view(-1) 46 | self._num_correct += torch.sum(correct).item() 47 | self._num_examples += correct.shape[0] - ignore.sum().item() 48 | 49 | def compute(self): 50 | if self._num_examples == 0: 51 | raise NotComputableError('Accuracy must have at least one example before it can be computed') 52 | return self._num_correct / self._num_examples 53 | 54 | 55 | class AutoKVMetric(Metric): 56 | def __init__(self): 57 | self.kv_sum_metric = defaultdict(lambda: torch.tensor(0., device="cuda")) 58 | self.kv_sum_inst = defaultdict(lambda: torch.tensor(0., device="cuda")) 59 | 60 | self.kv_metric = defaultdict(lambda: 0) 61 | 62 | super(AutoKVMetric, self).__init__() 63 | 64 | def update(self, output): 65 | if not isinstance(output, dict): 66 | raise TypeError('The output must be a key-value dict.') 67 | 68 | for k in output.keys(): 69 | self.kv_sum_metric[k].add_(output[k]) 70 | self.kv_sum_inst[k].add_(1) 71 | 72 | def reset(self): 73 | for k in self.kv_sum_metric.keys(): 74 | self.kv_sum_metric[k].zero_() 75 | self.kv_sum_inst[k].zero_() 76 | self.kv_metric[k] = 0 77 | 78 | def compute(self): 79 | for k in self.kv_sum_metric.keys(): 80 | if self.kv_sum_inst[k] == 0: 81 | continue 82 | # raise NotComputableError('Accuracy must have at least one example before it can be computed') 83 | 84 | metric_value = self.kv_sum_metric[k] / self.kv_sum_inst[k] 85 | self.kv_metric[k] = metric_value.item() 86 | 87 | return self.kv_metric 88 | -------------------------------------------------------------------------------- /IDKL/grad_cam/README.md: 
-------------------------------------------------------------------------------- 1 | ## Grad-CAM 2 | - Original Impl: [https://github.com/jacobgil/pytorch-grad-cam](https://github.com/jacobgil/pytorch-grad-cam) 3 | - Grad-CAM简介: [https://b23.tv/1kccjmb](https://b23.tv/1kccjmb) 4 | - 使用Pytorch实现Grad-CAM并绘制热力图: [https://b23.tv/n1e60vN](https://b23.tv/n1e60vN) 5 | 6 | ## 使用流程(替换成自己的网络) 7 | 1. 将创建模型部分代码替换成自己创建模型的代码,并载入自己训练好的权重 8 | 2. 根据自己网络设置合适的`target_layers` 9 | 3. 根据自己的网络设置合适的预处理方法 10 | 4. 将要预测的图片路径赋值给`img_path` 11 | 5. 将感兴趣的类别id赋值给`target_category` 12 | 13 | -------------------------------------------------------------------------------- /IDKL/grad_cam/both.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/grad_cam/both.png -------------------------------------------------------------------------------- /IDKL/grad_cam/imagenet1k_classes.txt: -------------------------------------------------------------------------------- 1 | tench, Tinca tinca 2 | goldfish, Carassius auratus 3 | great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias 4 | tiger shark, Galeocerdo cuvieri 5 | hammerhead, hammerhead shark 6 | electric ray, crampfish, numbfish, torpedo 7 | stingray 8 | cock 9 | hen 10 | ostrich, Struthio camelus 11 | brambling, Fringilla montifringilla 12 | goldfinch, Carduelis carduelis 13 | house finch, linnet, Carpodacus mexicanus 14 | junco, snowbird 15 | indigo bunting, indigo finch, indigo bird, Passerina cyanea 16 | robin, American robin, Turdus migratorius 17 | bulbul 18 | jay 19 | magpie 20 | chickadee 21 | water ouzel, dipper 22 | kite 23 | bald eagle, American eagle, Haliaeetus leucocephalus 24 | vulture 25 | great grey owl, great gray owl, Strix nebulosa 26 | European fire salamander, Salamandra salamandra 27 | common newt, Triturus vulgaris 28 | eft 29 | spotted salamander, Ambystoma maculatum 30 | axolotl, mud puppy, Ambystoma mexicanum 31 | bullfrog, Rana catesbeiana 32 | tree frog, tree-frog 33 | tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui 34 | loggerhead, loggerhead turtle, Caretta caretta 35 | leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea 36 | mud turtle 37 | terrapin 38 | box turtle, box tortoise 39 | banded gecko 40 | common iguana, iguana, Iguana iguana 41 | American chameleon, anole, Anolis carolinensis 42 | whiptail, whiptail lizard 43 | agama 44 | frilled lizard, Chlamydosaurus kingi 45 | alligator lizard 46 | Gila monster, Heloderma suspectum 47 | green lizard, Lacerta viridis 48 | African chameleon, Chamaeleo chamaeleon 49 | Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis 50 | African crocodile, Nile crocodile, Crocodylus niloticus 51 | American alligator, Alligator mississipiensis 52 | triceratops 53 | thunder snake, worm snake, Carphophis amoenus 54 | ringneck snake, ring-necked snake, ring snake 55 | hognose snake, puff adder, sand viper 56 | green snake, grass snake 57 | king snake, kingsnake 58 | garter snake, grass snake 59 | water snake 60 | vine snake 61 | night snake, Hypsiglena torquata 62 | boa constrictor, Constrictor constrictor 63 | rock python, rock snake, Python sebae 64 | Indian cobra, Naja naja 65 | green mamba 66 | sea snake 67 | horned viper, cerastes, sand viper, horned asp, Cerastes cornutus 68 | diamondback, diamondback rattlesnake, Crotalus adamanteus 69 | sidewinder, horned rattlesnake, Crotalus cerastes 70 | 
trilobite 71 | harvestman, daddy longlegs, Phalangium opilio 72 | scorpion 73 | black and gold garden spider, Argiope aurantia 74 | barn spider, Araneus cavaticus 75 | garden spider, Aranea diademata 76 | black widow, Latrodectus mactans 77 | tarantula 78 | wolf spider, hunting spider 79 | tick 80 | centipede 81 | black grouse 82 | ptarmigan 83 | ruffed grouse, partridge, Bonasa umbellus 84 | prairie chicken, prairie grouse, prairie fowl 85 | peacock 86 | quail 87 | partridge 88 | African grey, African gray, Psittacus erithacus 89 | macaw 90 | sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita 91 | lorikeet 92 | coucal 93 | bee eater 94 | hornbill 95 | hummingbird 96 | jacamar 97 | toucan 98 | drake 99 | red-breasted merganser, Mergus serrator 100 | goose 101 | black swan, Cygnus atratus 102 | tusker 103 | echidna, spiny anteater, anteater 104 | platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus 105 | wallaby, brush kangaroo 106 | koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus 107 | wombat 108 | jellyfish 109 | sea anemone, anemone 110 | brain coral 111 | flatworm, platyhelminth 112 | nematode, nematode worm, roundworm 113 | conch 114 | snail 115 | slug 116 | sea slug, nudibranch 117 | chiton, coat-of-mail shell, sea cradle, polyplacophore 118 | chambered nautilus, pearly nautilus, nautilus 119 | Dungeness crab, Cancer magister 120 | rock crab, Cancer irroratus 121 | fiddler crab 122 | king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica 123 | American lobster, Northern lobster, Maine lobster, Homarus americanus 124 | spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish 125 | crayfish, crawfish, crawdad, crawdaddy 126 | hermit crab 127 | isopod 128 | white stork, Ciconia ciconia 129 | black stork, Ciconia nigra 130 | spoonbill 131 | flamingo 132 | little blue heron, Egretta caerulea 133 | American egret, great white heron, Egretta albus 134 | bittern 135 | crane 136 | limpkin, Aramus pictus 137 | European gallinule, Porphyrio porphyrio 138 | American coot, marsh hen, mud hen, water hen, Fulica americana 139 | bustard 140 | ruddy turnstone, Arenaria interpres 141 | red-backed sandpiper, dunlin, Erolia alpina 142 | redshank, Tringa totanus 143 | dowitcher 144 | oystercatcher, oyster catcher 145 | pelican 146 | king penguin, Aptenodytes patagonica 147 | albatross, mollymawk 148 | grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus 149 | killer whale, killer, orca, grampus, sea wolf, Orcinus orca 150 | dugong, Dugong dugon 151 | sea lion 152 | Chihuahua 153 | Japanese spaniel 154 | Maltese dog, Maltese terrier, Maltese 155 | Pekinese, Pekingese, Peke 156 | Shih-Tzu 157 | Blenheim spaniel 158 | papillon 159 | toy terrier 160 | Rhodesian ridgeback 161 | Afghan hound, Afghan 162 | basset, basset hound 163 | beagle 164 | bloodhound, sleuthhound 165 | bluetick 166 | black-and-tan coonhound 167 | Walker hound, Walker foxhound 168 | English foxhound 169 | redbone 170 | borzoi, Russian wolfhound 171 | Irish wolfhound 172 | Italian greyhound 173 | whippet 174 | Ibizan hound, Ibizan Podenco 175 | Norwegian elkhound, elkhound 176 | otterhound, otter hound 177 | Saluki, gazelle hound 178 | Scottish deerhound, deerhound 179 | Weimaraner 180 | Staffordshire bullterrier, Staffordshire bull terrier 181 | American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier 182 | Bedlington terrier 183 | Border terrier 184 | Kerry 
blue terrier 185 | Irish terrier 186 | Norfolk terrier 187 | Norwich terrier 188 | Yorkshire terrier 189 | wire-haired fox terrier 190 | Lakeland terrier 191 | Sealyham terrier, Sealyham 192 | Airedale, Airedale terrier 193 | cairn, cairn terrier 194 | Australian terrier 195 | Dandie Dinmont, Dandie Dinmont terrier 196 | Boston bull, Boston terrier 197 | miniature schnauzer 198 | giant schnauzer 199 | standard schnauzer 200 | Scotch terrier, Scottish terrier, Scottie 201 | Tibetan terrier, chrysanthemum dog 202 | silky terrier, Sydney silky 203 | soft-coated wheaten terrier 204 | West Highland white terrier 205 | Lhasa, Lhasa apso 206 | flat-coated retriever 207 | curly-coated retriever 208 | golden retriever 209 | Labrador retriever 210 | Chesapeake Bay retriever 211 | German short-haired pointer 212 | vizsla, Hungarian pointer 213 | English setter 214 | Irish setter, red setter 215 | Gordon setter 216 | Brittany spaniel 217 | clumber, clumber spaniel 218 | English springer, English springer spaniel 219 | Welsh springer spaniel 220 | cocker spaniel, English cocker spaniel, cocker 221 | Sussex spaniel 222 | Irish water spaniel 223 | kuvasz 224 | schipperke 225 | groenendael 226 | malinois 227 | briard 228 | kelpie 229 | komondor 230 | Old English sheepdog, bobtail 231 | Shetland sheepdog, Shetland sheep dog, Shetland 232 | collie 233 | Border collie 234 | Bouvier des Flandres, Bouviers des Flandres 235 | Rottweiler 236 | German shepherd, German shepherd dog, German police dog, alsatian 237 | Doberman, Doberman pinscher 238 | miniature pinscher 239 | Greater Swiss Mountain dog 240 | Bernese mountain dog 241 | Appenzeller 242 | EntleBucher 243 | boxer 244 | bull mastiff 245 | Tibetan mastiff 246 | French bulldog 247 | Great Dane 248 | Saint Bernard, St Bernard 249 | Eskimo dog, husky 250 | malamute, malemute, Alaskan malamute 251 | Siberian husky 252 | dalmatian, coach dog, carriage dog 253 | affenpinscher, monkey pinscher, monkey dog 254 | basenji 255 | pug, pug-dog 256 | Leonberg 257 | Newfoundland, Newfoundland dog 258 | Great Pyrenees 259 | Samoyed, Samoyede 260 | Pomeranian 261 | chow, chow chow 262 | keeshond 263 | Brabancon griffon 264 | Pembroke, Pembroke Welsh corgi 265 | Cardigan, Cardigan Welsh corgi 266 | toy poodle 267 | miniature poodle 268 | standard poodle 269 | Mexican hairless 270 | timber wolf, grey wolf, gray wolf, Canis lupus 271 | white wolf, Arctic wolf, Canis lupus tundrarum 272 | red wolf, maned wolf, Canis rufus, Canis niger 273 | coyote, prairie wolf, brush wolf, Canis latrans 274 | dingo, warrigal, warragal, Canis dingo 275 | dhole, Cuon alpinus 276 | African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus 277 | hyena, hyaena 278 | red fox, Vulpes vulpes 279 | kit fox, Vulpes macrotis 280 | Arctic fox, white fox, Alopex lagopus 281 | grey fox, gray fox, Urocyon cinereoargenteus 282 | tabby, tabby cat 283 | tiger cat 284 | Persian cat 285 | Siamese cat, Siamese 286 | Egyptian cat 287 | cougar, puma, catamount, mountain lion, painter, panther, Felis concolor 288 | lynx, catamount 289 | leopard, Panthera pardus 290 | snow leopard, ounce, Panthera uncia 291 | jaguar, panther, Panthera onca, Felis onca 292 | lion, king of beasts, Panthera leo 293 | tiger, Panthera tigris 294 | cheetah, chetah, Acinonyx jubatus 295 | brown bear, bruin, Ursus arctos 296 | American black bear, black bear, Ursus americanus, Euarctos americanus 297 | ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus 298 | sloth bear, Melursus ursinus, Ursus ursinus 299 | mongoose 300 | 
meerkat, mierkat 301 | tiger beetle 302 | ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle 303 | ground beetle, carabid beetle 304 | long-horned beetle, longicorn, longicorn beetle 305 | leaf beetle, chrysomelid 306 | dung beetle 307 | rhinoceros beetle 308 | weevil 309 | fly 310 | bee 311 | ant, emmet, pismire 312 | grasshopper, hopper 313 | cricket 314 | walking stick, walkingstick, stick insect 315 | cockroach, roach 316 | mantis, mantid 317 | cicada, cicala 318 | leafhopper 319 | lacewing, lacewing fly 320 | dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk 321 | damselfly 322 | admiral 323 | ringlet, ringlet butterfly 324 | monarch, monarch butterfly, milkweed butterfly, Danaus plexippus 325 | cabbage butterfly 326 | sulphur butterfly, sulfur butterfly 327 | lycaenid, lycaenid butterfly 328 | starfish, sea star 329 | sea urchin 330 | sea cucumber, holothurian 331 | wood rabbit, cottontail, cottontail rabbit 332 | hare 333 | Angora, Angora rabbit 334 | hamster 335 | porcupine, hedgehog 336 | fox squirrel, eastern fox squirrel, Sciurus niger 337 | marmot 338 | beaver 339 | guinea pig, Cavia cobaya 340 | sorrel 341 | zebra 342 | hog, pig, grunter, squealer, Sus scrofa 343 | wild boar, boar, Sus scrofa 344 | warthog 345 | hippopotamus, hippo, river horse, Hippopotamus amphibius 346 | ox 347 | water buffalo, water ox, Asiatic buffalo, Bubalus bubalis 348 | bison 349 | ram, tup 350 | bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis 351 | ibex, Capra ibex 352 | hartebeest 353 | impala, Aepyceros melampus 354 | gazelle 355 | Arabian camel, dromedary, Camelus dromedarius 356 | llama 357 | weasel 358 | mink 359 | polecat, fitch, foulmart, foumart, Mustela putorius 360 | black-footed ferret, ferret, Mustela nigripes 361 | otter 362 | skunk, polecat, wood pussy 363 | badger 364 | armadillo 365 | three-toed sloth, ai, Bradypus tridactylus 366 | orangutan, orang, orangutang, Pongo pygmaeus 367 | gorilla, Gorilla gorilla 368 | chimpanzee, chimp, Pan troglodytes 369 | gibbon, Hylobates lar 370 | siamang, Hylobates syndactylus, Symphalangus syndactylus 371 | guenon, guenon monkey 372 | patas, hussar monkey, Erythrocebus patas 373 | baboon 374 | macaque 375 | langur 376 | colobus, colobus monkey 377 | proboscis monkey, Nasalis larvatus 378 | marmoset 379 | capuchin, ringtail, Cebus capucinus 380 | howler monkey, howler 381 | titi, titi monkey 382 | spider monkey, Ateles geoffroyi 383 | squirrel monkey, Saimiri sciureus 384 | Madagascar cat, ring-tailed lemur, Lemur catta 385 | indri, indris, Indri indri, Indri brevicaudatus 386 | Indian elephant, Elephas maximus 387 | African elephant, Loxodonta africana 388 | lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens 389 | giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca 390 | barracouta, snoek 391 | eel 392 | coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch 393 | rock beauty, Holocanthus tricolor 394 | anemone fish 395 | sturgeon 396 | gar, garfish, garpike, billfish, Lepisosteus osseus 397 | lionfish 398 | puffer, pufferfish, blowfish, globefish 399 | abacus 400 | abaya 401 | academic gown, academic robe, judge's robe 402 | accordion, piano accordion, squeeze box 403 | acoustic guitar 404 | aircraft carrier, carrier, flattop, attack aircraft carrier 405 | airliner 406 | airship, dirigible 407 | altar 408 | ambulance 409 | amphibian, amphibious vehicle 410 | analog clock 
411 | apiary, bee house 412 | apron 413 | ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin 414 | assault rifle, assault gun 415 | backpack, back pack, knapsack, packsack, rucksack, haversack 416 | bakery, bakeshop, bakehouse 417 | balance beam, beam 418 | balloon 419 | ballpoint, ballpoint pen, ballpen, Biro 420 | Band Aid 421 | banjo 422 | bannister, banister, balustrade, balusters, handrail 423 | barbell 424 | barber chair 425 | barbershop 426 | barn 427 | barometer 428 | barrel, cask 429 | barrow, garden cart, lawn cart, wheelbarrow 430 | baseball 431 | basketball 432 | bassinet 433 | bassoon 434 | bathing cap, swimming cap 435 | bath towel 436 | bathtub, bathing tub, bath, tub 437 | beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon 438 | beacon, lighthouse, beacon light, pharos 439 | beaker 440 | bearskin, busby, shako 441 | beer bottle 442 | beer glass 443 | bell cote, bell cot 444 | bib 445 | bicycle-built-for-two, tandem bicycle, tandem 446 | bikini, two-piece 447 | binder, ring-binder 448 | binoculars, field glasses, opera glasses 449 | birdhouse 450 | boathouse 451 | bobsled, bobsleigh, bob 452 | bolo tie, bolo, bola tie, bola 453 | bonnet, poke bonnet 454 | bookcase 455 | bookshop, bookstore, bookstall 456 | bottlecap 457 | bow 458 | bow tie, bow-tie, bowtie 459 | brass, memorial tablet, plaque 460 | brassiere, bra, bandeau 461 | breakwater, groin, groyne, mole, bulwark, seawall, jetty 462 | breastplate, aegis, egis 463 | broom 464 | bucket, pail 465 | buckle 466 | bulletproof vest 467 | bullet train, bullet 468 | butcher shop, meat market 469 | cab, hack, taxi, taxicab 470 | caldron, cauldron 471 | candle, taper, wax light 472 | cannon 473 | canoe 474 | can opener, tin opener 475 | cardigan 476 | car mirror 477 | carousel, carrousel, merry-go-round, roundabout, whirligig 478 | carpenter's kit, tool kit 479 | carton 480 | car wheel 481 | cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM 482 | cassette 483 | cassette player 484 | castle 485 | catamaran 486 | CD player 487 | cello, violoncello 488 | cellular telephone, cellular phone, cellphone, cell, mobile phone 489 | chain 490 | chainlink fence 491 | chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour 492 | chain saw, chainsaw 493 | chest 494 | chiffonier, commode 495 | chime, bell, gong 496 | china cabinet, china closet 497 | Christmas stocking 498 | church, church building 499 | cinema, movie theater, movie theatre, movie house, picture palace 500 | cleaver, meat cleaver, chopper 501 | cliff dwelling 502 | cloak 503 | clog, geta, patten, sabot 504 | cocktail shaker 505 | coffee mug 506 | coffeepot 507 | coil, spiral, volute, whorl, helix 508 | combination lock 509 | computer keyboard, keypad 510 | confectionery, confectionary, candy store 511 | container ship, containership, container vessel 512 | convertible 513 | corkscrew, bottle screw 514 | cornet, horn, trumpet, trump 515 | cowboy boot 516 | cowboy hat, ten-gallon hat 517 | cradle 518 | crane 519 | crash helmet 520 | crate 521 | crib, cot 522 | Crock Pot 523 | croquet ball 524 | crutch 525 | cuirass 526 | dam, dike, dyke 527 | desk 528 | desktop computer 529 | dial telephone, dial phone 530 | diaper, nappy, napkin 531 | digital clock 532 | digital watch 533 | dining table, board 534 | dishrag, dishcloth 535 | dishwasher, dish washer, dishwashing machine 536 | disk brake, disc brake 537 | 
dock, dockage, docking facility 538 | dogsled, dog sled, dog sleigh 539 | dome 540 | doormat, welcome mat 541 | drilling platform, offshore rig 542 | drum, membranophone, tympan 543 | drumstick 544 | dumbbell 545 | Dutch oven 546 | electric fan, blower 547 | electric guitar 548 | electric locomotive 549 | entertainment center 550 | envelope 551 | espresso maker 552 | face powder 553 | feather boa, boa 554 | file, file cabinet, filing cabinet 555 | fireboat 556 | fire engine, fire truck 557 | fire screen, fireguard 558 | flagpole, flagstaff 559 | flute, transverse flute 560 | folding chair 561 | football helmet 562 | forklift 563 | fountain 564 | fountain pen 565 | four-poster 566 | freight car 567 | French horn, horn 568 | frying pan, frypan, skillet 569 | fur coat 570 | garbage truck, dustcart 571 | gasmask, respirator, gas helmet 572 | gas pump, gasoline pump, petrol pump, island dispenser 573 | goblet 574 | go-kart 575 | golf ball 576 | golfcart, golf cart 577 | gondola 578 | gong, tam-tam 579 | gown 580 | grand piano, grand 581 | greenhouse, nursery, glasshouse 582 | grille, radiator grille 583 | grocery store, grocery, food market, market 584 | guillotine 585 | hair slide 586 | hair spray 587 | half track 588 | hammer 589 | hamper 590 | hand blower, blow dryer, blow drier, hair dryer, hair drier 591 | hand-held computer, hand-held microcomputer 592 | handkerchief, hankie, hanky, hankey 593 | hard disc, hard disk, fixed disk 594 | harmonica, mouth organ, harp, mouth harp 595 | harp 596 | harvester, reaper 597 | hatchet 598 | holster 599 | home theater, home theatre 600 | honeycomb 601 | hook, claw 602 | hoopskirt, crinoline 603 | horizontal bar, high bar 604 | horse cart, horse-cart 605 | hourglass 606 | iPod 607 | iron, smoothing iron 608 | jack-o'-lantern 609 | jean, blue jean, denim 610 | jeep, landrover 611 | jersey, T-shirt, tee shirt 612 | jigsaw puzzle 613 | jinrikisha, ricksha, rickshaw 614 | joystick 615 | kimono 616 | knee pad 617 | knot 618 | lab coat, laboratory coat 619 | ladle 620 | lampshade, lamp shade 621 | laptop, laptop computer 622 | lawn mower, mower 623 | lens cap, lens cover 624 | letter opener, paper knife, paperknife 625 | library 626 | lifeboat 627 | lighter, light, igniter, ignitor 628 | limousine, limo 629 | liner, ocean liner 630 | lipstick, lip rouge 631 | Loafer 632 | lotion 633 | loudspeaker, speaker, speaker unit, loudspeaker system, speaker system 634 | loupe, jeweler's loupe 635 | lumbermill, sawmill 636 | magnetic compass 637 | mailbag, postbag 638 | mailbox, letter box 639 | maillot 640 | maillot, tank suit 641 | manhole cover 642 | maraca 643 | marimba, xylophone 644 | mask 645 | matchstick 646 | maypole 647 | maze, labyrinth 648 | measuring cup 649 | medicine chest, medicine cabinet 650 | megalith, megalithic structure 651 | microphone, mike 652 | microwave, microwave oven 653 | military uniform 654 | milk can 655 | minibus 656 | miniskirt, mini 657 | minivan 658 | missile 659 | mitten 660 | mixing bowl 661 | mobile home, manufactured home 662 | Model T 663 | modem 664 | monastery 665 | monitor 666 | moped 667 | mortar 668 | mortarboard 669 | mosque 670 | mosquito net 671 | motor scooter, scooter 672 | mountain bike, all-terrain bike, off-roader 673 | mountain tent 674 | mouse, computer mouse 675 | mousetrap 676 | moving van 677 | muzzle 678 | nail 679 | neck brace 680 | necklace 681 | nipple 682 | notebook, notebook computer 683 | obelisk 684 | oboe, hautboy, hautbois 685 | ocarina, sweet potato 686 | odometer, hodometer, mileometer, milometer 
687 | oil filter 688 | organ, pipe organ 689 | oscilloscope, scope, cathode-ray oscilloscope, CRO 690 | overskirt 691 | oxcart 692 | oxygen mask 693 | packet 694 | paddle, boat paddle 695 | paddlewheel, paddle wheel 696 | padlock 697 | paintbrush 698 | pajama, pyjama, pj's, jammies 699 | palace 700 | panpipe, pandean pipe, syrinx 701 | paper towel 702 | parachute, chute 703 | parallel bars, bars 704 | park bench 705 | parking meter 706 | passenger car, coach, carriage 707 | patio, terrace 708 | pay-phone, pay-station 709 | pedestal, plinth, footstall 710 | pencil box, pencil case 711 | pencil sharpener 712 | perfume, essence 713 | Petri dish 714 | photocopier 715 | pick, plectrum, plectron 716 | pickelhaube 717 | picket fence, paling 718 | pickup, pickup truck 719 | pier 720 | piggy bank, penny bank 721 | pill bottle 722 | pillow 723 | ping-pong ball 724 | pinwheel 725 | pirate, pirate ship 726 | pitcher, ewer 727 | plane, carpenter's plane, woodworking plane 728 | planetarium 729 | plastic bag 730 | plate rack 731 | plow, plough 732 | plunger, plumber's helper 733 | Polaroid camera, Polaroid Land camera 734 | pole 735 | police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria 736 | poncho 737 | pool table, billiard table, snooker table 738 | pop bottle, soda bottle 739 | pot, flowerpot 740 | potter's wheel 741 | power drill 742 | prayer rug, prayer mat 743 | printer 744 | prison, prison house 745 | projectile, missile 746 | projector 747 | puck, hockey puck 748 | punching bag, punch bag, punching ball, punchball 749 | purse 750 | quill, quill pen 751 | quilt, comforter, comfort, puff 752 | racer, race car, racing car 753 | racket, racquet 754 | radiator 755 | radio, wireless 756 | radio telescope, radio reflector 757 | rain barrel 758 | recreational vehicle, RV, R.V. 
759 | reel 760 | reflex camera 761 | refrigerator, icebox 762 | remote control, remote 763 | restaurant, eating house, eating place, eatery 764 | revolver, six-gun, six-shooter 765 | rifle 766 | rocking chair, rocker 767 | rotisserie 768 | rubber eraser, rubber, pencil eraser 769 | rugby ball 770 | rule, ruler 771 | running shoe 772 | safe 773 | safety pin 774 | saltshaker, salt shaker 775 | sandal 776 | sarong 777 | sax, saxophone 778 | scabbard 779 | scale, weighing machine 780 | school bus 781 | schooner 782 | scoreboard 783 | screen, CRT screen 784 | screw 785 | screwdriver 786 | seat belt, seatbelt 787 | sewing machine 788 | shield, buckler 789 | shoe shop, shoe-shop, shoe store 790 | shoji 791 | shopping basket 792 | shopping cart 793 | shovel 794 | shower cap 795 | shower curtain 796 | ski 797 | ski mask 798 | sleeping bag 799 | slide rule, slipstick 800 | sliding door 801 | slot, one-armed bandit 802 | snorkel 803 | snowmobile 804 | snowplow, snowplough 805 | soap dispenser 806 | soccer ball 807 | sock 808 | solar dish, solar collector, solar furnace 809 | sombrero 810 | soup bowl 811 | space bar 812 | space heater 813 | space shuttle 814 | spatula 815 | speedboat 816 | spider web, spider's web 817 | spindle 818 | sports car, sport car 819 | spotlight, spot 820 | stage 821 | steam locomotive 822 | steel arch bridge 823 | steel drum 824 | stethoscope 825 | stole 826 | stone wall 827 | stopwatch, stop watch 828 | stove 829 | strainer 830 | streetcar, tram, tramcar, trolley, trolley car 831 | stretcher 832 | studio couch, day bed 833 | stupa, tope 834 | submarine, pigboat, sub, U-boat 835 | suit, suit of clothes 836 | sundial 837 | sunglass 838 | sunglasses, dark glasses, shades 839 | sunscreen, sunblock, sun blocker 840 | suspension bridge 841 | swab, swob, mop 842 | sweatshirt 843 | swimming trunks, bathing trunks 844 | swing 845 | switch, electric switch, electrical switch 846 | syringe 847 | table lamp 848 | tank, army tank, armored combat vehicle, armoured combat vehicle 849 | tape player 850 | teapot 851 | teddy, teddy bear 852 | television, television system 853 | tennis ball 854 | thatch, thatched roof 855 | theater curtain, theatre curtain 856 | thimble 857 | thresher, thrasher, threshing machine 858 | throne 859 | tile roof 860 | toaster 861 | tobacco shop, tobacconist shop, tobacconist 862 | toilet seat 863 | torch 864 | totem pole 865 | tow truck, tow car, wrecker 866 | toyshop 867 | tractor 868 | trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi 869 | tray 870 | trench coat 871 | tricycle, trike, velocipede 872 | trimaran 873 | tripod 874 | triumphal arch 875 | trolleybus, trolley coach, trackless trolley 876 | trombone 877 | tub, vat 878 | turnstile 879 | typewriter keyboard 880 | umbrella 881 | unicycle, monocycle 882 | upright, upright piano 883 | vacuum, vacuum cleaner 884 | vase 885 | vault 886 | velvet 887 | vending machine 888 | vestment 889 | viaduct 890 | violin, fiddle 891 | volleyball 892 | waffle iron 893 | wall clock 894 | wallet, billfold, notecase, pocketbook 895 | wardrobe, closet, press 896 | warplane, military plane 897 | washbasin, handbasin, washbowl, lavabo, wash-hand basin 898 | washer, automatic washer, washing machine 899 | water bottle 900 | water jug 901 | water tower 902 | whiskey jug 903 | whistle 904 | wig 905 | window screen 906 | window shade 907 | Windsor tie 908 | wine bottle 909 | wing 910 | wok 911 | wooden spoon 912 | wool, woolen, woollen 913 | worm fence, snake fence, snake-rail fence, Virginia fence 914 | 
wreck 915 | yawl 916 | yurt 917 | web site, website, internet site, site 918 | comic book 919 | crossword puzzle, crossword 920 | street sign 921 | traffic light, traffic signal, stoplight 922 | book jacket, dust cover, dust jacket, dust wrapper 923 | menu 924 | plate 925 | guacamole 926 | consomme 927 | hot pot, hotpot 928 | trifle 929 | ice cream, icecream 930 | ice lolly, lolly, lollipop, popsicle 931 | French loaf 932 | bagel, beigel 933 | pretzel 934 | cheeseburger 935 | hotdog, hot dog, red hot 936 | mashed potato 937 | head cabbage 938 | broccoli 939 | cauliflower 940 | zucchini, courgette 941 | spaghetti squash 942 | acorn squash 943 | butternut squash 944 | cucumber, cuke 945 | artichoke, globe artichoke 946 | bell pepper 947 | cardoon 948 | mushroom 949 | Granny Smith 950 | strawberry 951 | orange 952 | lemon 953 | fig 954 | pineapple, ananas 955 | banana 956 | jackfruit, jak, jack 957 | custard apple 958 | pomegranate 959 | hay 960 | carbonara 961 | chocolate sauce, chocolate syrup 962 | dough 963 | meat loaf, meatloaf 964 | pizza, pizza pie 965 | potpie 966 | burrito 967 | red wine 968 | espresso 969 | cup 970 | eggnog 971 | alp 972 | bubble 973 | cliff, drop, drop-off 974 | coral reef 975 | geyser 976 | lakeside, lakeshore 977 | promontory, headland, head, foreland 978 | sandbar, sand bar 979 | seashore, coast, seacoast, sea-coast 980 | valley, vale 981 | volcano 982 | ballplayer, baseball player 983 | groom, bridegroom 984 | scuba diver 985 | rapeseed 986 | daisy 987 | yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum 988 | corn 989 | acorn 990 | hip, rose hip, rosehip 991 | buckeye, horse chestnut, conker 992 | coral fungus 993 | agaric 994 | gyromitra 995 | stinkhorn, carrion fungus 996 | earthstar 997 | hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa 998 | bolete 999 | ear, spike, capitulum 1000 | toilet tissue, toilet paper, bathroom tissue -------------------------------------------------------------------------------- /IDKL/grad_cam/main_cnn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from PIL import Image 5 | import matplotlib.pyplot as plt 6 | from torchvision import models 7 | from torchvision import transforms 8 | from utils import GradCAM, show_cam_on_image, center_crop_img 9 | 10 | 11 | def main(): 12 | model = models.mobilenet_v3_large(pretrained=True) 13 | target_layers = [model.features[-1]] 14 | 15 | # model = models.vgg16(pretrained=True) 16 | # target_layers = [model.features] 17 | 18 | # model = models.resnet34(pretrained=True) 19 | # target_layers = [model.layer4] 20 | 21 | # model = models.regnet_y_800mf(pretrained=True) 22 | # target_layers = [model.trunk_output] 23 | 24 | # model = models.efficientnet_b0(pretrained=True) 25 | # target_layers = [model.features] 26 | 27 | data_transform = transforms.Compose([transforms.ToTensor(), 28 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 29 | # load image 30 | img_path = "both.png" 31 | assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path) 32 | img = Image.open(img_path).convert('RGB') 33 | img = np.array(img, dtype=np.uint8) 34 | # img = center_crop_img(img, 224) 35 | 36 | # [C, H, W] 37 | img_tensor = data_transform(img) 38 | # expand batch dimension 39 | # [C, H, W] -> [N, C, H, W] 40 | input_tensor = torch.unsqueeze(img_tensor, dim=0) 41 | 42 | cam = GradCAM(model=model, 
target_layers=target_layers, use_cuda=False) 43 | target_category = 281 # tabby, tabby cat 44 | # target_category = 254 # pug, pug-dog 45 | 46 | grayscale_cam = cam(input_tensor=input_tensor, target_category=target_category) 47 | 48 | grayscale_cam = grayscale_cam[0, :] 49 | visualization = show_cam_on_image(img.astype(dtype=np.float32) / 255., 50 | grayscale_cam, 51 | use_rgb=True) 52 | plt.imshow(visualization) 53 | plt.show() 54 | 55 | 56 | if __name__ == '__main__': 57 | main() 58 | -------------------------------------------------------------------------------- /IDKL/grad_cam/main_swin.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import numpy as np 4 | import torch 5 | from PIL import Image 6 | import matplotlib.pyplot as plt 7 | from torchvision import transforms 8 | from utils import GradCAM, show_cam_on_image, center_crop_img 9 | from swin_model import swin_base_patch4_window7_224 10 | 11 | 12 | class ResizeTransform: 13 | def __init__(self, im_h: int, im_w: int): 14 | self.height = self.feature_size(im_h) 15 | self.width = self.feature_size(im_w) 16 | 17 | @staticmethod 18 | def feature_size(s): 19 | s = math.ceil(s / 4) # PatchEmbed 20 | s = math.ceil(s / 2) # PatchMerging1 21 | s = math.ceil(s / 2) # PatchMerging2 22 | s = math.ceil(s / 2) # PatchMerging3 23 | return s 24 | 25 | def __call__(self, x): 26 | result = x.reshape(x.size(0), 27 | self.height, 28 | self.width, 29 | x.size(2)) 30 | 31 | # Bring the channels to the first dimension, 32 | # like in CNNs. 33 | # [batch_size, H, W, C] -> [batch, C, H, W] 34 | result = result.permute(0, 3, 1, 2) 35 | 36 | return result 37 | 38 | 39 | def main(): 40 | # Note: the input image size must be an integer multiple of 32, 41 | # otherwise the attention map will drift because of the padding 42 | img_size = 224 43 | assert img_size % 32 == 0 44 | 45 | model = swin_base_patch4_window7_224() 46 | # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth 47 | weights_path = "./swin_base_patch4_window7_224.pth" 48 | model.load_state_dict(torch.load(weights_path, map_location="cpu")["model"], strict=False) 49 | 50 | target_layers = [model.norm] 51 | 52 | data_transform = transforms.Compose([transforms.ToTensor(), 53 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 54 | # load image 55 | img_path = "both.png" 56 | assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path) 57 | img = Image.open(img_path).convert('RGB') 58 | img = np.array(img, dtype=np.uint8) 59 | img = center_crop_img(img, img_size) 60 | 61 | # [C, H, W] 62 | img_tensor = data_transform(img) 63 | # expand batch dimension 64 | # [C, H, W] -> [N, C, H, W] 65 | input_tensor = torch.unsqueeze(img_tensor, dim=0) 66 | 67 | cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False, 68 | reshape_transform=ResizeTransform(im_h=img_size, im_w=img_size)) 69 | target_category = 281 # tabby, tabby cat 70 | # target_category = 254 # pug, pug-dog 71 | 72 | grayscale_cam = cam(input_tensor=input_tensor, target_category=target_category) 73 | 74 | grayscale_cam = grayscale_cam[0, :] 75 | visualization = show_cam_on_image(img / 255., grayscale_cam, use_rgb=True) 76 | plt.imshow(visualization) 77 | plt.show() 78 | 79 | 80 | if __name__ == '__main__': 81 | main() 82 | -------------------------------------------------------------------------------- /IDKL/grad_cam/main_vit.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import
torch 4 | from PIL import Image 5 | import matplotlib.pyplot as plt 6 | from torchvision import transforms 7 | from utils import GradCAM, show_cam_on_image, center_crop_img 8 | from vit_model import vit_base_patch16_224 9 | 10 | 11 | class ReshapeTransform: 12 | def __init__(self, model): 13 | input_size = model.patch_embed.img_size 14 | patch_size = model.patch_embed.patch_size 15 | self.h = input_size[0] // patch_size[0] 16 | self.w = input_size[1] // patch_size[1] 17 | 18 | def __call__(self, x): 19 | # remove cls token and reshape 20 | # [batch_size, num_tokens, token_dim] 21 | result = x[:, 1:, :].reshape(x.size(0), 22 | self.h, 23 | self.w, 24 | x.size(2)) 25 | 26 | # Bring the channels to the first dimension, 27 | # like in CNNs. 28 | # [batch_size, H, W, C] -> [batch, C, H, W] 29 | result = result.permute(0, 3, 1, 2) 30 | return result 31 | 32 | 33 | def main(): 34 | model = vit_base_patch16_224() 35 | # Link: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA password: eu9f 36 | weights_path = "./vit_base_patch16_224.pth" 37 | model.load_state_dict(torch.load(weights_path, map_location="cpu")) 38 | # Since the final classification is done on the class token computed in the last attention block, 39 | # the output will not be affected by the 14x14 channels in the last layer. 40 | # The gradient of the output with respect to them will be 0! 41 | # We should choose any layer before the final attention block. 42 | target_layers = [model.blocks[-1].norm1] 43 | 44 | data_transform = transforms.Compose([transforms.ToTensor(), 45 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 46 | # load image 47 | img_path = "both.png" 48 | assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path) 49 | img = Image.open(img_path).convert('RGB') 50 | img = np.array(img, dtype=np.uint8) 51 | img = center_crop_img(img, 224) 52 | # [C, H, W] 53 | img_tensor = data_transform(img) 54 | # expand batch dimension 55 | # [C, H, W] -> [N, C, H, W] 56 | input_tensor = torch.unsqueeze(img_tensor, dim=0) 57 | 58 | cam = GradCAM(model=model, 59 | target_layers=target_layers, 60 | use_cuda=False, 61 | reshape_transform=ReshapeTransform(model)) 62 | target_category = 281 # tabby, tabby cat 63 | # target_category = 254 # pug, pug-dog 64 | 65 | grayscale_cam = cam(input_tensor=input_tensor, target_category=target_category) 66 | 67 | grayscale_cam = grayscale_cam[0, :] 68 | visualization = show_cam_on_image(img / 255., grayscale_cam, use_rgb=True) 69 | plt.imshow(visualization) 70 | plt.show() 71 | 72 | 73 | if __name__ == '__main__': 74 | main() 75 | -------------------------------------------------------------------------------- /IDKL/grad_cam/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | class ActivationsAndGradients: 6 | """ Class for extracting activations and 7 | registering gradients from targeted intermediate layers """ 8 | 9 | def __init__(self, model, target_layers, reshape_transform): 10 | self.model = model 11 | self.gradients = [] 12 | self.activations = [] 13 | self.reshape_transform = reshape_transform 14 | self.handles = [] 15 | for target_layer in target_layers: 16 | self.handles.append( 17 | target_layer.register_forward_hook( 18 | self.save_activation)) 19 | # Backward compatibility with older pytorch versions: 20 | if hasattr(target_layer, 'register_full_backward_hook'): 21 | self.handles.append( 22 | target_layer.register_full_backward_hook( 23 | self.save_gradient)) 24 | else: 25 |
self.handles.append( 26 | target_layer.register_backward_hook( 27 | self.save_gradient)) 28 | 29 | def save_activation(self, module, input, output): 30 | activation = output 31 | if self.reshape_transform is not None: 32 | activation = self.reshape_transform(activation) 33 | self.activations.append(activation.cpu().detach()) 34 | 35 | def save_gradient(self, module, grad_input, grad_output): 36 | # Gradients are computed in reverse order 37 | grad = grad_output[0] 38 | if self.reshape_transform is not None: 39 | grad = self.reshape_transform(grad) 40 | self.gradients = [grad.cpu().detach()] + self.gradients 41 | 42 | def __call__(self, x): 43 | self.gradients = [] 44 | self.activations = [] 45 | return self.model(x) 46 | 47 | def release(self): 48 | for handle in self.handles: 49 | handle.remove() 50 | 51 | 52 | class GradCAM: 53 | def __init__(self, 54 | model, 55 | target_layers, 56 | reshape_transform=None, 57 | use_cuda=False): 58 | self.model = model.eval() 59 | self.target_layers = target_layers 60 | self.reshape_transform = reshape_transform 61 | self.cuda = use_cuda 62 | if self.cuda: 63 | self.model = model.cuda() 64 | self.activations_and_grads = ActivationsAndGradients( 65 | self.model, target_layers, reshape_transform) 66 | 67 | """ Get a vector of weights for every channel in the target layer. 68 | Methods that return weights channels, 69 | will typically need to only implement this function. """ 70 | 71 | @staticmethod 72 | def get_cam_weights(grads): 73 | return np.mean(grads, axis=(2, 3), keepdims=True) 74 | 75 | @staticmethod 76 | def get_loss(output, target_category): 77 | loss = 0 78 | for i in range(len(target_category)): 79 | loss = loss + output[i, target_category[i]] 80 | return loss 81 | 82 | def get_cam_image(self, activations, grads): 83 | weights = self.get_cam_weights(grads) 84 | weighted_activations = weights * activations 85 | cam = weighted_activations.sum(axis=1) 86 | 87 | return cam 88 | 89 | @staticmethod 90 | def get_target_width_height(input_tensor): 91 | width, height = input_tensor.size(-1), input_tensor.size(-2) 92 | return width, height 93 | 94 | def compute_cam_per_layer(self, input_tensor): 95 | activations_list = [a.cpu().data.numpy() 96 | for a in self.activations_and_grads.activations] 97 | grads_list = [g.cpu().data.numpy() 98 | for g in self.activations_and_grads.gradients] 99 | target_size = self.get_target_width_height(input_tensor) 100 | 101 | cam_per_target_layer = [] 102 | # Loop over the saliency image from every layer 103 | 104 | for layer_activations, layer_grads in zip(activations_list, grads_list): 105 | cam = self.get_cam_image(layer_activations, layer_grads) 106 | #cam = cam*2-cam.mean() 107 | cam[cam < 0] = 0 # works like mute the min-max scale in the function of scale_cam_image 108 | #cam = cam**0.6 109 | 110 | #cam = np.exp(cam)-np.exp() 111 | 112 | scaled = self.scale_cam_image(cam, target_size) 113 | cam_per_target_layer.append(scaled[:, None, :]) 114 | 115 | return cam_per_target_layer 116 | 117 | def aggregate_multi_layers(self, cam_per_target_layer): 118 | cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1) 119 | cam_per_target_layer = np.maximum(cam_per_target_layer, 0) 120 | result = np.mean(cam_per_target_layer, axis=1) 121 | return self.scale_cam_image(result) 122 | 123 | @staticmethod 124 | def scale_cam_image(cam, target_size=None): 125 | result = [] 126 | for img in cam: 127 | img = img - np.min(img) 128 | img = img / (1e-7 + np.max(img)) 129 | if target_size is not None: 130 | img = cv2.resize(img, 
target_size) 131 | result.append(img) 132 | result = np.float32(result) 133 | 134 | return result 135 | 136 | def __call__(self, input_tensor, target_category=None): 137 | 138 | if self.cuda: 139 | input_tensor = input_tensor.cuda() 140 | 141 | # 正向传播得到网络输出logits(未经过softmax) 142 | output = self.activations_and_grads(input_tensor) 143 | if isinstance(target_category, int): 144 | target_category = [target_category] * input_tensor.size(0) 145 | 146 | if target_category is None: 147 | target_category = np.argmax(output.cpu().data.numpy(), axis=-1) 148 | print(f"category id: {target_category}") 149 | else: 150 | assert (len(target_category) == input_tensor.size(0)) 151 | 152 | self.model.zero_grad() 153 | loss = self.get_loss(output, target_category) 154 | loss.backward(retain_graph=True) 155 | 156 | # In most of the saliency attribution papers, the saliency is 157 | # computed with a single target layer. 158 | # Commonly it is the last convolutional layer. 159 | # Here we support passing a list with multiple target layers. 160 | # It will compute the saliency image for every image, 161 | # and then aggregate them (with a default mean aggregation). 162 | # This gives you more flexibility in case you just want to 163 | # use all conv layers for example, all Batchnorm layers, 164 | # or something else. 165 | cam_per_layer = self.compute_cam_per_layer(input_tensor) 166 | return self.aggregate_multi_layers(cam_per_layer) 167 | 168 | def __del__(self): 169 | self.activations_and_grads.release() 170 | 171 | def __enter__(self): 172 | return self 173 | 174 | def __exit__(self, exc_type, exc_value, exc_tb): 175 | self.activations_and_grads.release() 176 | if isinstance(exc_value, IndexError): 177 | # Handle IndexError here... 178 | print( 179 | f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}") 180 | return True 181 | 182 | 183 | def show_cam_on_image(img: np.ndarray, 184 | mask: np.ndarray, 185 | use_rgb: bool = False, 186 | colormap: int = cv2.COLORMAP_JET) -> np.ndarray: 187 | """ This function overlays the cam mask on the image as an heatmap. 188 | By default the heatmap is in BGR format. 189 | 190 | :param img: The base image in RGB or BGR format. 191 | :param mask: The cam mask. 192 | :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format. 193 | :param colormap: The OpenCV colormap to be used. 194 | :returns: The default image with the cam overlay. 
195 | """ 196 | 197 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap) 198 | if use_rgb: 199 | heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) 200 | heatmap = np.float32(heatmap) / 255 201 | 202 | if np.max(img) > 1: 203 | raise Exception( 204 | "The input image should np.float32 in the range [0, 1]") 205 | 206 | cam = heatmap * 0.4 + img * 0.6 #heatmap + img 207 | cam = cam / np.max(cam) 208 | return np.uint8(255 * cam) 209 | 210 | 211 | def center_crop_img(img: np.ndarray, size: int): 212 | h, w, c = img.shape 213 | 214 | if w == h == size: 215 | return img 216 | 217 | if w < h: 218 | ratio = size / w 219 | new_w = size 220 | new_h = int(h * ratio) 221 | else: 222 | ratio = size / h 223 | new_h = size 224 | new_w = int(w * ratio) 225 | 226 | img = cv2.resize(img, dsize=(new_w, new_h)) 227 | 228 | if new_w == size: 229 | h = (new_h - size) // 2 230 | img = img[h: h+size] 231 | else: 232 | w = (new_w - size) // 2 233 | img = img[:, w: w+size] 234 | 235 | return img 236 | -------------------------------------------------------------------------------- /IDKL/grad_cam/vit_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | original code from rwightman: 3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 4 | """ 5 | from functools import partial 6 | from collections import OrderedDict 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | def drop_path(x, drop_prob: float = 0., training: bool = False): 13 | """ 14 | Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 15 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 16 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 17 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 18 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 19 | 'survival rate' as the argument. 20 | """ 21 | if drop_prob == 0. or not training: 22 | return x 23 | keep_prob = 1 - drop_prob 24 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 25 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 26 | random_tensor.floor_() # binarize 27 | output = x.div(keep_prob) * random_tensor 28 | return output 29 | 30 | 31 | class DropPath(nn.Module): 32 | """ 33 | Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
34 | """ 35 | def __init__(self, drop_prob=None): 36 | super(DropPath, self).__init__() 37 | self.drop_prob = drop_prob 38 | 39 | def forward(self, x): 40 | return drop_path(x, self.drop_prob, self.training) 41 | 42 | 43 | class PatchEmbed(nn.Module): 44 | """ 45 | 2D Image to Patch Embedding 46 | """ 47 | def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None): 48 | super().__init__() 49 | img_size = (img_size, img_size) 50 | patch_size = (patch_size, patch_size) 51 | self.img_size = img_size 52 | self.patch_size = patch_size 53 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 54 | self.num_patches = self.grid_size[0] * self.grid_size[1] 55 | 56 | self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size) 57 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 58 | 59 | def forward(self, x): 60 | B, C, H, W = x.shape 61 | assert H == self.img_size[0] and W == self.img_size[1], \ 62 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 63 | 64 | # flatten: [B, C, H, W] -> [B, C, HW] 65 | # transpose: [B, C, HW] -> [B, HW, C] 66 | x = self.proj(x).flatten(2).transpose(1, 2) 67 | x = self.norm(x) 68 | return x 69 | 70 | 71 | class Attention(nn.Module): 72 | def __init__(self, 73 | dim, # 输入token的dim 74 | num_heads=8, 75 | qkv_bias=False, 76 | qk_scale=None, 77 | attn_drop_ratio=0., 78 | proj_drop_ratio=0.): 79 | super(Attention, self).__init__() 80 | self.num_heads = num_heads 81 | head_dim = dim // num_heads 82 | self.scale = qk_scale or head_dim ** -0.5 83 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 84 | self.attn_drop = nn.Dropout(attn_drop_ratio) 85 | self.proj = nn.Linear(dim, dim) 86 | self.proj_drop = nn.Dropout(proj_drop_ratio) 87 | 88 | def forward(self, x): 89 | # [batch_size, num_patches + 1, total_embed_dim] 90 | B, N, C = x.shape 91 | 92 | # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim] 93 | # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head] 94 | # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head] 95 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 96 | # [batch_size, num_heads, num_patches + 1, embed_dim_per_head] 97 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 98 | 99 | # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1] 100 | # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1] 101 | attn = (q @ k.transpose(-2, -1)) * self.scale 102 | attn = attn.softmax(dim=-1) 103 | attn = self.attn_drop(attn) 104 | 105 | # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head] 106 | # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head] 107 | # reshape: -> [batch_size, num_patches + 1, total_embed_dim] 108 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 109 | x = self.proj(x) 110 | x = self.proj_drop(x) 111 | return x 112 | 113 | 114 | class Mlp(nn.Module): 115 | """ 116 | MLP as used in Vision Transformer, MLP-Mixer and related networks 117 | """ 118 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 119 | super().__init__() 120 | out_features = out_features or in_features 121 | hidden_features = hidden_features or in_features 122 | self.fc1 = nn.Linear(in_features, hidden_features) 123 | self.act = act_layer() 124 | self.fc2 = 
nn.Linear(hidden_features, out_features) 125 | self.drop = nn.Dropout(drop) 126 | 127 | def forward(self, x): 128 | x = self.fc1(x) 129 | x = self.act(x) 130 | x = self.drop(x) 131 | x = self.fc2(x) 132 | x = self.drop(x) 133 | return x 134 | 135 | 136 | class Block(nn.Module): 137 | def __init__(self, 138 | dim, 139 | num_heads, 140 | mlp_ratio=4., 141 | qkv_bias=False, 142 | qk_scale=None, 143 | drop_ratio=0., 144 | attn_drop_ratio=0., 145 | drop_path_ratio=0., 146 | act_layer=nn.GELU, 147 | norm_layer=nn.LayerNorm): 148 | super(Block, self).__init__() 149 | self.norm1 = norm_layer(dim) 150 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 151 | attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio) 152 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 153 | self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity() 154 | self.norm2 = norm_layer(dim) 155 | mlp_hidden_dim = int(dim * mlp_ratio) 156 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio) 157 | 158 | def forward(self, x): 159 | x = x + self.drop_path(self.attn(self.norm1(x))) 160 | x = x + self.drop_path(self.mlp(self.norm2(x))) 161 | return x 162 | 163 | 164 | class VisionTransformer(nn.Module): 165 | def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000, 166 | embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, 167 | qk_scale=None, representation_size=None, distilled=False, drop_ratio=0., 168 | attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None, 169 | act_layer=None): 170 | """ 171 | Args: 172 | img_size (int, tuple): input image size 173 | patch_size (int, tuple): patch size 174 | in_c (int): number of input channels 175 | num_classes (int): number of classes for classification head 176 | embed_dim (int): embedding dimension 177 | depth (int): depth of transformer 178 | num_heads (int): number of attention heads 179 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 180 | qkv_bias (bool): enable bias for qkv if True 181 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 182 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 183 | distilled (bool): model includes a distillation token and head as in DeiT models 184 | drop_ratio (float): dropout rate 185 | attn_drop_ratio (float): attention dropout rate 186 | drop_path_ratio (float): stochastic depth rate 187 | embed_layer (nn.Module): patch embedding layer 188 | norm_layer: (nn.Module): normalization layer 189 | """ 190 | super(VisionTransformer, self).__init__() 191 | self.num_classes = num_classes 192 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 193 | self.num_tokens = 2 if distilled else 1 194 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 195 | act_layer = act_layer or nn.GELU 196 | 197 | self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim) 198 | num_patches = self.patch_embed.num_patches 199 | 200 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 201 | self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None 202 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 203 | self.pos_drop = nn.Dropout(p=drop_ratio) 204 | 205 | dpr = [x.item() for x in torch.linspace(0, 
drop_path_ratio, depth)] # stochastic depth decay rule 206 | self.blocks = nn.Sequential(*[ 207 | Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 208 | drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i], 209 | norm_layer=norm_layer, act_layer=act_layer) 210 | for i in range(depth) 211 | ]) 212 | self.norm = norm_layer(embed_dim) 213 | 214 | # Representation layer 215 | if representation_size and not distilled: 216 | self.has_logits = True 217 | self.num_features = representation_size 218 | self.pre_logits = nn.Sequential(OrderedDict([ 219 | ("fc", nn.Linear(embed_dim, representation_size)), 220 | ("act", nn.Tanh()) 221 | ])) 222 | else: 223 | self.has_logits = False 224 | self.pre_logits = nn.Identity() 225 | 226 | # Classifier head(s) 227 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 228 | self.head_dist = None 229 | if distilled: 230 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 231 | 232 | # Weight init 233 | nn.init.trunc_normal_(self.pos_embed, std=0.02) 234 | if self.dist_token is not None: 235 | nn.init.trunc_normal_(self.dist_token, std=0.02) 236 | 237 | nn.init.trunc_normal_(self.cls_token, std=0.02) 238 | self.apply(_init_vit_weights) 239 | 240 | def forward_features(self, x): 241 | # [B, C, H, W] -> [B, num_patches, embed_dim] 242 | x = self.patch_embed(x) # [B, 196, 768] 243 | # [1, 1, 768] -> [B, 1, 768] 244 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) 245 | if self.dist_token is None: 246 | x = torch.cat((cls_token, x), dim=1) # [B, 197, 768] 247 | else: 248 | x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) 249 | 250 | x = self.pos_drop(x + self.pos_embed) 251 | x = self.blocks(x) 252 | x = self.norm(x) 253 | if self.dist_token is None: 254 | return self.pre_logits(x[:, 0]) 255 | else: 256 | return x[:, 0], x[:, 1] 257 | 258 | def forward(self, x): 259 | x = self.forward_features(x) 260 | if self.head_dist is not None: 261 | x, x_dist = self.head(x[0]), self.head_dist(x[1]) 262 | if self.training and not torch.jit.is_scripting(): 263 | # during inference, return the average of both classifier predictions 264 | return x, x_dist 265 | else: 266 | return (x + x_dist) / 2 267 | else: 268 | x = self.head(x) 269 | return x 270 | 271 | 272 | def _init_vit_weights(m): 273 | """ 274 | ViT weight initialization 275 | :param m: module 276 | """ 277 | if isinstance(m, nn.Linear): 278 | nn.init.trunc_normal_(m.weight, std=.01) 279 | if m.bias is not None: 280 | nn.init.zeros_(m.bias) 281 | elif isinstance(m, nn.Conv2d): 282 | nn.init.kaiming_normal_(m.weight, mode="fan_out") 283 | if m.bias is not None: 284 | nn.init.zeros_(m.bias) 285 | elif isinstance(m, nn.LayerNorm): 286 | nn.init.zeros_(m.bias) 287 | nn.init.ones_(m.weight) 288 | 289 | 290 | def vit_base_patch16_224(num_classes: int = 1000): 291 | """ 292 | ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 293 | ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer. 
294 | weights ported from official Google JAX impl: 295 | 链接: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA 密码: eu9f 296 | """ 297 | model = VisionTransformer(img_size=224, 298 | patch_size=16, 299 | embed_dim=768, 300 | depth=12, 301 | num_heads=12, 302 | representation_size=None, 303 | num_classes=num_classes) 304 | return model 305 | 306 | 307 | def vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True): 308 | """ 309 | ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 310 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 311 | weights ported from official Google JAX impl: 312 | https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth 313 | """ 314 | model = VisionTransformer(img_size=224, 315 | patch_size=16, 316 | embed_dim=768, 317 | depth=12, 318 | num_heads=12, 319 | representation_size=768 if has_logits else None, 320 | num_classes=num_classes) 321 | return model 322 | 323 | 324 | def vit_base_patch32_224(num_classes: int = 1000): 325 | """ 326 | ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). 327 | ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer. 328 | weights ported from official Google JAX impl: 329 | 链接: https://pan.baidu.com/s/1hCv0U8pQomwAtHBYc4hmZg 密码: s5hl 330 | """ 331 | model = VisionTransformer(img_size=224, 332 | patch_size=32, 333 | embed_dim=768, 334 | depth=12, 335 | num_heads=12, 336 | representation_size=None, 337 | num_classes=num_classes) 338 | return model 339 | 340 | 341 | def vit_base_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True): 342 | """ 343 | ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). 344 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 345 | weights ported from official Google JAX impl: 346 | https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth 347 | """ 348 | model = VisionTransformer(img_size=224, 349 | patch_size=32, 350 | embed_dim=768, 351 | depth=12, 352 | num_heads=12, 353 | representation_size=768 if has_logits else None, 354 | num_classes=num_classes) 355 | return model 356 | 357 | 358 | def vit_large_patch16_224(num_classes: int = 1000): 359 | """ 360 | ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). 361 | ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer. 362 | weights ported from official Google JAX impl: 363 | 链接: https://pan.baidu.com/s/1cxBgZJJ6qUWPSBNcE4TdRQ 密码: qqt8 364 | """ 365 | model = VisionTransformer(img_size=224, 366 | patch_size=16, 367 | embed_dim=1024, 368 | depth=24, 369 | num_heads=16, 370 | representation_size=None, 371 | num_classes=num_classes) 372 | return model 373 | 374 | 375 | def vit_large_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True): 376 | """ 377 | ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). 378 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 
379 | weights ported from official Google JAX impl: 380 | https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth 381 | """ 382 | model = VisionTransformer(img_size=224, 383 | patch_size=16, 384 | embed_dim=1024, 385 | depth=24, 386 | num_heads=16, 387 | representation_size=1024 if has_logits else None, 388 | num_classes=num_classes) 389 | return model 390 | 391 | 392 | def vit_large_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True): 393 | """ 394 | ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). 395 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 396 | weights ported from official Google JAX impl: 397 | https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth 398 | """ 399 | model = VisionTransformer(img_size=224, 400 | patch_size=32, 401 | embed_dim=1024, 402 | depth=24, 403 | num_heads=16, 404 | representation_size=1024 if has_logits else None, 405 | num_classes=num_classes) 406 | return model 407 | 408 | 409 | def vit_huge_patch14_224_in21k(num_classes: int = 21843, has_logits: bool = True): 410 | """ 411 | ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). 412 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 413 | NOTE: converted weights not currently available, too large for github release hosting. 414 | """ 415 | model = VisionTransformer(img_size=224, 416 | patch_size=14, 417 | embed_dim=1280, 418 | depth=32, 419 | num_heads=16, 420 | representation_size=1280 if has_logits else None, 421 | num_classes=num_classes) 422 | return model 423 | -------------------------------------------------------------------------------- /IDKL/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from layers.loss.am_softmax import AMSoftmaxLoss 2 | from layers.loss.center_loss import CenterLoss 3 | from layers.loss.triplet_loss import TripletLoss 4 | from layers.loss.rerank_loss import RerankLoss 5 | from layers.loss.local_center_loss import CenterTripletLoss 6 | from layers.module.norm_linear import NormalizeLinear 7 | from layers.module.reverse_grad import ReverseGrad 8 | from layers.loss.JSD import js_div 9 | from layers.module.CBAM import cbam 10 | from layers.module.NonLocal import NonLocalBlockND 11 | 12 | 13 | __all__ = ['RerankLoss','CenterLoss', 'CenterTripletLoss', 'AMSoftmaxLoss', 'TripletLoss', 'NormalizeLinear', 'js_div', 'cbam', 'NonLocalBlockND'] -------------------------------------------------------------------------------- /IDKL/layers/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /IDKL/layers/loss/JSD.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | # import torch.softmax as softmax 3 | from torch.nn import functional as F 4 | 5 | class js_div: 6 | def __init__(self): 7 | self.KLDivLoss = nn.KLDivLoss(reduction='batchmean') 8 | 9 | def __call__(self, p_output, q_output, get_softmax=True): 10 | """ 11 | Function that measures JS divergence between target and output logits: 12 | 
""" 13 | if get_softmax: 14 | p_output = F.softmax(p_output, 1) 15 | q_output = F.softmax(q_output, 1) 16 | log_mean_output = ((p_output + q_output) / 2).log() 17 | return (self.KLDivLoss(log_mean_output, p_output) + self.KLDivLoss(log_mean_output, q_output))/2 -------------------------------------------------------------------------------- /IDKL/layers/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/loss/__init__.py -------------------------------------------------------------------------------- /IDKL/layers/loss/__pycache__/JSD.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/loss/__pycache__/JSD.cpython-37.pyc -------------------------------------------------------------------------------- /IDKL/layers/loss/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/loss/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /IDKL/layers/loss/__pycache__/am_softmax.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/loss/__pycache__/am_softmax.cpython-37.pyc -------------------------------------------------------------------------------- /IDKL/layers/loss/__pycache__/center_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/loss/__pycache__/center_loss.cpython-37.pyc -------------------------------------------------------------------------------- /IDKL/layers/loss/__pycache__/crossquad_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/loss/__pycache__/crossquad_loss.cpython-37.pyc -------------------------------------------------------------------------------- /IDKL/layers/loss/__pycache__/crosstriplet_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/loss/__pycache__/crosstriplet_loss.cpython-37.pyc -------------------------------------------------------------------------------- /IDKL/layers/loss/__pycache__/local_center_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/loss/__pycache__/local_center_loss.cpython-37.pyc -------------------------------------------------------------------------------- /IDKL/layers/loss/__pycache__/mixtriplet_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/loss/__pycache__/mixtriplet_loss.cpython-37.pyc 
-------------------------------------------------------------------------------- /IDKL/layers/loss/__pycache__/trapezoid_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/loss/__pycache__/trapezoid_loss.cpython-37.pyc -------------------------------------------------------------------------------- /IDKL/layers/loss/__pycache__/triplet_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/loss/__pycache__/triplet_loss.cpython-37.pyc -------------------------------------------------------------------------------- /IDKL/layers/loss/am_softmax.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class AMSoftmaxLoss(nn.Module): 8 | def __init__(self, scale, margin, weight=None, ignore_index=-100, reduction='mean'): 9 | super(AMSoftmaxLoss, self).__init__() 10 | self.weight = weight 11 | self.ignore_index = ignore_index 12 | self.reduction = reduction 13 | self.scale = scale 14 | self.margin = margin 15 | 16 | def forward(self, x, y): 17 | y_onehot = torch.zeros_like(x, device=x.device) 18 | y_onehot.scatter_(1, y.data.view(-1, 1), self.margin) 19 | 20 | out = self.scale * (x - y_onehot) 21 | loss = F.cross_entropy(out, y, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction) 22 | 23 | return loss 24 | -------------------------------------------------------------------------------- /IDKL/layers/loss/center_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class CenterLoss(nn.Module): 6 | """Center loss. 7 | 8 | Reference: 9 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. 10 | 11 | Args: 12 | num_classes (int): number of classes. 13 | feat_dim (int): feature dimension. 14 | """ 15 | 16 | def __init__(self, num_classes, feat_dim, reduction='mean'): 17 | super(CenterLoss, self).__init__() 18 | self.num_classes = num_classes 19 | self.feat_dim = feat_dim 20 | self.reduction = reduction 21 | 22 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) 23 | 24 | def forward(self, x, labels): 25 | """ 26 | Args: 27 | x: feature matrix with shape (batch_size, feat_dim). 28 | labels: ground truth labels with shape (batch_size). 
29 | """ 30 | batch_size = x.size(0) 31 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \ 32 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() 33 | distmat.addmm_(1, -2, x, self.centers.t()) 34 | 35 | classes = torch.arange(self.num_classes).to(device=x.device, dtype=torch.long) 36 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) 37 | mask = labels.eq(classes.expand(batch_size, self.num_classes)) 38 | 39 | loss = distmat * mask.float() 40 | 41 | if self.reduction == 'mean': 42 | loss = loss.mean() 43 | elif self.reduction == 'sum': 44 | loss = loss.sum() 45 | 46 | return loss 47 | -------------------------------------------------------------------------------- /IDKL/layers/loss/local_center_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class CenterTripletLoss(nn.Module): 5 | def __init__(self, k_size, margin=0): 6 | super(CenterTripletLoss, self).__init__() 7 | self.margin = margin 8 | self.k_size = k_size 9 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 10 | 11 | def forward(self, inputs, targets): 12 | n = inputs.size(0) 13 | 14 | # Come to centers 15 | centers = [] 16 | for i in range(n): 17 | centers.append(inputs[targets == targets[i]].mean(0)) 18 | centers = torch.stack(centers) 19 | 20 | dist_pc = (inputs - centers)**2 21 | dist_pc = dist_pc.sum(1) 22 | dist_pc = dist_pc.sqrt() 23 | 24 | # Compute pairwise distance, replace by the official when merged 25 | dist = torch.pow(centers, 2).sum(dim=1, keepdim=True).expand(n, n) 26 | dist = dist + dist.t() 27 | dist.addmm_(1, -2, centers, centers.t()) 28 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 29 | 30 | # For each anchor, find the hardest positive and negative 31 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 32 | dist_an, dist_ap = [], [] 33 | for i in range(0, n, self.k_size): 34 | dist_an.append( (self.margin - dist[i][mask[i] == 0]).clamp(min=0.0).mean() ) 35 | dist_an = torch.stack(dist_an) 36 | 37 | # Compute ranking hinge loss 38 | y = dist_an.data.new() 39 | y.resize_as_(dist_an.data) 40 | y.fill_(1) 41 | loss = dist_pc.mean() + dist_an.mean() 42 | return loss, dist_pc.mean(), dist_an.mean() 43 | -------------------------------------------------------------------------------- /IDKL/layers/loss/rerank_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from utils.rerank import pairwise_distance 4 | 5 | def intersect1d(tensor1, tensor2): 6 | return torch.unique(torch.cat([tensor1[tensor1 == val] for val in tensor2])) 7 | 8 | # def rerank_vc(feat1, feat2, k1=20, k2=6, lambda_value=0.3, eval_type=True): #q_feat, g_feat ############代码结果不知正确与否,但没有增加显存了,aini 9 | # feats = torch.cat([feat1, feat2], 0) 10 | # 11 | # dist = torch.cdist(feats, feats) 12 | # original_dist = dist.clone() 13 | # all_num = original_dist.shape[0] 14 | # original_dist = (original_dist / original_dist.max(dim=0, keepdim=True).values).transpose(0, 1) 15 | # 16 | # V = torch.zeros_like(original_dist) 17 | # 18 | # query_num = feat1.size(0) 19 | # if eval_type: 20 | # max_val = dist.max() 21 | # dist = torch.cat((dist[:, :query_num], max_val.expand_as(dist[:, query_num:])), dim=1) 22 | # initial_rank = torch.argsort(dist, dim=1) 23 | # 24 | # for i in range(all_num): 25 | # forward_k_neigh_index = initial_rank[i, :k1 + 1] 26 | # 
backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1] 27 | # fi = (backward_k_neigh_index == i).nonzero(as_tuple=True)[0] 28 | # k_reciprocal_index = forward_k_neigh_index[fi] 29 | # k_reciprocal_expansion_index = k_reciprocal_index 30 | # 31 | # for j in k_reciprocal_index: 32 | # candidate = j 33 | # candidate_forward_k_neigh_index = initial_rank[candidate, :int(round(k1 / 2)) + 1] 34 | # candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index, :int(round(k1 / 2)) + 1] 35 | # fi_candidate = (candidate_backward_k_neigh_index == candidate).nonzero(as_tuple=True)[0] 36 | # candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 37 | # if len(intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len( 38 | # candidate_k_reciprocal_index): 39 | # k_reciprocal_expansion_index = torch.unique( 40 | # torch.cat([k_reciprocal_expansion_index, candidate_k_reciprocal_index], dim=0)) 41 | # 42 | # weight = torch.exp(-original_dist[i, k_reciprocal_expansion_index]) 43 | # V[i, k_reciprocal_expansion_index] = weight / torch.sum(weight) 44 | # 45 | # original_dist = original_dist[:query_num, ] 46 | # if k2 != 1: 47 | # V_qe = torch.zeros_like(V) 48 | # for i in range(all_num): 49 | # V_qe[i, :] = torch.mean(V[initial_rank[i, :k2], :], dim=0) 50 | # V = V_qe 51 | # 52 | # invIndex = [] 53 | # for i in range(all_num): 54 | # invIndex.append((V[:, i] != 0).nonzero(as_tuple=True)[0]) 55 | # 56 | # jaccard_dist = torch.zeros_like(original_dist) 57 | # 58 | # for i in range(query_num): 59 | # temp_min = torch.zeros([1, all_num]).cuda() 60 | # indNonZero = (V[i, :] != 0).nonzero(as_tuple=True)[0] 61 | # indImages = [invIndex[ind] for ind in indNonZero] 62 | # for j, val in enumerate(indNonZero): 63 | # temp_min[0, indImages[j]] += torch.minimum(V[i, val], V[indImages[j], val]) 64 | # 65 | # jaccard_dist[i] = 1 - temp_min / (2 - temp_min) 66 | # 67 | # final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value 68 | # final_dist = final_dist[:query_num, query_num:] 69 | # 70 | # return final_dist 71 | 72 | def rerank_dist(feat1, feat2, k1=20, k2=6, lambda_value=0.3, eval_type=True): #q_feat, g_feat 73 | 74 | #with torch.no_grad(): 75 | feats = torch.cat([feat1, feat2], 0) ####### 76 | dist = pairwise_distance(feats, feats) 77 | original_dist = dist.clone() # .detach() # .clone() 78 | # import pdb 79 | # pdb.set_trace() 80 | all_num = original_dist.shape[0] 81 | 82 | #original_dist = original_dist / torch.max(original_dist, dim=0).values 83 | 84 | original_dist = torch.transpose(original_dist, 0,1) #.transpose(0, 1) 85 | V = torch.zeros_like(original_dist) # .half() 86 | 87 | 88 | query_num = feat1.size(0) 89 | 90 | #with torch.no_grad(): 91 | if eval_type: 92 | # dist[:, query_num:] = dist.max()罪魁祸首 93 | max_val = dist.max() 94 | dist = torch.cat((dist[:, :query_num], max_val.expand_as(dist[:, query_num:])), dim=1) 95 | initial_rank = torch.argsort(dist, dim=1) 96 | # import pdb 97 | # pdb.set_trace() 98 | 99 | 100 | 101 | for i in range(all_num): 102 | forward_k_neigh_index = initial_rank[i, :k1 + 1] 103 | backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1] 104 | fi = torch.where(backward_k_neigh_index == i)[0] 105 | k_reciprocal_index = forward_k_neigh_index[fi] 106 | k_reciprocal_expansion_index = k_reciprocal_index 107 | 108 | for j in k_reciprocal_index: 109 | candidate = j.item() 110 | candidate_forward_k_neigh_index = initial_rank[candidate, :int(round(k1 / 2)) + 1] 111 | 
candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index, 112 | :int(round(k1 / 2)) + 1] 113 | # import pdb 114 | # pdb.set_trace() 115 | fi_candidate = torch.where(candidate_backward_k_neigh_index == candidate)[0] 116 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 117 | 118 | if len(intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len( 119 | candidate_k_reciprocal_index): 120 | k_reciprocal_expansion_index = torch.unique( 121 | torch.cat([k_reciprocal_expansion_index, candidate_k_reciprocal_index], 0)) 122 | 123 | weight = torch.exp(-original_dist[i, k_reciprocal_expansion_index]) 124 | V[i, k_reciprocal_expansion_index] = (weight / torch.sum(weight)) # .half() 125 | 126 | 127 | original_dist = original_dist[:query_num, ] 128 | # print('before') 129 | # objgraph.show_growth(limit=3) 130 | 131 | if k2 != 1: 132 | V_qe = torch.zeros_like(V) # .half() 133 | for i in range(all_num): 134 | V_qe[i, :] = torch.mean(V[initial_rank[i, :k2], :], dim=0) 135 | V = V_qe 136 | 137 | invIndex = [] 138 | for i in range(all_num): 139 | invIndex.append(torch.where(V[:, i] != 0)[0]) 140 | 141 | jaccard_dist = torch.zeros_like(original_dist) # .half() 142 | 143 | # print('after') 144 | # objgraph.show_growth(limit=3) 145 | 146 | # with torch.no_grad(): 147 | # for i in range(query_num): 148 | # temp_min = torch.zeros([1, all_num], device="cuda") 149 | # indNonZero = torch.where(V[i, :] != 0)[0] 150 | # indImages = [invIndex[ind] for ind in indNonZero] 151 | # for j in range(len(indNonZero)): 152 | # temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + torch.min(V[i, indNonZero[j]], 153 | # V[indImages[j], indNonZero[j]]) 154 | # jaccard_dist[i] = 1 - temp_min / (2 - temp_min) 155 | 156 | for i in range(query_num): 157 | temp_min = torch.zeros([1, all_num], device="cuda") # .half() 158 | indNonZero = torch.where(V[i, :] != 0)[0] 159 | indImages = [invIndex[ind] for ind in indNonZero] 160 | for j in range(len(indNonZero)): 161 | temp_min[0, indImages[j]] += torch.min(V[i, indNonZero[j]], V[indImages[j], indNonZero[j]]) 162 | jaccard_dist[i] = 1 - temp_min / (2 - temp_min) 163 | 164 | # print('before') 165 | # objgraph.show_growth(limit=3) 166 | # print("Before:", torch.cuda.memory_allocated()) 167 | final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value 168 | # print("After:", torch.cuda.memory_allocated()) 169 | # import pdb 170 | # pdb.set_trace() ####jaccard_dist有细微差异和原来的rerank对比 171 | # del temp_min, jaccard_dist, V, original_dist,forward_k_neigh_index,backward_k_neigh_index,\ 172 | # k_reciprocal_expansion_index,V_qe,candidate_k_reciprocal_index,candidate_backward_k_neigh_index,\ 173 | # candidate_forward_k_neigh_index,k_reciprocal_index,fi,fi_candidate 174 | # torch.cuda.empty_cache() 175 | final_dist = final_dist[:query_num, query_num:] 176 | # import pdb 177 | # pdb.set_trace() 178 | # del original_dist, dist 179 | # torch.cuda.empty_cache() 180 | return final_dist 181 | 182 | 183 | class RerankLoss(nn.Module): 184 | def __init__(self, margin=0.03): 185 | super(RerankLoss, self).__init__() 186 | self.margin = margin 187 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 188 | 189 | def forward(self, inputs, targets): 190 | #def forward(self, inputs1, inputs2, targets): 191 | 192 | #n = inputs1.size(0) 193 | n = inputs.size(0) 194 | dist = rerank_dist(inputs, inputs) 195 | #dist = rerank_dist(inputs1, inputs2) 196 | 197 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 198 | dist_ap, 
dist_an = [], [] 199 | for i in range(n): 200 | dist_ap.append(dist[i][mask[i]].max()) 201 | dist_an.append(dist[i][mask[i] == 0].min()) 202 | dist_ap = torch.stack(dist_ap) 203 | dist_an = torch.stack(dist_an) 204 | 205 | # Compute ranking hinge loss 206 | # y = dist_an.data.new() 207 | # y.resize_as_(dist_an.data) 208 | # y.fill_(1) 209 | y = torch.ones_like(dist_an) 210 | loss = self.ranking_loss(dist_an, dist_ap, y) 211 | #prec = dist_an.data > dist_ap.data 212 | #length = torch.sqrt((inputs * inputs).sum(1)).mean() 213 | return loss, dist,dist_ap, dist_an -------------------------------------------------------------------------------- /IDKL/layers/loss/triplet_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class TripletLoss(nn.Module): 5 | def __init__(self, margin=0): 6 | super(TripletLoss, self).__init__() 7 | self.margin = margin 8 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 9 | 10 | def forward(self, inputs, targets): 11 | n = inputs.size(0) 12 | # Compute pairwise distance, replace by the official when merged 13 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 14 | dist = dist + dist.t() 15 | dist.addmm_(1, -2, inputs, inputs.t()) 16 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 17 | 18 | # For each anchor, find the hardest positive and negative 19 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 20 | dist_ap, dist_an = [], [] 21 | for i in range(n): 22 | dist_ap.append(dist[i][mask[i]].max()) 23 | dist_an.append(dist[i][mask[i] == 0].min()) 24 | dist_ap = torch.stack(dist_ap) 25 | dist_an = torch.stack(dist_an) 26 | 27 | # Compute ranking hinge loss 28 | # y = dist_an.data.new() 29 | # y.resize_as_(dist_an.data) 30 | # y.fill_(1) 31 | y = torch.ones_like(dist_an) 32 | loss = self.ranking_loss(dist_an, dist_ap, y) 33 | prec = dist_an.data > dist_ap.data 34 | length = torch.sqrt((inputs * inputs).sum(1)).mean() 35 | return loss, dist,dist_ap, dist_an 36 | -------------------------------------------------------------------------------- /IDKL/layers/module/CBAM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ChannelAttention(nn.Module): 6 | def __init__(self, in_planes, ratio=16): 7 | super(ChannelAttention, self).__init__() 8 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 9 | self.max_pool = nn.AdaptiveMaxPool2d(1) 10 | 11 | self.sharedMLP = nn.Sequential( 12 | nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False), 13 | nn.ReLU(), 14 | nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)) 15 | self.sigmoid = nn.Sigmoid() 16 | 17 | def forward(self, x): 18 | avgout = self.sharedMLP(self.avg_pool(x)) 19 | maxout = self.sharedMLP(self.max_pool(x)) 20 | return self.sigmoid(avgout + maxout) 21 | 22 | 23 | class SpatialAttention(nn.Module): 24 | def __init__(self, kernel_size=7): 25 | super(SpatialAttention, self).__init__() 26 | assert kernel_size in (3,7), "kernel size must be 3 or 7" 27 | padding = 3 if kernel_size == 7 else 1 28 | 29 | self.conv = nn.Conv2d(2,1,kernel_size, padding=padding, bias=False) 30 | self.sigmoid = nn.Sigmoid() 31 | 32 | def forward(self, x): 33 | avgout = torch.mean(x, dim=1, keepdim=True) 34 | maxout, _ = torch.max(x, dim=1, keepdim=True) 35 | x = torch.cat([avgout, maxout], dim=1) 36 | x = self.conv(x) 37 | return self.sigmoid(x) 38 | 39 | 40 | class cbam(nn.Module): 41 | def __init__(self, planes): 42 | 
super(cbam, self).__init__() 43 | self.ca = ChannelAttention(planes) 44 | self.sa = SpatialAttention() 45 | 46 | def forward(self, x): 47 | x = self.ca(x) * x 48 | x = self.sa(x) * x 49 | return x -------------------------------------------------------------------------------- /IDKL/layers/module/NonLocal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class NonLocalBlockND(nn.Module): 7 | """ 8 | 调用过程 9 | NONLocalBlock2D(in_channels=32), 10 | super(NONLocalBlock2D, self).__init__(in_channels, 11 | inter_channels=inter_channels, 12 | dimension=2, sub_sample=sub_sample, 13 | bn_layer=bn_layer) 14 | """ 15 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 16 | super(NonLocalBlockND, self).__init__() 17 | 18 | assert dimension in [1, 2, 3] 19 | 20 | self.dimension = dimension 21 | self.sub_sample = sub_sample 22 | 23 | self.in_channels = in_channels 24 | self.inter_channels = inter_channels 25 | 26 | if self.inter_channels is None: 27 | self.inter_channels = in_channels // 2 28 | if self.inter_channels == 0: 29 | self.inter_channels = 1 30 | 31 | if dimension == 3: 32 | conv_nd = nn.Conv3d 33 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 34 | bn = nn.BatchNorm3d 35 | elif dimension == 2: 36 | conv_nd = nn.Conv2d 37 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 38 | bn = nn.BatchNorm2d 39 | else: 40 | conv_nd = nn.Conv1d 41 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 42 | bn = nn.BatchNorm1d 43 | 44 | self.g = conv_nd(in_channels=self.in_channels, 45 | out_channels=self.inter_channels, 46 | kernel_size=1, 47 | stride=1, 48 | padding=0) 49 | 50 | if bn_layer: 51 | self.W = nn.Sequential( 52 | conv_nd(in_channels=self.inter_channels, 53 | out_channels=self.in_channels, 54 | kernel_size=1, 55 | stride=1, 56 | padding=0), bn(self.in_channels)) 57 | nn.init.constant_(self.W[1].weight, 0) 58 | nn.init.constant_(self.W[1].bias, 0) 59 | else: 60 | self.W = conv_nd(in_channels=self.inter_channels, 61 | out_channels=self.in_channels, 62 | kernel_size=1, 63 | stride=1, 64 | padding=0) 65 | nn.init.constant_(self.W.weight, 0) 66 | nn.init.constant_(self.W.bias, 0) 67 | 68 | self.theta = conv_nd(in_channels=self.in_channels, 69 | out_channels=self.inter_channels, 70 | kernel_size=1, 71 | stride=1, 72 | padding=0) 73 | self.phi = conv_nd(in_channels=self.in_channels, 74 | out_channels=self.inter_channels, 75 | kernel_size=1, 76 | stride=1, 77 | padding=0) 78 | 79 | if sub_sample: 80 | self.g = nn.Sequential(self.g, max_pool_layer) 81 | self.phi = nn.Sequential(self.phi, max_pool_layer) 82 | 83 | def forward(self, x): 84 | ''' 85 | :param x: (b, c, h, w) 86 | :return: 87 | ''' 88 | 89 | batch_size = x.size(0) 90 | 91 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)#[bs, c, w*h] 92 | g_x = g_x.permute(0, 2, 1) 93 | 94 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 95 | theta_x = theta_x.permute(0, 2, 1) 96 | 97 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 98 | 99 | f = torch.matmul(theta_x, phi_x) 100 | 101 | # print(f.shape) 102 | 103 | f_div_C = F.softmax(f, dim=-1) 104 | 105 | y = torch.matmul(f_div_C, g_x) 106 | y = y.permute(0, 2, 1).contiguous() 107 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 108 | W_y = self.W(y) 109 | z = W_y + x 110 | return z -------------------------------------------------------------------------------- 
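Both attention modules above are re-exported through `layers/__init__.py`; a minimal usage sketch follows (the batch size and feature-map shape are illustrative assumptions, not values taken from the training configs):

```python
import torch
from layers import cbam, NonLocalBlockND

feat = torch.randn(8, 2048, 24, 9)                    # (batch, channels, height, width) backbone feature map
att = cbam(2048)                                      # channel attention followed by spatial attention
nl = NonLocalBlockND(in_channels=2048, dimension=2)   # 2D non-local block (Conv2d / BatchNorm2d branches)
out = nl(att(feat))                                   # both modules preserve the input shape
```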
/IDKL/layers/module/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/module/__init__.py -------------------------------------------------------------------------------- /IDKL/layers/module/__pycache__/CBAM.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/module/__pycache__/CBAM.cpython-37.pyc -------------------------------------------------------------------------------- /IDKL/layers/module/__pycache__/NonLocal.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/module/__pycache__/NonLocal.cpython-37.pyc -------------------------------------------------------------------------------- /IDKL/layers/module/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/module/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /IDKL/layers/module/__pycache__/norm_linear.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/module/__pycache__/norm_linear.cpython-37.pyc -------------------------------------------------------------------------------- /IDKL/layers/module/__pycache__/reverse_grad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/module/__pycache__/reverse_grad.cpython-37.pyc -------------------------------------------------------------------------------- /IDKL/layers/module/norm_linear.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | import torch.nn.functional as F 6 | 7 | 8 | class NormalizeLinear(nn.Module): 9 | def __init__(self, in_features, num_class): 10 | super(NormalizeLinear, self).__init__() 11 | self.weight = nn.Parameter(torch.Tensor(num_class, in_features)) 12 | self.reset_parameters() 13 | 14 | def reset_parameters(self): 15 | init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 16 | 17 | def forward(self, x): 18 | w = F.normalize(self.weight.float(), p=2, dim=1) 19 | return F.linear(x.float(), w) 20 | -------------------------------------------------------------------------------- /IDKL/layers/module/reverse_grad.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.autograd import Function 3 | 4 | 5 | class ReverseGradFunction(Function): 6 | 7 | @staticmethod 8 | def forward(ctx, data, alpha=1.0): 9 | ctx.alpha = alpha 10 | return data 11 | 12 | @staticmethod 13 | def backward(ctx, grad_outputs): 14 | grad = None 15 | 16 | if ctx.needs_input_grad[0]: 17 | grad = -ctx.alpha * grad_outputs 18 | 19 | return grad, None 20 | 21 | 22 | class ReverseGrad(nn.Module): 23 | def __init__(self): 24 | super(ReverseGrad, self).__init__() 25 | 26 | def 
forward(self, x, alpha=1.0): 27 | return ReverseGradFunction.apply(x, alpha) 28 | -------------------------------------------------------------------------------- /IDKL/models/__pycache__/baseline.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/models/__pycache__/baseline.cpython-36.pyc -------------------------------------------------------------------------------- /IDKL/models/__pycache__/baseline.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/models/__pycache__/baseline.cpython-37.pyc -------------------------------------------------------------------------------- /IDKL/models/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/models/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /IDKL/models/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/models/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /IDKL/models/baseline.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import init 5 | from torch.nn import functional as F 6 | from torch.nn import Parameter 7 | import numpy as np 8 | 9 | import cv2 10 | from layers.module.reverse_grad import ReverseGrad 11 | from models.resnet import resnet50, embed_net, convDiscrimination, Discrimination 12 | from utils.calc_acc import calc_acc 13 | 14 | from layers import TripletLoss, RerankLoss 15 | from layers import CenterTripletLoss 16 | from layers import CenterLoss 17 | from layers import cbam 18 | from layers import NonLocalBlockND 19 | from utils.rerank import re_ranking, pairwise_distance 20 | 21 | def intersect1d(tensor1, tensor2): 22 | return torch.unique(torch.cat([tensor1[tensor1 == val] for val in tensor2])) 23 | 24 | def spearman_loss(dist_matrix, rerank_matrix): 25 | 26 | sorted_idx_dist = torch.argsort(dist_matrix, dim=1) 27 | sorted_idx_rerank = torch.argsort(rerank_matrix, dim=1) 28 | 29 | rank_corr = 0 30 | n = dist_matrix.size(1) 31 | for i in range(dist_matrix.size(0)): 32 | diff = sorted_idx_dist[i] - sorted_idx_rerank[i] 33 | rank_corr += 1 - (6 * torch.sum(diff * diff) / (n * (n**2 - 1))) 34 | 35 | rank_corr /= dist_matrix.size(0) 36 | 37 | return 1 - rank_corr 38 | 39 | 40 | def Fb_dt(feat, labels): 41 | feat_dt = feat 42 | n_ft = feat_dt.size(0) 43 | dist_f = torch.pow(feat_dt, 2).sum(dim=1, keepdim=True).expand(n_ft, n_ft) 44 | dist_f = dist_f + dist_f.t() 45 | dist_f.addmm_(1, -2, feat_dt, feat_dt.t()) 46 | dist_f = dist_f.clamp(min=1e-12).sqrt() 47 | mask_ft = labels.expand(n_ft, n_ft).eq(labels.expand(n_ft, n_ft).t()) 48 | mask_ft_1 = torch.ones(n_ft, n_ft, dtype=bool) 49 | for i in range(n_ft): 50 | mask_ft_1[i, i] = 0 51 | mask_ft_2 = [] 52 | for i in range(n_ft): 53 | 54 | mask_ft_2.append(mask_ft[i][mask_ft_1[i]]) 55 | mask_ft_2 = torch.stack(mask_ft_2) 56 | dist_f_2 = [] 57 | for i in 
range(n_ft): 58 | 59 | dist_f_2.append(dist_f[i][mask_ft_1[i]]) 60 | dist_f_2 = torch.stack(dist_f_2) 61 | dist_f_2 = F.softmax(-(dist_f_2 - 1), 1) 62 | cN_ft = (mask_ft_2[0] == True).sum() 63 | f_d_ap = [] 64 | for i in range(n_ft): 65 | 66 | f_d_ap.append(dist_f_2[i][mask_ft_2[i]]) 67 | f_d_ap = torch.stack(f_d_ap).flatten() 68 | loss_f_d_ap = [] 69 | xs_ft = 1 70 | m_ft = f_d_ap.size(0) 71 | for i in range(m_ft): 72 | loss_f_d_ap.append( 73 | -xs_ft * (1 / cN_ft) * torch.log(xs_ft * cN_ft * f_d_ap[i])) 74 | loss_f_d_ap = torch.stack(loss_f_d_ap).clamp(max=1e+3).sum() / n_ft 75 | return loss_f_d_ap 76 | 77 | 78 | def gem(x, p=3, eps=1e-6): 79 | return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1. / p) 80 | 81 | def gem_p(x): 82 | ss = gem(x).squeeze() # Gem池化 83 | ss= ss.view(ss.size(0), -1) # Gem池化 84 | return ss 85 | def pairwise_dist(x, y): 86 | # Compute pairwise distance of vectors 87 | xx = (x**2).sum(dim=1, keepdim=True) 88 | yy = (y**2).sum(dim=1, keepdim=True).t() 89 | dist = xx + yy - 2.0 * torch.mm(x, y.t()) 90 | dist = dist.clamp(min=1e-6).sqrt() # for numerical stability 91 | return dist 92 | 93 | def kl_soft_dist(feat1,feat2): 94 | n_st = feat1.size(0) 95 | dist_st = pairwise_dist(feat1, feat2) 96 | mask_st_1 = torch.ones(n_st, n_st, dtype=bool) 97 | for i in range(n_st): # 将同一类样本中自己与自己的距离舍弃 98 | mask_st_1[i, i] = 0 99 | dist_st_2 = [] 100 | for i in range(n_st): 101 | dist_st_2.append(dist_st[i][mask_st_1[i]]) 102 | dist_st_2 = torch.stack(dist_st_2) 103 | return dist_st_2 104 | 105 | 106 | def Bg_kl(logits1, logits2):####输入:(60,206),(60,206) 107 | KL = nn.KLDivLoss(reduction='batchmean') 108 | kl_loss_12 = KL(F.log_softmax(logits1, 1), F.softmax(logits2, 1)) 109 | kl_loss_21 = KL(F.log_softmax(logits2, 1), F.softmax(logits1, 1)) 110 | bg_loss_kl = kl_loss_12 + kl_loss_21 111 | return kl_loss_12, bg_loss_kl 112 | def Sm_kl(logits1, logits2, labels): 113 | KL = nn.KLDivLoss(reduction='batchmean') 114 | m_kl = torch.div((labels == labels[0]).sum(), 2, rounding_mode='floor') 115 | v_logits_s = logits1.split(m_kl, 0) 116 | i_logits_s = logits2.split(m_kl, 0) 117 | sm_v_logits = torch.cat(v_logits_s, 1) # .t() # 5,206*12->206*12,5 118 | sm_i_logits = torch.cat(i_logits_s, 1) # .t() 119 | sm_kl_loss_vi = KL(F.log_softmax(sm_v_logits, 1), F.softmax(sm_i_logits, 1)) 120 | sm_kl_loss_iv = KL(F.log_softmax(sm_i_logits, 1), F.softmax(sm_v_logits, 1)) 121 | sm_kl_loss = sm_kl_loss_vi + sm_kl_loss_iv 122 | return sm_kl_loss_vi, sm_kl_loss 123 | 124 | 125 | def samplewise_entropy(logits): 126 | probabilities = F.softmax(logits, dim=1) 127 | log_probabilities = F.log_softmax(logits, dim=1) 128 | entropies = -torch.sum(probabilities * log_probabilities, dim=1) 129 | return entropies 130 | 131 | 132 | def entropy_margin_loss(logits1, logits2, margin): 133 | entropy1 = samplewise_entropy(logits1) 134 | entropy2 = samplewise_entropy(logits2) 135 | losses = torch.exp(F.relu(entropy2 - entropy1 + margin)) - 1 136 | return losses.mean() 137 | 138 | 139 | def compute_centroid_distance(features, labels, modalities): 140 | """ 141 | 计算每个类别不同模态的中心特征的距离。 142 | 143 | 参数: 144 | features -- 特征矩阵,形状为(B, C)。 145 | labels -- 类别标签,形状为(B,)。 146 | modalities -- 模态标签,形状为(B,)。 147 | 148 | 返回: 149 | distances -- 每个类别模态中心距离的列表。 150 | """ 151 | unique_labels = torch.unique(labels) 152 | distances = [] 153 | for label in unique_labels: 154 | # 分别获取当前类别下的两种模态的特征 155 | features_modality_0 = features[(labels == label) & (modalities == 0)] 156 | features_modality_1 = features[(labels == label) 
& (modalities == 1)] 157 | 158 | # 计算中心特征 159 | centroid_modality_0 = features_modality_0.mean(dim=0) 160 | centroid_modality_1 = features_modality_1.mean(dim=0) 161 | 162 | # 计算两个中心特征之间的距离,这里使用欧氏距离 163 | distance = F.pairwise_distance(centroid_modality_0.unsqueeze(0), centroid_modality_1.unsqueeze(0)) 164 | distances.append(distance) 165 | 166 | 167 | return torch.stack(distances) 168 | 169 | 170 | def modal_centroid_loss(F1, F2, labels, modalities, margin): 171 | """ 172 | 计算损失函数,要求F2中每个类别不同模态的中心距离比F1更小,并施加一个margin。 173 | 174 | 参数: 175 | F1 -- 第一组特征,形状为(B, C)。 176 | F2 -- 第二组特征,经过网络结构优化,形状为(B, C)。 177 | labels -- 类别标签,形状为(B,)。 178 | modalities -- 模态标签,形状为(B,)。 179 | margin -- 施加的margin值。 180 | 181 | 返回: 182 | loss -- 计算的损失值。 183 | """ 184 | # 计算F1和F2的中心距离 185 | distances_F1 = compute_centroid_distance(F1, labels, modalities) 186 | distances_F2 = compute_centroid_distance(F2, labels, modalities) 187 | 188 | # 计算带margin的损失 189 | losses = F.relu(distances_F2 - distances_F1 + margin) 190 | 191 | # 返回损失的平均值 192 | return losses.mean() 193 | class Baseline(nn.Module): 194 | def __init__(self, num_classes=None, drop_last_stride=False, decompose=False, **kwargs): 195 | super(Baseline, self).__init__() 196 | 197 | self.drop_last_stride = drop_last_stride 198 | self.decompose = decompose 199 | self.backbone = embed_net(drop_last_stride=drop_last_stride, decompose=decompose) 200 | 201 | self.base_dim = 2048 202 | self.dim = 0 203 | self.part_num = kwargs.get('num_parts', 0) 204 | 205 | 206 | print("output feat length:{}".format(self.base_dim + self.dim * self.part_num)) 207 | self.bn_neck = nn.BatchNorm1d(self.base_dim + self.dim * self.part_num) 208 | nn.init.constant_(self.bn_neck.bias, 0) 209 | self.bn_neck.bias.requires_grad_(False) 210 | self.bn_neck_sp = nn.BatchNorm1d(self.base_dim + self.dim * self.part_num) 211 | nn.init.constant_(self.bn_neck_sp.bias, 0) 212 | self.bn_neck_sp.bias.requires_grad_(False) 213 | 214 | if kwargs.get('eval', False): 215 | return 216 | 217 | self.classification = kwargs.get('classification', False) 218 | self.triplet = kwargs.get('triplet', False) 219 | self.center_cluster = kwargs.get('center_cluster', False) 220 | self.center_loss = kwargs.get('center', False) 221 | self.margin = kwargs.get('margin', 0.3) 222 | self.CSA1 = kwargs.get('bg_kl', False) 223 | self.CSA2 = kwargs.get('sm_kl', False) 224 | self.TGSA = kwargs.get('distalign', False) 225 | self.IP = kwargs.get('IP', False) 226 | self.fb_dt = kwargs.get('fb_dt', False) 227 | 228 | if self.decompose: 229 | self.classifier = nn.Linear(self.base_dim + self.dim * self.part_num, num_classes, bias=False) 230 | self.classifier_sp = nn.Linear(self.base_dim, num_classes, bias=False) 231 | self.D_special = Discrimination() 232 | self.C_sp_f = nn.Linear(self.base_dim, num_classes, bias=False) 233 | 234 | self.D_shared_pseu = Discrimination(2048) 235 | self.grl = ReverseGrad() 236 | 237 | else: 238 | self.classifier = nn.Linear(self.base_dim + self.dim * self.part_num, num_classes, bias=False) 239 | if self.classification: 240 | self.id_loss = nn.CrossEntropyLoss(ignore_index=-1) 241 | if self.triplet: 242 | self.triplet_loss = TripletLoss(margin=self.margin) 243 | self.rerank_loss = RerankLoss(margin=0.7) 244 | if self.center_cluster: 245 | k_size = kwargs.get('k_size', 8) 246 | self.center_cluster_loss = CenterTripletLoss(k_size=k_size, margin=self.margin) 247 | if self.center_loss: 248 | self.center_loss = CenterLoss(num_classes, self.base_dim + self.dim * self.part_num) 249 | 250 | def forward(self, inputs, 
labels=None, **kwargs): 251 | 252 | cam_ids = kwargs.get('cam_ids') 253 | sub = (cam_ids == 3) + (cam_ids == 6) 254 | #epoch = kwargs.get('epoch') 255 | # CNN 256 | sh_feat, sh_pl, sp_pl, sp_IN,sp_IN_p,x_sp_f,x_sp_f_p = self.backbone(inputs) 257 | 258 | 259 | feats = sh_pl 260 | 261 | if not self.training: 262 | if feats.size(0) == 2048: 263 | feats = self.bn_neck(feats.permute(1, 0)) 264 | logits = self.classifier(feats) 265 | return logits # feats # 266 | 267 | 268 | else: 269 | feats = self.bn_neck( 270 | feats) 271 | return feats 272 | 273 | else: 274 | return self.train_forward(feats, sp_pl, labels, 275 | sub, sp_IN,sp_IN_p,x_sp_f,x_sp_f_p, **kwargs) 276 | 277 | 278 | 279 | def train_forward(self, feat, sp_pl, labels, 280 | sub, sp_IN,sp_IN_p,x_sp_f,x_sp_f_p, **kwargs): 281 | epoch = kwargs.get('epoch') 282 | metric = {} 283 | loss = 0 284 | 285 | if self.triplet: 286 | 287 | triplet_loss, dist, sh_ap, sh_an = self.triplet_loss(feat.float(), labels) 288 | triplet_loss_im, _, sp_ap, sp_an = self.triplet_loss(sp_pl.float(), labels) 289 | trip_loss = triplet_loss + triplet_loss_im 290 | loss += trip_loss 291 | metric.update({'tri': trip_loss.data}) 292 | 293 | 294 | bb = 120 #90 295 | if self.TGSA: 296 | 297 | sf_sp_dist_v = kl_soft_dist(sp_pl[sub == 0], sp_pl[sub == 0]) 298 | sf_sp_dist_i = kl_soft_dist(sp_pl[sub == 1], sp_pl[sub == 1]) 299 | sf_sh_dist_v = kl_soft_dist(feat[sub == 0], feat[sub == 0]) 300 | sf_sh_dist_i = kl_soft_dist(feat[sub == 1], feat[sub == 1]) 301 | half_B0 = feat[sub == 0].shape[0] // 2 302 | feat_half0 = feat[sub == 0][:half_B0] 303 | half_B1 = feat[sub == 1].shape[0] // 2 304 | feat_half1 = feat[sub == 1][:half_B1] 305 | feat_cross = torch.cat((feat_half0, feat_half1), dim=0) 306 | sf_sh_dist_vi = kl_soft_dist(feat_cross, feat_cross) 307 | 308 | 309 | 310 | _, kl_inter_v = Bg_kl(sf_sh_dist_v, sf_sp_dist_v) 311 | _, kl_inter_i = Bg_kl(sf_sh_dist_i, sf_sp_dist_i) 312 | 313 | 314 | _, kl_intra1 = Bg_kl(sf_sh_dist_v, sf_sh_dist_i) 315 | _, kl_intra2 = Bg_kl(sf_sh_dist_v, sf_sh_dist_vi) 316 | _, kl_intra3 = Bg_kl(sf_sh_dist_vi, sf_sh_dist_i) 317 | 318 | kl_intra = kl_intra1 + kl_intra2 + kl_intra3 319 | 320 | 321 | 322 | if feat.size(0) == bb: 323 | soft_dt = kl_intra + (kl_inter_v + kl_inter_i) * 0.6 324 | 325 | 326 | else: 327 | soft_dt = (kl_intra1 + kl_inter_v + kl_inter_i) * 0.1 328 | 329 | loss += soft_dt 330 | metric.update({'soft_dt': soft_dt.data}) 331 | 332 | if self.center_loss: 333 | center_loss = self.center_loss(feat.float(), labels) 334 | loss += center_loss 335 | metric.update({'cen': center_loss.data}) 336 | 337 | if self.center_cluster: 338 | center_cluster_loss, _, _ = self.center_cluster_loss(feat.float(), labels) 339 | loss += center_cluster_loss 340 | metric.update({'cc': center_cluster_loss.data}) 341 | 342 | 343 | if self.fb_dt: 344 | loss_f_d_ap = Fb_dt(feat, labels) 345 | loss_Fb_im = Fb_dt(sp_pl, labels) 346 | fb_loss = loss_f_d_ap + loss_Fb_im 347 | loss += fb_loss 348 | 349 | metric.update({'f_dt': fb_loss.data}) 350 | 351 | feat = self.bn_neck(feat) 352 | sp_pl = self.bn_neck_sp(sp_pl) 353 | sub_nb = sub + 0 ##模态标签 354 | 355 | if self.IP: 356 | ################ 357 | ################ 358 | l_F = self.C_sp_f(gem_p(x_sp_f)) 359 | l_F_p = self.C_sp_f(gem_p(x_sp_f_p)) 360 | loss_F = entropy_margin_loss(l_F, l_F_p, 0) 361 | loss_m_IN = modal_centroid_loss(gem_p(sp_IN), gem_p(sp_IN_p), labels, sub, 0) 362 | 363 | loss += 0.1 * (loss_F + loss_m_IN) 364 | metric.update({'IN_p': loss_m_IN.data}) 365 | metric.update({'F_p': loss_F.data}) 366 | 
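        # NOTE (descriptive summary): the `decompose` branch that follows appears to add three terms
        # on top of the shared-branch losses:
        #   (1) an identity classification loss on the modality-specific feature via `classifier_sp`,
        #   (2) a modality-discrimination loss where `D_special` predicts the modality label `sub_nb` from `sp_pl`,
        #   (3) a confusion-style loss where `D_shared_pseu` is applied to the shared feature with rolled
        #       pseudo modality labels (`torch.roll`), pushing the shared feature toward modality invariance.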
367 | ################ 368 | ################ 369 | 370 | if self.decompose: 371 | logits_sp = self.classifier_sp(sp_pl) # self.bn_neck_un(sp_pl) 372 | loss_id_sp = self.id_loss(logits_sp.float(), labels) 373 | 374 | 375 | sp_logits = self.D_special(sp_pl) 376 | unad_loss_b = self.id_loss(sp_logits.float(), sub_nb) 377 | unad_loss = unad_loss_b 378 | 379 | 380 | pseu_sh_logits = self.D_shared_pseu(feat) 381 | p_sub = sub_nb.chunk(2)[0].repeat_interleave(2) 382 | pp_sub = torch.roll(p_sub, -1) 383 | pseu_loss = self.id_loss(pseu_sh_logits.float(), pp_sub) 384 | 385 | loss += loss_id_sp + unad_loss + pseu_loss 386 | 387 | metric.update({'unad': unad_loss.data}) 388 | metric.update({'id_pl': loss_id_sp.data}) 389 | 390 | metric.update({'pse': pseu_loss.data}) 391 | 392 | 393 | 394 | 395 | if self.classification: 396 | logits = self.classifier(feat) 397 | if self.CSA1: 398 | 399 | _, inter_bg_v = Bg_kl(logits[sub == 0], logits_sp[sub == 0]) 400 | _, inter_bg_i = Bg_kl(logits[sub == 1], logits_sp[sub == 1]) 401 | 402 | _, intra_bg = Bg_kl(logits[sub == 0], logits[sub == 1]) 403 | 404 | 405 | if feat.size(0) == bb: 406 | bg_loss = intra_bg + (inter_bg_v + inter_bg_i) * 0.8 # intra_bg + (inter_bg_v + inter_bg_i) * 0.7 407 | 408 | else: 409 | bg_loss = intra_bg + (inter_bg_v + inter_bg_i) * 0.3 410 | loss += bg_loss 411 | metric.update({'bg_kl': bg_loss.data}) 412 | 413 | if self.CSA2: 414 | _, inter_Sm_v = Sm_kl(logits[sub == 0], logits_sp[sub == 0], labels) 415 | _, inter_Sm_i = Sm_kl(logits[sub == 1], logits_sp[sub == 1], labels) 416 | inter_Sm = inter_Sm_v + inter_Sm_i 417 | _, intra_Sm = Sm_kl(logits[sub == 0], logits[sub == 1], labels) 418 | 419 | if feat.size(0) == bb: 420 | sm_kl_loss = intra_Sm + inter_Sm * 0.8 421 | 422 | else: 423 | sm_kl_loss = intra_Sm + inter_Sm * 0.3 424 | loss += sm_kl_loss 425 | metric.update({'sm_kl': sm_kl_loss.data}) 426 | cls_loss = self.id_loss(logits.float(), labels) 427 | loss += cls_loss 428 | metric.update({'acc': calc_acc(logits.data, labels), 'ce': cls_loss.data}) 429 | 430 | return loss, metric 431 | -------------------------------------------------------------------------------- /IDKL/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn import functional as F 3 | # from torchvision.models.utils import load_state_dict_from_url #原来 4 | from torch.hub import load_state_dict_from_url 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] 8 | 9 | model_urls = { 10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 16 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 17 | } 18 | 19 | 20 | class convDiscrimination(nn.Module): 21 | def __init__(self, dim=512): 22 | super(convDiscrimination, self).__init__() 23 | self.conv1 = conv3x3(dim, 512, stride=2) 24 | self.bn1 = nn.BatchNorm2d(512) 25 | self.conv2 = conv3x3(512, 128, stride=2) 26 | self.bn2 = nn.BatchNorm2d(128) 27 | self.conv3 = conv3x3(128, 128, stride=2) 28 | self.bn3 = 
nn.BatchNorm2d(128) 29 | self.fc = nn.Linear(128, 2) 30 | 31 | def forward(self, x): 32 | x = F.dropout(F.relu(self.bn1(self.conv1(x))), training=self.training) 33 | x = F.dropout(F.relu(self.bn2(self.conv2(x))), training=self.training) 34 | x = F.dropout(F.relu(self.bn3(self.conv3(x))), training=self.training) 35 | x = F.avg_pool2d(x, (x.size(2), x.size(3))) 36 | x = x.view(-1, 128) 37 | x = self.fc(x) 38 | return x 39 | 40 | 41 | class Discrimination(nn.Module): 42 | def __init__(self, dim=2048): 43 | super(Discrimination, self).__init__() 44 | self.fc1 = nn.Linear(dim, 100) 45 | self.bn1 = nn.BatchNorm1d(100) 46 | self.fc2 = nn.Linear(100, 100) 47 | self.bn2 = nn.BatchNorm1d(100) 48 | self.fc3 = nn.Linear(100, 2) 49 | 50 | def forward(self, x): 51 | x = F.dropout(F.relu(self.bn1(self.fc1(x))), training=self.training) 52 | x = F.dropout(F.relu(self.bn2(self.fc2(x))), training=self.training) 53 | x = self.fc3(x) 54 | return x 55 | 56 | 57 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 58 | """3x3 convolution with padding""" 59 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 60 | padding=dilation, groups=groups, bias=False, dilation=dilation) 61 | 62 | 63 | def conv1x1(in_planes, out_planes, stride=1): 64 | """1x1 convolution""" 65 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 66 | 67 | 68 | class MAM(nn.Module): 69 | def __init__(self, dim, r=16): 70 | super(MAM, self).__init__() 71 | 72 | self.channel_attention = nn.Sequential( 73 | nn.Conv2d(dim, dim // r, kernel_size=1, bias=False), 74 | nn.ReLU(inplace=True), 75 | nn.Conv2d(dim // r, dim, kernel_size=1, bias=False), 76 | nn.Sigmoid() 77 | ) 78 | self.IN = nn.InstanceNorm2d(dim, track_running_stats=False) 79 | 80 | def forward(self, x): 81 | pooled = F.avg_pool2d(x, x.size()[2:]) 82 | mask = self.channel_attention(pooled) 83 | x = x * mask + self.IN(x) * (1 - mask) 84 | 85 | return x 86 | 87 | 88 | class BasicBlock(nn.Module): 89 | expansion = 1 90 | 91 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 92 | base_width=64, dilation=1, norm_layer=None): 93 | super(BasicBlock, self).__init__() 94 | if norm_layer is None: 95 | norm_layer = nn.BatchNorm2d 96 | if groups != 1 or base_width != 64: 97 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 98 | if dilation > 1: 99 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 100 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 101 | self.conv1 = conv3x3(inplanes, planes, stride) 102 | self.bn1 = norm_layer(planes) 103 | self.relu = nn.ReLU(inplace=True) 104 | self.conv2 = conv3x3(planes, planes) 105 | self.bn2 = norm_layer(planes) 106 | self.downsample = downsample 107 | self.stride = stride 108 | 109 | def forward(self, x): 110 | identity = x 111 | 112 | out = self.conv1(x) 113 | out = self.bn1(out) 114 | out = self.relu(out) 115 | 116 | out = self.conv2(out) 117 | out = self.bn2(out) 118 | 119 | if self.downsample is not None: 120 | identity = self.downsample(x) 121 | 122 | out += identity 123 | out = self.relu(out) 124 | 125 | return out 126 | 127 | 128 | class Bottleneck(nn.Module): 129 | expansion = 4 130 | 131 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 132 | base_width=64, dilation=1, norm_layer=None): 133 | super(Bottleneck, self).__init__() 134 | if norm_layer is None: 135 | norm_layer = nn.BatchNorm2d 136 | width = int(planes * (base_width / 64.)) * groups 137 | # 
Both self.conv2 and self.downsample layers downsample the input when stride != 1 138 | self.conv1 = conv1x1(inplanes, width) 139 | self.bn1 = norm_layer(width) 140 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 141 | self.bn2 = norm_layer(width) 142 | self.conv3 = conv1x1(width, planes * self.expansion) 143 | self.bn3 = norm_layer(planes * self.expansion) 144 | self.relu = nn.ReLU(inplace=True) 145 | self.downsample = downsample 146 | self.stride = stride 147 | 148 | def forward(self, x): 149 | identity = x 150 | 151 | out = self.conv1(x) 152 | out = self.bn1(out) 153 | out = self.relu(out) 154 | 155 | out = self.conv2(out) 156 | out = self.bn2(out) 157 | out = self.relu(out) 158 | 159 | out = self.conv3(out) 160 | out = self.bn3(out) 161 | 162 | if self.downsample is not None: 163 | identity = self.downsample(x) 164 | 165 | out += identity 166 | out = self.relu(out) 167 | 168 | return out 169 | 170 | 171 | class ResNet(nn.Module): 172 | 173 | def __init__(self, block, layers, zero_init_residual=False, modality_attention=0, 174 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 175 | norm_layer=None, drop_last_stride=False): 176 | super(ResNet, self).__init__() 177 | if norm_layer is None: 178 | norm_layer = nn.BatchNorm2d 179 | self._norm_layer = norm_layer 180 | 181 | self.inplanes = 64 182 | self.dilation = 1 183 | if replace_stride_with_dilation is None: 184 | # each element in the tuple indicates if we should replace 185 | # the 2x2 stride with a dilated convolution instead 186 | replace_stride_with_dilation = [False, False, False] 187 | if len(replace_stride_with_dilation) != 3: 188 | raise ValueError("replace_stride_with_dilation should be None " 189 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 190 | self.groups = groups 191 | self.base_width = width_per_group 192 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 193 | bias=False) 194 | self.bn1 = norm_layer(self.inplanes) 195 | self.relu = nn.ReLU(inplace=True) 196 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 197 | self.layer1 = self._make_layer(block, 64, layers[0]) 198 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 199 | dilate=replace_stride_with_dilation[0]) 200 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 201 | dilate=replace_stride_with_dilation[1]) 202 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1 if drop_last_stride else 2, 203 | dilate=replace_stride_with_dilation[2]) 204 | 205 | 206 | for m in self.modules(): 207 | if isinstance(m, nn.Conv2d): 208 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 209 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 210 | nn.init.constant_(m.weight, 1) 211 | nn.init.constant_(m.bias, 0) 212 | 213 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 214 | norm_layer = self._norm_layer 215 | downsample = None 216 | previous_dilation = self.dilation 217 | if dilate: 218 | self.dilation *= stride 219 | stride = 1 220 | if stride != 1 or self.inplanes != planes * block.expansion: 221 | downsample = nn.Sequential( 222 | conv1x1(self.inplanes, planes * block.expansion, stride), 223 | norm_layer(planes * block.expansion), 224 | ) 225 | 226 | layers = [] 227 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 228 | self.base_width, previous_dilation, norm_layer)) 229 | self.inplanes = planes * block.expansion 230 | for _ in range(1, blocks): 231 | 
layers.append(block(self.inplanes, planes, groups=self.groups, 232 | base_width=self.base_width, dilation=self.dilation, 233 | norm_layer=norm_layer)) 234 | 235 | return nn.Sequential(*layers) 236 | 237 | def forward(self, x): 238 | x = self.conv1(x) 239 | x = self.bn1(x) 240 | x = self.relu(x) 241 | x = self.maxpool(x) 242 | 243 | x = self.layer1(x) 244 | x = self.layer2(x) 245 | x = self.layer3(x) 246 | 247 | x = self.layer4(x) 248 | 249 | return x 250 | 251 | 252 | class Shared_module_fr(nn.Module): 253 | def __init__(self, drop_last_stride, modality_attention): 254 | super(Shared_module_fr, self).__init__() 255 | 256 | model_sh_fr = resnet50(pretrained=True, drop_last_stride=drop_last_stride, 257 | modality_attention=modality_attention) 258 | # avg pooling to global pooling 259 | self.model_sh_fr = model_sh_fr 260 | 261 | def forward(self, x): 262 | x = self.model_sh_fr.conv1(x) 263 | x = self.model_sh_fr.bn1(x) 264 | x = self.model_sh_fr.relu(x) 265 | x = self.model_sh_fr.maxpool(x) 266 | x = self.model_sh_fr.layer1(x) 267 | x = self.model_sh_fr.layer2(x) 268 | # x = self.model_sh_fr.layer3(x) 269 | return x 270 | 271 | 272 | class Special_module(nn.Module): 273 | def __init__(self, drop_last_stride, modality_attention): 274 | super(Special_module, self).__init__() 275 | 276 | special_module = resnet50(pretrained=True, drop_last_stride=drop_last_stride, 277 | ) 278 | 279 | self.special_module = special_module 280 | 281 | def forward(self, x): 282 | # x = self.special_module.layer2(x) 283 | x = self.special_module.layer3(x) 284 | x = self.special_module.layer4(x) 285 | return x 286 | 287 | 288 | class Shared_module_bh(nn.Module): 289 | def __init__(self, drop_last_stride, modality_attention): 290 | super(Shared_module_bh, self).__init__() 291 | 292 | model_sh_bh = resnet50(pretrained=True, drop_last_stride=drop_last_stride,) # model_sh_fr model_sh_bh 293 | 294 | self.model_sh_bh = model_sh_bh # self.model_sh_bh = model_sh_bh #self.model_sh_fr = model_sh_fr 295 | 296 | def forward(self, x): 297 | # x = self.model_sh_bh.layer2(x) 298 | x_sh3 = self.model_sh_bh.layer3(x) # self.model_sh_fr self.model_sh_bh 299 | x_sh4 = self.model_sh_bh.layer4(x_sh3) # self.model_sh_fr self.model_sh_bh 300 | return x_sh3, x_sh4 301 | 302 | 303 | class Mask(nn.Module): 304 | def __init__(self, dim, r=16): 305 | super(Mask, self).__init__() 306 | 307 | self.channel_attention = nn.Sequential( 308 | nn.Conv2d(dim, dim // r, kernel_size=1, bias=False), 309 | nn.ReLU(inplace=True), 310 | nn.Conv2d(dim // r, dim, kernel_size=1, bias=False), 311 | nn.Sigmoid() 312 | ) 313 | 314 | def forward(self, x): 315 | mask = self.channel_attention(x) 316 | return mask 317 | 318 | 319 | class special_att(nn.Module): 320 | def __init__(self, dim, r=16): 321 | super(special_att, self).__init__() 322 | 323 | self.channel_attention = nn.Sequential( 324 | nn.Conv2d(dim, dim // r, kernel_size=1, bias=False), 325 | nn.ReLU(inplace=True), 326 | nn.Conv2d(dim // r, dim, kernel_size=1, bias=False), 327 | nn.Sigmoid() 328 | ) 329 | self.IN = nn.InstanceNorm2d(dim, track_running_stats=False) #self.IN = nn.InstanceNorm2d(dim, track_running_stats=True, affine=True) 330 | 331 | def forward(self, x): 332 | x_IN = self.IN(x) 333 | x_R = x - x_IN 334 | pooled = gem(x_R) 335 | mask = self.channel_attention(pooled) 336 | x_sp = x_R * mask + x_IN # x 337 | 338 | return x_sp, x_IN 339 | 340 | 341 | def gem(x, p=3, eps=1e-6): 342 | return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1. 
/ p) 343 | 344 | 345 | class embed_net(nn.Module): 346 | def __init__(self, drop_last_stride, decompose=False): 347 | super(embed_net, self).__init__() 348 | 349 | self.shared_module_fr = Shared_module_fr(drop_last_stride=drop_last_stride, 350 | ) 351 | self.shared_module_bh = Shared_module_bh(drop_last_stride=drop_last_stride, 352 | ) 353 | 354 | self.special = Special_module(drop_last_stride=drop_last_stride) 355 | 356 | self.decompose = decompose 357 | self.IN = nn.InstanceNorm2d(2048, track_running_stats=True, affine=True) 358 | if decompose: 359 | self.special_att = special_att(2048) 360 | self.mask1 = Mask(2048) 361 | self.mask2 = Mask(2048) 362 | 363 | def forward(self, x): 364 | x2 = self.shared_module_fr(x) 365 | x3, x_sh = self.shared_module_bh(x2) # bchw 366 | 367 | sh_pl = gem(x_sh).squeeze() # Gem池化 368 | sh_pl = sh_pl.view(sh_pl.size(0), -1) # Gem池化 369 | 370 | if self.decompose: 371 | ######special structure 372 | 373 | x_sp_f = self.special(x2) 374 | sp_IN = self.IN(x_sp_f) 375 | m_IN = self.mask1(sp_IN) 376 | m_F = self.mask2(x_sp_f) 377 | sp_IN_p = m_IN * sp_IN 378 | x_sp_f_p = m_F * x_sp_f 379 | x_sp = m_IN * x_sp_f_p + m_F * sp_IN_p 380 | 381 | sp_pl = gem(x_sp).squeeze() # Gem池化 382 | sp_pl = sp_pl.view(sp_pl.size(0), -1) # Gem池化 383 | 384 | 385 | return x_sh, sh_pl, sp_pl,sp_IN,sp_IN_p,x_sp_f,x_sp_f_p 386 | 387 | 388 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 389 | model = ResNet(block, layers, **kwargs) 390 | if pretrained: 391 | state_dict = load_state_dict_from_url(model_urls[arch], 392 | progress=progress) 393 | model.load_state_dict(state_dict, strict=False) 394 | return model 395 | 396 | 397 | def resnet18(pretrained=False, progress=True, **kwargs): 398 | """Constructs a ResNet-18 model. 399 | 400 | Args: 401 | pretrained (bool): If True, returns a model pre-trained on ImageNet 402 | progress (bool): If True, displays a progress bar of the download to stderr 403 | """ 404 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 405 | **kwargs) 406 | 407 | 408 | def resnet34(pretrained=False, progress=True, **kwargs): 409 | """Constructs a ResNet-34 model. 410 | 411 | Args: 412 | pretrained (bool): If True, returns a model pre-trained on ImageNet 413 | progress (bool): If True, displays a progress bar of the download to stderr 414 | """ 415 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 416 | **kwargs) 417 | 418 | 419 | def resnet50(pretrained=False, progress=True, **kwargs): 420 | """Constructs a ResNet-50 model. 421 | 422 | Args: 423 | pretrained (bool): If True, returns a model pre-trained on ImageNet 424 | progress (bool): If True, displays a progress bar of the download to stderr 425 | """ 426 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 427 | **kwargs) 428 | 429 | 430 | def resnet101(pretrained=False, progress=True, **kwargs): 431 | """Constructs a ResNet-101 model. 432 | 433 | Args: 434 | pretrained (bool): If True, returns a model pre-trained on ImageNet 435 | progress (bool): If True, displays a progress bar of the download to stderr 436 | """ 437 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 438 | **kwargs) 439 | 440 | 441 | def resnet152(pretrained=False, progress=True, **kwargs): 442 | """Constructs a ResNet-152 model. 
443 | 444 | Args: 445 | pretrained (bool): If True, returns a model pre-trained on ImageNet 446 | progress (bool): If True, displays a progress bar of the download to stderr 447 | """ 448 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 449 | **kwargs) 450 | 451 | 452 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 453 | """Constructs a ResNeXt-50 32x4d model. 454 | 455 | Args: 456 | pretrained (bool): If True, returns a model pre-trained on ImageNet 457 | progress (bool): If True, displays a progress bar of the download to stderr 458 | """ 459 | kwargs['groups'] = 32 460 | kwargs['width_per_group'] = 4 461 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 462 | pretrained, progress, **kwargs) 463 | 464 | 465 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 466 | """Constructs a ResNeXt-101 32x8d model. 467 | 468 | Args: 469 | pretrained (bool): If True, returns a model pre-trained on ImageNet 470 | progress (bool): If True, displays a progress bar of the download to stderr 471 | """ 472 | kwargs['groups'] = 32 473 | kwargs['width_per_group'] = 8 474 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 475 | pretrained, progress, **kwargs) 476 | -------------------------------------------------------------------------------- /IDKL/train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pprint 4 | 5 | import torch 6 | import yaml 7 | from apex import amp 8 | from torch import optim 9 | 10 | from data import get_test_loader 11 | from data import get_train_loader 12 | from engine import get_trainer 13 | from models.baseline import Baseline 14 | # from torchstat import stat 15 | 16 | # from WarmUpLR import WarmUpStepLR 17 | 18 | def train(cfg): 19 | # set logger 20 | log_dir = os.path.join("logs/", cfg.dataset, cfg.prefix) 21 | if not os.path.isdir(log_dir): 22 | os.makedirs(log_dir, exist_ok=True) 23 | 24 | logging.basicConfig(format="%(asctime)s %(message)s", 25 | filename=log_dir + "/" + "log.txt", 26 | filemode="w") 27 | 28 | logger = logging.getLogger() 29 | logger.setLevel(logging.INFO) 30 | stream_handler = logging.StreamHandler() 31 | stream_handler.setLevel(logging.INFO) 32 | logger.addHandler(stream_handler) 33 | 34 | logger.info(pprint.pformat(cfg)) 35 | 36 | # training data loader 37 | train_loader = get_train_loader(dataset=cfg.dataset, 38 | root=cfg.data_root, 39 | sample_method=cfg.sample_method, 40 | batch_size=cfg.batch_size, 41 | p_size=cfg.p_size, 42 | k_size=cfg.k_size, 43 | random_flip=cfg.random_flip, 44 | random_crop=cfg.random_crop, 45 | random_erase=cfg.random_erase, 46 | color_jitter=cfg.color_jitter, 47 | padding=cfg.padding, 48 | image_size=cfg.image_size, 49 | num_workers=8) 50 | 51 | # evaluation data loader 52 | gallery_loader, query_loader = None, None 53 | if cfg.eval_interval > 0: 54 | if True == False: # tsne 55 | query_loader = get_train_loader(dataset=cfg.dataset, 56 | root=cfg.data_root, 57 | sample_method=cfg.sample_method, 58 | batch_size=cfg.batch_size, 59 | p_size=cfg.p_size, 60 | k_size=cfg.k_size, 61 | # random_flip=cfg.random_flip, 62 | # random_crop=cfg.random_crop, 63 | # random_erase=cfg.random_erase, 64 | # color_jitter=cfg.color_jitter, 65 | # padding=cfg.padding, 66 | image_size=cfg.image_size, 67 | num_workers=8) 68 | gallery_loader = query_loader 69 | 70 | else: 71 | gallery_loader, query_loader = get_test_loader(dataset=cfg.dataset, 72 | root=cfg.data_root, 73 | 
batch_size=64, 74 | image_size=cfg.image_size, 75 | num_workers=4) 76 | 77 | 78 | 79 | # model 80 | model = Baseline(num_classes=cfg.num_id, 81 | pattern_attention=cfg.pattern_attention, 82 | modality_attention=cfg.modality_attention, 83 | mutual_learning=cfg.mutual_learning, 84 | decompose=cfg.decompose, 85 | drop_last_stride=cfg.drop_last_stride, 86 | triplet=cfg.triplet, 87 | k_size=cfg.k_size, 88 | center_cluster=cfg.center_cluster, 89 | center=cfg.center, 90 | margin=cfg.margin, 91 | num_parts=cfg.num_parts, 92 | weight_KL=cfg.weight_KL, 93 | weight_sid=cfg.weight_sid, 94 | weight_sep=cfg.weight_sep, 95 | update_rate=cfg.update_rate, 96 | classification=cfg.classification, 97 | bg_kl=cfg.bg_kl, 98 | sm_kl=cfg.sm_kl, 99 | fb_dt=cfg.fb_dt, 100 | IP=cfg.IP, 101 | distalign=cfg.distalign) 102 | 103 | def get_parameter_number(net): 104 | total_num = sum(p.numel() for p in net.parameters()) 105 | trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad) 106 | return {'Total': total_num, 'Trainable': trainable_num} 107 | 108 | print(get_parameter_number(model)) 109 | 110 | model.to(device) 111 | 112 | # optimizer 113 | assert cfg.optimizer in ['adam', 'sgd'] 114 | if cfg.optimizer == 'adam': 115 | #optimizer = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd) 116 | optimizer = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd) 117 | else: 118 | optimizer = optim.SGD(model.parameters(), lr=cfg.lr, momentum=0.9, weight_decay=cfg.wd) 119 | 120 | # convert model for mixed precision training 121 | model, optimizer = amp.initialize(model, optimizer, enabled=cfg.fp16, opt_level="O1") 122 | if cfg.center: 123 | model.center_loss.centers = model.center_loss.centers.float() 124 | lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, 125 | milestones=cfg.lr_step, 126 | gamma=0.1) 127 | 128 | 129 | 130 | if cfg.resume: 131 | checkpoint = torch.load(cfg.resume) 132 | model.load_state_dict(checkpoint) 133 | 134 | # stat(model,(3,224,224)) 135 | # import pdb 136 | # pdb.set_trace() 137 | 138 | # engine 139 | checkpoint_dir = os.path.join("checkpoints", cfg.dataset, cfg.prefix) 140 | engine = get_trainer(dataset=cfg.dataset, 141 | model=model, 142 | optimizer=optimizer, 143 | lr_scheduler=lr_scheduler, 144 | logger=logger, 145 | non_blocking=True, 146 | log_period=cfg.log_period, 147 | save_dir=checkpoint_dir, 148 | prefix=cfg.prefix, 149 | eval_interval=cfg.eval_interval, 150 | start_eval=cfg.start_eval, 151 | gallery_loader=gallery_loader, 152 | query_loader=query_loader, 153 | rerank=cfg.rerank) 154 | 155 | # training 156 | engine.run(train_loader, max_epochs=cfg.num_epoch) 157 | 158 | 159 | if __name__ == '__main__': 160 | import argparse 161 | import random 162 | import numpy as np 163 | from configs.default import strategy_cfg 164 | from configs.default import dataset_cfg 165 | 166 | parser = argparse.ArgumentParser() 167 | parser.add_argument("--cfg", type=str, default="configs/softmax.yml") 168 | ################ 169 | parser.add_argument('--gpu', default='0', type=str, 170 | help='gpu device ids for CUDA_VISIBLE_DEVICES') 171 | #################### 172 | args = parser.parse_args() 173 | 174 | ###################### 175 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 176 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 177 | ################ 178 | 179 | # set random seed 180 | seed = 1 181 | random.seed(seed) 182 | np.random.RandomState(seed) 183 | np.random.seed(seed) 184 | torch.manual_seed(seed) 185 | torch.cuda.manual_seed(seed) 
186 | 
187 |     # enable cudnn backend
188 |     torch.backends.cudnn.benchmark = True
189 |     # torch.backends.cudnn.benchmark = False
190 |     # torch.backends.cudnn.deterministic = True
191 | 
192 |     # load configuration
193 |     customized_cfg = yaml.load(open(args.cfg, "r"), Loader=yaml.SafeLoader)
194 | 
195 |     cfg = strategy_cfg
196 |     cfg.merge_from_file(args.cfg)
197 | 
198 |     dataset_cfg = dataset_cfg.get(cfg.dataset)
199 | 
200 |     for k, v in dataset_cfg.items():
201 |         cfg[k] = v
202 | 
203 |     if cfg.sample_method in ('identity_uniform', 'identity_random'):  # note: the original `== 'identity_uniform' or 'identity_random'` was always true
204 |         cfg.batch_size = cfg.p_size * cfg.k_size
205 | 
206 |     cfg.freeze()
207 | 
208 |     train(cfg)
209 | 
--------------------------------------------------------------------------------
/IDKL/utils/__pycache__/calc_acc.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/utils/__pycache__/calc_acc.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/utils/__pycache__/eval_regdb.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/utils/__pycache__/eval_regdb.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/utils/__pycache__/eval_sysu.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/utils/__pycache__/eval_sysu.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/utils/__pycache__/neighbor.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/utils/__pycache__/neighbor.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/utils/__pycache__/rerank.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/utils/__pycache__/rerank.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/utils/calc_acc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def calc_acc(logits, label, ignore_index=-100, mode="multiclass"):
5 |     if mode == "binary":
6 |         indices = torch.round(logits).type(label.type())
7 |     elif mode == "multiclass":
8 |         indices = torch.max(logits, dim=1)[1]
9 | 
10 |     if label.size() == logits.size():
11 |         ignore = 1 - torch.round(label.sum(dim=1))
12 |         label = torch.max(label, dim=1)[1]
13 |     else:
14 |         ignore = torch.eq(label, ignore_index).view(-1)
15 | 
16 |     correct = torch.eq(indices, label).view(-1)
17 |     num_correct = torch.sum(correct)
18 |     num_examples = logits.shape[0] - ignore.sum()
19 | 
20 |     return num_correct.float() / num_examples.float()
21 | 
--------------------------------------------------------------------------------
/IDKL/utils/eval_llcm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import numpy as np
4 | import torch
5 | from sklearn.preprocessing import
normalize 6 | from torch.nn import functional as F 7 | from .rerank import re_ranking, pairwise_distance 8 | 9 | 10 | def get_gallery_names(perm, cams, ids, trial_id, num_shots=1): 11 | names = [] 12 | for cam in cams: 13 | cam_perm = perm[cam - 1][0].squeeze() 14 | for i in ids: 15 | instance_id = cam_perm[i - 1][trial_id][:num_shots] 16 | names.extend(['cam{}/{:0>4d}/{:0>4d}'.format(cam, i, ins) for ins in instance_id.tolist()]) 17 | 18 | return names 19 | 20 | 21 | def get_unique(array): 22 | _, idx = np.unique(array, return_index=True) 23 | return array[np.sort(idx)] 24 | 25 | 26 | def get_cmc(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids): 27 | gallery_unique_count = get_unique(gallery_ids).shape[0] 28 | match_counter = np.zeros((gallery_unique_count,)) 29 | 30 | result = gallery_ids[sorted_indices] 31 | cam_locations_result = gallery_cam_ids[sorted_indices] 32 | 33 | valid_probe_sample_count = 0 34 | 35 | for probe_index in range(sorted_indices.shape[0]): 36 | # remove gallery samples from the same camera of the probe 37 | result_i = result[probe_index, :] 38 | result_i[np.equal(cam_locations_result[probe_index], query_cam_ids[probe_index])] = -1 39 | 40 | # remove the -1 entries from the label result 41 | result_i = np.array([i for i in result_i if i != -1]) 42 | 43 | # remove duplicated id in "stable" manner 44 | result_i_unique = get_unique(result_i) 45 | 46 | # match for probe i 47 | match_i = np.equal(result_i_unique, query_ids[probe_index]) 48 | 49 | if np.sum(match_i) != 0: # if there is true matching in gallery 50 | valid_probe_sample_count += 1 51 | match_counter += match_i 52 | 53 | rank = match_counter / valid_probe_sample_count 54 | cmc = np.cumsum(rank) 55 | return cmc 56 | 57 | 58 | def get_mAP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids): 59 | result = gallery_ids[sorted_indices] 60 | cam_locations_result = gallery_cam_ids[sorted_indices] 61 | 62 | valid_probe_sample_count = 0 63 | avg_precision_sum = 0 64 | 65 | for probe_index in range(sorted_indices.shape[0]): 66 | # remove gallery samples from the same camera of the probe 67 | result_i = result[probe_index, :] 68 | result_i[cam_locations_result[probe_index, :] == query_cam_ids[probe_index]] = -1 69 | 70 | # remove the -1 entries from the label result 71 | result_i = np.array([i for i in result_i if i != -1]) 72 | 73 | # match for probe i 74 | match_i = result_i == query_ids[probe_index] 75 | true_match_count = np.sum(match_i) 76 | 77 | if true_match_count != 0: # if there is true matching in gallery 78 | valid_probe_sample_count += 1 79 | true_match_rank = np.where(match_i)[0] 80 | 81 | ap = np.mean(np.arange(1, true_match_count + 1) / (true_match_rank + 1)) 82 | avg_precision_sum += ap 83 | 84 | mAP = avg_precision_sum / valid_probe_sample_count 85 | return mAP 86 | 87 | 88 | # def eval_llcm(query_feats, q_pids, q_camids, gallery_feats, g_pids, g_camids, max_rank=20, rerank=False): 89 | # """Evaluation with sysu metric 90 | # Key: for each query identity, its gallery images from the same camera view are discarded. 
"Following the original setting in ite dataset" 91 | # """ 92 | # ptr = 0 93 | # query_feat = np.zeros((nquery, 2048)) 94 | # query_feat_att = np.zeros((nquery, 2048)) 95 | # with torch.no_grad(): 96 | # for batch_idx, (input, label) in enumerate(query_loader): 97 | # batch_num = input.size(0) 98 | # input = Variable(input.cuda()) 99 | # feat, feat_att = net(input, input, test_mode[1]) 100 | # query_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy() 101 | # query_feat_att[ptr:ptr + batch_num, :] = feat_att.detach().cpu().numpy() 102 | # ptr = ptr + batch_num 103 | # distmat = -np.matmul(query_feats.cpu().numpy(), np.transpose(gallery_feats.cpu().numpy())) 104 | # num_q, num_g = distmat.shape 105 | # if num_g < max_rank: 106 | # max_rank = num_g 107 | # print("Note: number of gallery samples is quite small, got {}".format(num_g)) 108 | # indices = np.argsort(distmat, axis=1) 109 | # pred_label = g_pids[indices] 110 | # matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 111 | # 112 | # # compute cmc curve for each query 113 | # new_all_cmc = [] 114 | # all_cmc = [] 115 | # all_AP = [] 116 | # all_INP = [] 117 | # num_valid_q = 0. # number of valid query 118 | # for q_idx in range(num_q): 119 | # # get query pid and camid 120 | # q_pid = q_pids[q_idx] 121 | # q_camid = q_camids[q_idx] 122 | # 123 | # # remove gallery samples that have the same pid and camid with query 124 | # 125 | # order = indices[q_idx] 126 | # remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 127 | # keep = np.invert(remove) 128 | # 129 | # # compute cmc curve 130 | # # the cmc calculation is different from standard protocol 131 | # # we follow the protocol of the author's released code 132 | # new_cmc = pred_label[q_idx][keep] 133 | # new_index = np.unique(new_cmc, return_index=True)[1] 134 | # 135 | # new_cmc = [new_cmc[index] for index in sorted(new_index)] 136 | # 137 | # new_match = (new_cmc == q_pid).astype(np.int32) 138 | # new_cmc = new_match.cumsum() 139 | # new_all_cmc.append(new_cmc[:max_rank]) 140 | # 141 | # orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 142 | # if not np.any(orig_cmc): 143 | # # this condition is true when query identity does not appear in gallery 144 | # continue 145 | # 146 | # cmc = orig_cmc.cumsum() 147 | # 148 | # # compute mINP 149 | # # refernece Deep Learning for Person Re-identification: A Survey and Outlook 150 | # pos_idx = np.where(orig_cmc == 1) 151 | # pos_max_idx = np.max(pos_idx) 152 | # inp = cmc[pos_max_idx] / (pos_max_idx + 1.0) 153 | # all_INP.append(inp) 154 | # 155 | # cmc[cmc > 1] = 1 156 | # 157 | # all_cmc.append(cmc[:max_rank]) 158 | # num_valid_q += 1. 159 | # 160 | # # compute average precision 161 | # # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 162 | # num_rel = orig_cmc.sum() 163 | # tmp_cmc = orig_cmc.cumsum() 164 | # tmp_cmc = [x / (i + 1.) 
for i, x in enumerate(tmp_cmc)] 165 | # tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 166 | # AP = tmp_cmc.sum() / num_rel 167 | # all_AP.append(AP) 168 | # 169 | # assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 170 | # 171 | # all_cmc = np.asarray(all_cmc).astype(np.float32) 172 | # all_cmc = all_cmc.sum(0) / num_valid_q # standard CMC 173 | # 174 | # new_all_cmc = np.asarray(new_all_cmc).astype(np.float32) 175 | # new_all_cmc = new_all_cmc.sum(0) / num_valid_q 176 | # mAP = np.mean(all_AP) 177 | # mINP = np.mean(all_INP) 178 | # return new_all_cmc, mAP, mINP 179 | 180 | 181 | def eval_llcm(query_feats, query_ids, query_cam_ids, gallery_feats, gallery_ids, gallery_cam_ids, gallery_img_paths, rerank=False): 182 | # gallery_feats = F.normalize(gallery_feats, dim=1) 183 | # query_feats = F.normalize(query_feats, dim=1) 184 | 185 | if rerank: 186 | dist_mat = re_ranking(query_feats, gallery_feats, eval_type=False) 187 | else: 188 | dist_mat = pairwise_distance(query_feats, gallery_feats) 189 | # dist_mat = -torch.mm(query_feats, gallery_feats.t()) 190 | 191 | sorted_indices = np.argsort(dist_mat, axis=1) 192 | 193 | mAP = get_mAP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids) 194 | cmc = get_cmc(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids) 195 | 196 | r1 = cmc[0] 197 | r5 = cmc[4] 198 | r10 = cmc[9] 199 | r20 = cmc[19] 200 | 201 | r1 = r1 * 100 202 | r5 = r5 * 100 203 | r10 = r10 * 100 204 | r20 = r20 * 100 205 | mAP = mAP * 100 206 | 207 | perf = 'r1 precision = {:.2f} , r10 precision = {:.2f} , r20 precision = {:.2f}, mAP = {:.2f}' 208 | logging.info(perf.format(r1, r10, r20, mAP)) 209 | 210 | return mAP, r1, r5, r10, r20 211 | -------------------------------------------------------------------------------- /IDKL/utils/eval_regdb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import numpy as np 4 | import torch 5 | from sklearn.preprocessing import normalize 6 | from torch.nn import functional as F 7 | from .rerank import re_ranking, pairwise_distance 8 | 9 | 10 | def get_gallery_names(perm, cams, ids, trial_id, num_shots=1): 11 | names = [] 12 | for cam in cams: 13 | cam_perm = perm[cam - 1][0].squeeze() 14 | for i in ids: 15 | instance_id = cam_perm[i - 1][trial_id][:num_shots] 16 | names.extend(['cam{}/{:0>4d}/{:0>4d}'.format(cam, i, ins) for ins in instance_id.tolist()]) 17 | 18 | return names 19 | 20 | 21 | def get_unique(array): 22 | _, idx = np.unique(array, return_index=True) 23 | return array[np.sort(idx)] 24 | 25 | 26 | def get_cmc(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids): 27 | gallery_unique_count = get_unique(gallery_ids).shape[0] 28 | match_counter = np.zeros((gallery_unique_count,)) 29 | 30 | result = gallery_ids[sorted_indices] 31 | cam_locations_result = gallery_cam_ids[sorted_indices] 32 | 33 | valid_probe_sample_count = 0 34 | 35 | for probe_index in range(sorted_indices.shape[0]): 36 | # remove gallery samples from the same camera of the probe 37 | result_i = result[probe_index, :] 38 | result_i[np.equal(cam_locations_result[probe_index], query_cam_ids[probe_index])] = -1 39 | 40 | # remove the -1 entries from the label result 41 | result_i = np.array([i for i in result_i if i != -1]) 42 | 43 | # remove duplicated id in "stable" manner 44 | result_i_unique = get_unique(result_i) 45 | 46 | # match for probe i 47 | match_i = np.equal(result_i_unique, query_ids[probe_index]) 48 | 49 | if 
np.sum(match_i) != 0: # if there is true matching in gallery 50 | valid_probe_sample_count += 1 51 | match_counter += match_i 52 | 53 | rank = match_counter / valid_probe_sample_count 54 | cmc = np.cumsum(rank) 55 | return cmc 56 | 57 | 58 | def get_mAP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids): 59 | result = gallery_ids[sorted_indices] 60 | cam_locations_result = gallery_cam_ids[sorted_indices] 61 | 62 | valid_probe_sample_count = 0 63 | avg_precision_sum = 0 64 | 65 | for probe_index in range(sorted_indices.shape[0]): 66 | # remove gallery samples from the same camera of the probe 67 | result_i = result[probe_index, :] 68 | result_i[cam_locations_result[probe_index, :] == query_cam_ids[probe_index]] = -1 69 | 70 | # remove the -1 entries from the label result 71 | result_i = np.array([i for i in result_i if i != -1]) 72 | 73 | # match for probe i 74 | match_i = result_i == query_ids[probe_index] 75 | true_match_count = np.sum(match_i) 76 | 77 | if true_match_count != 0: # if there is true matching in gallery 78 | valid_probe_sample_count += 1 79 | true_match_rank = np.where(match_i)[0] 80 | 81 | ap = np.mean(np.arange(1, true_match_count + 1) / (true_match_rank + 1)) 82 | avg_precision_sum += ap 83 | 84 | mAP = avg_precision_sum / valid_probe_sample_count 85 | return mAP 86 | 87 | def eval_regdb(query_feats, query_ids, query_cam_ids, gallery_feats, gallery_ids, gallery_cam_ids, gallery_img_paths, rerank=False): 88 | # gallery_feats = F.normalize(gallery_feats, dim=1) 89 | # query_feats = F.normalize(query_feats, dim=1) 90 | 91 | if rerank: 92 | dist_mat = re_ranking(query_feats, gallery_feats, eval_type=False) 93 | else: 94 | dist_mat = pairwise_distance(query_feats, gallery_feats) 95 | # dist_mat = -torch.mm(query_feats, gallery_feats.t()) 96 | 97 | sorted_indices = np.argsort(dist_mat, axis=1) 98 | 99 | mAP = get_mAP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids) 100 | cmc = get_cmc(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids) 101 | 102 | r1 = cmc[0] 103 | r5 = cmc[4] 104 | r10 = cmc[9] 105 | r20 = cmc[19] 106 | 107 | r1 = r1 * 100 108 | r5 = r5 * 100 109 | r10 = r10 * 100 110 | r20 = r20 * 100 111 | mAP = mAP * 100 112 | 113 | perf = 'r1 precision = {:.2f} , r10 precision = {:.2f} , r20 precision = {:.2f}, mAP = {:.2f}' 114 | logging.info(perf.format(r1, r10, r20, mAP)) 115 | 116 | return mAP, r1, r5, r10, r20 117 | -------------------------------------------------------------------------------- /IDKL/utils/eval_sysu.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import torch 4 | import numpy as np 5 | from sklearn.preprocessing import normalize 6 | from .rerank import re_ranking, pairwise_distance 7 | from torch.nn import functional as F 8 | 9 | 10 | def get_gallery_names(perm, cams, ids, trial_id, num_shots=1): 11 | names = [] 12 | for cam in cams: 13 | cam_perm = perm[cam - 1][0].squeeze() 14 | for i in ids: 15 | instance_id = cam_perm[i - 1][trial_id][:num_shots] 16 | names.extend(['cam{}/{:0>4d}/{:0>4d}'.format(cam, i, ins) for ins in instance_id.tolist()]) 17 | 18 | return names 19 | 20 | 21 | def get_unique(array): 22 | _, idx = np.unique(array, return_index=True) 23 | return array[np.sort(idx)] 24 | 25 | 26 | def get_cmc(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids): 27 | gallery_unique_count = get_unique(gallery_ids).shape[0] 28 | match_counter = np.zeros((gallery_unique_count,)) 29 | 30 | result 
= gallery_ids[sorted_indices] 31 | cam_locations_result = gallery_cam_ids[sorted_indices] 32 | 33 | valid_probe_sample_count = 0 34 | 35 | for probe_index in range(sorted_indices.shape[0]): 36 | # remove gallery samples from the same camera of the probe 37 | result_i = result[probe_index, :] 38 | result_i[np.equal(cam_locations_result[probe_index], query_cam_ids[probe_index])] = -1 39 | 40 | # remove the -1 entries from the label result 41 | result_i = np.array([i for i in result_i if i != -1]) 42 | 43 | # remove duplicated id in "stable" manner 44 | result_i_unique = get_unique(result_i) 45 | 46 | # match for probe i 47 | match_i = np.equal(result_i_unique, query_ids[probe_index]) 48 | 49 | if np.sum(match_i) != 0: # if there is true matching in gallery 50 | valid_probe_sample_count += 1 51 | match_counter += match_i 52 | 53 | rank = match_counter / valid_probe_sample_count 54 | cmc = np.cumsum(rank) 55 | return cmc 56 | 57 | 58 | def get_mAP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids): 59 | result = gallery_ids[sorted_indices] 60 | cam_locations_result = gallery_cam_ids[sorted_indices] 61 | 62 | valid_probe_sample_count = 0 63 | avg_precision_sum = 0 64 | 65 | for probe_index in range(sorted_indices.shape[0]): 66 | # remove gallery samples from the same camera of the probe 67 | result_i = result[probe_index, :] 68 | result_i[cam_locations_result[probe_index, :] == query_cam_ids[probe_index]] = -1 69 | 70 | # remove the -1 entries from the label result 71 | result_i = np.array([i for i in result_i if i != -1]) 72 | 73 | # match for probe i 74 | match_i = result_i == query_ids[probe_index] 75 | true_match_count = np.sum(match_i) 76 | 77 | if true_match_count != 0: # if there is true matching in gallery 78 | valid_probe_sample_count += 1 79 | true_match_rank = np.where(match_i)[0] 80 | 81 | ap = np.mean(np.arange(1, true_match_count + 1) / (true_match_rank + 1)) 82 | avg_precision_sum += ap 83 | 84 | mAP = avg_precision_sum / valid_probe_sample_count 85 | return mAP 86 | 87 | 88 | def eval_sysu(query_feats, query_ids, query_cam_ids, gallery_feats, gallery_ids, gallery_cam_ids, gallery_img_paths, 89 | perm, mode='all', num_shots=1, num_trials=10, rerank=False): 90 | assert mode in ['indoor', 'all'] 91 | 92 | gallery_cams = [1, 2] if mode == 'indoor' else [1, 2, 4, 5] 93 | 94 | # cam2 and cam3 are in the same location 95 | query_cam_ids[np.equal(query_cam_ids, 3)] = 2 96 | query_feats = F.normalize(query_feats, dim=1) 97 | 98 | gallery_indices = np.in1d(gallery_cam_ids, gallery_cams) 99 | 100 | gallery_feats = gallery_feats[gallery_indices] 101 | gallery_feats = F.normalize(gallery_feats, dim=1) 102 | gallery_cam_ids = gallery_cam_ids[gallery_indices] 103 | gallery_ids = gallery_ids[gallery_indices] 104 | gallery_img_paths = gallery_img_paths[gallery_indices] 105 | gallery_names = np.array(['/'.join(os.path.splitext(path)[0].split('/')[-3:]) for path in gallery_img_paths]) 106 | 107 | gallery_id_set = np.unique(gallery_ids) 108 | 109 | mAP, r1, r5, r10, r20 = 0, 0, 0, 0, 0 110 | for t in range(num_trials): 111 | names = get_gallery_names(perm, gallery_cams, gallery_id_set, t, num_shots) 112 | flag = np.in1d(gallery_names, names) 113 | 114 | g_feat = gallery_feats[flag] 115 | g_ids = gallery_ids[flag] 116 | g_cam_ids = gallery_cam_ids[flag] 117 | 118 | if rerank: 119 | dist_mat = re_ranking(query_feats, g_feat) 120 | else: 121 | dist_mat = pairwise_distance(query_feats, g_feat) 122 | # dist_mat = -torch.mm(query_feats, g_feat.permute(1,0)) 123 | 124 | sorted_indices 
= np.argsort(dist_mat, axis=1) 125 | 126 | mAP += get_mAP(sorted_indices, query_ids, query_cam_ids, g_ids, g_cam_ids) 127 | cmc = get_cmc(sorted_indices, query_ids, query_cam_ids, g_ids, g_cam_ids) 128 | 129 | r1 += cmc[0] 130 | r5 += cmc[4] 131 | r10 += cmc[9] 132 | r20 += cmc[19] 133 | 134 | r1 = r1 / num_trials * 100 135 | r5 = r5 / num_trials * 100 136 | r10 = r10 / num_trials * 100 137 | r20 = r20 / num_trials * 100 138 | mAP = mAP / num_trials * 100 139 | 140 | perf = '{} num-shot:{} r1 precision = {:.2f} , r10 precision = {:.2f} , r20 precision = {:.2f}, mAP = {:.2f}' 141 | logging.info(perf.format(mode, num_shots, r1, r10, r20, mAP)) 142 | 143 | return mAP, r1, r5, r10, r20 144 | -------------------------------------------------------------------------------- /IDKL/utils/rerank.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def k_reciprocal_neigh( initial_rank, i, k1): 5 | forward_k_neigh_index = initial_rank[i,:k1+1] 6 | backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1] 7 | fi = np.where(backward_k_neigh_index==i)[0] 8 | return forward_k_neigh_index[fi] 9 | 10 | def pairwise_distance(query_features, gallery_features): 11 | x = query_features 12 | y = gallery_features 13 | m, n = x.size(0), y.size(0) 14 | x = x.view(m, -1) 15 | y = y.view(n, -1) 16 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 17 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 18 | dist.addmm_(1, -2, x, y.t()) 19 | return dist 20 | 21 | def re_ranking(q_feat, g_feat, k1=20, k2=6, lambda_value=0.3, eval_type=True): 22 | # The following naming, e.g. gallery_num, is different from outer scope. 23 | # Don't care about it. 24 | feats = torch.cat([q_feat, g_feat], 0) 25 | dist = pairwise_distance(feats, feats) 26 | original_dist = dist.detach().cpu().numpy() 27 | all_num = original_dist.shape[0] 28 | original_dist = np.transpose(original_dist / np.max(original_dist, axis=0)) 29 | V = np.zeros_like(original_dist).astype(np.float16) 30 | 31 | query_num = q_feat.size(0) 32 | all_num = original_dist.shape[0] 33 | if eval_type: 34 | dist[:, query_num:] = dist.max() 35 | dist = dist.detach().cpu().numpy() 36 | initial_rank = np.argsort(dist).astype(np.int32) 37 | 38 | # print("start re-ranking") 39 | for i in range(all_num): 40 | # k-reciprocal neighbors 41 | forward_k_neigh_index = initial_rank[i, :k1 + 1] 42 | backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1] 43 | fi = np.where(backward_k_neigh_index == i)[0] 44 | k_reciprocal_index = forward_k_neigh_index[fi] 45 | k_reciprocal_expansion_index = k_reciprocal_index 46 | for j in range(len(k_reciprocal_index)): 47 | candidate = k_reciprocal_index[j] 48 | candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1] 49 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index, 50 | :int(np.around(k1 / 2)) + 1] 51 | # import pdb 52 | # pdb.set_trace() 53 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0] 54 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 55 | if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len( 56 | candidate_k_reciprocal_index): 57 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index) 58 | 59 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) 60 | weight = np.exp(-original_dist[i, 
k_reciprocal_expansion_index]) 61 | V[i, k_reciprocal_expansion_index] = weight / np.sum(weight) 62 | original_dist = original_dist[:query_num, ] 63 | if k2 != 1: 64 | V_qe = np.zeros_like(V, dtype=np.float16) 65 | for i in range(all_num): 66 | V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0) 67 | V = V_qe 68 | del V_qe 69 | del initial_rank 70 | invIndex = [] 71 | for i in range(all_num): 72 | invIndex.append(np.where(V[:, i] != 0)[0]) 73 | 74 | jaccard_dist = np.zeros_like(original_dist, dtype=np.float16) 75 | 76 | 77 | for i in range(query_num): 78 | temp_min = np.zeros(shape=[1, all_num], dtype=np.float16) 79 | indNonZero = np.where(V[i, :] != 0)[0] 80 | indImages = [] 81 | indImages = [invIndex[ind] for ind in indNonZero] 82 | for j in range(len(indNonZero)): 83 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]], 84 | V[indImages[j], indNonZero[j]]) 85 | jaccard_dist[i] = 1 - temp_min / (2 - temp_min) 86 | 87 | final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value 88 | del original_dist 89 | del V 90 | del jaccard_dist 91 | final_dist = final_dist[:query_num, query_num:] 92 | return final_dist -------------------------------------------------------------------------------- /IDKL/utils/tsne.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import scipy.io as sio 5 | import matplotlib as mpl 6 | 7 | mpl.use('AGG') 8 | import matplotlib.pyplot as plt 9 | from sklearn.manifold import TSNE 10 | 11 | if __name__ == '__main__': 12 | test_ids = [ 13 | 6, 10, 17, 21, 24, 25, 27, 28, 31, 34, 36, 37, 40, 41, 42, 43, 44, 45, 49, 50, 51, 54, 63, 69, 75, 80, 81, 82, 14 | 83, 84, 85, 86, 87, 88, 89, 90, 93, 102, 104, 105, 106, 108, 112, 116, 117, 122, 125, 129, 130, 134, 138, 139, 15 | 150, 152, 162, 166, 167, 170, 172, 176, 185, 190, 192, 202, 204, 207, 210, 215, 223, 229, 232, 237, 252, 253, 16 | 257, 259, 263, 266, 269, 272, 273, 274, 275, 282, 285, 291, 300, 301, 302, 303, 307, 312, 315, 318, 331, 333 17 | ] 18 | random.seed(0) 19 | tsne = TSNE(n_components=2, init='pca') 20 | selected_ids = random.sample(test_ids, 20) 21 | plt.figure(figsize=(5, 5)) 22 | 23 | # features without dual path 24 | q_mat_path = 'features/sysu/query-sysu-test-nodual-nore-adam-16x8-grey_model_150.mat' 25 | g_mat_path = 'features/sysu/gallery-sysu-test-nodual-nore-adam-16x8-grey_model_150.mat' 26 | 27 | mat = sio.loadmat(q_mat_path) 28 | q_feats = mat["feat"] 29 | q_ids = mat["ids"].squeeze() 30 | flag = np.in1d(q_ids, selected_ids) 31 | q_feats = q_feats[flag] 32 | 33 | mat = sio.loadmat(g_mat_path) 34 | g_feats = mat["feat"] 35 | g_ids = mat["ids"].squeeze() 36 | flag = np.in1d(g_ids, selected_ids) 37 | g_feats = g_feats[flag] 38 | 39 | embed = tsne.fit_transform(np.concatenate([q_feats, g_feats], axis=0)) 40 | c = ['r'] * q_feats.shape[0] + ['b'] * g_feats.shape[0] 41 | # plt.subplot(1, 2, 1) 42 | plt.scatter(embed[:, 0], embed[:, 1], c=c) 43 | 44 | # # features with dual path 45 | # q_mat_path = 'features/sysu/query-sysu-test-dual-nore-separatelayer12-0.05_model_30.mat' 46 | # g_mat_path = 'features/sysu/gallery-sysu-test-dual-nore-separatelayer12-0.05_model_30.mat' 47 | # 48 | # mat = sio.loadmat(q_mat_path) 49 | # q_feats = mat["feat"] 50 | # q_ids = mat["ids"].squeeze() 51 | # flag = np.in1d(q_ids, selected_ids) 52 | # q_feats = q_feats[flag] 53 | # 54 | # mat = sio.loadmat(g_mat_path) 55 | # g_feats = mat["feat"] 56 | # g_ids = mat["ids"].squeeze() 57 
| # flag = np.in1d(g_ids, selected_ids)
58 |     # g_feats = g_feats[flag]
59 |     #
60 |     # embed = tsne.fit_transform(np.concatenate([q_feats, g_feats], axis=0))
61 |     # c = ['r'] * q_feats.shape[0] + ['b'] * g_feats.shape[0]
62 |     # plt.subplot(1, 2, 2)
63 |     # plt.scatter(embed[:, 0], embed[:, 1], c=c)
64 | 
65 |     plt.tight_layout()
66 |     plt.savefig('tsne-adv-layer2-separate-l2.jpg')
67 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [CVPR2024]IDKL: Implicit Discriminative Knowledge Learning for Visible-Infrared Person Re-Identification. (https://arxiv.org/abs/2403.11708)
2 | 
3 | ## Environmental requirements:
4 | 
5 | PyTorch == 1.10.0
6 | 
7 | ignite == 0.2.1
8 | 
9 | torchvision == 0.11.2
10 | 
11 | apex == 0.1
12 | 
13 | ## Training:
14 | 
15 | To train the model, you can use the following command:
16 | 
17 | SYSU-MM01:
18 | ```Shell
19 | python train.py --cfg ./configs/SYSU.yml
20 | ```
21 | 
22 | RegDB:
23 | ```Shell
24 | python train.py --cfg ./configs/RegDB.yml
25 | ```
26 | 
27 | LLCM:
28 | ```Shell
29 | python train.py --cfg ./configs/LLCM.yml
30 | ```
31 | 
32 | 
33 | 
--------------------------------------------------------------------------------
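Note: `train.py` also exposes a `--gpu` option (a string of device ids written to `CUDA_VISIBLE_DEVICES`, default `'0'`). A minimal usage sketch, assuming a machine where GPU 0 is available:

```Shell
python train.py --cfg ./configs/SYSU.yml --gpu 0
```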