├── IDKL
│   ├── .idea
│   │   ├── IDKL.iml
│   │   ├── deployment.xml
│   │   ├── inspectionProfiles
│   │   │   └── profiles_settings.xml
│   │   ├── modules.xml
│   │   ├── vcs.xml
│   │   └── workspace.xml
│   ├── README.md
│   ├── configs
│   │   ├── LLCM.yml
│   │   ├── RegDB.yml
│   │   ├── SYSU.yml
│   │   └── default
│   │       ├── __init__.py
│   │       ├── __pycache__
│   │       │   ├── __init__.cpython-37.pyc
│   │       │   ├── dataset.cpython-37.pyc
│   │       │   └── strategy.cpython-37.pyc
│   │       ├── dataset.py
│   │       └── strategy.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── clone.cpython-37.pyc
│   │   │   ├── dataset.cpython-36.pyc
│   │   │   ├── dataset.cpython-37.pyc
│   │   │   ├── sampler.cpython-36.pyc
│   │   │   └── sampler.cpython-37.pyc
│   │   ├── dataset.py
│   │   └── sampler.py
│   ├── engine
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── engine.cpython-37.pyc
│   │   │   └── metric.cpython-37.pyc
│   │   ├── engine.py
│   │   └── metric.py
│   ├── grad_cam
│   │   ├── README.md
│   │   ├── both.png
│   │   ├── imagenet1k_classes.txt
│   │   ├── imagenet21k_classes.txt
│   │   ├── main_cnn.py
│   │   ├── main_swin.py
│   │   ├── main_vit.py
│   │   ├── swin_model.py
│   │   ├── utils.py
│   │   └── vit_model.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   └── __init__.cpython-37.pyc
│   │   ├── loss
│   │   │   ├── JSD.py
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── JSD.cpython-37.pyc
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   ├── am_softmax.cpython-37.pyc
│   │   │   │   ├── center_loss.cpython-37.pyc
│   │   │   │   ├── crossquad_loss.cpython-37.pyc
│   │   │   │   ├── crosstriplet_loss.cpython-37.pyc
│   │   │   │   ├── local_center_loss.cpython-37.pyc
│   │   │   │   ├── mixtriplet_loss.cpython-37.pyc
│   │   │   │   ├── trapezoid_loss.cpython-37.pyc
│   │   │   │   └── triplet_loss.cpython-37.pyc
│   │   │   ├── am_softmax.py
│   │   │   ├── center_loss.py
│   │   │   ├── local_center_loss.py
│   │   │   ├── rerank_loss.py
│   │   │   └── triplet_loss.py
│   │   └── module
│   │       ├── CBAM.py
│   │       ├── NonLocal.py
│   │       ├── __init__.py
│   │       ├── __pycache__
│   │       │   ├── CBAM.cpython-37.pyc
│   │       │   ├── NonLocal.cpython-37.pyc
│   │       │   ├── __init__.cpython-37.pyc
│   │       │   ├── norm_linear.cpython-37.pyc
│   │       │   └── reverse_grad.cpython-37.pyc
│   │       ├── norm_linear.py
│   │       └── reverse_grad.py
│   ├── models
│   │   ├── __pycache__
│   │   │   ├── baseline.cpython-36.pyc
│   │   │   ├── baseline.cpython-37.pyc
│   │   │   ├── resnet.cpython-36.pyc
│   │   │   └── resnet.cpython-37.pyc
│   │   ├── baseline.py
│   │   └── resnet.py
│   ├── train.py
│   └── utils
│       ├── __pycache__
│       │   ├── calc_acc.cpython-37.pyc
│       │   ├── eval_regdb.cpython-37.pyc
│       │   ├── eval_sysu.cpython-37.pyc
│       │   ├── neighbor.cpython-37.pyc
│       │   └── rerank.cpython-37.pyc
│       ├── calc_acc.py
│       ├── eval_llcm.py
│       ├── eval_regdb.py
│       ├── eval_sysu.py
│       ├── rerank.py
│       └── tsne.py
└── README.md
/IDKL/README.md:
--------------------------------------------------------------------------------
1 | # [CVPR 2024] IDKL: Implicit Discriminative Knowledge Learning for Visible-Infrared Person Re-Identification (https://arxiv.org/abs/2403.11708)
2 | ## Requirements
3 |
4 | python == 3.7
5 | PyTorch == 1.10.1
6 | ignite == 0.2.1
7 | torchvision == 0.11.2
8 | apex == 0.1
9 |
10 | ## Training:
11 |
12 | To train the model, use one of the following commands:
13 |
14 | SYSU-MM01:
15 | ```Shell
16 | python train.py --cfg ./configs/SYSU.yml
17 | ```
18 |
19 | RegDB:
20 | ```Shell
21 | python train.py --cfg ./configs/RegDB.yml
22 | ```
23 |
24 | LLCM:
25 | ```Shell
26 | python train.py --cfg ./configs/LLCM.yml
27 | ```
28 |
29 |
--------------------------------------------------------------------------------
/IDKL/configs/LLCM.yml:
--------------------------------------------------------------------------------
1 | prefix: LLCM
2 | fp16: true
3 |
4 | # dataset
5 | sample_method: identity_random
6 | image_size: (384, 144)
7 | p_size: 12
8 | k_size: 10
9 |
10 | dataset: llcm
11 |
12 | # loss
13 | bg_kl: true
14 | sm_kl: true
15 | IP: true
16 | decompose: true
17 | distalign: true
18 | classification: true
19 | center_cluster: false
20 | triplet: true
21 | center: false
22 | fb_dt: false
23 |
24 | # parameters
25 | margin: 1.3 #0.7
26 | # pattern attention
27 | num_parts: 6
28 | weight_sep: 0.5
29 | # mutual learning
30 | update_rate: 0.2
31 | weight_sid: 0.5
32 | weight_KL: 2.5
33 |
34 | # architecture
35 | drop_last_stride: true
36 |
37 | # optimizer
38 | lr: 0.00035
39 | optimizer: adam
40 | num_epoch: 160 #160
41 | lr_step: [55, 95]
42 |
43 | # augmentation
44 | random_flip: true
45 | random_crop: true
46 | random_erase: true
47 | color_jitter: false
48 | padding: 10
49 |
50 | # log
51 | log_period: 150
52 | start_eval: 200
53 | eval_interval: 5
54 |
--------------------------------------------------------------------------------
/IDKL/configs/RegDB.yml:
--------------------------------------------------------------------------------
1 | prefix: RegDB
2 |
3 | fp16: true
4 |
5 | # dataset
6 | sample_method: identity_random
7 | image_size: (256, 128)
8 | p_size: 6
9 | k_size: 10
10 |
11 | dataset: regdb
12 |
13 | # loss
14 | bg_kl: true
15 | sm_kl: true
16 | decompose: true
17 | IP: true
18 | distalign: false
19 | classification: true
20 | center_cluster: false
21 | triplet: true
22 | fb_dt: false #true
23 | center: false
24 |
25 | # parameters
26 | margin: 1.3
27 |
28 | num_parts: 6
29 | weight_sep: 0.5
30 |
31 | update_rate: 0.2
32 | weight_sid: 0.5
33 | weight_KL: 2.5
34 |
35 | # architecture
36 | #mutual learning
37 | #rerank: false
38 | #pattern attention
39 |
40 | drop_last_stride: true
41 | pattern_attention: false
42 | mutual_learning: false
43 | modality_attention: 0
44 |
45 | # optimizer
46 | lr: 0.00035
47 | optimizer: adam
48 | num_epoch: 160
49 | lr_step: [55, 95]
50 |
51 | # augmentation
52 | random_flip: true
53 | random_crop: true
54 | random_erase: true
55 | color_jitter: false
56 | padding: 10
57 |
58 | # log
59 | log_period: 20
60 | start_eval: 0
61 | eval_interval: 5
62 |
--------------------------------------------------------------------------------
/IDKL/configs/SYSU.yml:
--------------------------------------------------------------------------------
1 | prefix: SYSU
2 | fp16: true
3 |
4 | # dataset
5 | sample_method: identity_random #identity_uniform #identity_random
6 | image_size: (384, 144) #(384, 144)
7 | p_size: 12
8 | k_size: 10
9 |
10 | dataset: sysu
11 |
12 | # loss
13 | bg_kl: true
14 | sm_kl: true
15 | decompose: true
16 | distalign: true
17 | IP: true
18 | classification: true
19 | center_cluster: false
20 | triplet: true
21 | center: false
22 | fb_dt: false
23 |
24 | # parameters
25 | margin: 1.3
26 | # pattern attention
27 | num_parts: 6
28 | weight_sep: 0.5
29 | # mutual learning
30 | update_rate: 0.2
31 | weight_sid: 0.5
32 | weight_KL: 2.5
33 |
34 | # architecture
35 | drop_last_stride: true
36 | pattern_attention: false
37 | mutual_learning: false
38 | modality_attention: 0
39 |
40 | # optimizer
41 | lr: 0.00035
42 | optimizer: adam
43 | num_epoch: 160 #160
44 | lr_step: [55, 95]
45 |
46 | # augmentation
47 | random_flip: true
48 | random_crop: true
49 | random_erase: true
50 | color_jitter: false
51 | padding: 10
52 |
53 | # log
54 | log_period: 150
55 | start_eval: 200
56 | eval_interval: 5
57 |
--------------------------------------------------------------------------------
/IDKL/configs/default/__init__.py:
--------------------------------------------------------------------------------
1 | from configs.default.dataset import dataset_cfg
2 | from configs.default.strategy import strategy_cfg
3 |
4 | __all__ = ["dataset_cfg", "strategy_cfg"]
5 |
--------------------------------------------------------------------------------
/IDKL/configs/default/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/configs/default/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/configs/default/__pycache__/dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/configs/default/__pycache__/dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/configs/default/__pycache__/strategy.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/configs/default/__pycache__/strategy.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/configs/default/dataset.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode
2 |
3 | dataset_cfg = CfgNode()
4 |
5 | # config for dataset
6 | dataset_cfg.sysu = CfgNode()
7 | dataset_cfg.sysu.num_id = 395
8 | dataset_cfg.sysu.num_cam = 6
9 | dataset_cfg.sysu.data_root = "../dataset/SYSU-MM01"
10 |
11 | dataset_cfg.regdb = CfgNode()
12 | dataset_cfg.regdb.num_id = 206
13 | dataset_cfg.regdb.num_cam = 2
14 | dataset_cfg.regdb.data_root = "../dataset/RegDB"
15 |
16 | dataset_cfg.llcm = CfgNode()
17 | dataset_cfg.llcm.num_id = 713
18 | dataset_cfg.llcm.num_cam = 2
19 | dataset_cfg.llcm.data_root = "../dataset/LLCM"
20 |
21 |
--------------------------------------------------------------------------------
/IDKL/configs/default/strategy.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode
2 |
3 | strategy_cfg = CfgNode()
4 |
5 | strategy_cfg.prefix = "baseline"
6 |
7 | # setting for loader
8 | strategy_cfg.sample_method = "random"
9 | strategy_cfg.batch_size = 128
10 | strategy_cfg.p_size = 16
11 | strategy_cfg.k_size = 8
12 |
13 | # setting for loss
14 | strategy_cfg.classification = True
15 | strategy_cfg.triplet = False
16 | strategy_cfg.center_cluster = False
17 | strategy_cfg.center = False
18 | strategy_cfg.sm_kl = False
19 | strategy_cfg.bg_kl = False
20 | strategy_cfg.IP = False
21 | strategy_cfg.decompose = False
22 | strategy_cfg.fb_dt = False
23 | strategy_cfg.distalign = False
24 |
25 | # setting for metric learning
26 | strategy_cfg.margin = 0.3
27 | strategy_cfg.weight_KL = 3.0
28 | strategy_cfg.weight_sid = 1.0
29 | strategy_cfg.weight_sep = 1.0
30 | strategy_cfg.update_rate = 1.0
31 |
32 | # settings for optimizer
33 | strategy_cfg.optimizer = "sgd"
34 | strategy_cfg.lr = 0.1
35 | strategy_cfg.wd = 5e-3
36 | # alternative weight decay: 5e-4
37 | strategy_cfg.lr_step = [40]
38 |
39 | strategy_cfg.fp16 = False
40 |
41 | strategy_cfg.num_epoch = 60
42 |
43 | # settings for dataset
44 | strategy_cfg.dataset = "sysu"
45 | strategy_cfg.image_size = (384, 128)
46 |
47 | # settings for augmentation
48 | strategy_cfg.random_flip = True
49 | strategy_cfg.random_crop = True
50 | strategy_cfg.random_erase = True
51 | strategy_cfg.color_jitter = False
52 | strategy_cfg.padding = 10
53 |
54 | # settings for base architecture
55 | strategy_cfg.drop_last_stride = False
56 | strategy_cfg.pattern_attention = False
57 | strategy_cfg.modality_attention = 0
58 | strategy_cfg.mutual_learning = False
59 | strategy_cfg.rerank = False
60 | strategy_cfg.num_parts = 6
61 |
62 | # logging
63 | strategy_cfg.eval_interval = -1
64 | strategy_cfg.start_eval = 60
65 | strategy_cfg.log_period = 10
66 |
67 | # testing
68 | strategy_cfg.resume = ''
69 | #/home/zhang/E/RKJ/MAPnet/MPA-LL2-cvpr/checkpoints/regdb/RegDB/model_best.pth
70 | #/root/MPANet/MPA-LL2-cvpr/checkpoints/sysu/SYSU/model_best.pth
71 | #/root/MPANet/MPA-LL2-cvpr/checkpoints/llcm/LLCM/model_best.pth
72 | #/home/zhang/E/RKJ/MAPnet/MPA-cvpr/checkpoints/llcm/LLCM/model_best.pth
--------------------------------------------------------------------------------
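The two modules above define the yacs defaults that the YAML files under `configs/` override. `train.py` is not included in this excerpt, so the exact merging code (including how tuple-valued options such as `image_size` are parsed from the YAML) is not visible here; the snippet below is only a minimal sketch of the standard yacs pattern these `CfgNode` objects support.

```python
# Minimal sketch of the usual yacs workflow; the real entry point (train.py) is not shown here.
from configs.default import strategy_cfg, dataset_cfg

cfg = strategy_cfg.clone()                         # start from the defaults above
cfg.merge_from_list(["dataset", "sysu",            # command-line-style overrides; the values
                     "p_size", 12, "k_size", 10])  # mirror configs/SYSU.yml
cfg.freeze()                                       # make the config immutable for the run

root = dataset_cfg[cfg.dataset].data_root          # "../dataset/SYSU-MM01"
batch_size = cfg.p_size * cfg.k_size               # identity samplers build P*K batches (see data/)
```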
/IDKL/data/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torchvision.transforms as T
5 |
6 | from torch.utils.data import DataLoader
7 | from data.dataset import SYSUDataset
8 | from data.dataset import RegDBDataset
9 | from data.dataset import LLCMData
10 | from data.dataset import MarketDataset
11 |
12 | from data.sampler import CrossModalityIdentitySampler
13 | from data.sampler import CrossModalityRandomSampler
14 | from data.sampler import RandomIdentitySampler
15 | from data.sampler import NormTripletSampler
16 | import random
17 |
18 |
19 | def collate_fn(batch): # img, label, cam_id, img_path, img_id
20 | samples = list(zip(*batch))
21 |
22 | data = [torch.stack(x, 0) for i, x in enumerate(samples) if i != 3]
23 | data.insert(3, samples[3])
24 | return data
25 |
26 |
27 | class ChannelAdapGray(object):
28 | """ Adaptive selects a channel or two channels.
29 | Args:
30 | probability: The probability that the Random Erasing operation will be performed.
31 | sl: Minimum proportion of erased area against input image.
32 | sh: Maximum proportion of erased area against input image.
33 | r1: Minimum aspect ratio of erased area.
34 | mean: Erasing value.
35 | """
36 |
37 | def __init__(self, probability=0.5):
38 | self.probability = probability
39 |
40 | def __call__(self, img):
41 |
42 | # if random.uniform(0, 1) > self.probability:
43 | # return img
44 |
45 | idx = random.randint(0, 3)
46 |
47 | if idx == 0:
48 | # random select R Channel
49 | img[1, :, :] = img[0, :, :]
50 | img[2, :, :] = img[0, :, :]
51 | elif idx == 1:
52 | # random select B Channel
53 | img[0, :, :] = img[1, :, :]
54 | img[2, :, :] = img[1, :, :]
55 | elif idx == 2:
56 | # random select G Channel
57 | img[0, :, :] = img[2, :, :]
58 | img[1, :, :] = img[2, :, :]
59 | else:
60 | if random.uniform(0, 1) > self.probability:
61 | # return img
62 | img = img
63 | else:
64 | tmp_img = 0.2989 * img[0, :, :] + 0.5870 * img[1, :, :] + 0.1140 * img[2, :, :]
65 | img[0, :, :] = tmp_img
66 | img[1, :, :] = tmp_img
67 | img[2, :, :] = tmp_img
68 | return img
69 |
70 | def get_train_loader(dataset, root, sample_method, batch_size, p_size, k_size, image_size, random_flip=False, random_crop=False,
71 | random_erase=False, color_jitter=False, padding=0, num_workers=4):
72 | if False:  # switch to plain resize/normalize transforms when extracting features for t-SNE
73 | transform = T.Compose([
74 | T.Resize(image_size),
75 | T.ToTensor(),
76 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
77 | ])
78 |
79 | else:
80 | # data pre-processing
81 | t = [T.Resize(image_size)]
82 |
83 | if random_flip:
84 | t.append(T.RandomHorizontalFlip())
85 |
86 | if color_jitter:
87 | t.append(T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0))
88 |
89 | if random_crop:
90 | t.extend([T.Pad(padding, fill=127), T.RandomCrop(image_size)])
91 |
92 | t.extend([T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
93 |
94 | if random_erase:
95 | t.append(T.RandomErasing())
96 | #t.append(ChannelAdapGray(probability=0.5)) ###58
97 | # t.append(Jigsaw())
98 |
99 | transform = T.Compose(t)
100 | # # data pre-processing
101 | # t = [T.Resize(image_size)]
102 | #
103 | # if random_flip:
104 | # t.append(T.RandomHorizontalFlip())
105 | #
106 | # if color_jitter:
107 | # t.append(T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0))
108 | #
109 | # if random_crop:
110 | # t.extend([T.Pad(padding, fill=127), T.RandomCrop(image_size)])
111 | #
112 | # t.extend([T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
113 | #
114 | # if random_erase:
115 | # t.append(T.RandomErasing())
116 | # # t.append(Jigsaw())
117 | #
118 | # transform = T.Compose(t)
119 |
120 | # dataset
121 | if dataset == 'sysu':
122 | train_dataset = SYSUDataset(root, mode='train', transform=transform)
123 | elif dataset == 'regdb':
124 | train_dataset = RegDBDataset(root, mode='train', transform=transform)
125 | elif dataset == 'llcm':
126 | train_dataset = LLCMData(root, mode='train', transform=transform)
127 | elif dataset == 'market':
128 | train_dataset = MarketDataset(root, mode='train', transform=transform)
129 |
130 | # sampler
131 | assert sample_method in ['random', 'identity_uniform', 'identity_random', 'norm_triplet']
132 | if sample_method == 'identity_uniform':
133 | batch_size = p_size * k_size
134 | sampler = CrossModalityIdentitySampler(train_dataset, p_size, k_size)
135 | elif sample_method == 'identity_random':
136 | batch_size = p_size * k_size
137 | sampler = RandomIdentitySampler(train_dataset, p_size * k_size, k_size)
138 | elif sample_method == 'norm_triplet':
139 | batch_size = p_size * k_size
140 | sampler = NormTripletSampler(train_dataset, p_size * k_size, k_size)
141 | else:
142 | sampler = CrossModalityRandomSampler(train_dataset, batch_size)
143 |
144 | # loader
145 | train_loader = DataLoader(train_dataset, batch_size, sampler=sampler, drop_last=True, pin_memory=True,
146 | collate_fn=collate_fn, num_workers=num_workers)
147 |
148 | return train_loader
149 |
150 |
151 | def get_test_loader(dataset, root, batch_size, image_size, num_workers=4):
152 | # transform
153 | transform = T.Compose([
154 | T.Resize(image_size),
155 | T.ToTensor(),
156 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
157 | ])
158 |
159 | # dataset
160 | if dataset == 'sysu':
161 | gallery_dataset = SYSUDataset(root, mode='gallery', transform=transform)
162 | query_dataset = SYSUDataset(root, mode='query', transform=transform)
163 | elif dataset == 'regdb':
164 | gallery_dataset = RegDBDataset(root, mode='gallery', transform=transform)
165 | query_dataset = RegDBDataset(root, mode='query', transform=transform)
166 | elif dataset == 'llcm':
167 | gallery_dataset = LLCMData(root, mode='gallery', transform=transform)
168 | query_dataset = LLCMData(root, mode='query', transform=transform)
169 | elif dataset == 'market':
170 | gallery_dataset = MarketDataset(root, mode='gallery', transform=transform)
171 | query_dataset = MarketDataset(root, mode='query', transform=transform)
172 |
173 | # dataloader
174 | query_loader = DataLoader(dataset=query_dataset,
175 | batch_size=batch_size,
176 | shuffle=False,
177 | pin_memory=True,
178 | drop_last=False,
179 | collate_fn=collate_fn,
180 | num_workers=num_workers)
181 |
182 | gallery_loader = DataLoader(dataset=gallery_dataset,
183 | batch_size=batch_size,
184 | shuffle=False,
185 | pin_memory=True,
186 | drop_last=False,
187 | collate_fn=collate_fn,
188 | num_workers=num_workers)
189 |
190 | return gallery_loader, query_loader
191 |
--------------------------------------------------------------------------------
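For orientation, here is a hedged usage sketch of the two factory functions above with the SYSU settings from `configs/SYSU.yml`; the actual call site in `train.py` is not part of this excerpt, and the data root is assumed to follow `configs/default/dataset.py`.

```python
# Usage sketch; assumes SYSU-MM01 is available under the default data root.
from data import get_train_loader, get_test_loader

root = "../dataset/SYSU-MM01"
train_loader = get_train_loader(
    dataset="sysu", root=root, sample_method="identity_random",
    batch_size=120, p_size=12, k_size=10, image_size=(384, 144),
    random_flip=True, random_crop=True, random_erase=True,
    color_jitter=False, padding=10, num_workers=4)

gallery_loader, query_loader = get_test_loader("sysu", root, batch_size=128, image_size=(384, 144))

# Batches follow collate_fn's layout: img, label, cam_id, img_path, img_id.
imgs, labels, cam_ids, img_paths, img_ids = next(iter(train_loader))
```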
/IDKL/data/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/data/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/IDKL/data/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/data/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/data/__pycache__/clone.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/data/__pycache__/clone.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/data/__pycache__/dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/data/__pycache__/dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/IDKL/data/__pycache__/dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/data/__pycache__/dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/data/__pycache__/sampler.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/data/__pycache__/sampler.cpython-36.pyc
--------------------------------------------------------------------------------
/IDKL/data/__pycache__/sampler.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/data/__pycache__/sampler.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/data/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import os.path as osp
4 | from glob import glob
5 | import numpy as np
6 | import torch
7 | from PIL import Image
8 | from torch.utils.data import Dataset
9 |
10 | '''
11 | Specific dataset classes for person re-identification dataset.
12 | '''
13 |
14 |
15 | class SYSUDataset(Dataset):
16 | def __init__(self, root, mode='train', transform=None):
17 | assert os.path.isdir(root)
18 | assert mode in ['train', 'gallery', 'query']
19 |
20 | if mode == 'train':
21 | train_ids = open(os.path.join(root, 'exp', 'train_id.txt')).readline()
22 | val_ids = open(os.path.join(root, 'exp', 'val_id.txt')).readline()
23 |
24 | train_ids = train_ids.strip('\n').split(',')
25 | val_ids = val_ids.strip('\n').split(',')
26 | selected_ids = train_ids + val_ids
27 | else:
28 | test_ids = open(os.path.join(root, 'exp', 'test_id.txt')).readline()
29 | selected_ids = test_ids.strip('\n').split(',')
30 |
31 | selected_ids = [int(i) for i in selected_ids]
32 | num_ids = len(selected_ids)
33 |
34 | img_paths = glob(os.path.join(root, '**/*.jpg'), recursive=True)
35 | img_paths = [path for path in img_paths if int(path.split('/')[-2]) in selected_ids]
36 |
37 | if mode == 'gallery':
38 | img_paths = [path for path in img_paths if int(path.split('/')[-3][-1]) in (1, 2, 4, 5)]
39 | elif mode == 'query':
40 | img_paths = [path for path in img_paths if int(path.split('/')[-3][-1]) in (3, 6)]
41 |
42 | img_paths = sorted(img_paths)
43 | self.img_paths = img_paths
44 | self.cam_ids = [int(path.split('/')[-3][-1]) for path in img_paths]
45 | self.num_ids = num_ids
46 | self.transform = transform
47 |
48 | if mode == 'train':
49 | id_map = dict(zip(selected_ids, range(num_ids)))
50 | self.ids = [id_map[int(path.split('/')[-2])] for path in img_paths]
51 | else:
52 | self.ids = [int(path.split('/')[-2]) for path in img_paths]
53 |
54 | def __len__(self):
55 | return len(self.img_paths)
56 |
57 | def __getitem__(self, item):
58 | path = self.img_paths[item]
59 | img = Image.open(path)
60 | if self.transform is not None:
61 | img = self.transform(img)
62 |
63 | label = torch.tensor(self.ids[item], dtype=torch.long)
64 | cam = torch.tensor(self.cam_ids[item], dtype=torch.long)
65 | item = torch.tensor(item, dtype=torch.long)
66 |
67 | return img, label, cam, path, item
68 |
69 | class RegDBDataset(Dataset):
70 | def __init__(self, root, mode='train', transform=None):
71 | assert os.path.isdir(root)
72 | assert mode in ['train', 'gallery', 'query']
73 |
74 | def loadIdx(index):
75 | Lines = index.readlines()
76 | idx = []
77 | for line in Lines:
78 | tmp = line.strip('\n')
79 | tmp = tmp.split(' ')
80 | idx.append(tmp)
81 | return idx
82 |
83 | num = '1'
84 | if mode == 'train':
85 | index_RGB = loadIdx(open(root + '/idx/train_visible_'+num+'.txt','r'))
86 | index_IR = loadIdx(open(root + '/idx/train_thermal_'+num+'.txt','r'))
87 | else:
88 | index_RGB = loadIdx(open(root + '/idx/test_visible_'+num+'.txt','r'))
89 | index_IR = loadIdx(open(root + '/idx/test_thermal_'+num+'.txt','r'))
90 |
91 | if mode == 'gallery':
92 | img_paths = [root + '/' + path for path, _ in index_RGB]
93 | elif mode == 'query':
94 | img_paths = [root + '/' + path for path, _ in index_IR]
95 | else:
96 | img_paths = [root + '/' + path for path, _ in index_RGB] + [root + '/' + path for path, _ in index_IR]
97 |
98 | selected_ids = [int(path.split('/')[-2]) for path in img_paths]
99 | selected_ids = list(set(selected_ids))
100 | num_ids = len(selected_ids)
101 |
102 | img_paths = sorted(img_paths)
103 | self.img_paths = img_paths
104 | self.cam_ids = [int(path.split('/')[-3] == 'Thermal') + 2 for path in img_paths]
105 | # the visible cams are 1 2 4 5 and thermal cams are 3 6 in sysu
106 | # to simplify the code, visible cam is 2 and thermal cam is 3 in regdb
107 | self.num_ids = num_ids
108 | self.transform = transform
109 |
110 | if mode == 'train':
111 | id_map = dict(zip(selected_ids, range(num_ids)))
112 | self.ids = [id_map[int(path.split('/')[-2])] for path in img_paths]
113 | else:
114 | self.ids = [int(path.split('/')[-2]) for path in img_paths]
115 |
116 | def __len__(self):
117 | return len(self.img_paths)
118 |
119 | def __getitem__(self, item):
120 | path = self.img_paths[item]
121 | img = Image.open(path)
122 | if self.transform is not None:
123 | img = self.transform(img)
124 |
125 | label = torch.tensor(self.ids[item], dtype=torch.long)
126 | cam = torch.tensor(self.cam_ids[item], dtype=torch.long)
127 | item = torch.tensor(item, dtype=torch.long)
128 |
129 | return img, label, cam, path, item
130 |
131 | class LLCMData(Dataset):
132 | def __init__(self, root, mode='train', transform=None, colorIndex=None, thermalIndex=None):
133 | # Load training images (path) and labels
134 | assert os.path.isdir(root)
135 | assert mode in ['train', 'gallery', 'query']
136 |
137 | def loadIdx(index):
138 | Lines = index.readlines()
139 | idx = []
140 | for line in Lines:
141 | tmp = line.strip('\n')
142 | tmp = tmp.split(' ')
143 | idx.append(tmp)
144 | return idx
145 |
146 | if mode == 'train':
147 | index_RGB = loadIdx(open(root + '/idx/train_vis.txt','r'))
148 | index_IR = loadIdx(open(root + '/idx/train_nir.txt','r'))
149 | else:
150 | index_RGB = loadIdx(open(root + '/idx/test_vis.txt','r'))
151 | index_IR = loadIdx(open(root + '/idx/test_nir.txt','r'))
152 |
153 |
154 | if mode == 'gallery':
155 | img_paths = [root + '/' + path for path, _ in index_RGB]
156 | elif mode == 'query':
157 | img_paths = [root + '/' + path for path, _ in index_IR]
158 | else:
159 | img_paths = [root + '/' + path for path, _ in index_RGB] + [root + '/' + path for path, _ in index_IR]
160 |
161 | selected_ids = [int(path.split('/')[-2]) for path in img_paths]
162 | selected_ids = list(set(selected_ids))
163 | num_ids = len(selected_ids)
164 | # path = '/home/zhang/E/RKJ/MAPnet/dataset/LLCM/nir/0351/0351_c06_s200656_f4830_nir.jpg'
165 | # img = Image.open(path).convert('RGB')
166 | # img = np.array(img, dtype=np.uint8)
167 | # import pdb
168 | # pdb.set_trace()
169 |
170 | img_paths = sorted(img_paths)
171 | self.img_paths = img_paths
172 | self.cam_ids = [int(path.split('/')[-3] == 'nir') + 2 for path in img_paths]
173 | self.num_ids = num_ids
174 | self.transform = transform
175 |
176 | if mode == 'train':
177 | id_map = dict(zip(selected_ids, range(num_ids)))
178 |
179 | self.ids = [id_map[int(path.split('/')[-2])] for path in img_paths]
180 | else:
181 | self.ids = [int(path.split('/')[-2]) for path in img_paths]
182 |
183 | def __len__(self):
184 | return len(self.img_paths)
185 |
186 | def __getitem__(self, item):
187 | path = self.img_paths[item]
188 | img = Image.open(path)
189 | if self.transform is not None:
190 | img = self.transform(img)
191 |
192 | label = torch.tensor(self.ids[item], dtype=torch.long)
193 | cam = torch.tensor(self.cam_ids[item], dtype=torch.long)
194 | item = torch.tensor(item, dtype=torch.long)
195 |
196 | return img, label, cam, path, item
197 |
198 | class MarketDataset(Dataset):
199 | def __init__(self, root, mode='train', transform=None):
200 | assert os.path.isdir(root)
201 | assert mode in ['train', 'gallery', 'query']
202 |
203 | self.transform = transform
204 |
205 | if mode == 'train':
206 | img_paths = glob(os.path.join(root, 'bounding_box_train/*.jpg'), recursive=True)
207 | elif mode == 'gallery':
208 | img_paths = glob(os.path.join(root, 'bounding_box_test/*.jpg'), recursive=True)
209 | elif mode == 'query':
210 | img_paths = glob(os.path.join(root, 'query/*.jpg'), recursive=True)
211 |
212 | pattern = re.compile(r'([-\d]+)_c(\d)')
213 | all_pids = {}
214 | relabel = mode == 'train'
215 | self.img_paths = []
216 | self.cam_ids = []
217 | self.ids = []
218 | for fpath in img_paths:
219 | fname = osp.basename(fpath)
220 | pid, cam = map(int, pattern.search(fname).groups())
221 | if pid == -1: continue
222 | if relabel:
223 | if pid not in all_pids:
224 | all_pids[pid] = len(all_pids)
225 | else:
226 | if pid not in all_pids:
227 | all_pids[pid] = pid
228 | self.img_paths.append(fpath)
229 | self.ids.append(all_pids[pid])
230 | self.cam_ids.append(cam - 1)
231 |
232 | def __len__(self):
233 | return len(self.img_paths)
234 |
235 | def __getitem__(self, item):
236 | path = self.img_paths[item]
237 | img = Image.open(path)
238 | if self.transform is not None:
239 | img = self.transform(img)
240 |
241 | label = torch.tensor(self.ids[item], dtype=torch.long)
242 | cam = torch.tensor(self.cam_ids[item], dtype=torch.long)
243 | item = torch.tensor(item, dtype=torch.long)
244 |
245 | return img, label, cam, path, item
246 |
--------------------------------------------------------------------------------
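All four dataset classes above share the same interface: `__getitem__` returns an `(img, label, cam, path, item)` tuple, and the train split relabels identities to a contiguous range. A small hedged sketch, assuming the SYSU-MM01 directory layout expected by `SYSUDataset`:

```python
# Illustration only; assumes SYSU-MM01 is on disk under the default data root.
import torchvision.transforms as T
from data.dataset import SYSUDataset

transform = T.Compose([T.Resize((384, 144)), T.ToTensor()])
train_set = SYSUDataset("../dataset/SYSU-MM01", mode="train", transform=transform)

img, label, cam, path, idx = train_set[0]   # tensors for img/label/cam/idx, str for path
print(len(train_set), train_set.num_ids)    # number of images / number of relabelled identities
```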
/IDKL/data/sampler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import copy
4 | from torch.utils.data import Sampler
5 | from collections import defaultdict
6 |
7 |
8 | class CrossModalityRandomSampler(Sampler):
9 | def __init__(self, dataset, batch_size):
10 | self.dataset = dataset
11 | self.batch_size = batch_size
12 |
13 | self.rgb_list = []
14 | self.ir_list = []
15 | for i, cam in enumerate(dataset.cam_ids):
16 | if cam in [3, 6]:
17 | self.ir_list.append(i)
18 | else:
19 | self.rgb_list.append(i)
20 |
21 | def __len__(self):
22 | return max(len(self.rgb_list), len(self.ir_list)) * 2
23 |
24 | def __iter__(self):
25 | sample_list = []
26 | rgb_list = np.random.permutation(self.rgb_list).tolist()
27 | ir_list = np.random.permutation(self.ir_list).tolist()
28 |
29 | rgb_size = len(self.rgb_list)
30 | ir_size = len(self.ir_list)
31 | if rgb_size >= ir_size:
32 | diff = rgb_size - ir_size
33 | reps = diff // ir_size
34 | pad_size = diff % ir_size
35 | for _ in range(reps):
36 | ir_list.extend(np.random.permutation(self.ir_list).tolist())
37 | ir_list.extend(np.random.choice(self.ir_list, pad_size, replace=False).tolist())
38 | else:
39 | diff = ir_size - rgb_size
40 | reps = diff // rgb_size
41 | pad_size = diff % rgb_size
42 | for _ in range(reps):
43 | rgb_list.extend(np.random.permutation(self.rgb_list).tolist())
44 | rgb_list.extend(np.random.choice(self.rgb_list, pad_size, replace=False).tolist())
45 |
46 | assert len(rgb_list) == len(ir_list)
47 |
48 | half_bs = self.batch_size // 2
49 | for start in range(0, len(rgb_list), half_bs):
50 | sample_list.extend(rgb_list[start:start + half_bs])
51 | sample_list.extend(ir_list[start:start + half_bs])
52 |
53 | return iter(sample_list)
54 |
55 |
56 | class CrossModalityIdentitySampler(Sampler):
57 | def __init__(self, dataset, p_size, k_size):
58 | self.dataset = dataset
59 | self.p_size = p_size
60 | self.k_size = k_size // 2
61 | self.batch_size = p_size * k_size * 2
62 |
63 | self.id2idx_rgb = defaultdict(list)
64 | self.id2idx_ir = defaultdict(list)
65 | for i, identity in enumerate(dataset.ids):
66 | if dataset.cam_ids[i] in [3, 6]:
67 | self.id2idx_ir[identity].append(i)
68 | else:
69 | self.id2idx_rgb[identity].append(i)
70 |
71 | def __len__(self):
72 | return self.dataset.num_ids * self.k_size * 2
73 |
74 | def __iter__(self):
75 | sample_list = []
76 |
77 | id_perm = np.random.permutation(self.dataset.num_ids)
78 | for start in range(0, self.dataset.num_ids, self.p_size):
79 | selected_ids = id_perm[start:start + self.p_size]
80 |
81 | sample = []
82 | for identity in selected_ids:
83 | replace = len(self.id2idx_rgb[identity]) < self.k_size
84 | s = np.random.choice(self.id2idx_rgb[identity], size=self.k_size, replace=replace)
85 | sample.extend(s)
86 |
87 | sample_list.extend(sample)
88 |
89 | sample.clear()
90 | for identity in selected_ids:
91 | replace = len(self.id2idx_ir[identity]) < self.k_size
92 | s = np.random.choice(self.id2idx_ir[identity], size=self.k_size, replace=replace)
93 | sample.extend(s)
94 |
95 | sample_list.extend(sample)
96 |
97 | return iter(sample_list)
98 |
99 |
100 | class RandomIdentitySampler(Sampler):
101 | def __init__(self, data_source, batch_size, num_instances):
102 | self.data_source = data_source
103 | self.batch_size = batch_size
104 | self.num_instances = num_instances
105 | self.num_pids_per_batch = self.batch_size // self.num_instances
106 | self.index_dic_R = defaultdict(list)
107 | self.index_dic_I = defaultdict(list)
108 | for i, identity in enumerate(data_source.ids):
109 | if data_source.cam_ids[i] in [3, 6]:
110 | self.index_dic_I[identity].append(i)
111 | else:
112 | self.index_dic_R[identity].append(i)
113 | self.pids = list(self.index_dic_I.keys())
114 |
115 | # estimate number of examples in an epoch
116 | self.length = 0
117 | for pid in self.pids:
118 | idxs = self.index_dic_I[pid]
119 | num = len(idxs)
120 | if num < self.num_instances:
121 | num = self.num_instances
122 | self.length += num - num % self.num_instances
123 |
124 | def __iter__(self):
125 | batch_idxs_dict = defaultdict(list)
126 |
127 | for pid in self.pids:
128 | idxs_I = copy.deepcopy(self.index_dic_I[pid])
129 | idxs_R = copy.deepcopy(self.index_dic_R[pid])
130 | if len(idxs_I) < self.num_instances // 2 and len(idxs_R) < self.num_instances // 2:
131 | idxs_I = np.random.choice(idxs_I, size=self.num_instances // 2, replace=True)
132 | idxs_R = np.random.choice(idxs_R, size=self.num_instances // 2, replace=True)
133 | if len(idxs_I) > len(idxs_R):
134 | idxs_I = np.random.choice(idxs_I, size=len(idxs_R), replace=False)
135 | if len(idxs_R) > len(idxs_I):
136 | idxs_R = np.random.choice(idxs_R, size=len(idxs_I), replace=False)
137 | np.random.shuffle(idxs_I)
138 | np.random.shuffle(idxs_R)
139 | batch_idxs = []
140 | for idx_I, idx_R in zip(idxs_I, idxs_R):
141 | batch_idxs.append(idx_I)
142 | batch_idxs.append(idx_R)
143 | if len(batch_idxs) == self.num_instances:
144 | batch_idxs_dict[pid].append(batch_idxs)
145 | batch_idxs = []
146 |
147 | avai_pids = copy.deepcopy(self.pids)
148 | final_idxs = []
149 |
150 | while len(avai_pids) >= self.num_pids_per_batch:
151 | selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False)
152 | for pid in selected_pids:
153 | batch_idxs = batch_idxs_dict[pid].pop(0)
154 | final_idxs.extend(batch_idxs)
155 | if len(batch_idxs_dict[pid]) == 0:
156 | avai_pids.remove(pid)
157 |
158 | self.length = len(final_idxs)
159 | return iter(final_idxs)
160 |
161 | def __len__(self):
162 | return self.length
163 |
164 |
165 | class NormTripletSampler(Sampler):
166 | """
167 | Randomly sample N identities, then for each identity,
168 | randomly sample K instances, therefore batch size is N*K.
169 | Args:
170 | - data_source (list): list of (img_path, pid, camid).
171 | - num_instances (int): number of instances per identity in a batch.
172 | - batch_size (int): number of examples in a batch.
173 | """
174 |
175 | def __init__(self, data_source, batch_size, num_instances):
176 | self.data_source = data_source
177 | self.batch_size = batch_size
178 | self.num_instances = num_instances
179 | self.num_pids_per_batch = self.batch_size // self.num_instances
180 | self.index_dic = defaultdict(list)
181 | for index, pid in enumerate(self.data_source.ids):
182 | self.index_dic[pid].append(index)
183 | self.pids = list(self.index_dic.keys())
184 |
185 | # estimate number of examples in an epoch
186 | self.length = 0
187 | for pid in self.pids:
188 | idxs = self.index_dic[pid]
189 | num = len(idxs)
190 | if num < self.num_instances:
191 | num = self.num_instances
192 | self.length += num - num % self.num_instances
193 |
194 | def __iter__(self):
195 | batch_idxs_dict = defaultdict(list)
196 |
197 | for pid in self.pids:
198 | idxs = copy.deepcopy(self.index_dic[pid])
199 | if len(idxs) < self.num_instances:
200 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True)
201 | np.random.shuffle(idxs)
202 | batch_idxs = []
203 | for idx in idxs:
204 | batch_idxs.append(idx)
205 | if len(batch_idxs) == self.num_instances:
206 | batch_idxs_dict[pid].append(batch_idxs)
207 | batch_idxs = []
208 |
209 | avai_pids = copy.deepcopy(self.pids)
210 | final_idxs = []
211 |
212 | while len(avai_pids) >= self.num_pids_per_batch:
213 | selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False)
214 | for pid in selected_pids:
215 | batch_idxs = batch_idxs_dict[pid].pop(0)
216 | final_idxs.extend(batch_idxs)
217 | if len(batch_idxs_dict[pid]) == 0:
218 | avai_pids.remove(pid)
219 |
220 | self.length = len(final_idxs)
221 | return iter(final_idxs)
222 |
223 | def __len__(self):
224 | return self.length
--------------------------------------------------------------------------------
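`RandomIdentitySampler` only reads the `.ids` and `.cam_ids` attributes of its data source, so its P×K, modality-interleaved batch layout can be inspected without any images. The stand-in dataset below is purely hypothetical and exists only to show the index pattern the sampler emits.

```python
from types import SimpleNamespace
from data.sampler import RandomIdentitySampler

# Hypothetical stand-in: 4 identities, each with 6 RGB (cam 1) and 6 IR (cam 3) samples.
ids = [pid for pid in range(4) for _ in range(12)]
cams = ([1] * 6 + [3] * 6) * 4
fake_data = SimpleNamespace(ids=ids, cam_ids=cams)

sampler = RandomIdentitySampler(fake_data, batch_size=8, num_instances=4)
indices = list(iter(sampler))
# Every consecutive group of 4 indices covers one identity, alternating IR and RGB instances;
# each batch of 8 therefore contains 2 identities x 4 instances.
```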
/IDKL/engine/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import numpy as np
4 | import torch
5 | import scipy.io as sio
6 |
7 | from ignite.engine import Events
8 | from ignite.handlers import ModelCheckpoint
9 | from ignite.handlers import Timer
10 |
11 | from engine.engine import create_eval_engine
12 | from engine.engine import create_train_engine
13 | from engine.metric import AutoKVMetric
14 | from utils.eval_sysu import eval_sysu
15 | from utils.eval_regdb import eval_regdb
16 | from utils.eval_llcm import eval_llcm
17 | from configs.default.dataset import dataset_cfg
18 | from configs.default.strategy import strategy_cfg
19 |
20 | def get_trainer(dataset, model, optimizer, lr_scheduler=None, logger=None, writer=None, non_blocking=False, log_period=10,
21 | save_dir="checkpoints", prefix="model", gallery_loader=None, query_loader=None,
22 | eval_interval=None, start_eval=None, rerank=False):
23 | if logger is None:
24 | logger = logging.getLogger()
25 | logger.setLevel(logging.WARN)
26 |
27 | # trainer
28 | trainer = create_train_engine(model, optimizer, non_blocking)
29 |
30 | setattr(trainer, "rerank", rerank)
31 |
32 | # checkpoint handler
33 | handler = ModelCheckpoint(save_dir, prefix, save_interval=eval_interval, n_saved=3, create_dir=True,
34 | save_as_state_dict=True, require_empty=False)
35 | trainer.add_event_handler(Events.EPOCH_COMPLETED, handler, {"model": model})
36 |
37 | # metric
38 | timer = Timer(average=True)
39 | rank = True
40 |
41 | kv_metric = AutoKVMetric()
42 |
43 | # evaluator
44 | evaluator = None
45 | if not type(eval_interval) == int:
46 | raise TypeError("The parameter 'eval_interval' must be type INT.")
47 | if not type(start_eval) == int:
48 | raise TypeError("The parameter 'start_eval' must be type INT.")
49 | if eval_interval > 0 and gallery_loader is not None and query_loader is not None:
50 | evaluator = create_eval_engine(model, non_blocking)
51 |
52 | @trainer.on(Events.STARTED)
53 | def train_start(engine):
54 | setattr(engine.state, "best_rank1", 0.0)
55 |
56 | @trainer.on(Events.COMPLETED)
57 | def train_completed(engine):
58 | torch.cuda.empty_cache()
59 |
60 | # extract query feature
61 | evaluator.run(query_loader)
62 |
63 | q_feats = torch.cat(evaluator.state.feat_list, dim=0)
64 | q_ids = torch.cat(evaluator.state.id_list, dim=0).numpy()
65 | q_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy()
66 | q_img_paths = np.concatenate(evaluator.state.img_path_list, axis=0)
67 |
68 | # extract gallery feature
69 | evaluator.run(gallery_loader)
70 |
71 | g_feats = torch.cat(evaluator.state.feat_list, dim=0)
72 | g_ids = torch.cat(evaluator.state.id_list, dim=0).numpy()
73 | g_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy()
74 | g_img_paths = np.concatenate(evaluator.state.img_path_list, axis=0)
75 |
76 | # print("best rank1={:.2f}%".format(engine.state.best_rank1))
77 |
78 | if dataset == 'sysu':
79 | perm = sio.loadmat(os.path.join(dataset_cfg.sysu.data_root, 'exp', 'rand_perm_cam.mat'))[
80 | 'rand_perm_cam']
81 | eval_sysu(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, perm, mode='all', num_shots=1, rerank=rank)
82 | eval_sysu(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, perm, mode='all', num_shots=10, rerank=rank)
83 | eval_sysu(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, perm, mode='indoor', num_shots=1, rerank=rank)
84 | eval_sysu(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, perm, mode='indoor', num_shots=10, rerank=rank)
85 | elif dataset == 'regdb':
86 | print('infrared to visible')
87 | eval_regdb(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, rerank=engine.rerank)
88 | print('visible to infrared')
89 | eval_regdb(g_feats, g_ids, g_cams, q_feats, q_ids, q_cams, q_img_paths, rerank=engine.rerank)
90 | elif dataset == 'llcm':
91 | print('infrared to visible')
92 | eval_llcm(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, rerank=rank)
93 | print('visible to infrared')
94 | eval_llcm(g_feats, g_ids, g_cams, q_feats, q_ids, q_cams, q_img_paths, rerank=rank)
95 |
96 |
97 | evaluator.state.feat_list.clear()
98 | evaluator.state.id_list.clear()
99 | evaluator.state.cam_list.clear()
100 | evaluator.state.img_path_list.clear()
101 | del q_feats, q_ids, q_cams, g_feats, g_ids, g_cams
102 |
103 | torch.cuda.empty_cache()
104 |
105 | @trainer.on(Events.EPOCH_STARTED)
106 | def epoch_started_callback(engine):
107 |
108 | epoch = engine.state.epoch
109 | if model.mutual_learning:
110 | model.update_rate = min(100 / (epoch + 1), 1.0) * model.update_rate_
111 |
112 | kv_metric.reset()
113 | timer.reset()
114 |
115 | @trainer.on(Events.EPOCH_COMPLETED)
116 | def epoch_completed_callback(engine):
117 | epoch = engine.state.epoch
118 |
119 | if lr_scheduler is not None:
120 | lr_scheduler.step()
121 |
122 | if epoch % eval_interval == 0:
123 | logger.info("Model saved at {}/{}_model_{}.pth".format(save_dir, prefix, epoch))
124 |
125 | if evaluator and epoch % eval_interval == 0 and epoch > start_eval:
126 | torch.cuda.empty_cache()
127 |
128 | # extract query feature
129 | evaluator.run(query_loader)
130 |
131 | q_feats = torch.cat(evaluator.state.feat_list, dim=0)
132 | q_ids = torch.cat(evaluator.state.id_list, dim=0).numpy()
133 | q_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy()
134 | q_img_paths = np.concatenate(evaluator.state.img_path_list, axis=0)
135 |
136 | # extract gallery feature
137 | evaluator.run(gallery_loader)
138 |
139 | g_feats = torch.cat(evaluator.state.feat_list, dim=0)
140 | g_ids = torch.cat(evaluator.state.id_list, dim=0).numpy()
141 | g_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy()
142 | g_img_paths = np.concatenate(evaluator.state.img_path_list, axis=0)
143 |
144 | if dataset == 'sysu':
145 | perm = sio.loadmat(os.path.join(dataset_cfg.sysu.data_root, 'exp', 'rand_perm_cam.mat'))[
146 | 'rand_perm_cam']
147 | mAP, r1, r5, _, _ = eval_sysu(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, perm, mode='all', num_shots=1, rerank=rank)
148 | elif dataset == 'regdb':
149 | print('infrared to visible')
150 | mAP, r1, r5, _, _ = eval_regdb(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, rerank=engine.rerank)
151 | print('visible to infrared')
152 | mAP, r1_, r5, _, _ = eval_regdb(g_feats, g_ids, g_cams, q_feats, q_ids, q_cams, q_img_paths, rerank=engine.rerank)
153 | r1 = (r1 + r1_) / 2
154 | elif dataset == 'llcm':
155 | print('infrared to visible')
156 | mAP, r1, r5, _, _ = eval_llcm(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, rerank=rank)
157 | #new_all_cmc,mAP, _ = eval_llcm(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, rerank=engine.rerank)
158 | print('visible to infrared')
159 | mAP, r1_, r5, _, _ = eval_llcm(g_feats, g_ids, g_cams, q_feats, q_ids, q_cams, q_img_paths, rerank=rank)
160 | r1 = (r1 + r1_) / 2
161 |
162 | # new_all_cmc,mAP_, _= eval_llcm(g_feats, g_ids, g_cams, q_feats, q_ids, q_cams, q_img_paths,
163 | # rerank=engine.rerank)
164 | # r1 = (mAP + mAP_) / 2
165 | # import pdb
166 | # pdb.set_trace()
167 |
168 | if r1 > engine.state.best_rank1:
169 | engine.state.best_rank1 = r1
170 | torch.save(model.state_dict(), "{}/model_best.pth".format(save_dir))
171 |
172 | if writer is not None:
173 | writer.add_scalar('eval/mAP', mAP, epoch)
174 | writer.add_scalar('eval/r1', r1, epoch)
175 | writer.add_scalar('eval/r5', r5, epoch)
176 |
177 | evaluator.state.feat_list.clear()
178 | evaluator.state.id_list.clear()
179 | evaluator.state.cam_list.clear()
180 | evaluator.state.img_path_list.clear()
181 | del q_feats, q_ids, q_cams, g_feats, g_ids, g_cams
182 |
183 | torch.cuda.empty_cache()
184 |
185 | @trainer.on(Events.ITERATION_COMPLETED)
186 | def iteration_complete_callback(engine):
187 | timer.step()
188 |
189 | # print(engine.state.output)
190 | kv_metric.update(engine.state.output)
191 |
192 | epoch = engine.state.epoch
193 | iteration = engine.state.iteration
194 | iter_in_epoch = iteration - (epoch - 1) * len(engine.state.dataloader)
195 |
196 | if iter_in_epoch % log_period == 0 and iter_in_epoch > 0:
197 | batch_size = engine.state.batch[0].size(0)
198 | speed = batch_size / timer.value()
199 |
200 | msg = "Epoch[%d] Batch [%d]\tSpeed: %.2f samples/sec" % (epoch, iter_in_epoch, speed)
201 |
202 | metric_dict = kv_metric.compute()
203 |
204 | # log output information
205 | if logger is not None:
206 | for k in sorted(metric_dict.keys()):
207 | msg += "\t%s: %.4f" % (k, metric_dict[k])
208 | if writer is not None:
209 | writer.add_scalar('metric/{}'.format(k), metric_dict[k], iteration)
210 |
211 | logger.info(msg)
212 |
213 | kv_metric.reset()
214 | timer.reset()
215 |
216 | return trainer
217 |
--------------------------------------------------------------------------------
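`get_trainer` wires an ignite training `Engine` to checkpointing, periodic evaluation and metric logging. Its caller (`train.py`) is not part of this excerpt, so the call below is only a sketch: `model`, `optimizer`, `lr_scheduler` and the three loaders are assumed to have been built elsewhere, and the numeric values mirror `configs/SYSU.yml`.

```python
import logging
from engine import get_trainer

# Sketch only: model, optimizer, lr_scheduler, train_loader, gallery_loader and query_loader
# are assumed to exist (built in train.py, which is not shown in this excerpt).
trainer = get_trainer(
    dataset="sysu", model=model, optimizer=optimizer, lr_scheduler=lr_scheduler,
    logger=logging.getLogger("IDKL"), log_period=150,
    save_dir="checkpoints/sysu", prefix="SYSU",
    gallery_loader=gallery_loader, query_loader=query_loader,
    eval_interval=5, start_eval=200, rerank=False)

trainer.run(train_loader, max_epochs=160)
```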
/IDKL/engine/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/engine/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/IDKL/engine/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/engine/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/engine/__pycache__/engine.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/engine/__pycache__/engine.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/engine/__pycache__/metric.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/engine/__pycache__/metric.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/engine/engine.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn as nn
4 | import numpy as np
5 | import os
6 | from apex import amp
7 | from ignite.engine import Engine
8 | from ignite.engine import Events
9 | from torch.autograd import no_grad
10 | from torch.nn import functional as F
11 | import torchvision.transforms as T
12 | import cv2
13 | from torchvision.io.image import read_image
14 | from PIL import Image
15 | from torchvision.transforms.functional import normalize, resize, to_pil_image
16 |
17 | from torchvision import transforms
18 | from grad_cam.utils import GradCAM, show_cam_on_image, center_crop_img
19 | import copy
20 | from torch.optim.lr_scheduler import LambdaLR
21 |
22 |
23 | # import torch
24 | # import numpy as np
25 | # import os
26 | # from apex import amp
27 | # from ignite.engine import Engine
28 | # from ignite.engine import Events
29 | # from torch.autograd import no_grad
30 | # from torchvision import transforms
31 | # from PIL import Image
32 | # import cv2
33 | # from grad_cam.utils import GradCAM, show_cam_on_image, center_crop_img
34 | # from thop import profile
35 | # from thop import clever_format
36 | #
37 | # from utils.calc_acc import calc_acc
38 | # from torch.nn import functional as F
39 | # #import objgraph
40 |
41 |
42 | def some_function(epoch, initial_weight_decay):
43 | if epoch > 15:
44 | new_weight_decay = initial_weight_decay/100
45 | elif epoch > 5 and epoch <= 15:
46 | new_weight_decay = initial_weight_decay*1/10
47 | else:
48 | new_weight_decay = initial_weight_decay
49 | return new_weight_decay
50 |
51 | def create_train_engine(model, optimizer, non_blocking=False):
52 | device = torch.device("cuda") #"cuda", torch.cuda.current_device()
53 |
54 | def _process_func(engine, batch):
55 | model.train()
56 | #model.eval()
57 |
58 | data, labels, cam_ids, img_paths, img_ids = batch
59 | epoch = engine.state.epoch
60 | iteration = engine.state.iteration
61 |
62 | data = data.to(device, non_blocking=non_blocking)
63 | labels = labels.to(device, non_blocking=non_blocking)
64 | cam_ids = cam_ids.to(device, non_blocking=non_blocking)
65 |
66 | warmup = False
67 | if warmup == True:  # learning-rate warmup
68 | if epoch < 21:
69 | # warm up by gradually increasing the learning rate
70 | warm_iteration = 30 * 213
71 | lr = 0.00035 * iteration / warm_iteration
72 | for param_group in optimizer.param_groups:
73 | param_group['lr'] = lr
74 | if True:  # weight-decay (regularization) warmup schedule
75 | new_weight_decay = some_function(epoch, 0.5)
76 | for param_group in optimizer.param_groups:
77 | param_group['weight_decay'] = new_weight_decay
78 |
79 | optimizer.zero_grad()
80 |
81 | loss, metric = model(data, labels,
82 | cam_ids=cam_ids,
83 | epoch=epoch)
84 |
85 |
86 | with amp.scale_loss(loss, optimizer) as scaled_loss:
87 | scaled_loss.backward()
88 | optimizer.step()
89 |
90 | return metric
91 |
92 | return Engine(_process_func)
93 |
94 |
95 | def create_eval_engine(model, non_blocking=False):
96 | device = torch.device("cuda", torch.cuda.current_device())
97 |
98 | def _process_func(engine, batch):
99 | model.eval()
100 |
101 | data, labels, cam_ids, img_paths = batch[:4]
102 |
103 | data = data.to(device, non_blocking=non_blocking)
104 |
105 | with no_grad():
106 | feat = model(data, cam_ids=cam_ids.to(device, non_blocking=non_blocking))
107 |
108 | return feat.data.float().cpu(), labels, cam_ids, np.array(img_paths)
109 |
110 | engine = Engine(_process_func)
111 |
112 | @engine.on(Events.EPOCH_STARTED)
113 | def clear_data(engine):
114 | # feat list
115 | if not hasattr(engine.state, "feat_list"):
116 | setattr(engine.state, "feat_list", [])
117 | else:
118 | engine.state.feat_list.clear()
119 |
120 | # id_list
121 | if not hasattr(engine.state, "id_list"):
122 | setattr(engine.state, "id_list", [])
123 | else:
124 | engine.state.id_list.clear()
125 |
126 | # cam list
127 | if not hasattr(engine.state, "cam_list"):
128 | setattr(engine.state, "cam_list", [])
129 | else:
130 | engine.state.cam_list.clear()
131 |
132 | # img path list
133 | if not hasattr(engine.state, "img_path_list"):
134 | setattr(engine.state, "img_path_list", [])
135 | else:
136 | engine.state.img_path_list.clear()
137 |
138 | @engine.on(Events.ITERATION_COMPLETED)
139 | def store_data(engine):
140 | engine.state.feat_list.append(engine.state.output[0])
141 | engine.state.id_list.append(engine.state.output[1])
142 | engine.state.cam_list.append(engine.state.output[2])
143 | engine.state.img_path_list.append(engine.state.output[3])
144 |
145 | return engine
146 |
--------------------------------------------------------------------------------
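The training step above relies on `amp.scale_loss`, which only works once the model and optimizer have been registered with apex amp; that registration presumably happens in `train.py`, which is not shown here. A minimal sketch under that assumption, using a placeholder model:

```python
import torch.nn as nn
import torch.optim as optim
from apex import amp

# Placeholder model/optimizer for illustration; IDKL builds these from models.baseline
# and the merged config (not part of this excerpt).
model = nn.Linear(2048, 395).cuda()
optimizer = optim.Adam(model.parameters(), lr=0.00035)

# opt_level "O1" enables mixed precision (cfg.fp16: true); "O0" keeps pure FP32.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
```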
/IDKL/engine/metric.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 |
3 | import torch
4 | from ignite.exceptions import NotComputableError
5 | from ignite.metrics import Metric, Accuracy
6 |
7 |
8 | class ScalarMetric(Metric):
9 |
10 | def update(self, value):
11 | self.sum_metric += value
12 | self.sum_inst += 1
13 |
14 | def reset(self):
15 | self.sum_inst = 0
16 | self.sum_metric = 0
17 |
18 | def compute(self):
19 | if self.sum_inst == 0:
20 | raise NotComputableError('Accuracy must have at least one example before it can be computed')
21 | return self.sum_metric / self.sum_inst
22 |
23 |
24 | class IgnoreAccuracy(Accuracy):
25 | def __init__(self, ignore_index=-1):
26 | super(IgnoreAccuracy, self).__init__()
27 |
28 | self.ignore_index = ignore_index
29 |
30 | def reset(self):
31 | self._num_correct = 0
32 | self._num_examples = 0
33 |
34 | def update(self, output):
35 |
36 | y_pred, y = self._check_shape(output)
37 | self._check_type((y_pred, y))
38 |
39 | if self._type == "binary":
40 | indices = torch.round(y_pred).type(y.type())
41 | elif self._type == "multiclass":
42 | indices = torch.max(y_pred, dim=1)[1]
43 |
44 | correct = torch.eq(indices, y).view(-1)
45 | ignore = torch.eq(y, self.ignore_index).view(-1)
46 | self._num_correct += torch.sum(correct).item()
47 | self._num_examples += correct.shape[0] - ignore.sum().item()
48 |
49 | def compute(self):
50 | if self._num_examples == 0:
51 | raise NotComputableError('Accuracy must have at least one example before it can be computed')
52 | return self._num_correct / self._num_examples
53 |
54 |
55 | class AutoKVMetric(Metric):
56 | def __init__(self):
57 | self.kv_sum_metric = defaultdict(lambda: torch.tensor(0., device="cuda"))
58 | self.kv_sum_inst = defaultdict(lambda: torch.tensor(0., device="cuda"))
59 |
60 | self.kv_metric = defaultdict(lambda: 0)
61 |
62 | super(AutoKVMetric, self).__init__()
63 |
64 | def update(self, output):
65 | if not isinstance(output, dict):
66 | raise TypeError('The output must be a key-value dict.')
67 |
68 | for k in output.keys():
69 | self.kv_sum_metric[k].add_(output[k])
70 | self.kv_sum_inst[k].add_(1)
71 |
72 | def reset(self):
73 | for k in self.kv_sum_metric.keys():
74 | self.kv_sum_metric[k].zero_()
75 | self.kv_sum_inst[k].zero_()
76 | self.kv_metric[k] = 0
77 |
78 | def compute(self):
79 | for k in self.kv_sum_metric.keys():
80 | if self.kv_sum_inst[k] == 0:
81 | continue
82 | # raise NotComputableError('Accuracy must have at least one example before it can be computed')
83 |
84 | metric_value = self.kv_sum_metric[k] / self.kv_sum_inst[k]
85 | self.kv_metric[k] = metric_value.item()
86 |
87 | return self.kv_metric
88 |
--------------------------------------------------------------------------------
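`AutoKVMetric` averages whatever dictionary of scalars the training step returns (it is fed via `kv_metric.update(engine.state.output)` in `engine/__init__.py`). A small self-contained example; note the accumulators are created on the GPU, so a CUDA device is required:

```python
from engine.metric import AutoKVMetric

metric = AutoKVMetric()
metric.reset()
# Each call passes the per-iteration scalar dict returned by the training step.
metric.update({"id_loss": 2.0, "triplet_loss": 1.0})
metric.update({"id_loss": 1.0, "triplet_loss": 0.5})
print(metric.compute())   # averages: id_loss 1.5, triplet_loss 0.75
```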
/IDKL/grad_cam/README.md:
--------------------------------------------------------------------------------
1 | ## Grad-CAM
2 | - Original Impl: [https://github.com/jacobgil/pytorch-grad-cam](https://github.com/jacobgil/pytorch-grad-cam)
3 | - Grad-CAM introduction (video, in Chinese): [https://b23.tv/1kccjmb](https://b23.tv/1kccjmb)
4 | - Implementing Grad-CAM with PyTorch and plotting heatmaps (video, in Chinese): [https://b23.tv/n1e60vN](https://b23.tv/n1e60vN)
5 |
6 | ## Usage (adapting it to your own network)
7 | 1. Replace the model-creation code with your own, and load your trained weights
8 | 2. Set `target_layers` appropriately for your network
9 | 3. Choose a preprocessing pipeline that matches your network
10 | 4. Assign the path of the image you want to explain to `img_path`
11 | 5. Assign the class id of interest to `target_category` (a minimal sketch of these steps follows this file section; `main_cnn.py` in this directory is a complete example)
12 |
13 |
--------------------------------------------------------------------------------
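A minimal sketch mapping the five steps above onto code — it assumes a torchvision `resnet50` and a hypothetical checkpoint file `my_weights.pth`; `main_cnn.py` below is the complete, runnable version of the same flow:

```python
import numpy as np
import torch
from PIL import Image
from torchvision import models, transforms
from utils import GradCAM, show_cam_on_image

model = models.resnet50()                                                 # step 1: build your own model...
model.load_state_dict(torch.load("my_weights.pth", map_location="cpu"))  # ...and load your weights (hypothetical path)
target_layers = [model.layer4]                                            # step 2: pick suitable target layers
data_transform = transforms.Compose([                                     # step 3: preprocessing that matches your training
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
img_path = "both.png"                                                     # step 4: image to explain
img = np.array(Image.open(img_path).convert("RGB"), dtype=np.uint8)
input_tensor = data_transform(img).unsqueeze(0)
target_category = 281                                                     # step 5: class id of interest (tabby cat)

cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
grayscale_cam = cam(input_tensor=input_tensor, target_category=target_category)[0]
visualization = show_cam_on_image(img.astype(np.float32) / 255., grayscale_cam, use_rgb=True)
```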
/IDKL/grad_cam/both.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/grad_cam/both.png
--------------------------------------------------------------------------------
/IDKL/grad_cam/imagenet1k_classes.txt:
--------------------------------------------------------------------------------
1 | tench, Tinca tinca
2 | goldfish, Carassius auratus
3 | great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
4 | tiger shark, Galeocerdo cuvieri
5 | hammerhead, hammerhead shark
6 | electric ray, crampfish, numbfish, torpedo
7 | stingray
8 | cock
9 | hen
10 | ostrich, Struthio camelus
11 | brambling, Fringilla montifringilla
12 | goldfinch, Carduelis carduelis
13 | house finch, linnet, Carpodacus mexicanus
14 | junco, snowbird
15 | indigo bunting, indigo finch, indigo bird, Passerina cyanea
16 | robin, American robin, Turdus migratorius
17 | bulbul
18 | jay
19 | magpie
20 | chickadee
21 | water ouzel, dipper
22 | kite
23 | bald eagle, American eagle, Haliaeetus leucocephalus
24 | vulture
25 | great grey owl, great gray owl, Strix nebulosa
26 | European fire salamander, Salamandra salamandra
27 | common newt, Triturus vulgaris
28 | eft
29 | spotted salamander, Ambystoma maculatum
30 | axolotl, mud puppy, Ambystoma mexicanum
31 | bullfrog, Rana catesbeiana
32 | tree frog, tree-frog
33 | tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui
34 | loggerhead, loggerhead turtle, Caretta caretta
35 | leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea
36 | mud turtle
37 | terrapin
38 | box turtle, box tortoise
39 | banded gecko
40 | common iguana, iguana, Iguana iguana
41 | American chameleon, anole, Anolis carolinensis
42 | whiptail, whiptail lizard
43 | agama
44 | frilled lizard, Chlamydosaurus kingi
45 | alligator lizard
46 | Gila monster, Heloderma suspectum
47 | green lizard, Lacerta viridis
48 | African chameleon, Chamaeleo chamaeleon
49 | Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis
50 | African crocodile, Nile crocodile, Crocodylus niloticus
51 | American alligator, Alligator mississipiensis
52 | triceratops
53 | thunder snake, worm snake, Carphophis amoenus
54 | ringneck snake, ring-necked snake, ring snake
55 | hognose snake, puff adder, sand viper
56 | green snake, grass snake
57 | king snake, kingsnake
58 | garter snake, grass snake
59 | water snake
60 | vine snake
61 | night snake, Hypsiglena torquata
62 | boa constrictor, Constrictor constrictor
63 | rock python, rock snake, Python sebae
64 | Indian cobra, Naja naja
65 | green mamba
66 | sea snake
67 | horned viper, cerastes, sand viper, horned asp, Cerastes cornutus
68 | diamondback, diamondback rattlesnake, Crotalus adamanteus
69 | sidewinder, horned rattlesnake, Crotalus cerastes
70 | trilobite
71 | harvestman, daddy longlegs, Phalangium opilio
72 | scorpion
73 | black and gold garden spider, Argiope aurantia
74 | barn spider, Araneus cavaticus
75 | garden spider, Aranea diademata
76 | black widow, Latrodectus mactans
77 | tarantula
78 | wolf spider, hunting spider
79 | tick
80 | centipede
81 | black grouse
82 | ptarmigan
83 | ruffed grouse, partridge, Bonasa umbellus
84 | prairie chicken, prairie grouse, prairie fowl
85 | peacock
86 | quail
87 | partridge
88 | African grey, African gray, Psittacus erithacus
89 | macaw
90 | sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita
91 | lorikeet
92 | coucal
93 | bee eater
94 | hornbill
95 | hummingbird
96 | jacamar
97 | toucan
98 | drake
99 | red-breasted merganser, Mergus serrator
100 | goose
101 | black swan, Cygnus atratus
102 | tusker
103 | echidna, spiny anteater, anteater
104 | platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus
105 | wallaby, brush kangaroo
106 | koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus
107 | wombat
108 | jellyfish
109 | sea anemone, anemone
110 | brain coral
111 | flatworm, platyhelminth
112 | nematode, nematode worm, roundworm
113 | conch
114 | snail
115 | slug
116 | sea slug, nudibranch
117 | chiton, coat-of-mail shell, sea cradle, polyplacophore
118 | chambered nautilus, pearly nautilus, nautilus
119 | Dungeness crab, Cancer magister
120 | rock crab, Cancer irroratus
121 | fiddler crab
122 | king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica
123 | American lobster, Northern lobster, Maine lobster, Homarus americanus
124 | spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish
125 | crayfish, crawfish, crawdad, crawdaddy
126 | hermit crab
127 | isopod
128 | white stork, Ciconia ciconia
129 | black stork, Ciconia nigra
130 | spoonbill
131 | flamingo
132 | little blue heron, Egretta caerulea
133 | American egret, great white heron, Egretta albus
134 | bittern
135 | crane
136 | limpkin, Aramus pictus
137 | European gallinule, Porphyrio porphyrio
138 | American coot, marsh hen, mud hen, water hen, Fulica americana
139 | bustard
140 | ruddy turnstone, Arenaria interpres
141 | red-backed sandpiper, dunlin, Erolia alpina
142 | redshank, Tringa totanus
143 | dowitcher
144 | oystercatcher, oyster catcher
145 | pelican
146 | king penguin, Aptenodytes patagonica
147 | albatross, mollymawk
148 | grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus
149 | killer whale, killer, orca, grampus, sea wolf, Orcinus orca
150 | dugong, Dugong dugon
151 | sea lion
152 | Chihuahua
153 | Japanese spaniel
154 | Maltese dog, Maltese terrier, Maltese
155 | Pekinese, Pekingese, Peke
156 | Shih-Tzu
157 | Blenheim spaniel
158 | papillon
159 | toy terrier
160 | Rhodesian ridgeback
161 | Afghan hound, Afghan
162 | basset, basset hound
163 | beagle
164 | bloodhound, sleuthhound
165 | bluetick
166 | black-and-tan coonhound
167 | Walker hound, Walker foxhound
168 | English foxhound
169 | redbone
170 | borzoi, Russian wolfhound
171 | Irish wolfhound
172 | Italian greyhound
173 | whippet
174 | Ibizan hound, Ibizan Podenco
175 | Norwegian elkhound, elkhound
176 | otterhound, otter hound
177 | Saluki, gazelle hound
178 | Scottish deerhound, deerhound
179 | Weimaraner
180 | Staffordshire bullterrier, Staffordshire bull terrier
181 | American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier
182 | Bedlington terrier
183 | Border terrier
184 | Kerry blue terrier
185 | Irish terrier
186 | Norfolk terrier
187 | Norwich terrier
188 | Yorkshire terrier
189 | wire-haired fox terrier
190 | Lakeland terrier
191 | Sealyham terrier, Sealyham
192 | Airedale, Airedale terrier
193 | cairn, cairn terrier
194 | Australian terrier
195 | Dandie Dinmont, Dandie Dinmont terrier
196 | Boston bull, Boston terrier
197 | miniature schnauzer
198 | giant schnauzer
199 | standard schnauzer
200 | Scotch terrier, Scottish terrier, Scottie
201 | Tibetan terrier, chrysanthemum dog
202 | silky terrier, Sydney silky
203 | soft-coated wheaten terrier
204 | West Highland white terrier
205 | Lhasa, Lhasa apso
206 | flat-coated retriever
207 | curly-coated retriever
208 | golden retriever
209 | Labrador retriever
210 | Chesapeake Bay retriever
211 | German short-haired pointer
212 | vizsla, Hungarian pointer
213 | English setter
214 | Irish setter, red setter
215 | Gordon setter
216 | Brittany spaniel
217 | clumber, clumber spaniel
218 | English springer, English springer spaniel
219 | Welsh springer spaniel
220 | cocker spaniel, English cocker spaniel, cocker
221 | Sussex spaniel
222 | Irish water spaniel
223 | kuvasz
224 | schipperke
225 | groenendael
226 | malinois
227 | briard
228 | kelpie
229 | komondor
230 | Old English sheepdog, bobtail
231 | Shetland sheepdog, Shetland sheep dog, Shetland
232 | collie
233 | Border collie
234 | Bouvier des Flandres, Bouviers des Flandres
235 | Rottweiler
236 | German shepherd, German shepherd dog, German police dog, alsatian
237 | Doberman, Doberman pinscher
238 | miniature pinscher
239 | Greater Swiss Mountain dog
240 | Bernese mountain dog
241 | Appenzeller
242 | EntleBucher
243 | boxer
244 | bull mastiff
245 | Tibetan mastiff
246 | French bulldog
247 | Great Dane
248 | Saint Bernard, St Bernard
249 | Eskimo dog, husky
250 | malamute, malemute, Alaskan malamute
251 | Siberian husky
252 | dalmatian, coach dog, carriage dog
253 | affenpinscher, monkey pinscher, monkey dog
254 | basenji
255 | pug, pug-dog
256 | Leonberg
257 | Newfoundland, Newfoundland dog
258 | Great Pyrenees
259 | Samoyed, Samoyede
260 | Pomeranian
261 | chow, chow chow
262 | keeshond
263 | Brabancon griffon
264 | Pembroke, Pembroke Welsh corgi
265 | Cardigan, Cardigan Welsh corgi
266 | toy poodle
267 | miniature poodle
268 | standard poodle
269 | Mexican hairless
270 | timber wolf, grey wolf, gray wolf, Canis lupus
271 | white wolf, Arctic wolf, Canis lupus tundrarum
272 | red wolf, maned wolf, Canis rufus, Canis niger
273 | coyote, prairie wolf, brush wolf, Canis latrans
274 | dingo, warrigal, warragal, Canis dingo
275 | dhole, Cuon alpinus
276 | African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus
277 | hyena, hyaena
278 | red fox, Vulpes vulpes
279 | kit fox, Vulpes macrotis
280 | Arctic fox, white fox, Alopex lagopus
281 | grey fox, gray fox, Urocyon cinereoargenteus
282 | tabby, tabby cat
283 | tiger cat
284 | Persian cat
285 | Siamese cat, Siamese
286 | Egyptian cat
287 | cougar, puma, catamount, mountain lion, painter, panther, Felis concolor
288 | lynx, catamount
289 | leopard, Panthera pardus
290 | snow leopard, ounce, Panthera uncia
291 | jaguar, panther, Panthera onca, Felis onca
292 | lion, king of beasts, Panthera leo
293 | tiger, Panthera tigris
294 | cheetah, chetah, Acinonyx jubatus
295 | brown bear, bruin, Ursus arctos
296 | American black bear, black bear, Ursus americanus, Euarctos americanus
297 | ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus
298 | sloth bear, Melursus ursinus, Ursus ursinus
299 | mongoose
300 | meerkat, mierkat
301 | tiger beetle
302 | ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle
303 | ground beetle, carabid beetle
304 | long-horned beetle, longicorn, longicorn beetle
305 | leaf beetle, chrysomelid
306 | dung beetle
307 | rhinoceros beetle
308 | weevil
309 | fly
310 | bee
311 | ant, emmet, pismire
312 | grasshopper, hopper
313 | cricket
314 | walking stick, walkingstick, stick insect
315 | cockroach, roach
316 | mantis, mantid
317 | cicada, cicala
318 | leafhopper
319 | lacewing, lacewing fly
320 | dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk
321 | damselfly
322 | admiral
323 | ringlet, ringlet butterfly
324 | monarch, monarch butterfly, milkweed butterfly, Danaus plexippus
325 | cabbage butterfly
326 | sulphur butterfly, sulfur butterfly
327 | lycaenid, lycaenid butterfly
328 | starfish, sea star
329 | sea urchin
330 | sea cucumber, holothurian
331 | wood rabbit, cottontail, cottontail rabbit
332 | hare
333 | Angora, Angora rabbit
334 | hamster
335 | porcupine, hedgehog
336 | fox squirrel, eastern fox squirrel, Sciurus niger
337 | marmot
338 | beaver
339 | guinea pig, Cavia cobaya
340 | sorrel
341 | zebra
342 | hog, pig, grunter, squealer, Sus scrofa
343 | wild boar, boar, Sus scrofa
344 | warthog
345 | hippopotamus, hippo, river horse, Hippopotamus amphibius
346 | ox
347 | water buffalo, water ox, Asiatic buffalo, Bubalus bubalis
348 | bison
349 | ram, tup
350 | bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis
351 | ibex, Capra ibex
352 | hartebeest
353 | impala, Aepyceros melampus
354 | gazelle
355 | Arabian camel, dromedary, Camelus dromedarius
356 | llama
357 | weasel
358 | mink
359 | polecat, fitch, foulmart, foumart, Mustela putorius
360 | black-footed ferret, ferret, Mustela nigripes
361 | otter
362 | skunk, polecat, wood pussy
363 | badger
364 | armadillo
365 | three-toed sloth, ai, Bradypus tridactylus
366 | orangutan, orang, orangutang, Pongo pygmaeus
367 | gorilla, Gorilla gorilla
368 | chimpanzee, chimp, Pan troglodytes
369 | gibbon, Hylobates lar
370 | siamang, Hylobates syndactylus, Symphalangus syndactylus
371 | guenon, guenon monkey
372 | patas, hussar monkey, Erythrocebus patas
373 | baboon
374 | macaque
375 | langur
376 | colobus, colobus monkey
377 | proboscis monkey, Nasalis larvatus
378 | marmoset
379 | capuchin, ringtail, Cebus capucinus
380 | howler monkey, howler
381 | titi, titi monkey
382 | spider monkey, Ateles geoffroyi
383 | squirrel monkey, Saimiri sciureus
384 | Madagascar cat, ring-tailed lemur, Lemur catta
385 | indri, indris, Indri indri, Indri brevicaudatus
386 | Indian elephant, Elephas maximus
387 | African elephant, Loxodonta africana
388 | lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens
389 | giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca
390 | barracouta, snoek
391 | eel
392 | coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch
393 | rock beauty, Holocanthus tricolor
394 | anemone fish
395 | sturgeon
396 | gar, garfish, garpike, billfish, Lepisosteus osseus
397 | lionfish
398 | puffer, pufferfish, blowfish, globefish
399 | abacus
400 | abaya
401 | academic gown, academic robe, judge's robe
402 | accordion, piano accordion, squeeze box
403 | acoustic guitar
404 | aircraft carrier, carrier, flattop, attack aircraft carrier
405 | airliner
406 | airship, dirigible
407 | altar
408 | ambulance
409 | amphibian, amphibious vehicle
410 | analog clock
411 | apiary, bee house
412 | apron
413 | ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin
414 | assault rifle, assault gun
415 | backpack, back pack, knapsack, packsack, rucksack, haversack
416 | bakery, bakeshop, bakehouse
417 | balance beam, beam
418 | balloon
419 | ballpoint, ballpoint pen, ballpen, Biro
420 | Band Aid
421 | banjo
422 | bannister, banister, balustrade, balusters, handrail
423 | barbell
424 | barber chair
425 | barbershop
426 | barn
427 | barometer
428 | barrel, cask
429 | barrow, garden cart, lawn cart, wheelbarrow
430 | baseball
431 | basketball
432 | bassinet
433 | bassoon
434 | bathing cap, swimming cap
435 | bath towel
436 | bathtub, bathing tub, bath, tub
437 | beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon
438 | beacon, lighthouse, beacon light, pharos
439 | beaker
440 | bearskin, busby, shako
441 | beer bottle
442 | beer glass
443 | bell cote, bell cot
444 | bib
445 | bicycle-built-for-two, tandem bicycle, tandem
446 | bikini, two-piece
447 | binder, ring-binder
448 | binoculars, field glasses, opera glasses
449 | birdhouse
450 | boathouse
451 | bobsled, bobsleigh, bob
452 | bolo tie, bolo, bola tie, bola
453 | bonnet, poke bonnet
454 | bookcase
455 | bookshop, bookstore, bookstall
456 | bottlecap
457 | bow
458 | bow tie, bow-tie, bowtie
459 | brass, memorial tablet, plaque
460 | brassiere, bra, bandeau
461 | breakwater, groin, groyne, mole, bulwark, seawall, jetty
462 | breastplate, aegis, egis
463 | broom
464 | bucket, pail
465 | buckle
466 | bulletproof vest
467 | bullet train, bullet
468 | butcher shop, meat market
469 | cab, hack, taxi, taxicab
470 | caldron, cauldron
471 | candle, taper, wax light
472 | cannon
473 | canoe
474 | can opener, tin opener
475 | cardigan
476 | car mirror
477 | carousel, carrousel, merry-go-round, roundabout, whirligig
478 | carpenter's kit, tool kit
479 | carton
480 | car wheel
481 | cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM
482 | cassette
483 | cassette player
484 | castle
485 | catamaran
486 | CD player
487 | cello, violoncello
488 | cellular telephone, cellular phone, cellphone, cell, mobile phone
489 | chain
490 | chainlink fence
491 | chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour
492 | chain saw, chainsaw
493 | chest
494 | chiffonier, commode
495 | chime, bell, gong
496 | china cabinet, china closet
497 | Christmas stocking
498 | church, church building
499 | cinema, movie theater, movie theatre, movie house, picture palace
500 | cleaver, meat cleaver, chopper
501 | cliff dwelling
502 | cloak
503 | clog, geta, patten, sabot
504 | cocktail shaker
505 | coffee mug
506 | coffeepot
507 | coil, spiral, volute, whorl, helix
508 | combination lock
509 | computer keyboard, keypad
510 | confectionery, confectionary, candy store
511 | container ship, containership, container vessel
512 | convertible
513 | corkscrew, bottle screw
514 | cornet, horn, trumpet, trump
515 | cowboy boot
516 | cowboy hat, ten-gallon hat
517 | cradle
518 | crane
519 | crash helmet
520 | crate
521 | crib, cot
522 | Crock Pot
523 | croquet ball
524 | crutch
525 | cuirass
526 | dam, dike, dyke
527 | desk
528 | desktop computer
529 | dial telephone, dial phone
530 | diaper, nappy, napkin
531 | digital clock
532 | digital watch
533 | dining table, board
534 | dishrag, dishcloth
535 | dishwasher, dish washer, dishwashing machine
536 | disk brake, disc brake
537 | dock, dockage, docking facility
538 | dogsled, dog sled, dog sleigh
539 | dome
540 | doormat, welcome mat
541 | drilling platform, offshore rig
542 | drum, membranophone, tympan
543 | drumstick
544 | dumbbell
545 | Dutch oven
546 | electric fan, blower
547 | electric guitar
548 | electric locomotive
549 | entertainment center
550 | envelope
551 | espresso maker
552 | face powder
553 | feather boa, boa
554 | file, file cabinet, filing cabinet
555 | fireboat
556 | fire engine, fire truck
557 | fire screen, fireguard
558 | flagpole, flagstaff
559 | flute, transverse flute
560 | folding chair
561 | football helmet
562 | forklift
563 | fountain
564 | fountain pen
565 | four-poster
566 | freight car
567 | French horn, horn
568 | frying pan, frypan, skillet
569 | fur coat
570 | garbage truck, dustcart
571 | gasmask, respirator, gas helmet
572 | gas pump, gasoline pump, petrol pump, island dispenser
573 | goblet
574 | go-kart
575 | golf ball
576 | golfcart, golf cart
577 | gondola
578 | gong, tam-tam
579 | gown
580 | grand piano, grand
581 | greenhouse, nursery, glasshouse
582 | grille, radiator grille
583 | grocery store, grocery, food market, market
584 | guillotine
585 | hair slide
586 | hair spray
587 | half track
588 | hammer
589 | hamper
590 | hand blower, blow dryer, blow drier, hair dryer, hair drier
591 | hand-held computer, hand-held microcomputer
592 | handkerchief, hankie, hanky, hankey
593 | hard disc, hard disk, fixed disk
594 | harmonica, mouth organ, harp, mouth harp
595 | harp
596 | harvester, reaper
597 | hatchet
598 | holster
599 | home theater, home theatre
600 | honeycomb
601 | hook, claw
602 | hoopskirt, crinoline
603 | horizontal bar, high bar
604 | horse cart, horse-cart
605 | hourglass
606 | iPod
607 | iron, smoothing iron
608 | jack-o'-lantern
609 | jean, blue jean, denim
610 | jeep, landrover
611 | jersey, T-shirt, tee shirt
612 | jigsaw puzzle
613 | jinrikisha, ricksha, rickshaw
614 | joystick
615 | kimono
616 | knee pad
617 | knot
618 | lab coat, laboratory coat
619 | ladle
620 | lampshade, lamp shade
621 | laptop, laptop computer
622 | lawn mower, mower
623 | lens cap, lens cover
624 | letter opener, paper knife, paperknife
625 | library
626 | lifeboat
627 | lighter, light, igniter, ignitor
628 | limousine, limo
629 | liner, ocean liner
630 | lipstick, lip rouge
631 | Loafer
632 | lotion
633 | loudspeaker, speaker, speaker unit, loudspeaker system, speaker system
634 | loupe, jeweler's loupe
635 | lumbermill, sawmill
636 | magnetic compass
637 | mailbag, postbag
638 | mailbox, letter box
639 | maillot
640 | maillot, tank suit
641 | manhole cover
642 | maraca
643 | marimba, xylophone
644 | mask
645 | matchstick
646 | maypole
647 | maze, labyrinth
648 | measuring cup
649 | medicine chest, medicine cabinet
650 | megalith, megalithic structure
651 | microphone, mike
652 | microwave, microwave oven
653 | military uniform
654 | milk can
655 | minibus
656 | miniskirt, mini
657 | minivan
658 | missile
659 | mitten
660 | mixing bowl
661 | mobile home, manufactured home
662 | Model T
663 | modem
664 | monastery
665 | monitor
666 | moped
667 | mortar
668 | mortarboard
669 | mosque
670 | mosquito net
671 | motor scooter, scooter
672 | mountain bike, all-terrain bike, off-roader
673 | mountain tent
674 | mouse, computer mouse
675 | mousetrap
676 | moving van
677 | muzzle
678 | nail
679 | neck brace
680 | necklace
681 | nipple
682 | notebook, notebook computer
683 | obelisk
684 | oboe, hautboy, hautbois
685 | ocarina, sweet potato
686 | odometer, hodometer, mileometer, milometer
687 | oil filter
688 | organ, pipe organ
689 | oscilloscope, scope, cathode-ray oscilloscope, CRO
690 | overskirt
691 | oxcart
692 | oxygen mask
693 | packet
694 | paddle, boat paddle
695 | paddlewheel, paddle wheel
696 | padlock
697 | paintbrush
698 | pajama, pyjama, pj's, jammies
699 | palace
700 | panpipe, pandean pipe, syrinx
701 | paper towel
702 | parachute, chute
703 | parallel bars, bars
704 | park bench
705 | parking meter
706 | passenger car, coach, carriage
707 | patio, terrace
708 | pay-phone, pay-station
709 | pedestal, plinth, footstall
710 | pencil box, pencil case
711 | pencil sharpener
712 | perfume, essence
713 | Petri dish
714 | photocopier
715 | pick, plectrum, plectron
716 | pickelhaube
717 | picket fence, paling
718 | pickup, pickup truck
719 | pier
720 | piggy bank, penny bank
721 | pill bottle
722 | pillow
723 | ping-pong ball
724 | pinwheel
725 | pirate, pirate ship
726 | pitcher, ewer
727 | plane, carpenter's plane, woodworking plane
728 | planetarium
729 | plastic bag
730 | plate rack
731 | plow, plough
732 | plunger, plumber's helper
733 | Polaroid camera, Polaroid Land camera
734 | pole
735 | police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria
736 | poncho
737 | pool table, billiard table, snooker table
738 | pop bottle, soda bottle
739 | pot, flowerpot
740 | potter's wheel
741 | power drill
742 | prayer rug, prayer mat
743 | printer
744 | prison, prison house
745 | projectile, missile
746 | projector
747 | puck, hockey puck
748 | punching bag, punch bag, punching ball, punchball
749 | purse
750 | quill, quill pen
751 | quilt, comforter, comfort, puff
752 | racer, race car, racing car
753 | racket, racquet
754 | radiator
755 | radio, wireless
756 | radio telescope, radio reflector
757 | rain barrel
758 | recreational vehicle, RV, R.V.
759 | reel
760 | reflex camera
761 | refrigerator, icebox
762 | remote control, remote
763 | restaurant, eating house, eating place, eatery
764 | revolver, six-gun, six-shooter
765 | rifle
766 | rocking chair, rocker
767 | rotisserie
768 | rubber eraser, rubber, pencil eraser
769 | rugby ball
770 | rule, ruler
771 | running shoe
772 | safe
773 | safety pin
774 | saltshaker, salt shaker
775 | sandal
776 | sarong
777 | sax, saxophone
778 | scabbard
779 | scale, weighing machine
780 | school bus
781 | schooner
782 | scoreboard
783 | screen, CRT screen
784 | screw
785 | screwdriver
786 | seat belt, seatbelt
787 | sewing machine
788 | shield, buckler
789 | shoe shop, shoe-shop, shoe store
790 | shoji
791 | shopping basket
792 | shopping cart
793 | shovel
794 | shower cap
795 | shower curtain
796 | ski
797 | ski mask
798 | sleeping bag
799 | slide rule, slipstick
800 | sliding door
801 | slot, one-armed bandit
802 | snorkel
803 | snowmobile
804 | snowplow, snowplough
805 | soap dispenser
806 | soccer ball
807 | sock
808 | solar dish, solar collector, solar furnace
809 | sombrero
810 | soup bowl
811 | space bar
812 | space heater
813 | space shuttle
814 | spatula
815 | speedboat
816 | spider web, spider's web
817 | spindle
818 | sports car, sport car
819 | spotlight, spot
820 | stage
821 | steam locomotive
822 | steel arch bridge
823 | steel drum
824 | stethoscope
825 | stole
826 | stone wall
827 | stopwatch, stop watch
828 | stove
829 | strainer
830 | streetcar, tram, tramcar, trolley, trolley car
831 | stretcher
832 | studio couch, day bed
833 | stupa, tope
834 | submarine, pigboat, sub, U-boat
835 | suit, suit of clothes
836 | sundial
837 | sunglass
838 | sunglasses, dark glasses, shades
839 | sunscreen, sunblock, sun blocker
840 | suspension bridge
841 | swab, swob, mop
842 | sweatshirt
843 | swimming trunks, bathing trunks
844 | swing
845 | switch, electric switch, electrical switch
846 | syringe
847 | table lamp
848 | tank, army tank, armored combat vehicle, armoured combat vehicle
849 | tape player
850 | teapot
851 | teddy, teddy bear
852 | television, television system
853 | tennis ball
854 | thatch, thatched roof
855 | theater curtain, theatre curtain
856 | thimble
857 | thresher, thrasher, threshing machine
858 | throne
859 | tile roof
860 | toaster
861 | tobacco shop, tobacconist shop, tobacconist
862 | toilet seat
863 | torch
864 | totem pole
865 | tow truck, tow car, wrecker
866 | toyshop
867 | tractor
868 | trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi
869 | tray
870 | trench coat
871 | tricycle, trike, velocipede
872 | trimaran
873 | tripod
874 | triumphal arch
875 | trolleybus, trolley coach, trackless trolley
876 | trombone
877 | tub, vat
878 | turnstile
879 | typewriter keyboard
880 | umbrella
881 | unicycle, monocycle
882 | upright, upright piano
883 | vacuum, vacuum cleaner
884 | vase
885 | vault
886 | velvet
887 | vending machine
888 | vestment
889 | viaduct
890 | violin, fiddle
891 | volleyball
892 | waffle iron
893 | wall clock
894 | wallet, billfold, notecase, pocketbook
895 | wardrobe, closet, press
896 | warplane, military plane
897 | washbasin, handbasin, washbowl, lavabo, wash-hand basin
898 | washer, automatic washer, washing machine
899 | water bottle
900 | water jug
901 | water tower
902 | whiskey jug
903 | whistle
904 | wig
905 | window screen
906 | window shade
907 | Windsor tie
908 | wine bottle
909 | wing
910 | wok
911 | wooden spoon
912 | wool, woolen, woollen
913 | worm fence, snake fence, snake-rail fence, Virginia fence
914 | wreck
915 | yawl
916 | yurt
917 | web site, website, internet site, site
918 | comic book
919 | crossword puzzle, crossword
920 | street sign
921 | traffic light, traffic signal, stoplight
922 | book jacket, dust cover, dust jacket, dust wrapper
923 | menu
924 | plate
925 | guacamole
926 | consomme
927 | hot pot, hotpot
928 | trifle
929 | ice cream, icecream
930 | ice lolly, lolly, lollipop, popsicle
931 | French loaf
932 | bagel, beigel
933 | pretzel
934 | cheeseburger
935 | hotdog, hot dog, red hot
936 | mashed potato
937 | head cabbage
938 | broccoli
939 | cauliflower
940 | zucchini, courgette
941 | spaghetti squash
942 | acorn squash
943 | butternut squash
944 | cucumber, cuke
945 | artichoke, globe artichoke
946 | bell pepper
947 | cardoon
948 | mushroom
949 | Granny Smith
950 | strawberry
951 | orange
952 | lemon
953 | fig
954 | pineapple, ananas
955 | banana
956 | jackfruit, jak, jack
957 | custard apple
958 | pomegranate
959 | hay
960 | carbonara
961 | chocolate sauce, chocolate syrup
962 | dough
963 | meat loaf, meatloaf
964 | pizza, pizza pie
965 | potpie
966 | burrito
967 | red wine
968 | espresso
969 | cup
970 | eggnog
971 | alp
972 | bubble
973 | cliff, drop, drop-off
974 | coral reef
975 | geyser
976 | lakeside, lakeshore
977 | promontory, headland, head, foreland
978 | sandbar, sand bar
979 | seashore, coast, seacoast, sea-coast
980 | valley, vale
981 | volcano
982 | ballplayer, baseball player
983 | groom, bridegroom
984 | scuba diver
985 | rapeseed
986 | daisy
987 | yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum
988 | corn
989 | acorn
990 | hip, rose hip, rosehip
991 | buckeye, horse chestnut, conker
992 | coral fungus
993 | agaric
994 | gyromitra
995 | stinkhorn, carrion fungus
996 | earthstar
997 | hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa
998 | bolete
999 | ear, spike, capitulum
1000 | toilet tissue, toilet paper, bathroom tissue
--------------------------------------------------------------------------------
/IDKL/grad_cam/main_cnn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from PIL import Image
5 | import matplotlib.pyplot as plt
6 | from torchvision import models
7 | from torchvision import transforms
8 | from utils import GradCAM, show_cam_on_image, center_crop_img
9 |
10 |
11 | def main():
12 | model = models.mobilenet_v3_large(pretrained=True)
13 | target_layers = [model.features[-1]]
14 |
15 | # model = models.vgg16(pretrained=True)
16 | # target_layers = [model.features]
17 |
18 | # model = models.resnet34(pretrained=True)
19 | # target_layers = [model.layer4]
20 |
21 | # model = models.regnet_y_800mf(pretrained=True)
22 | # target_layers = [model.trunk_output]
23 |
24 | # model = models.efficientnet_b0(pretrained=True)
25 | # target_layers = [model.features]
26 |
27 | data_transform = transforms.Compose([transforms.ToTensor(),
28 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
29 | # load image
30 | img_path = "both.png"
31 | assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
32 | img = Image.open(img_path).convert('RGB')
33 | img = np.array(img, dtype=np.uint8)
34 | # img = center_crop_img(img, 224)
35 |
36 | # [C, H, W]
37 | img_tensor = data_transform(img)
38 | # expand batch dimension
39 | # [C, H, W] -> [N, C, H, W]
40 | input_tensor = torch.unsqueeze(img_tensor, dim=0)
41 |
42 | cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
43 | target_category = 281 # tabby, tabby cat
44 | # target_category = 254 # pug, pug-dog
45 |
46 | grayscale_cam = cam(input_tensor=input_tensor, target_category=target_category)
47 |
48 | grayscale_cam = grayscale_cam[0, :]
49 | visualization = show_cam_on_image(img.astype(dtype=np.float32) / 255.,
50 | grayscale_cam,
51 | use_rgb=True)
52 | plt.imshow(visualization)
53 | plt.show()
54 |
55 |
56 | if __name__ == '__main__':
57 | main()
58 |
--------------------------------------------------------------------------------
/IDKL/grad_cam/main_swin.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import numpy as np
4 | import torch
5 | from PIL import Image
6 | import matplotlib.pyplot as plt
7 | from torchvision import transforms
8 | from utils import GradCAM, show_cam_on_image, center_crop_img
9 | from swin_model import swin_base_patch4_window7_224
10 |
11 |
12 | class ResizeTransform:
13 | def __init__(self, im_h: int, im_w: int):
14 | self.height = self.feature_size(im_h)
15 | self.width = self.feature_size(im_w)
16 |
17 | @staticmethod
18 | def feature_size(s):
19 | s = math.ceil(s / 4) # PatchEmbed
20 | s = math.ceil(s / 2) # PatchMerging1
21 | s = math.ceil(s / 2) # PatchMerging2
22 | s = math.ceil(s / 2) # PatchMerging3
23 | return s
24 |
25 | def __call__(self, x):
26 | result = x.reshape(x.size(0),
27 | self.height,
28 | self.width,
29 | x.size(2))
30 |
31 | # Bring the channels to the first dimension,
32 | # like in CNNs.
33 | # [batch_size, H, W, C] -> [batch, C, H, W]
34 | result = result.permute(0, 3, 1, 2)
35 |
36 | return result
37 |
38 |
39 | def main():
40 | # Note: the input image size must be a multiple of 32,
41 | # otherwise padding will make the attention map drift off target
42 | img_size = 224
43 | assert img_size % 32 == 0
44 |
45 | model = swin_base_patch4_window7_224()
46 | # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth
47 | weights_path = "./swin_base_patch4_window7_224.pth"
48 | model.load_state_dict(torch.load(weights_path, map_location="cpu")["model"], strict=False)
49 |
50 | target_layers = [model.norm]
51 |
52 | data_transform = transforms.Compose([transforms.ToTensor(),
53 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
54 | # load image
55 | img_path = "both.png"
56 | assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
57 | img = Image.open(img_path).convert('RGB')
58 | img = np.array(img, dtype=np.uint8)
59 | img = center_crop_img(img, img_size)
60 |
61 | # [C, H, W]
62 | img_tensor = data_transform(img)
63 | # expand batch dimension
64 | # [C, H, W] -> [N, C, H, W]
65 | input_tensor = torch.unsqueeze(img_tensor, dim=0)
66 |
67 | cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False,
68 | reshape_transform=ResizeTransform(im_h=img_size, im_w=img_size))
69 | target_category = 281 # tabby, tabby cat
70 | # target_category = 254 # pug, pug-dog
71 |
72 | grayscale_cam = cam(input_tensor=input_tensor, target_category=target_category)
73 |
74 | grayscale_cam = grayscale_cam[0, :]
75 | visualization = show_cam_on_image(img / 255., grayscale_cam, use_rgb=True)
76 | plt.imshow(visualization)
77 | plt.show()
78 |
79 |
80 | if __name__ == '__main__':
81 | main()
82 |
--------------------------------------------------------------------------------
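A quick check of the `ResizeTransform` arithmetic above with concrete numbers (an illustrative sketch, not repository code): for a 224x224 input, Swin's 49 final-stage tokens map back to a 7x7 grid, which is why the image side must be a multiple of 32.

```python
import math

def feature_size(s: int) -> int:
    s = math.ceil(s / 4)        # PatchEmbed:       224 -> 56
    for _ in range(3):          # PatchMerging x3:   56 -> 28 -> 14 -> 7
        s = math.ceil(s / 2)
    return s

assert 224 % 32 == 0            # the multiple-of-32 requirement noted in main()
assert feature_size(224) == 7   # 7 * 7 = 49 tokens feed the [batch, 7, 7, C] reshape
```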
/IDKL/grad_cam/main_vit.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from PIL import Image
5 | import matplotlib.pyplot as plt
6 | from torchvision import transforms
7 | from utils import GradCAM, show_cam_on_image, center_crop_img
8 | from vit_model import vit_base_patch16_224
9 |
10 |
11 | class ReshapeTransform:
12 | def __init__(self, model):
13 | input_size = model.patch_embed.img_size
14 | patch_size = model.patch_embed.patch_size
15 | self.h = input_size[0] // patch_size[0]
16 | self.w = input_size[1] // patch_size[1]
17 |
18 | def __call__(self, x):
19 | # remove cls token and reshape
20 | # [batch_size, num_tokens, token_dim]
21 | result = x[:, 1:, :].reshape(x.size(0),
22 | self.h,
23 | self.w,
24 | x.size(2))
25 |
26 | # Bring the channels to the first dimension,
27 | # like in CNNs.
28 | # [batch_size, H, W, C] -> [batch, C, H, W]
29 | result = result.permute(0, 3, 1, 2)
30 | return result
31 |
32 |
33 | def main():
34 | model = vit_base_patch16_224()
35 | # Link: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA  password: eu9f
36 | weights_path = "./vit_base_patch16_224.pth"
37 | model.load_state_dict(torch.load(weights_path, map_location="cpu"))
38 | # Since the final classification is done on the class token computed in the last attention block,
39 | # the output will not be affected by the 14x14 channels in the last layer.
40 | # The gradient of the output with respect to them will be 0!
41 | # We should choose a layer before the final attention block.
42 | target_layers = [model.blocks[-1].norm1]
43 |
44 | data_transform = transforms.Compose([transforms.ToTensor(),
45 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
46 | # load image
47 | img_path = "both.png"
48 | assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
49 | img = Image.open(img_path).convert('RGB')
50 | img = np.array(img, dtype=np.uint8)
51 | img = center_crop_img(img, 224)
52 | # [C, H, W]
53 | img_tensor = data_transform(img)
54 | # expand batch dimension
55 | # [C, H, W] -> [N, C, H, W]
56 | input_tensor = torch.unsqueeze(img_tensor, dim=0)
57 |
58 | cam = GradCAM(model=model,
59 | target_layers=target_layers,
60 | use_cuda=False,
61 | reshape_transform=ReshapeTransform(model))
62 | target_category = 281 # tabby, tabby cat
63 | # target_category = 254 # pug, pug-dog
64 |
65 | grayscale_cam = cam(input_tensor=input_tensor, target_category=target_category)
66 |
67 | grayscale_cam = grayscale_cam[0, :]
68 | visualization = show_cam_on_image(img / 255., grayscale_cam, use_rgb=True)
69 | plt.imshow(visualization)
70 | plt.show()
71 |
72 |
73 | if __name__ == '__main__':
74 | main()
75 |
--------------------------------------------------------------------------------
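The effect of `ReshapeTransform` with concrete ViT-B/16 numbers (an illustrative sketch, not repository code): the class token is dropped and the remaining 196 patch tokens are rearranged into a channels-first 14x14 grid, so the CNN-oriented Grad-CAM code can consume them.

```python
import torch

x = torch.randn(1, 197, 768)                  # [batch, cls token + 14*14 patch tokens, embed dim]
grid = x[:, 1:, :].reshape(1, 14, 14, 768)    # drop the cls token, restore the spatial grid
grid = grid.permute(0, 3, 1, 2)               # [batch, H, W, C] -> [batch, C, H, W]
assert grid.shape == (1, 768, 14, 14)
```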
/IDKL/grad_cam/utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 |
5 | class ActivationsAndGradients:
6 | """ Class for extracting activations and
7 | registering gradients from targeted intermediate layers """
8 |
9 | def __init__(self, model, target_layers, reshape_transform):
10 | self.model = model
11 | self.gradients = []
12 | self.activations = []
13 | self.reshape_transform = reshape_transform
14 | self.handles = []
15 | for target_layer in target_layers:
16 | self.handles.append(
17 | target_layer.register_forward_hook(
18 | self.save_activation))
19 | # Backward compatibility with older pytorch versions:
20 | if hasattr(target_layer, 'register_full_backward_hook'):
21 | self.handles.append(
22 | target_layer.register_full_backward_hook(
23 | self.save_gradient))
24 | else:
25 | self.handles.append(
26 | target_layer.register_backward_hook(
27 | self.save_gradient))
28 |
29 | def save_activation(self, module, input, output):
30 | activation = output
31 | if self.reshape_transform is not None:
32 | activation = self.reshape_transform(activation)
33 | self.activations.append(activation.cpu().detach())
34 |
35 | def save_gradient(self, module, grad_input, grad_output):
36 | # Gradients are computed in reverse order
37 | grad = grad_output[0]
38 | if self.reshape_transform is not None:
39 | grad = self.reshape_transform(grad)
40 | self.gradients = [grad.cpu().detach()] + self.gradients
41 |
42 | def __call__(self, x):
43 | self.gradients = []
44 | self.activations = []
45 | return self.model(x)
46 |
47 | def release(self):
48 | for handle in self.handles:
49 | handle.remove()
50 |
51 |
52 | class GradCAM:
53 | def __init__(self,
54 | model,
55 | target_layers,
56 | reshape_transform=None,
57 | use_cuda=False):
58 | self.model = model.eval()
59 | self.target_layers = target_layers
60 | self.reshape_transform = reshape_transform
61 | self.cuda = use_cuda
62 | if self.cuda:
63 | self.model = model.cuda()
64 | self.activations_and_grads = ActivationsAndGradients(
65 | self.model, target_layers, reshape_transform)
66 |
67 | """ Get a vector of weights for every channel in the target layer.
68 | Methods that return channel weights
69 | will typically only need to implement this function. """
70 |
71 | @staticmethod
72 | def get_cam_weights(grads):
73 | return np.mean(grads, axis=(2, 3), keepdims=True)
74 |
75 | @staticmethod
76 | def get_loss(output, target_category):
77 | loss = 0
78 | for i in range(len(target_category)):
79 | loss = loss + output[i, target_category[i]]
80 | return loss
81 |
82 | def get_cam_image(self, activations, grads):
83 | weights = self.get_cam_weights(grads)
84 | weighted_activations = weights * activations
85 | cam = weighted_activations.sum(axis=1)
86 |
87 | return cam
88 |
89 | @staticmethod
90 | def get_target_width_height(input_tensor):
91 | width, height = input_tensor.size(-1), input_tensor.size(-2)
92 | return width, height
93 |
94 | def compute_cam_per_layer(self, input_tensor):
95 | activations_list = [a.cpu().data.numpy()
96 | for a in self.activations_and_grads.activations]
97 | grads_list = [g.cpu().data.numpy()
98 | for g in self.activations_and_grads.gradients]
99 | target_size = self.get_target_width_height(input_tensor)
100 |
101 | cam_per_target_layer = []
102 | # Loop over the saliency image from every layer
103 |
104 | for layer_activations, layer_grads in zip(activations_list, grads_list):
105 | cam = self.get_cam_image(layer_activations, layer_grads)
106 | #cam = cam*2-cam.mean()
107 | cam[cam < 0] = 0  # acts like a ReLU: keep only positive evidence before the min-max scaling in scale_cam_image
108 | #cam = cam**0.6
109 |
110 | #cam = np.exp(cam)-np.exp()
111 |
112 | scaled = self.scale_cam_image(cam, target_size)
113 | cam_per_target_layer.append(scaled[:, None, :])
114 |
115 | return cam_per_target_layer
116 |
117 | def aggregate_multi_layers(self, cam_per_target_layer):
118 | cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
119 | cam_per_target_layer = np.maximum(cam_per_target_layer, 0)
120 | result = np.mean(cam_per_target_layer, axis=1)
121 | return self.scale_cam_image(result)
122 |
123 | @staticmethod
124 | def scale_cam_image(cam, target_size=None):
125 | result = []
126 | for img in cam:
127 | img = img - np.min(img)
128 | img = img / (1e-7 + np.max(img))
129 | if target_size is not None:
130 | img = cv2.resize(img, target_size)
131 | result.append(img)
132 | result = np.float32(result)
133 |
134 | return result
135 |
136 | def __call__(self, input_tensor, target_category=None):
137 |
138 | if self.cuda:
139 | input_tensor = input_tensor.cuda()
140 |
141 | # Forward pass to get the network's output logits (before softmax)
142 | output = self.activations_and_grads(input_tensor)
143 | if isinstance(target_category, int):
144 | target_category = [target_category] * input_tensor.size(0)
145 |
146 | if target_category is None:
147 | target_category = np.argmax(output.cpu().data.numpy(), axis=-1)
148 | print(f"category id: {target_category}")
149 | else:
150 | assert (len(target_category) == input_tensor.size(0))
151 |
152 | self.model.zero_grad()
153 | loss = self.get_loss(output, target_category)
154 | loss.backward(retain_graph=True)
155 |
156 | # In most of the saliency attribution papers, the saliency is
157 | # computed with a single target layer.
158 | # Commonly it is the last convolutional layer.
159 | # Here we support passing a list with multiple target layers.
160 | # It will compute the saliency image for every target layer,
161 | # and then aggregate them (with a default mean aggregation).
162 | # This gives you more flexibility in case you just want to
163 | # use all conv layers for example, all Batchnorm layers,
164 | # or something else.
165 | cam_per_layer = self.compute_cam_per_layer(input_tensor)
166 | return self.aggregate_multi_layers(cam_per_layer)
167 |
168 | def __del__(self):
169 | self.activations_and_grads.release()
170 |
171 | def __enter__(self):
172 | return self
173 |
174 | def __exit__(self, exc_type, exc_value, exc_tb):
175 | self.activations_and_grads.release()
176 | if isinstance(exc_value, IndexError):
177 | # Handle IndexError here...
178 | print(
179 | f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")
180 | return True
181 |
182 |
183 | def show_cam_on_image(img: np.ndarray,
184 | mask: np.ndarray,
185 | use_rgb: bool = False,
186 | colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
187 | """ This function overlays the cam mask on the image as an heatmap.
188 | By default the heatmap is in BGR format.
189 |
190 | :param img: The base image in RGB or BGR format.
191 | :param mask: The cam mask.
192 | :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format.
193 | :param colormap: The OpenCV colormap to be used.
194 | :returns: The default image with the cam overlay.
195 | """
196 |
197 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
198 | if use_rgb:
199 | heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
200 | heatmap = np.float32(heatmap) / 255
201 |
202 | if np.max(img) > 1:
203 | raise Exception(
204 | "The input image should np.float32 in the range [0, 1]")
205 |
206 | cam = heatmap * 0.4 + img * 0.6 #heatmap + img
207 | cam = cam / np.max(cam)
208 | return np.uint8(255 * cam)
209 |
210 |
211 | def center_crop_img(img: np.ndarray, size: int):
212 | h, w, c = img.shape
213 |
214 | if w == h == size:
215 | return img
216 |
217 | if w < h:
218 | ratio = size / w
219 | new_w = size
220 | new_h = int(h * ratio)
221 | else:
222 | ratio = size / h
223 | new_h = size
224 | new_w = int(w * ratio)
225 |
226 | img = cv2.resize(img, dsize=(new_w, new_h))
227 |
228 | if new_w == size:
229 | h = (new_h - size) // 2
230 | img = img[h: h+size]
231 | else:
232 | w = (new_w - size) // 2
233 | img = img[:, w: w+size]
234 |
235 | return img
236 |
--------------------------------------------------------------------------------
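The core arithmetic of `get_cam_weights` / `get_cam_image` above, written out with NumPy on random arrays (an illustrative sketch only): the channel weights are the spatial mean of the gradients, and the CAM is the weighted sum of activations over channels, clamped at zero.

```python
import numpy as np

activations = np.random.rand(1, 512, 7, 7).astype(np.float32)  # feature maps [N, C, H, W]
gradients = np.random.rand(1, 512, 7, 7).astype(np.float32)    # d(class score) / d(activations)

weights = gradients.mean(axis=(2, 3), keepdims=True)           # get_cam_weights: [N, C, 1, 1]
cam = (weights * activations).sum(axis=1)                       # get_cam_image:  [N, H, W]
cam = np.maximum(cam, 0)                                        # keep positive evidence, as in compute_cam_per_layer
cam = (cam - cam.min()) / (1e-7 + cam.max())                    # scale_cam_image: normalize to [0, 1]
```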
/IDKL/grad_cam/vit_model.py:
--------------------------------------------------------------------------------
1 | """
2 | original code from rwightman:
3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
4 | """
5 | from functools import partial
6 | from collections import OrderedDict
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 |
12 | def drop_path(x, drop_prob: float = 0., training: bool = False):
13 | """
14 | Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
15 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
16 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
17 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
18 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
19 | 'survival rate' as the argument.
20 | """
21 | if drop_prob == 0. or not training:
22 | return x
23 | keep_prob = 1 - drop_prob
24 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
25 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
26 | random_tensor.floor_() # binarize
27 | output = x.div(keep_prob) * random_tensor
28 | return output
29 |
30 |
31 | class DropPath(nn.Module):
32 | """
33 | Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
34 | """
35 | def __init__(self, drop_prob=None):
36 | super(DropPath, self).__init__()
37 | self.drop_prob = drop_prob
38 |
39 | def forward(self, x):
40 | return drop_path(x, self.drop_prob, self.training)
41 |
42 |
43 | class PatchEmbed(nn.Module):
44 | """
45 | 2D Image to Patch Embedding
46 | """
47 | def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
48 | super().__init__()
49 | img_size = (img_size, img_size)
50 | patch_size = (patch_size, patch_size)
51 | self.img_size = img_size
52 | self.patch_size = patch_size
53 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
54 | self.num_patches = self.grid_size[0] * self.grid_size[1]
55 |
56 | self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
57 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
58 |
59 | def forward(self, x):
60 | B, C, H, W = x.shape
61 | assert H == self.img_size[0] and W == self.img_size[1], \
62 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
63 |
64 | # flatten: [B, C, H, W] -> [B, C, HW]
65 | # transpose: [B, C, HW] -> [B, HW, C]
66 | x = self.proj(x).flatten(2).transpose(1, 2)
67 | x = self.norm(x)
68 | return x
69 |
70 |
71 | class Attention(nn.Module):
72 | def __init__(self,
73 | dim, # dimension of the input tokens
74 | num_heads=8,
75 | qkv_bias=False,
76 | qk_scale=None,
77 | attn_drop_ratio=0.,
78 | proj_drop_ratio=0.):
79 | super(Attention, self).__init__()
80 | self.num_heads = num_heads
81 | head_dim = dim // num_heads
82 | self.scale = qk_scale or head_dim ** -0.5
83 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
84 | self.attn_drop = nn.Dropout(attn_drop_ratio)
85 | self.proj = nn.Linear(dim, dim)
86 | self.proj_drop = nn.Dropout(proj_drop_ratio)
87 |
88 | def forward(self, x):
89 | # [batch_size, num_patches + 1, total_embed_dim]
90 | B, N, C = x.shape
91 |
92 | # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
93 | # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
94 | # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
95 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
96 | # [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
97 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
98 |
99 | # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
100 | # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
101 | attn = (q @ k.transpose(-2, -1)) * self.scale
102 | attn = attn.softmax(dim=-1)
103 | attn = self.attn_drop(attn)
104 |
105 | # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
106 | # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
107 | # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
108 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
109 | x = self.proj(x)
110 | x = self.proj_drop(x)
111 | return x
112 |
113 |
114 | class Mlp(nn.Module):
115 | """
116 | MLP as used in Vision Transformer, MLP-Mixer and related networks
117 | """
118 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
119 | super().__init__()
120 | out_features = out_features or in_features
121 | hidden_features = hidden_features or in_features
122 | self.fc1 = nn.Linear(in_features, hidden_features)
123 | self.act = act_layer()
124 | self.fc2 = nn.Linear(hidden_features, out_features)
125 | self.drop = nn.Dropout(drop)
126 |
127 | def forward(self, x):
128 | x = self.fc1(x)
129 | x = self.act(x)
130 | x = self.drop(x)
131 | x = self.fc2(x)
132 | x = self.drop(x)
133 | return x
134 |
135 |
136 | class Block(nn.Module):
137 | def __init__(self,
138 | dim,
139 | num_heads,
140 | mlp_ratio=4.,
141 | qkv_bias=False,
142 | qk_scale=None,
143 | drop_ratio=0.,
144 | attn_drop_ratio=0.,
145 | drop_path_ratio=0.,
146 | act_layer=nn.GELU,
147 | norm_layer=nn.LayerNorm):
148 | super(Block, self).__init__()
149 | self.norm1 = norm_layer(dim)
150 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
151 | attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
152 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
153 | self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
154 | self.norm2 = norm_layer(dim)
155 | mlp_hidden_dim = int(dim * mlp_ratio)
156 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)
157 |
158 | def forward(self, x):
159 | x = x + self.drop_path(self.attn(self.norm1(x)))
160 | x = x + self.drop_path(self.mlp(self.norm2(x)))
161 | return x
162 |
163 |
164 | class VisionTransformer(nn.Module):
165 | def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
166 | embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
167 | qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
168 | attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
169 | act_layer=None):
170 | """
171 | Args:
172 | img_size (int, tuple): input image size
173 | patch_size (int, tuple): patch size
174 | in_c (int): number of input channels
175 | num_classes (int): number of classes for classification head
176 | embed_dim (int): embedding dimension
177 | depth (int): depth of transformer
178 | num_heads (int): number of attention heads
179 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim
180 | qkv_bias (bool): enable bias for qkv if True
181 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set
182 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
183 | distilled (bool): model includes a distillation token and head as in DeiT models
184 | drop_ratio (float): dropout rate
185 | attn_drop_ratio (float): attention dropout rate
186 | drop_path_ratio (float): stochastic depth rate
187 | embed_layer (nn.Module): patch embedding layer
188 | norm_layer (nn.Module): normalization layer
189 | """
190 | super(VisionTransformer, self).__init__()
191 | self.num_classes = num_classes
192 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
193 | self.num_tokens = 2 if distilled else 1
194 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
195 | act_layer = act_layer or nn.GELU
196 |
197 | self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
198 | num_patches = self.patch_embed.num_patches
199 |
200 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
201 | self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
202 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
203 | self.pos_drop = nn.Dropout(p=drop_ratio)
204 |
205 | dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)] # stochastic depth decay rule
206 | self.blocks = nn.Sequential(*[
207 | Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
208 | drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
209 | norm_layer=norm_layer, act_layer=act_layer)
210 | for i in range(depth)
211 | ])
212 | self.norm = norm_layer(embed_dim)
213 |
214 | # Representation layer
215 | if representation_size and not distilled:
216 | self.has_logits = True
217 | self.num_features = representation_size
218 | self.pre_logits = nn.Sequential(OrderedDict([
219 | ("fc", nn.Linear(embed_dim, representation_size)),
220 | ("act", nn.Tanh())
221 | ]))
222 | else:
223 | self.has_logits = False
224 | self.pre_logits = nn.Identity()
225 |
226 | # Classifier head(s)
227 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
228 | self.head_dist = None
229 | if distilled:
230 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
231 |
232 | # Weight init
233 | nn.init.trunc_normal_(self.pos_embed, std=0.02)
234 | if self.dist_token is not None:
235 | nn.init.trunc_normal_(self.dist_token, std=0.02)
236 |
237 | nn.init.trunc_normal_(self.cls_token, std=0.02)
238 | self.apply(_init_vit_weights)
239 |
240 | def forward_features(self, x):
241 | # [B, C, H, W] -> [B, num_patches, embed_dim]
242 | x = self.patch_embed(x) # [B, 196, 768]
243 | # [1, 1, 768] -> [B, 1, 768]
244 | cls_token = self.cls_token.expand(x.shape[0], -1, -1)
245 | if self.dist_token is None:
246 | x = torch.cat((cls_token, x), dim=1) # [B, 197, 768]
247 | else:
248 | x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
249 |
250 | x = self.pos_drop(x + self.pos_embed)
251 | x = self.blocks(x)
252 | x = self.norm(x)
253 | if self.dist_token is None:
254 | return self.pre_logits(x[:, 0])
255 | else:
256 | return x[:, 0], x[:, 1]
257 |
258 | def forward(self, x):
259 | x = self.forward_features(x)
260 | if self.head_dist is not None:
261 | x, x_dist = self.head(x[0]), self.head_dist(x[1])
262 | if self.training and not torch.jit.is_scripting():
263 | # during training, return both head outputs; at inference their average is returned below
264 | return x, x_dist
265 | else:
266 | return (x + x_dist) / 2
267 | else:
268 | x = self.head(x)
269 | return x
270 |
271 |
272 | def _init_vit_weights(m):
273 | """
274 | ViT weight initialization
275 | :param m: module
276 | """
277 | if isinstance(m, nn.Linear):
278 | nn.init.trunc_normal_(m.weight, std=.01)
279 | if m.bias is not None:
280 | nn.init.zeros_(m.bias)
281 | elif isinstance(m, nn.Conv2d):
282 | nn.init.kaiming_normal_(m.weight, mode="fan_out")
283 | if m.bias is not None:
284 | nn.init.zeros_(m.bias)
285 | elif isinstance(m, nn.LayerNorm):
286 | nn.init.zeros_(m.bias)
287 | nn.init.ones_(m.weight)
288 |
289 |
290 | def vit_base_patch16_224(num_classes: int = 1000):
291 | """
292 | ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
293 | ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
294 | weights ported from official Google JAX impl:
295 | Link: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA  password: eu9f
296 | """
297 | model = VisionTransformer(img_size=224,
298 | patch_size=16,
299 | embed_dim=768,
300 | depth=12,
301 | num_heads=12,
302 | representation_size=None,
303 | num_classes=num_classes)
304 | return model
305 |
306 |
307 | def vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
308 | """
309 | ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
310 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
311 | weights ported from official Google JAX impl:
312 | https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth
313 | """
314 | model = VisionTransformer(img_size=224,
315 | patch_size=16,
316 | embed_dim=768,
317 | depth=12,
318 | num_heads=12,
319 | representation_size=768 if has_logits else None,
320 | num_classes=num_classes)
321 | return model
322 |
323 |
324 | def vit_base_patch32_224(num_classes: int = 1000):
325 | """
326 | ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
327 | ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
328 | weights ported from official Google JAX impl:
329 | Link: https://pan.baidu.com/s/1hCv0U8pQomwAtHBYc4hmZg  Password: s5hl
330 | """
331 | model = VisionTransformer(img_size=224,
332 | patch_size=32,
333 | embed_dim=768,
334 | depth=12,
335 | num_heads=12,
336 | representation_size=None,
337 | num_classes=num_classes)
338 | return model
339 |
340 |
341 | def vit_base_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
342 | """
343 | ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
344 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
345 | weights ported from official Google JAX impl:
346 | https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth
347 | """
348 | model = VisionTransformer(img_size=224,
349 | patch_size=32,
350 | embed_dim=768,
351 | depth=12,
352 | num_heads=12,
353 | representation_size=768 if has_logits else None,
354 | num_classes=num_classes)
355 | return model
356 |
357 |
358 | def vit_large_patch16_224(num_classes: int = 1000):
359 | """
360 | ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
361 | ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
362 | weights ported from official Google JAX impl:
363 | Link: https://pan.baidu.com/s/1cxBgZJJ6qUWPSBNcE4TdRQ  Password: qqt8
364 | """
365 | model = VisionTransformer(img_size=224,
366 | patch_size=16,
367 | embed_dim=1024,
368 | depth=24,
369 | num_heads=16,
370 | representation_size=None,
371 | num_classes=num_classes)
372 | return model
373 |
374 |
375 | def vit_large_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
376 | """
377 | ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
378 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
379 | weights ported from official Google JAX impl:
380 | https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth
381 | """
382 | model = VisionTransformer(img_size=224,
383 | patch_size=16,
384 | embed_dim=1024,
385 | depth=24,
386 | num_heads=16,
387 | representation_size=1024 if has_logits else None,
388 | num_classes=num_classes)
389 | return model
390 |
391 |
392 | def vit_large_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
393 | """
394 | ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
395 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
396 | weights ported from official Google JAX impl:
397 | https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth
398 | """
399 | model = VisionTransformer(img_size=224,
400 | patch_size=32,
401 | embed_dim=1024,
402 | depth=24,
403 | num_heads=16,
404 | representation_size=1024 if has_logits else None,
405 | num_classes=num_classes)
406 | return model
407 |
408 |
409 | def vit_huge_patch14_224_in21k(num_classes: int = 21843, has_logits: bool = True):
410 | """
411 | ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
412 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
413 | NOTE: converted weights not currently available, too large for github release hosting.
414 | """
415 | model = VisionTransformer(img_size=224,
416 | patch_size=14,
417 | embed_dim=1280,
418 | depth=32,
419 | num_heads=16,
420 | representation_size=1280 if has_logits else None,
421 | num_classes=num_classes)
422 | return model
423 |
--------------------------------------------------------------------------------
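A minimal usage sketch for the factory functions above, assuming this file is grad_cam/vit_model.py and is run from the grad_cam directory; the class count is a placeholder:

    import torch
    from vit_model import vit_base_patch16_224  # assumed import path

    model = vit_base_patch16_224(num_classes=1000)
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))  # -> [1, 1000] class logits
    print(logits.shape)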
/IDKL/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from layers.loss.am_softmax import AMSoftmaxLoss
2 | from layers.loss.center_loss import CenterLoss
3 | from layers.loss.triplet_loss import TripletLoss
4 | from layers.loss.rerank_loss import RerankLoss
5 | from layers.loss.local_center_loss import CenterTripletLoss
6 | from layers.module.norm_linear import NormalizeLinear
7 | from layers.module.reverse_grad import ReverseGrad
8 | from layers.loss.JSD import js_div
9 | from layers.module.CBAM import cbam
10 | from layers.module.NonLocal import NonLocalBlockND
11 |
12 |
13 | __all__ = ['RerankLoss', 'CenterLoss', 'CenterTripletLoss', 'AMSoftmaxLoss', 'TripletLoss', 'NormalizeLinear', 'js_div', 'cbam', 'NonLocalBlockND']
--------------------------------------------------------------------------------
/IDKL/layers/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/layers/loss/JSD.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | # import torch.softmax as softmax
3 | from torch.nn import functional as F
4 |
5 | class js_div:
6 | def __init__(self):
7 | self.KLDivLoss = nn.KLDivLoss(reduction='batchmean')
8 |
9 | def __call__(self, p_output, q_output, get_softmax=True):
10 | """
11 | Function that measures JS divergence between target and output logits:
12 | """
13 | if get_softmax:
14 | p_output = F.softmax(p_output, 1)
15 | q_output = F.softmax(q_output, 1)
16 | log_mean_output = ((p_output + q_output) / 2).log()
17 | return (self.KLDivLoss(log_mean_output, p_output) + self.KLDivLoss(log_mean_output, q_output))/2
--------------------------------------------------------------------------------
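A minimal usage sketch for js_div above, assuming the IDKL/ root is on PYTHONPATH; the batch size and class count are placeholders:

    import torch
    from layers.loss.JSD import js_div

    jsd = js_div()
    p_logits = torch.randn(8, 206)   # placeholder logits
    q_logits = torch.randn(8, 206)
    loss = jsd(p_logits, q_logits)   # softmax is applied internally (get_softmax=True)
    print(loss.item())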
/IDKL/layers/loss/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/loss/__init__.py
--------------------------------------------------------------------------------
/IDKL/layers/loss/__pycache__/JSD.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/loss/__pycache__/JSD.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/layers/loss/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/loss/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/layers/loss/__pycache__/am_softmax.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/loss/__pycache__/am_softmax.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/layers/loss/__pycache__/center_loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/loss/__pycache__/center_loss.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/layers/loss/__pycache__/crossquad_loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/loss/__pycache__/crossquad_loss.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/layers/loss/__pycache__/crosstriplet_loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/loss/__pycache__/crosstriplet_loss.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/layers/loss/__pycache__/local_center_loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/loss/__pycache__/local_center_loss.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/layers/loss/__pycache__/mixtriplet_loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/loss/__pycache__/mixtriplet_loss.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/layers/loss/__pycache__/trapezoid_loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/loss/__pycache__/trapezoid_loss.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/layers/loss/__pycache__/triplet_loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/loss/__pycache__/triplet_loss.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/layers/loss/am_softmax.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class AMSoftmaxLoss(nn.Module):
8 | def __init__(self, scale, margin, weight=None, ignore_index=-100, reduction='mean'):
9 | super(AMSoftmaxLoss, self).__init__()
10 | self.weight = weight
11 | self.ignore_index = ignore_index
12 | self.reduction = reduction
13 | self.scale = scale
14 | self.margin = margin
15 |
16 | def forward(self, x, y):
17 | y_onehot = torch.zeros_like(x, device=x.device)
18 | y_onehot.scatter_(1, y.data.view(-1, 1), self.margin)
19 |
20 | out = self.scale * (x - y_onehot)
21 | loss = F.cross_entropy(out, y, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction)
22 |
23 | return loss
24 |
--------------------------------------------------------------------------------
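A minimal usage sketch for AMSoftmaxLoss, assuming the IDKL/ root is on PYTHONPATH; the scale/margin values and shapes are illustrative, not the repository's configured ones:

    import torch
    from layers.loss.am_softmax import AMSoftmaxLoss

    criterion = AMSoftmaxLoss(scale=30.0, margin=0.35)  # illustrative hyper-parameters
    logits = torch.randn(8, 206)                        # ideally cosine logits, e.g. from NormalizeLinear
    labels = torch.randint(0, 206, (8,))
    loss = criterion(logits, labels)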
/IDKL/layers/loss/center_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class CenterLoss(nn.Module):
6 | """Center loss.
7 |
8 | Reference:
9 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
10 |
11 | Args:
12 | num_classes (int): number of classes.
13 | feat_dim (int): feature dimension.
14 | """
15 |
16 | def __init__(self, num_classes, feat_dim, reduction='mean'):
17 | super(CenterLoss, self).__init__()
18 | self.num_classes = num_classes
19 | self.feat_dim = feat_dim
20 | self.reduction = reduction
21 |
22 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
23 |
24 | def forward(self, x, labels):
25 | """
26 | Args:
27 | x: feature matrix with shape (batch_size, feat_dim).
28 | labels: ground truth labels with shape (batch_size).
29 | """
30 | batch_size = x.size(0)
31 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
32 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
33 |         distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)  # keyword form; the old positional beta/alpha call is removed in recent PyTorch
34 |
35 | classes = torch.arange(self.num_classes).to(device=x.device, dtype=torch.long)
36 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
37 | mask = labels.eq(classes.expand(batch_size, self.num_classes))
38 |
39 | loss = distmat * mask.float()
40 |
41 | if self.reduction == 'mean':
42 | loss = loss.mean()
43 | elif self.reduction == 'sum':
44 | loss = loss.sum()
45 |
46 | return loss
47 |
--------------------------------------------------------------------------------
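A minimal usage sketch for CenterLoss, assuming the IDKL/ root is on PYTHONPATH; the class count and feature dimension are placeholders:

    import torch
    from layers.loss.center_loss import CenterLoss

    center_loss = CenterLoss(num_classes=206, feat_dim=2048)  # placeholder dimensions
    feats = torch.randn(8, 2048)
    labels = torch.randint(0, 206, (8,))
    loss = center_loss(feats, labels)  # scalar with reduction='mean'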
/IDKL/layers/loss/local_center_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | class CenterTripletLoss(nn.Module):
5 | def __init__(self, k_size, margin=0):
6 | super(CenterTripletLoss, self).__init__()
7 | self.margin = margin
8 | self.k_size = k_size
9 | self.ranking_loss = nn.MarginRankingLoss(margin=margin)
10 |
11 | def forward(self, inputs, targets):
12 | n = inputs.size(0)
13 |
14 |         # Compute the class center for each sample's identity
15 | centers = []
16 | for i in range(n):
17 | centers.append(inputs[targets == targets[i]].mean(0))
18 | centers = torch.stack(centers)
19 |
20 | dist_pc = (inputs - centers)**2
21 | dist_pc = dist_pc.sum(1)
22 | dist_pc = dist_pc.sqrt()
23 |
24 | # Compute pairwise distance, replace by the official when merged
25 | dist = torch.pow(centers, 2).sum(dim=1, keepdim=True).expand(n, n)
26 | dist = dist + dist.t()
27 |         dist.addmm_(centers, centers.t(), beta=1, alpha=-2)
28 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
29 |
30 |         # For each k_size group, penalise negative centers that fall inside the margin
31 | mask = targets.expand(n, n).eq(targets.expand(n, n).t())
32 | dist_an, dist_ap = [], []
33 | for i in range(0, n, self.k_size):
34 | dist_an.append( (self.margin - dist[i][mask[i] == 0]).clamp(min=0.0).mean() )
35 | dist_an = torch.stack(dist_an)
36 |
37 |         # Final loss: pull each sample toward its center; the y tensor below is unused
38 | y = dist_an.data.new()
39 | y.resize_as_(dist_an.data)
40 | y.fill_(1)
41 | loss = dist_pc.mean() + dist_an.mean()
42 | return loss, dist_pc.mean(), dist_an.mean()
43 |
--------------------------------------------------------------------------------
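A minimal usage sketch for CenterTripletLoss, assuming the IDKL/ root is on PYTHONPATH and that the batch is arranged as k_size consecutive samples per identity (values below are placeholders); note that forward returns a tuple:

    import torch
    from layers.loss.local_center_loss import CenterTripletLoss

    criterion = CenterTripletLoss(k_size=4, margin=0.3)   # placeholder hyper-parameters
    feats = torch.randn(8, 2048)
    labels = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])       # k_size samples per identity
    loss, dist_pc, dist_an = criterion(feats, labels)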
/IDKL/layers/loss/rerank_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from utils.rerank import pairwise_distance
4 |
5 | def intersect1d(tensor1, tensor2):
6 | return torch.unique(torch.cat([tensor1[tensor1 == val] for val in tensor2]))
7 |
8 | # def rerank_vc(feat1, feat2, k1=20, k2=6, lambda_value=0.3, eval_type=True): #q_feat, g_feat ############ result correctness not verified, but GPU memory usage no longer grows
9 | # feats = torch.cat([feat1, feat2], 0)
10 | #
11 | # dist = torch.cdist(feats, feats)
12 | # original_dist = dist.clone()
13 | # all_num = original_dist.shape[0]
14 | # original_dist = (original_dist / original_dist.max(dim=0, keepdim=True).values).transpose(0, 1)
15 | #
16 | # V = torch.zeros_like(original_dist)
17 | #
18 | # query_num = feat1.size(0)
19 | # if eval_type:
20 | # max_val = dist.max()
21 | # dist = torch.cat((dist[:, :query_num], max_val.expand_as(dist[:, query_num:])), dim=1)
22 | # initial_rank = torch.argsort(dist, dim=1)
23 | #
24 | # for i in range(all_num):
25 | # forward_k_neigh_index = initial_rank[i, :k1 + 1]
26 | # backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
27 | # fi = (backward_k_neigh_index == i).nonzero(as_tuple=True)[0]
28 | # k_reciprocal_index = forward_k_neigh_index[fi]
29 | # k_reciprocal_expansion_index = k_reciprocal_index
30 | #
31 | # for j in k_reciprocal_index:
32 | # candidate = j
33 | # candidate_forward_k_neigh_index = initial_rank[candidate, :int(round(k1 / 2)) + 1]
34 | # candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index, :int(round(k1 / 2)) + 1]
35 | # fi_candidate = (candidate_backward_k_neigh_index == candidate).nonzero(as_tuple=True)[0]
36 | # candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
37 | # if len(intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len(
38 | # candidate_k_reciprocal_index):
39 | # k_reciprocal_expansion_index = torch.unique(
40 | # torch.cat([k_reciprocal_expansion_index, candidate_k_reciprocal_index], dim=0))
41 | #
42 | # weight = torch.exp(-original_dist[i, k_reciprocal_expansion_index])
43 | # V[i, k_reciprocal_expansion_index] = weight / torch.sum(weight)
44 | #
45 | # original_dist = original_dist[:query_num, ]
46 | # if k2 != 1:
47 | # V_qe = torch.zeros_like(V)
48 | # for i in range(all_num):
49 | # V_qe[i, :] = torch.mean(V[initial_rank[i, :k2], :], dim=0)
50 | # V = V_qe
51 | #
52 | # invIndex = []
53 | # for i in range(all_num):
54 | # invIndex.append((V[:, i] != 0).nonzero(as_tuple=True)[0])
55 | #
56 | # jaccard_dist = torch.zeros_like(original_dist)
57 | #
58 | # for i in range(query_num):
59 | # temp_min = torch.zeros([1, all_num]).cuda()
60 | # indNonZero = (V[i, :] != 0).nonzero(as_tuple=True)[0]
61 | # indImages = [invIndex[ind] for ind in indNonZero]
62 | # for j, val in enumerate(indNonZero):
63 | # temp_min[0, indImages[j]] += torch.minimum(V[i, val], V[indImages[j], val])
64 | #
65 | # jaccard_dist[i] = 1 - temp_min / (2 - temp_min)
66 | #
67 | # final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value
68 | # final_dist = final_dist[:query_num, query_num:]
69 | #
70 | # return final_dist
71 |
72 | def rerank_dist(feat1, feat2, k1=20, k2=6, lambda_value=0.3, eval_type=True): #q_feat, g_feat
73 |
74 | #with torch.no_grad():
75 | feats = torch.cat([feat1, feat2], 0) #######
76 | dist = pairwise_distance(feats, feats)
77 | original_dist = dist.clone() # .detach() # .clone()
78 | # import pdb
79 | # pdb.set_trace()
80 | all_num = original_dist.shape[0]
81 |
82 | #original_dist = original_dist / torch.max(original_dist, dim=0).values
83 |
84 | original_dist = torch.transpose(original_dist, 0,1) #.transpose(0, 1)
85 | V = torch.zeros_like(original_dist) # .half()
86 |
87 |
88 | query_num = feat1.size(0)
89 |
90 | #with torch.no_grad():
91 | if eval_type:
92 |     # dist[:, query_num:] = dist.max()  # the culprit (replaced by the torch.cat below)
93 | max_val = dist.max()
94 | dist = torch.cat((dist[:, :query_num], max_val.expand_as(dist[:, query_num:])), dim=1)
95 | initial_rank = torch.argsort(dist, dim=1)
96 | # import pdb
97 | # pdb.set_trace()
98 |
99 |
100 |
101 | for i in range(all_num):
102 | forward_k_neigh_index = initial_rank[i, :k1 + 1]
103 | backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
104 | fi = torch.where(backward_k_neigh_index == i)[0]
105 | k_reciprocal_index = forward_k_neigh_index[fi]
106 | k_reciprocal_expansion_index = k_reciprocal_index
107 |
108 | for j in k_reciprocal_index:
109 | candidate = j.item()
110 | candidate_forward_k_neigh_index = initial_rank[candidate, :int(round(k1 / 2)) + 1]
111 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,
112 | :int(round(k1 / 2)) + 1]
113 | # import pdb
114 | # pdb.set_trace()
115 | fi_candidate = torch.where(candidate_backward_k_neigh_index == candidate)[0]
116 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
117 |
118 | if len(intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len(
119 | candidate_k_reciprocal_index):
120 | k_reciprocal_expansion_index = torch.unique(
121 | torch.cat([k_reciprocal_expansion_index, candidate_k_reciprocal_index], 0))
122 |
123 | weight = torch.exp(-original_dist[i, k_reciprocal_expansion_index])
124 | V[i, k_reciprocal_expansion_index] = (weight / torch.sum(weight)) # .half()
125 |
126 |
127 | original_dist = original_dist[:query_num, ]
128 | # print('before')
129 | # objgraph.show_growth(limit=3)
130 |
131 | if k2 != 1:
132 | V_qe = torch.zeros_like(V) # .half()
133 | for i in range(all_num):
134 | V_qe[i, :] = torch.mean(V[initial_rank[i, :k2], :], dim=0)
135 | V = V_qe
136 |
137 | invIndex = []
138 | for i in range(all_num):
139 | invIndex.append(torch.where(V[:, i] != 0)[0])
140 |
141 | jaccard_dist = torch.zeros_like(original_dist) # .half()
142 |
143 | # print('after')
144 | # objgraph.show_growth(limit=3)
145 |
146 | # with torch.no_grad():
147 | # for i in range(query_num):
148 | # temp_min = torch.zeros([1, all_num], device="cuda")
149 | # indNonZero = torch.where(V[i, :] != 0)[0]
150 | # indImages = [invIndex[ind] for ind in indNonZero]
151 | # for j in range(len(indNonZero)):
152 | # temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + torch.min(V[i, indNonZero[j]],
153 | # V[indImages[j], indNonZero[j]])
154 | # jaccard_dist[i] = 1 - temp_min / (2 - temp_min)
155 |
156 | for i in range(query_num):
157 | temp_min = torch.zeros([1, all_num], device="cuda") # .half()
158 | indNonZero = torch.where(V[i, :] != 0)[0]
159 | indImages = [invIndex[ind] for ind in indNonZero]
160 | for j in range(len(indNonZero)):
161 | temp_min[0, indImages[j]] += torch.min(V[i, indNonZero[j]], V[indImages[j], indNonZero[j]])
162 | jaccard_dist[i] = 1 - temp_min / (2 - temp_min)
163 |
164 | # print('before')
165 | # objgraph.show_growth(limit=3)
166 | # print("Before:", torch.cuda.memory_allocated())
167 | final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value
168 | # print("After:", torch.cuda.memory_allocated())
169 | # import pdb
170 |     # pdb.set_trace()  #### jaccard_dist differs slightly from the original rerank implementation
171 | # del temp_min, jaccard_dist, V, original_dist,forward_k_neigh_index,backward_k_neigh_index,\
172 | # k_reciprocal_expansion_index,V_qe,candidate_k_reciprocal_index,candidate_backward_k_neigh_index,\
173 | # candidate_forward_k_neigh_index,k_reciprocal_index,fi,fi_candidate
174 | # torch.cuda.empty_cache()
175 | final_dist = final_dist[:query_num, query_num:]
176 | # import pdb
177 | # pdb.set_trace()
178 | # del original_dist, dist
179 | # torch.cuda.empty_cache()
180 | return final_dist
181 |
182 |
183 | class RerankLoss(nn.Module):
184 | def __init__(self, margin=0.03):
185 | super(RerankLoss, self).__init__()
186 | self.margin = margin
187 | self.ranking_loss = nn.MarginRankingLoss(margin=margin)
188 |
189 | def forward(self, inputs, targets):
190 | #def forward(self, inputs1, inputs2, targets):
191 |
192 | #n = inputs1.size(0)
193 | n = inputs.size(0)
194 | dist = rerank_dist(inputs, inputs)
195 | #dist = rerank_dist(inputs1, inputs2)
196 |
197 | mask = targets.expand(n, n).eq(targets.expand(n, n).t())
198 | dist_ap, dist_an = [], []
199 | for i in range(n):
200 | dist_ap.append(dist[i][mask[i]].max())
201 | dist_an.append(dist[i][mask[i] == 0].min())
202 | dist_ap = torch.stack(dist_ap)
203 | dist_an = torch.stack(dist_an)
204 |
205 | # Compute ranking hinge loss
206 | # y = dist_an.data.new()
207 | # y.resize_as_(dist_an.data)
208 | # y.fill_(1)
209 | y = torch.ones_like(dist_an)
210 | loss = self.ranking_loss(dist_an, dist_ap, y)
211 | #prec = dist_an.data > dist_ap.data
212 | #length = torch.sqrt((inputs * inputs).sum(1)).mean()
213 | return loss, dist,dist_ap, dist_an
--------------------------------------------------------------------------------
/IDKL/layers/loss/triplet_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | class TripletLoss(nn.Module):
5 | def __init__(self, margin=0):
6 | super(TripletLoss, self).__init__()
7 | self.margin = margin
8 | self.ranking_loss = nn.MarginRankingLoss(margin=margin)
9 |
10 | def forward(self, inputs, targets):
11 | n = inputs.size(0)
12 | # Compute pairwise distance, replace by the official when merged
13 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
14 | dist = dist + dist.t()
15 |         dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)
16 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
17 |
18 | # For each anchor, find the hardest positive and negative
19 | mask = targets.expand(n, n).eq(targets.expand(n, n).t())
20 | dist_ap, dist_an = [], []
21 | for i in range(n):
22 | dist_ap.append(dist[i][mask[i]].max())
23 | dist_an.append(dist[i][mask[i] == 0].min())
24 | dist_ap = torch.stack(dist_ap)
25 | dist_an = torch.stack(dist_an)
26 |
27 | # Compute ranking hinge loss
28 | # y = dist_an.data.new()
29 | # y.resize_as_(dist_an.data)
30 | # y.fill_(1)
31 | y = torch.ones_like(dist_an)
32 | loss = self.ranking_loss(dist_an, dist_ap, y)
33 | prec = dist_an.data > dist_ap.data
34 | length = torch.sqrt((inputs * inputs).sum(1)).mean()
35 | return loss, dist,dist_ap, dist_an
36 |
--------------------------------------------------------------------------------
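A minimal usage sketch for TripletLoss, assuming the IDKL/ root is on PYTHONPATH; shapes and margin are placeholders, and each identity needs both positives and negatives in the batch:

    import torch
    from layers.loss.triplet_loss import TripletLoss

    criterion = TripletLoss(margin=0.3)
    feats = torch.randn(8, 2048)
    labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])
    loss, dist, dist_ap, dist_an = criterion(feats, labels)  # forward also returns the distance matrix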
/IDKL/layers/module/CBAM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class ChannelAttention(nn.Module):
6 | def __init__(self, in_planes, ratio=16):
7 | super(ChannelAttention, self).__init__()
8 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
9 | self.max_pool = nn.AdaptiveMaxPool2d(1)
10 |
11 | self.sharedMLP = nn.Sequential(
12 | nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),
13 | nn.ReLU(),
14 | nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False))
15 | self.sigmoid = nn.Sigmoid()
16 |
17 | def forward(self, x):
18 | avgout = self.sharedMLP(self.avg_pool(x))
19 | maxout = self.sharedMLP(self.max_pool(x))
20 | return self.sigmoid(avgout + maxout)
21 |
22 |
23 | class SpatialAttention(nn.Module):
24 | def __init__(self, kernel_size=7):
25 | super(SpatialAttention, self).__init__()
26 | assert kernel_size in (3,7), "kernel size must be 3 or 7"
27 | padding = 3 if kernel_size == 7 else 1
28 |
29 | self.conv = nn.Conv2d(2,1,kernel_size, padding=padding, bias=False)
30 | self.sigmoid = nn.Sigmoid()
31 |
32 | def forward(self, x):
33 | avgout = torch.mean(x, dim=1, keepdim=True)
34 | maxout, _ = torch.max(x, dim=1, keepdim=True)
35 | x = torch.cat([avgout, maxout], dim=1)
36 | x = self.conv(x)
37 | return self.sigmoid(x)
38 |
39 |
40 | class cbam(nn.Module):
41 | def __init__(self, planes):
42 | super(cbam, self).__init__()
43 | self.ca = ChannelAttention(planes)
44 | self.sa = SpatialAttention()
45 |
46 | def forward(self, x):
47 | x = self.ca(x) * x
48 | x = self.sa(x) * x
49 | return x
--------------------------------------------------------------------------------
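A minimal usage sketch for the cbam module, assuming the IDKL/ root is on PYTHONPATH; the channel count and spatial size are placeholders:

    import torch
    from layers.module.CBAM import cbam

    block = cbam(planes=64)            # channel attention followed by spatial attention
    x = torch.randn(2, 64, 32, 32)
    out = block(x)                     # same shape as the input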
/IDKL/layers/module/NonLocal.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 |
6 | class NonLocalBlockND(nn.Module):
7 | """
8 |     Usage (call pattern):
9 | NONLocalBlock2D(in_channels=32),
10 | super(NONLocalBlock2D, self).__init__(in_channels,
11 | inter_channels=inter_channels,
12 | dimension=2, sub_sample=sub_sample,
13 | bn_layer=bn_layer)
14 | """
15 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
16 | super(NonLocalBlockND, self).__init__()
17 |
18 | assert dimension in [1, 2, 3]
19 |
20 | self.dimension = dimension
21 | self.sub_sample = sub_sample
22 |
23 | self.in_channels = in_channels
24 | self.inter_channels = inter_channels
25 |
26 | if self.inter_channels is None:
27 | self.inter_channels = in_channels // 2
28 | if self.inter_channels == 0:
29 | self.inter_channels = 1
30 |
31 | if dimension == 3:
32 | conv_nd = nn.Conv3d
33 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
34 | bn = nn.BatchNorm3d
35 | elif dimension == 2:
36 | conv_nd = nn.Conv2d
37 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
38 | bn = nn.BatchNorm2d
39 | else:
40 | conv_nd = nn.Conv1d
41 | max_pool_layer = nn.MaxPool1d(kernel_size=(2))
42 | bn = nn.BatchNorm1d
43 |
44 | self.g = conv_nd(in_channels=self.in_channels,
45 | out_channels=self.inter_channels,
46 | kernel_size=1,
47 | stride=1,
48 | padding=0)
49 |
50 | if bn_layer:
51 | self.W = nn.Sequential(
52 | conv_nd(in_channels=self.inter_channels,
53 | out_channels=self.in_channels,
54 | kernel_size=1,
55 | stride=1,
56 | padding=0), bn(self.in_channels))
57 | nn.init.constant_(self.W[1].weight, 0)
58 | nn.init.constant_(self.W[1].bias, 0)
59 | else:
60 | self.W = conv_nd(in_channels=self.inter_channels,
61 | out_channels=self.in_channels,
62 | kernel_size=1,
63 | stride=1,
64 | padding=0)
65 | nn.init.constant_(self.W.weight, 0)
66 | nn.init.constant_(self.W.bias, 0)
67 |
68 | self.theta = conv_nd(in_channels=self.in_channels,
69 | out_channels=self.inter_channels,
70 | kernel_size=1,
71 | stride=1,
72 | padding=0)
73 | self.phi = conv_nd(in_channels=self.in_channels,
74 | out_channels=self.inter_channels,
75 | kernel_size=1,
76 | stride=1,
77 | padding=0)
78 |
79 | if sub_sample:
80 | self.g = nn.Sequential(self.g, max_pool_layer)
81 | self.phi = nn.Sequential(self.phi, max_pool_layer)
82 |
83 | def forward(self, x):
84 | '''
85 | :param x: (b, c, h, w)
86 | :return:
87 | '''
88 |
89 | batch_size = x.size(0)
90 |
91 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)#[bs, c, w*h]
92 | g_x = g_x.permute(0, 2, 1)
93 |
94 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
95 | theta_x = theta_x.permute(0, 2, 1)
96 |
97 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
98 |
99 | f = torch.matmul(theta_x, phi_x)
100 |
101 | # print(f.shape)
102 |
103 | f_div_C = F.softmax(f, dim=-1)
104 |
105 | y = torch.matmul(f_div_C, g_x)
106 | y = y.permute(0, 2, 1).contiguous()
107 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
108 | W_y = self.W(y)
109 | z = W_y + x
110 | return z
--------------------------------------------------------------------------------
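A minimal usage sketch for NonLocalBlockND, assuming the IDKL/ root is on PYTHONPATH; the channel count and spatial size are placeholders:

    import torch
    from layers.module.NonLocal import NonLocalBlockND

    block = NonLocalBlockND(in_channels=64, dimension=2)  # dimension=2 for [B, C, H, W] feature maps
    x = torch.randn(2, 64, 32, 32)
    out = block(x)                                        # residual output, same shape as the input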
/IDKL/layers/module/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/module/__init__.py
--------------------------------------------------------------------------------
/IDKL/layers/module/__pycache__/CBAM.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/module/__pycache__/CBAM.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/layers/module/__pycache__/NonLocal.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/module/__pycache__/NonLocal.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/layers/module/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/module/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/layers/module/__pycache__/norm_linear.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/module/__pycache__/norm_linear.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/layers/module/__pycache__/reverse_grad.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/layers/module/__pycache__/reverse_grad.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/layers/module/norm_linear.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.init as init
5 | import torch.nn.functional as F
6 |
7 |
8 | class NormalizeLinear(nn.Module):
9 | def __init__(self, in_features, num_class):
10 | super(NormalizeLinear, self).__init__()
11 | self.weight = nn.Parameter(torch.Tensor(num_class, in_features))
12 | self.reset_parameters()
13 |
14 | def reset_parameters(self):
15 | init.kaiming_uniform_(self.weight, a=math.sqrt(5))
16 |
17 | def forward(self, x):
18 | w = F.normalize(self.weight.float(), p=2, dim=1)
19 | return F.linear(x.float(), w)
20 |
--------------------------------------------------------------------------------
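A minimal usage sketch for NormalizeLinear, assuming the IDKL/ root is on PYTHONPATH; dimensions are placeholders. It produces cosine-style logits and pairs naturally with AMSoftmaxLoss:

    import torch
    import torch.nn.functional as F
    from layers.module.norm_linear import NormalizeLinear

    layer = NormalizeLinear(in_features=2048, num_class=206)  # placeholder dimensions
    feats = F.normalize(torch.randn(8, 2048), dim=1)          # normalising the input is up to the caller
    logits = layer(feats)                                     # [8, 206] cosine-style logits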
/IDKL/layers/module/reverse_grad.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.autograd import Function
3 |
4 |
5 | class ReverseGradFunction(Function):
6 |
7 | @staticmethod
8 | def forward(ctx, data, alpha=1.0):
9 | ctx.alpha = alpha
10 | return data
11 |
12 | @staticmethod
13 | def backward(ctx, grad_outputs):
14 | grad = None
15 |
16 | if ctx.needs_input_grad[0]:
17 | grad = -ctx.alpha * grad_outputs
18 |
19 | return grad, None
20 |
21 |
22 | class ReverseGrad(nn.Module):
23 | def __init__(self):
24 | super(ReverseGrad, self).__init__()
25 |
26 | def forward(self, x, alpha=1.0):
27 | return ReverseGradFunction.apply(x, alpha)
28 |
--------------------------------------------------------------------------------
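A minimal sketch showing the gradient-reversal behaviour of ReverseGrad, assuming the IDKL/ root is on PYTHONPATH:

    import torch
    from layers.module.reverse_grad import ReverseGrad

    grl = ReverseGrad()
    x = torch.randn(4, 8, requires_grad=True)
    out = grl(x, alpha=0.5)      # identity in the forward pass
    out.sum().backward()
    print(x.grad[0, 0])          # gradient is scaled by -alpha, i.e. reversed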
/IDKL/models/__pycache__/baseline.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/models/__pycache__/baseline.cpython-36.pyc
--------------------------------------------------------------------------------
/IDKL/models/__pycache__/baseline.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/models/__pycache__/baseline.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/models/__pycache__/resnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/models/__pycache__/resnet.cpython-36.pyc
--------------------------------------------------------------------------------
/IDKL/models/__pycache__/resnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/models/__pycache__/resnet.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/models/baseline.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import init
5 | from torch.nn import functional as F
6 | from torch.nn import Parameter
7 | import numpy as np
8 |
9 | import cv2
10 | from layers.module.reverse_grad import ReverseGrad
11 | from models.resnet import resnet50, embed_net, convDiscrimination, Discrimination
12 | from utils.calc_acc import calc_acc
13 |
14 | from layers import TripletLoss, RerankLoss
15 | from layers import CenterTripletLoss
16 | from layers import CenterLoss
17 | from layers import cbam
18 | from layers import NonLocalBlockND
19 | from utils.rerank import re_ranking, pairwise_distance
20 |
21 | def intersect1d(tensor1, tensor2):
22 | return torch.unique(torch.cat([tensor1[tensor1 == val] for val in tensor2]))
23 |
24 | def spearman_loss(dist_matrix, rerank_matrix):
25 |
26 | sorted_idx_dist = torch.argsort(dist_matrix, dim=1)
27 | sorted_idx_rerank = torch.argsort(rerank_matrix, dim=1)
28 |
29 | rank_corr = 0
30 | n = dist_matrix.size(1)
31 | for i in range(dist_matrix.size(0)):
32 | diff = sorted_idx_dist[i] - sorted_idx_rerank[i]
33 | rank_corr += 1 - (6 * torch.sum(diff * diff) / (n * (n**2 - 1)))
34 |
35 | rank_corr /= dist_matrix.size(0)
36 |
37 | return 1 - rank_corr
38 |
39 |
40 | def Fb_dt(feat, labels):
41 | feat_dt = feat
42 | n_ft = feat_dt.size(0)
43 | dist_f = torch.pow(feat_dt, 2).sum(dim=1, keepdim=True).expand(n_ft, n_ft)
44 | dist_f = dist_f + dist_f.t()
45 |     dist_f.addmm_(feat_dt, feat_dt.t(), beta=1, alpha=-2)
46 | dist_f = dist_f.clamp(min=1e-12).sqrt()
47 | mask_ft = labels.expand(n_ft, n_ft).eq(labels.expand(n_ft, n_ft).t())
48 | mask_ft_1 = torch.ones(n_ft, n_ft, dtype=bool)
49 | for i in range(n_ft):
50 | mask_ft_1[i, i] = 0
51 | mask_ft_2 = []
52 | for i in range(n_ft):
53 |
54 | mask_ft_2.append(mask_ft[i][mask_ft_1[i]])
55 | mask_ft_2 = torch.stack(mask_ft_2)
56 | dist_f_2 = []
57 | for i in range(n_ft):
58 |
59 | dist_f_2.append(dist_f[i][mask_ft_1[i]])
60 | dist_f_2 = torch.stack(dist_f_2)
61 | dist_f_2 = F.softmax(-(dist_f_2 - 1), 1)
62 | cN_ft = (mask_ft_2[0] == True).sum()
63 | f_d_ap = []
64 | for i in range(n_ft):
65 |
66 | f_d_ap.append(dist_f_2[i][mask_ft_2[i]])
67 | f_d_ap = torch.stack(f_d_ap).flatten()
68 | loss_f_d_ap = []
69 | xs_ft = 1
70 | m_ft = f_d_ap.size(0)
71 | for i in range(m_ft):
72 | loss_f_d_ap.append(
73 | -xs_ft * (1 / cN_ft) * torch.log(xs_ft * cN_ft * f_d_ap[i]))
74 | loss_f_d_ap = torch.stack(loss_f_d_ap).clamp(max=1e+3).sum() / n_ft
75 | return loss_f_d_ap
76 |
77 |
78 | def gem(x, p=3, eps=1e-6):
79 | return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1. / p)
80 |
81 | def gem_p(x):
82 |     ss = gem(x).squeeze()  # GeM pooling
83 |     ss = ss.view(ss.size(0), -1)  # GeM pooling
84 | return ss
85 | def pairwise_dist(x, y):
86 | # Compute pairwise distance of vectors
87 | xx = (x**2).sum(dim=1, keepdim=True)
88 | yy = (y**2).sum(dim=1, keepdim=True).t()
89 | dist = xx + yy - 2.0 * torch.mm(x, y.t())
90 | dist = dist.clamp(min=1e-6).sqrt() # for numerical stability
91 | return dist
92 |
93 | def kl_soft_dist(feat1,feat2):
94 | n_st = feat1.size(0)
95 | dist_st = pairwise_dist(feat1, feat2)
96 | mask_st_1 = torch.ones(n_st, n_st, dtype=bool)
97 |     for i in range(n_st):  # discard each sample's distance to itself
98 | mask_st_1[i, i] = 0
99 | dist_st_2 = []
100 | for i in range(n_st):
101 | dist_st_2.append(dist_st[i][mask_st_1[i]])
102 | dist_st_2 = torch.stack(dist_st_2)
103 | return dist_st_2
104 |
105 |
106 | def Bg_kl(logits1, logits2):  #### inputs: (60, 206), (60, 206)
107 | KL = nn.KLDivLoss(reduction='batchmean')
108 | kl_loss_12 = KL(F.log_softmax(logits1, 1), F.softmax(logits2, 1))
109 | kl_loss_21 = KL(F.log_softmax(logits2, 1), F.softmax(logits1, 1))
110 | bg_loss_kl = kl_loss_12 + kl_loss_21
111 | return kl_loss_12, bg_loss_kl
112 | def Sm_kl(logits1, logits2, labels):
113 | KL = nn.KLDivLoss(reduction='batchmean')
114 | m_kl = torch.div((labels == labels[0]).sum(), 2, rounding_mode='floor')
115 | v_logits_s = logits1.split(m_kl, 0)
116 | i_logits_s = logits2.split(m_kl, 0)
117 | sm_v_logits = torch.cat(v_logits_s, 1) # .t() # 5,206*12->206*12,5
118 | sm_i_logits = torch.cat(i_logits_s, 1) # .t()
119 | sm_kl_loss_vi = KL(F.log_softmax(sm_v_logits, 1), F.softmax(sm_i_logits, 1))
120 | sm_kl_loss_iv = KL(F.log_softmax(sm_i_logits, 1), F.softmax(sm_v_logits, 1))
121 | sm_kl_loss = sm_kl_loss_vi + sm_kl_loss_iv
122 | return sm_kl_loss_vi, sm_kl_loss
123 |
124 |
125 | def samplewise_entropy(logits):
126 | probabilities = F.softmax(logits, dim=1)
127 | log_probabilities = F.log_softmax(logits, dim=1)
128 | entropies = -torch.sum(probabilities * log_probabilities, dim=1)
129 | return entropies
130 |
131 |
132 | def entropy_margin_loss(logits1, logits2, margin):
133 | entropy1 = samplewise_entropy(logits1)
134 | entropy2 = samplewise_entropy(logits2)
135 | losses = torch.exp(F.relu(entropy2 - entropy1 + margin)) - 1
136 | return losses.mean()
137 |
138 |
139 | def compute_centroid_distance(features, labels, modalities):
140 | """
141 |     Compute the distance between the modality-specific centroid features of each class.
142 | 
143 |     Args:
144 |     features -- feature matrix of shape (B, C).
145 |     labels -- class labels of shape (B,).
146 |     modalities -- modality labels of shape (B,).
147 | 
148 |     Returns:
149 |     distances -- per-class distances between the two modality centroids.
150 | """
151 | unique_labels = torch.unique(labels)
152 | distances = []
153 | for label in unique_labels:
154 |         # gather the features of each modality for the current class
155 |         features_modality_0 = features[(labels == label) & (modalities == 0)]
156 |         features_modality_1 = features[(labels == label) & (modalities == 1)]
157 | 
158 |         # compute the centroid feature of each modality
159 |         centroid_modality_0 = features_modality_0.mean(dim=0)
160 |         centroid_modality_1 = features_modality_1.mean(dim=0)
161 | 
162 |         # distance between the two centroids (Euclidean distance here)
163 | distance = F.pairwise_distance(centroid_modality_0.unsqueeze(0), centroid_modality_1.unsqueeze(0))
164 | distances.append(distance)
165 |
166 |
167 | return torch.stack(distances)
168 |
169 |
170 | def modal_centroid_loss(F1, F2, labels, modalities, margin):
171 | """
172 |     Loss requiring that, for every class, the distance between modality centroids in F2 is smaller than in F1, with a margin.
173 | 
174 |     Args:
175 |     F1 -- first set of features, shape (B, C).
176 |     F2 -- second set of features (refined by the network), shape (B, C).
177 |     labels -- class labels, shape (B,).
178 |     modalities -- modality labels, shape (B,).
179 |     margin -- margin value to enforce.
180 | 
181 |     Returns:
182 |     loss -- the computed loss value.
183 | """
184 |     # centroid distances for F1 and F2
185 |     distances_F1 = compute_centroid_distance(F1, labels, modalities)
186 |     distances_F2 = compute_centroid_distance(F2, labels, modalities)
187 | 
188 |     # hinge loss with margin
189 |     losses = F.relu(distances_F2 - distances_F1 + margin)
190 | 
191 |     # return the mean loss
192 | return losses.mean()
193 | class Baseline(nn.Module):
194 | def __init__(self, num_classes=None, drop_last_stride=False, decompose=False, **kwargs):
195 | super(Baseline, self).__init__()
196 |
197 | self.drop_last_stride = drop_last_stride
198 | self.decompose = decompose
199 | self.backbone = embed_net(drop_last_stride=drop_last_stride, decompose=decompose)
200 |
201 | self.base_dim = 2048
202 | self.dim = 0
203 | self.part_num = kwargs.get('num_parts', 0)
204 |
205 |
206 | print("output feat length:{}".format(self.base_dim + self.dim * self.part_num))
207 | self.bn_neck = nn.BatchNorm1d(self.base_dim + self.dim * self.part_num)
208 | nn.init.constant_(self.bn_neck.bias, 0)
209 | self.bn_neck.bias.requires_grad_(False)
210 | self.bn_neck_sp = nn.BatchNorm1d(self.base_dim + self.dim * self.part_num)
211 | nn.init.constant_(self.bn_neck_sp.bias, 0)
212 | self.bn_neck_sp.bias.requires_grad_(False)
213 |
214 | if kwargs.get('eval', False):
215 | return
216 |
217 | self.classification = kwargs.get('classification', False)
218 | self.triplet = kwargs.get('triplet', False)
219 | self.center_cluster = kwargs.get('center_cluster', False)
220 | self.center_loss = kwargs.get('center', False)
221 | self.margin = kwargs.get('margin', 0.3)
222 | self.CSA1 = kwargs.get('bg_kl', False)
223 | self.CSA2 = kwargs.get('sm_kl', False)
224 | self.TGSA = kwargs.get('distalign', False)
225 | self.IP = kwargs.get('IP', False)
226 | self.fb_dt = kwargs.get('fb_dt', False)
227 |
228 | if self.decompose:
229 | self.classifier = nn.Linear(self.base_dim + self.dim * self.part_num, num_classes, bias=False)
230 | self.classifier_sp = nn.Linear(self.base_dim, num_classes, bias=False)
231 | self.D_special = Discrimination()
232 | self.C_sp_f = nn.Linear(self.base_dim, num_classes, bias=False)
233 |
234 | self.D_shared_pseu = Discrimination(2048)
235 | self.grl = ReverseGrad()
236 |
237 | else:
238 | self.classifier = nn.Linear(self.base_dim + self.dim * self.part_num, num_classes, bias=False)
239 | if self.classification:
240 | self.id_loss = nn.CrossEntropyLoss(ignore_index=-1)
241 | if self.triplet:
242 | self.triplet_loss = TripletLoss(margin=self.margin)
243 | self.rerank_loss = RerankLoss(margin=0.7)
244 | if self.center_cluster:
245 | k_size = kwargs.get('k_size', 8)
246 | self.center_cluster_loss = CenterTripletLoss(k_size=k_size, margin=self.margin)
247 | if self.center_loss:
248 | self.center_loss = CenterLoss(num_classes, self.base_dim + self.dim * self.part_num)
249 |
250 | def forward(self, inputs, labels=None, **kwargs):
251 |
252 | cam_ids = kwargs.get('cam_ids')
253 | sub = (cam_ids == 3) + (cam_ids == 6)
254 | #epoch = kwargs.get('epoch')
255 | # CNN
256 | sh_feat, sh_pl, sp_pl, sp_IN,sp_IN_p,x_sp_f,x_sp_f_p = self.backbone(inputs)
257 |
258 |
259 | feats = sh_pl
260 |
261 | if not self.training:
262 | if feats.size(0) == 2048:
263 | feats = self.bn_neck(feats.permute(1, 0))
264 | logits = self.classifier(feats)
265 | return logits # feats #
266 |
267 |
268 | else:
269 | feats = self.bn_neck(
270 | feats)
271 | return feats
272 |
273 | else:
274 | return self.train_forward(feats, sp_pl, labels,
275 | sub, sp_IN,sp_IN_p,x_sp_f,x_sp_f_p, **kwargs)
276 |
277 |
278 |
279 | def train_forward(self, feat, sp_pl, labels,
280 | sub, sp_IN,sp_IN_p,x_sp_f,x_sp_f_p, **kwargs):
281 | epoch = kwargs.get('epoch')
282 | metric = {}
283 | loss = 0
284 |
285 | if self.triplet:
286 |
287 | triplet_loss, dist, sh_ap, sh_an = self.triplet_loss(feat.float(), labels)
288 | triplet_loss_im, _, sp_ap, sp_an = self.triplet_loss(sp_pl.float(), labels)
289 | trip_loss = triplet_loss + triplet_loss_im
290 | loss += trip_loss
291 | metric.update({'tri': trip_loss.data})
292 |
293 |
294 | bb = 120 #90
295 | if self.TGSA:
296 |
297 | sf_sp_dist_v = kl_soft_dist(sp_pl[sub == 0], sp_pl[sub == 0])
298 | sf_sp_dist_i = kl_soft_dist(sp_pl[sub == 1], sp_pl[sub == 1])
299 | sf_sh_dist_v = kl_soft_dist(feat[sub == 0], feat[sub == 0])
300 | sf_sh_dist_i = kl_soft_dist(feat[sub == 1], feat[sub == 1])
301 | half_B0 = feat[sub == 0].shape[0] // 2
302 | feat_half0 = feat[sub == 0][:half_B0]
303 | half_B1 = feat[sub == 1].shape[0] // 2
304 | feat_half1 = feat[sub == 1][:half_B1]
305 | feat_cross = torch.cat((feat_half0, feat_half1), dim=0)
306 | sf_sh_dist_vi = kl_soft_dist(feat_cross, feat_cross)
307 |
308 |
309 |
310 | _, kl_inter_v = Bg_kl(sf_sh_dist_v, sf_sp_dist_v)
311 | _, kl_inter_i = Bg_kl(sf_sh_dist_i, sf_sp_dist_i)
312 |
313 |
314 | _, kl_intra1 = Bg_kl(sf_sh_dist_v, sf_sh_dist_i)
315 | _, kl_intra2 = Bg_kl(sf_sh_dist_v, sf_sh_dist_vi)
316 | _, kl_intra3 = Bg_kl(sf_sh_dist_vi, sf_sh_dist_i)
317 |
318 | kl_intra = kl_intra1 + kl_intra2 + kl_intra3
319 |
320 |
321 |
322 | if feat.size(0) == bb:
323 | soft_dt = kl_intra + (kl_inter_v + kl_inter_i) * 0.6
324 |
325 |
326 | else:
327 | soft_dt = (kl_intra1 + kl_inter_v + kl_inter_i) * 0.1
328 |
329 | loss += soft_dt
330 | metric.update({'soft_dt': soft_dt.data})
331 |
332 | if self.center_loss:
333 | center_loss = self.center_loss(feat.float(), labels)
334 | loss += center_loss
335 | metric.update({'cen': center_loss.data})
336 |
337 | if self.center_cluster:
338 | center_cluster_loss, _, _ = self.center_cluster_loss(feat.float(), labels)
339 | loss += center_cluster_loss
340 | metric.update({'cc': center_cluster_loss.data})
341 |
342 |
343 | if self.fb_dt:
344 | loss_f_d_ap = Fb_dt(feat, labels)
345 | loss_Fb_im = Fb_dt(sp_pl, labels)
346 | fb_loss = loss_f_d_ap + loss_Fb_im
347 | loss += fb_loss
348 |
349 | metric.update({'f_dt': fb_loss.data})
350 |
351 | feat = self.bn_neck(feat)
352 | sp_pl = self.bn_neck_sp(sp_pl)
353 |         sub_nb = sub + 0  ## modality labels (bool mask cast to int)
354 |
355 | if self.IP:
356 | ################
357 | ################
358 | l_F = self.C_sp_f(gem_p(x_sp_f))
359 | l_F_p = self.C_sp_f(gem_p(x_sp_f_p))
360 | loss_F = entropy_margin_loss(l_F, l_F_p, 0)
361 | loss_m_IN = modal_centroid_loss(gem_p(sp_IN), gem_p(sp_IN_p), labels, sub, 0)
362 |
363 | loss += 0.1 * (loss_F + loss_m_IN)
364 | metric.update({'IN_p': loss_m_IN.data})
365 | metric.update({'F_p': loss_F.data})
366 |
367 | ################
368 | ################
369 |
370 | if self.decompose:
371 | logits_sp = self.classifier_sp(sp_pl) # self.bn_neck_un(sp_pl)
372 | loss_id_sp = self.id_loss(logits_sp.float(), labels)
373 |
374 |
375 | sp_logits = self.D_special(sp_pl)
376 | unad_loss_b = self.id_loss(sp_logits.float(), sub_nb)
377 | unad_loss = unad_loss_b
378 |
379 |
380 | pseu_sh_logits = self.D_shared_pseu(feat)
381 | p_sub = sub_nb.chunk(2)[0].repeat_interleave(2)
382 | pp_sub = torch.roll(p_sub, -1)
383 | pseu_loss = self.id_loss(pseu_sh_logits.float(), pp_sub)
384 |
385 | loss += loss_id_sp + unad_loss + pseu_loss
386 |
387 | metric.update({'unad': unad_loss.data})
388 | metric.update({'id_pl': loss_id_sp.data})
389 |
390 | metric.update({'pse': pseu_loss.data})
391 |
392 |
393 |
394 |
395 | if self.classification:
396 | logits = self.classifier(feat)
397 | if self.CSA1:
398 |
399 | _, inter_bg_v = Bg_kl(logits[sub == 0], logits_sp[sub == 0])
400 | _, inter_bg_i = Bg_kl(logits[sub == 1], logits_sp[sub == 1])
401 |
402 | _, intra_bg = Bg_kl(logits[sub == 0], logits[sub == 1])
403 |
404 |
405 | if feat.size(0) == bb:
406 | bg_loss = intra_bg + (inter_bg_v + inter_bg_i) * 0.8 # intra_bg + (inter_bg_v + inter_bg_i) * 0.7
407 |
408 | else:
409 | bg_loss = intra_bg + (inter_bg_v + inter_bg_i) * 0.3
410 | loss += bg_loss
411 | metric.update({'bg_kl': bg_loss.data})
412 |
413 | if self.CSA2:
414 | _, inter_Sm_v = Sm_kl(logits[sub == 0], logits_sp[sub == 0], labels)
415 | _, inter_Sm_i = Sm_kl(logits[sub == 1], logits_sp[sub == 1], labels)
416 | inter_Sm = inter_Sm_v + inter_Sm_i
417 | _, intra_Sm = Sm_kl(logits[sub == 0], logits[sub == 1], labels)
418 |
419 | if feat.size(0) == bb:
420 | sm_kl_loss = intra_Sm + inter_Sm * 0.8
421 |
422 | else:
423 | sm_kl_loss = intra_Sm + inter_Sm * 0.3
424 | loss += sm_kl_loss
425 | metric.update({'sm_kl': sm_kl_loss.data})
426 | cls_loss = self.id_loss(logits.float(), labels)
427 | loss += cls_loss
428 | metric.update({'acc': calc_acc(logits.data, labels), 'ce': cls_loss.data})
429 |
430 | return loss, metric
431 |
--------------------------------------------------------------------------------
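A minimal sketch of calling the modal_centroid_loss helper from baseline.py on plain tensors, assuming the IDKL/ root is on PYTHONPATH and cv2 is installed (importing models.baseline pulls in the full model code); all shapes and the margin are placeholders:

    import torch
    from models.baseline import modal_centroid_loss

    F1 = torch.randn(16, 2048)                          # reference features
    F2 = torch.randn(16, 2048)                          # features expected to have tighter modality centroids
    labels = torch.arange(4).repeat_interleave(4)       # 4 identities x 4 samples each
    modalities = torch.tensor([0, 0, 1, 1]).repeat(4)   # both modalities present for every identity
    loss = modal_centroid_loss(F1, F2, labels, modalities, margin=0.1)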
/IDKL/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch.nn import functional as F
3 | # from torchvision.models.utils import load_state_dict_from_url  # original import path (older torchvision)
4 | from torch.hub import load_state_dict_from_url
5 |
6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
7 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']
8 |
9 | model_urls = {
10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
15 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
16 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
17 | }
18 |
19 |
20 | class convDiscrimination(nn.Module):
21 | def __init__(self, dim=512):
22 | super(convDiscrimination, self).__init__()
23 | self.conv1 = conv3x3(dim, 512, stride=2)
24 | self.bn1 = nn.BatchNorm2d(512)
25 | self.conv2 = conv3x3(512, 128, stride=2)
26 | self.bn2 = nn.BatchNorm2d(128)
27 | self.conv3 = conv3x3(128, 128, stride=2)
28 | self.bn3 = nn.BatchNorm2d(128)
29 | self.fc = nn.Linear(128, 2)
30 |
31 | def forward(self, x):
32 | x = F.dropout(F.relu(self.bn1(self.conv1(x))), training=self.training)
33 | x = F.dropout(F.relu(self.bn2(self.conv2(x))), training=self.training)
34 | x = F.dropout(F.relu(self.bn3(self.conv3(x))), training=self.training)
35 | x = F.avg_pool2d(x, (x.size(2), x.size(3)))
36 | x = x.view(-1, 128)
37 | x = self.fc(x)
38 | return x
39 |
40 |
41 | class Discrimination(nn.Module):
42 | def __init__(self, dim=2048):
43 | super(Discrimination, self).__init__()
44 | self.fc1 = nn.Linear(dim, 100)
45 | self.bn1 = nn.BatchNorm1d(100)
46 | self.fc2 = nn.Linear(100, 100)
47 | self.bn2 = nn.BatchNorm1d(100)
48 | self.fc3 = nn.Linear(100, 2)
49 |
50 | def forward(self, x):
51 | x = F.dropout(F.relu(self.bn1(self.fc1(x))), training=self.training)
52 | x = F.dropout(F.relu(self.bn2(self.fc2(x))), training=self.training)
53 | x = self.fc3(x)
54 | return x
55 |
56 |
57 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
58 | """3x3 convolution with padding"""
59 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
60 | padding=dilation, groups=groups, bias=False, dilation=dilation)
61 |
62 |
63 | def conv1x1(in_planes, out_planes, stride=1):
64 | """1x1 convolution"""
65 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
66 |
67 |
68 | class MAM(nn.Module):
69 | def __init__(self, dim, r=16):
70 | super(MAM, self).__init__()
71 |
72 | self.channel_attention = nn.Sequential(
73 | nn.Conv2d(dim, dim // r, kernel_size=1, bias=False),
74 | nn.ReLU(inplace=True),
75 | nn.Conv2d(dim // r, dim, kernel_size=1, bias=False),
76 | nn.Sigmoid()
77 | )
78 | self.IN = nn.InstanceNorm2d(dim, track_running_stats=False)
79 |
80 | def forward(self, x):
81 | pooled = F.avg_pool2d(x, x.size()[2:])
82 | mask = self.channel_attention(pooled)
83 | x = x * mask + self.IN(x) * (1 - mask)
84 |
85 | return x
86 |
87 |
88 | class BasicBlock(nn.Module):
89 | expansion = 1
90 |
91 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
92 | base_width=64, dilation=1, norm_layer=None):
93 | super(BasicBlock, self).__init__()
94 | if norm_layer is None:
95 | norm_layer = nn.BatchNorm2d
96 | if groups != 1 or base_width != 64:
97 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
98 | if dilation > 1:
99 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
100 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
101 | self.conv1 = conv3x3(inplanes, planes, stride)
102 | self.bn1 = norm_layer(planes)
103 | self.relu = nn.ReLU(inplace=True)
104 | self.conv2 = conv3x3(planes, planes)
105 | self.bn2 = norm_layer(planes)
106 | self.downsample = downsample
107 | self.stride = stride
108 |
109 | def forward(self, x):
110 | identity = x
111 |
112 | out = self.conv1(x)
113 | out = self.bn1(out)
114 | out = self.relu(out)
115 |
116 | out = self.conv2(out)
117 | out = self.bn2(out)
118 |
119 | if self.downsample is not None:
120 | identity = self.downsample(x)
121 |
122 | out += identity
123 | out = self.relu(out)
124 |
125 | return out
126 |
127 |
128 | class Bottleneck(nn.Module):
129 | expansion = 4
130 |
131 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
132 | base_width=64, dilation=1, norm_layer=None):
133 | super(Bottleneck, self).__init__()
134 | if norm_layer is None:
135 | norm_layer = nn.BatchNorm2d
136 | width = int(planes * (base_width / 64.)) * groups
137 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
138 | self.conv1 = conv1x1(inplanes, width)
139 | self.bn1 = norm_layer(width)
140 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
141 | self.bn2 = norm_layer(width)
142 | self.conv3 = conv1x1(width, planes * self.expansion)
143 | self.bn3 = norm_layer(planes * self.expansion)
144 | self.relu = nn.ReLU(inplace=True)
145 | self.downsample = downsample
146 | self.stride = stride
147 |
148 | def forward(self, x):
149 | identity = x
150 |
151 | out = self.conv1(x)
152 | out = self.bn1(out)
153 | out = self.relu(out)
154 |
155 | out = self.conv2(out)
156 | out = self.bn2(out)
157 | out = self.relu(out)
158 |
159 | out = self.conv3(out)
160 | out = self.bn3(out)
161 |
162 | if self.downsample is not None:
163 | identity = self.downsample(x)
164 |
165 | out += identity
166 | out = self.relu(out)
167 |
168 | return out
169 |
170 |
171 | class ResNet(nn.Module):
172 |
173 | def __init__(self, block, layers, zero_init_residual=False, modality_attention=0,
174 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
175 | norm_layer=None, drop_last_stride=False):
176 | super(ResNet, self).__init__()
177 | if norm_layer is None:
178 | norm_layer = nn.BatchNorm2d
179 | self._norm_layer = norm_layer
180 |
181 | self.inplanes = 64
182 | self.dilation = 1
183 | if replace_stride_with_dilation is None:
184 | # each element in the tuple indicates if we should replace
185 | # the 2x2 stride with a dilated convolution instead
186 | replace_stride_with_dilation = [False, False, False]
187 | if len(replace_stride_with_dilation) != 3:
188 | raise ValueError("replace_stride_with_dilation should be None "
189 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
190 | self.groups = groups
191 | self.base_width = width_per_group
192 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
193 | bias=False)
194 | self.bn1 = norm_layer(self.inplanes)
195 | self.relu = nn.ReLU(inplace=True)
196 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
197 | self.layer1 = self._make_layer(block, 64, layers[0])
198 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
199 | dilate=replace_stride_with_dilation[0])
200 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
201 | dilate=replace_stride_with_dilation[1])
202 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1 if drop_last_stride else 2,
203 | dilate=replace_stride_with_dilation[2])
204 |
205 |
206 | for m in self.modules():
207 | if isinstance(m, nn.Conv2d):
208 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
209 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
210 | nn.init.constant_(m.weight, 1)
211 | nn.init.constant_(m.bias, 0)
212 |
213 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
214 | norm_layer = self._norm_layer
215 | downsample = None
216 | previous_dilation = self.dilation
217 | if dilate:
218 | self.dilation *= stride
219 | stride = 1
220 | if stride != 1 or self.inplanes != planes * block.expansion:
221 | downsample = nn.Sequential(
222 | conv1x1(self.inplanes, planes * block.expansion, stride),
223 | norm_layer(planes * block.expansion),
224 | )
225 |
226 | layers = []
227 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
228 | self.base_width, previous_dilation, norm_layer))
229 | self.inplanes = planes * block.expansion
230 | for _ in range(1, blocks):
231 | layers.append(block(self.inplanes, planes, groups=self.groups,
232 | base_width=self.base_width, dilation=self.dilation,
233 | norm_layer=norm_layer))
234 |
235 | return nn.Sequential(*layers)
236 |
237 | def forward(self, x):
238 | x = self.conv1(x)
239 | x = self.bn1(x)
240 | x = self.relu(x)
241 | x = self.maxpool(x)
242 |
243 | x = self.layer1(x)
244 | x = self.layer2(x)
245 | x = self.layer3(x)
246 |
247 | x = self.layer4(x)
248 |
249 | return x
250 |
251 |
252 | class Shared_module_fr(nn.Module):
253 |     def __init__(self, drop_last_stride, modality_attention=0):  # default lets embed_net omit it
254 | super(Shared_module_fr, self).__init__()
255 |
256 | model_sh_fr = resnet50(pretrained=True, drop_last_stride=drop_last_stride,
257 | modality_attention=modality_attention)
258 | # avg pooling to global pooling
259 | self.model_sh_fr = model_sh_fr
260 |
261 | def forward(self, x):
262 | x = self.model_sh_fr.conv1(x)
263 | x = self.model_sh_fr.bn1(x)
264 | x = self.model_sh_fr.relu(x)
265 | x = self.model_sh_fr.maxpool(x)
266 | x = self.model_sh_fr.layer1(x)
267 | x = self.model_sh_fr.layer2(x)
268 | # x = self.model_sh_fr.layer3(x)
269 | return x
270 |
271 |
272 | class Special_module(nn.Module):
273 |     def __init__(self, drop_last_stride, modality_attention=0):
274 | super(Special_module, self).__init__()
275 |
276 | special_module = resnet50(pretrained=True, drop_last_stride=drop_last_stride,
277 | )
278 |
279 | self.special_module = special_module
280 |
281 | def forward(self, x):
282 | # x = self.special_module.layer2(x)
283 | x = self.special_module.layer3(x)
284 | x = self.special_module.layer4(x)
285 | return x
286 |
287 |
288 | class Shared_module_bh(nn.Module):
289 |     def __init__(self, drop_last_stride, modality_attention=0):
290 | super(Shared_module_bh, self).__init__()
291 |
292 | model_sh_bh = resnet50(pretrained=True, drop_last_stride=drop_last_stride,) # model_sh_fr model_sh_bh
293 |
294 | self.model_sh_bh = model_sh_bh # self.model_sh_bh = model_sh_bh #self.model_sh_fr = model_sh_fr
295 |
296 | def forward(self, x):
297 | # x = self.model_sh_bh.layer2(x)
298 | x_sh3 = self.model_sh_bh.layer3(x) # self.model_sh_fr self.model_sh_bh
299 | x_sh4 = self.model_sh_bh.layer4(x_sh3) # self.model_sh_fr self.model_sh_bh
300 | return x_sh3, x_sh4
301 |
302 |
303 | class Mask(nn.Module):
304 | def __init__(self, dim, r=16):
305 | super(Mask, self).__init__()
306 |
307 | self.channel_attention = nn.Sequential(
308 | nn.Conv2d(dim, dim // r, kernel_size=1, bias=False),
309 | nn.ReLU(inplace=True),
310 | nn.Conv2d(dim // r, dim, kernel_size=1, bias=False),
311 | nn.Sigmoid()
312 | )
313 |
314 | def forward(self, x):
315 | mask = self.channel_attention(x)
316 | return mask
317 |
318 |
319 | class special_att(nn.Module):
320 | def __init__(self, dim, r=16):
321 | super(special_att, self).__init__()
322 |
323 | self.channel_attention = nn.Sequential(
324 | nn.Conv2d(dim, dim // r, kernel_size=1, bias=False),
325 | nn.ReLU(inplace=True),
326 | nn.Conv2d(dim // r, dim, kernel_size=1, bias=False),
327 | nn.Sigmoid()
328 | )
329 | self.IN = nn.InstanceNorm2d(dim, track_running_stats=False) #self.IN = nn.InstanceNorm2d(dim, track_running_stats=True, affine=True)
330 |
331 | def forward(self, x):
332 | x_IN = self.IN(x)
333 | x_R = x - x_IN
334 | pooled = gem(x_R)
335 | mask = self.channel_attention(pooled)
336 | x_sp = x_R * mask + x_IN # x
337 |
338 | return x_sp, x_IN
339 |
340 |
341 | def gem(x, p=3, eps=1e-6):
342 | return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1. / p)
343 |
344 |
345 | class embed_net(nn.Module):
346 | def __init__(self, drop_last_stride, decompose=False):
347 | super(embed_net, self).__init__()
348 |
349 | self.shared_module_fr = Shared_module_fr(drop_last_stride=drop_last_stride,
350 | )
351 | self.shared_module_bh = Shared_module_bh(drop_last_stride=drop_last_stride,
352 | )
353 |
354 | self.special = Special_module(drop_last_stride=drop_last_stride)
355 |
356 | self.decompose = decompose
357 | self.IN = nn.InstanceNorm2d(2048, track_running_stats=True, affine=True)
358 | if decompose:
359 | self.special_att = special_att(2048)
360 | self.mask1 = Mask(2048)
361 | self.mask2 = Mask(2048)
362 |
363 | def forward(self, x):
364 | x2 = self.shared_module_fr(x)
365 | x3, x_sh = self.shared_module_bh(x2) # bchw
366 |
367 |         sh_pl = gem(x_sh).squeeze()  # GeM pooling
368 |         sh_pl = sh_pl.view(sh_pl.size(0), -1)  # flatten pooled shared features
369 |
370 | if self.decompose:
371 | ######special structure
372 |
373 | x_sp_f = self.special(x2)
374 | sp_IN = self.IN(x_sp_f)
375 | m_IN = self.mask1(sp_IN)
376 | m_F = self.mask2(x_sp_f)
377 | sp_IN_p = m_IN * sp_IN
378 | x_sp_f_p = m_F * x_sp_f
379 | x_sp = m_IN * x_sp_f_p + m_F * sp_IN_p
380 |
381 |             sp_pl = gem(x_sp).squeeze()  # GeM pooling
382 |             sp_pl = sp_pl.view(sp_pl.size(0), -1)  # flatten pooled specific features
383 |
384 |
385 |         return x_sh, sh_pl, sp_pl, sp_IN, sp_IN_p, x_sp_f, x_sp_f_p  # sp_* tensors exist only when decompose=True
386 |
387 |
388 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
389 | model = ResNet(block, layers, **kwargs)
390 | if pretrained:
391 | state_dict = load_state_dict_from_url(model_urls[arch],
392 | progress=progress)
393 | model.load_state_dict(state_dict, strict=False)
394 | return model
395 |
396 |
397 | def resnet18(pretrained=False, progress=True, **kwargs):
398 | """Constructs a ResNet-18 model.
399 |
400 | Args:
401 | pretrained (bool): If True, returns a model pre-trained on ImageNet
402 | progress (bool): If True, displays a progress bar of the download to stderr
403 | """
404 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
405 | **kwargs)
406 |
407 |
408 | def resnet34(pretrained=False, progress=True, **kwargs):
409 | """Constructs a ResNet-34 model.
410 |
411 | Args:
412 | pretrained (bool): If True, returns a model pre-trained on ImageNet
413 | progress (bool): If True, displays a progress bar of the download to stderr
414 | """
415 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
416 | **kwargs)
417 |
418 |
419 | def resnet50(pretrained=False, progress=True, **kwargs):
420 | """Constructs a ResNet-50 model.
421 |
422 | Args:
423 | pretrained (bool): If True, returns a model pre-trained on ImageNet
424 | progress (bool): If True, displays a progress bar of the download to stderr
425 | """
426 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
427 | **kwargs)
428 |
429 |
430 | def resnet101(pretrained=False, progress=True, **kwargs):
431 | """Constructs a ResNet-101 model.
432 |
433 | Args:
434 | pretrained (bool): If True, returns a model pre-trained on ImageNet
435 | progress (bool): If True, displays a progress bar of the download to stderr
436 | """
437 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
438 | **kwargs)
439 |
440 |
441 | def resnet152(pretrained=False, progress=True, **kwargs):
442 | """Constructs a ResNet-152 model.
443 |
444 | Args:
445 | pretrained (bool): If True, returns a model pre-trained on ImageNet
446 | progress (bool): If True, displays a progress bar of the download to stderr
447 | """
448 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
449 | **kwargs)
450 |
451 |
452 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
453 | """Constructs a ResNeXt-50 32x4d model.
454 |
455 | Args:
456 | pretrained (bool): If True, returns a model pre-trained on ImageNet
457 | progress (bool): If True, displays a progress bar of the download to stderr
458 | """
459 | kwargs['groups'] = 32
460 | kwargs['width_per_group'] = 4
461 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
462 | pretrained, progress, **kwargs)
463 |
464 |
465 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
466 | """Constructs a ResNeXt-101 32x8d model.
467 |
468 | Args:
469 | pretrained (bool): If True, returns a model pre-trained on ImageNet
470 | progress (bool): If True, displays a progress bar of the download to stderr
471 | """
472 | kwargs['groups'] = 32
473 | kwargs['width_per_group'] = 8
474 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
475 | pretrained, progress, **kwargs)
476 |
--------------------------------------------------------------------------------
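The `gem` helper in resnet.py above implements generalized-mean (GeM) pooling: activations are raised to the power `p`, average-pooled over the spatial map, and the `p`-th root is taken. A minimal sketch (not part of the repo; the shapes and `p` values are illustrative) showing how it collapses a feature map into the per-sample vectors used for `sh_pl`/`sp_pl`, and how `p` interpolates between average and max pooling:

```Python
import torch
import torch.nn.functional as F

def gem(x, p=3, eps=1e-6):
    # same formula as models/resnet.py: average-pool x**p over H x W, then take the p-th root
    return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1. / p)

x = torch.rand(2, 2048, 18, 9)                 # illustrative (B, C, H, W) feature map
pooled = gem(x).squeeze()                      # -> shape (2, 2048), as used for sh_pl / sp_pl
print(pooled.shape)

# p = 1 reduces GeM to plain average pooling
print(torch.allclose(gem(x, p=1).squeeze(), x.mean(dim=(-2, -1)), atol=1e-5))
# larger p pushes each channel toward its spatial maximum
print(gem(x, p=20).squeeze()[0, :3])
print(x.amax(dim=(-2, -1))[0, :3])
```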
/IDKL/train.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import pprint
4 |
5 | import torch
6 | import yaml
7 | from apex import amp
8 | from torch import optim
9 |
10 | from data import get_test_loader
11 | from data import get_train_loader
12 | from engine import get_trainer
13 | from models.baseline import Baseline
14 | # from torchstat import stat
15 |
16 | # from WarmUpLR import WarmUpStepLR
17 |
18 | def train(cfg):
19 | # set logger
20 | log_dir = os.path.join("logs/", cfg.dataset, cfg.prefix)
21 | if not os.path.isdir(log_dir):
22 | os.makedirs(log_dir, exist_ok=True)
23 |
24 | logging.basicConfig(format="%(asctime)s %(message)s",
25 | filename=log_dir + "/" + "log.txt",
26 | filemode="w")
27 |
28 | logger = logging.getLogger()
29 | logger.setLevel(logging.INFO)
30 | stream_handler = logging.StreamHandler()
31 | stream_handler.setLevel(logging.INFO)
32 | logger.addHandler(stream_handler)
33 |
34 | logger.info(pprint.pformat(cfg))
35 |
36 | # training data loader
37 | train_loader = get_train_loader(dataset=cfg.dataset,
38 | root=cfg.data_root,
39 | sample_method=cfg.sample_method,
40 | batch_size=cfg.batch_size,
41 | p_size=cfg.p_size,
42 | k_size=cfg.k_size,
43 | random_flip=cfg.random_flip,
44 | random_crop=cfg.random_crop,
45 | random_erase=cfg.random_erase,
46 | color_jitter=cfg.color_jitter,
47 | padding=cfg.padding,
48 | image_size=cfg.image_size,
49 | num_workers=8)
50 |
51 | # evaluation data loader
52 | gallery_loader, query_loader = None, None
53 | if cfg.eval_interval > 0:
54 |         if False:  # t-SNE feature-dump path, disabled by default
55 | query_loader = get_train_loader(dataset=cfg.dataset,
56 | root=cfg.data_root,
57 | sample_method=cfg.sample_method,
58 | batch_size=cfg.batch_size,
59 | p_size=cfg.p_size,
60 | k_size=cfg.k_size,
61 | # random_flip=cfg.random_flip,
62 | # random_crop=cfg.random_crop,
63 | # random_erase=cfg.random_erase,
64 | # color_jitter=cfg.color_jitter,
65 | # padding=cfg.padding,
66 | image_size=cfg.image_size,
67 | num_workers=8)
68 | gallery_loader = query_loader
69 |
70 | else:
71 | gallery_loader, query_loader = get_test_loader(dataset=cfg.dataset,
72 | root=cfg.data_root,
73 | batch_size=64,
74 | image_size=cfg.image_size,
75 | num_workers=4)
76 |
77 |
78 |
79 | # model
80 | model = Baseline(num_classes=cfg.num_id,
81 | pattern_attention=cfg.pattern_attention,
82 | modality_attention=cfg.modality_attention,
83 | mutual_learning=cfg.mutual_learning,
84 | decompose=cfg.decompose,
85 | drop_last_stride=cfg.drop_last_stride,
86 | triplet=cfg.triplet,
87 | k_size=cfg.k_size,
88 | center_cluster=cfg.center_cluster,
89 | center=cfg.center,
90 | margin=cfg.margin,
91 | num_parts=cfg.num_parts,
92 | weight_KL=cfg.weight_KL,
93 | weight_sid=cfg.weight_sid,
94 | weight_sep=cfg.weight_sep,
95 | update_rate=cfg.update_rate,
96 | classification=cfg.classification,
97 | bg_kl=cfg.bg_kl,
98 | sm_kl=cfg.sm_kl,
99 | fb_dt=cfg.fb_dt,
100 | IP=cfg.IP,
101 | distalign=cfg.distalign)
102 |
103 | def get_parameter_number(net):
104 | total_num = sum(p.numel() for p in net.parameters())
105 | trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
106 | return {'Total': total_num, 'Trainable': trainable_num}
107 |
108 | print(get_parameter_number(model))
109 |
110 | model.to(device)
111 |
112 | # optimizer
113 | assert cfg.optimizer in ['adam', 'sgd']
114 | if cfg.optimizer == 'adam':
115 | #optimizer = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd)
116 | optimizer = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd)
117 | else:
118 | optimizer = optim.SGD(model.parameters(), lr=cfg.lr, momentum=0.9, weight_decay=cfg.wd)
119 |
120 | # convert model for mixed precision training
121 | model, optimizer = amp.initialize(model, optimizer, enabled=cfg.fp16, opt_level="O1")
122 | if cfg.center:
123 | model.center_loss.centers = model.center_loss.centers.float()
124 | lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
125 | milestones=cfg.lr_step,
126 | gamma=0.1)
127 |
128 |
129 |
130 | if cfg.resume:
131 | checkpoint = torch.load(cfg.resume)
132 | model.load_state_dict(checkpoint)
133 |
134 | # stat(model,(3,224,224))
135 | # import pdb
136 | # pdb.set_trace()
137 |
138 | # engine
139 | checkpoint_dir = os.path.join("checkpoints", cfg.dataset, cfg.prefix)
140 | engine = get_trainer(dataset=cfg.dataset,
141 | model=model,
142 | optimizer=optimizer,
143 | lr_scheduler=lr_scheduler,
144 | logger=logger,
145 | non_blocking=True,
146 | log_period=cfg.log_period,
147 | save_dir=checkpoint_dir,
148 | prefix=cfg.prefix,
149 | eval_interval=cfg.eval_interval,
150 | start_eval=cfg.start_eval,
151 | gallery_loader=gallery_loader,
152 | query_loader=query_loader,
153 | rerank=cfg.rerank)
154 |
155 | # training
156 | engine.run(train_loader, max_epochs=cfg.num_epoch)
157 |
158 |
159 | if __name__ == '__main__':
160 | import argparse
161 | import random
162 | import numpy as np
163 | from configs.default import strategy_cfg
164 | from configs.default import dataset_cfg
165 |
166 | parser = argparse.ArgumentParser()
167 | parser.add_argument("--cfg", type=str, default="configs/softmax.yml")
168 | ################
169 | parser.add_argument('--gpu', default='0', type=str,
170 | help='gpu device ids for CUDA_VISIBLE_DEVICES')
171 | ####################
172 | args = parser.parse_args()
173 |
174 | ######################
175 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
176 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
177 | ################
178 |
179 | # set random seed
180 | seed = 1
181 | random.seed(seed)
182 | np.random.RandomState(seed)
183 | np.random.seed(seed)
184 | torch.manual_seed(seed)
185 | torch.cuda.manual_seed(seed)
186 |
187 | # enable cudnn backend
188 | torch.backends.cudnn.benchmark = True
189 | # torch.backends.cudnn.benchmark = False
190 | # torch.backends.cudnn.deterministic = True
191 |
192 | # load configuration
193 | customized_cfg = yaml.load(open(args.cfg, "r"), Loader=yaml.SafeLoader)
194 |
195 | cfg = strategy_cfg
196 | cfg.merge_from_file(args.cfg)
197 |
198 | dataset_cfg = dataset_cfg.get(cfg.dataset)
199 |
200 | for k, v in dataset_cfg.items():
201 | cfg[k] = v
202 |
203 |     if cfg.sample_method in ('identity_uniform', 'identity_random'):
204 | cfg.batch_size = cfg.p_size * cfg.k_size
205 |
206 | cfg.freeze()
207 |
208 | train(cfg)
209 |
--------------------------------------------------------------------------------
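train.py merges a strategy config with the per-dataset YAML, derives `batch_size` from the PK sampler (`p_size` identities x `k_size` images each), and then freezes the config. The `merge_from_file`/`freeze` calls suggest a yacs `CfgNode`; a minimal sketch of that pattern (assuming the `yacs` package; the option names below are hypothetical stand-ins for `configs/default/strategy.py`, which is not shown in this dump):

```Python
from yacs.config import CfgNode as CN

# hypothetical defaults standing in for configs/default/strategy.py
strategy_cfg = CN()
strategy_cfg.dataset = 'sysu'
strategy_cfg.sample_method = 'identity_random'
strategy_cfg.p_size = 8
strategy_cfg.k_size = 8
strategy_cfg.batch_size = 64

cfg = strategy_cfg.clone()
# cfg.merge_from_file('configs/SYSU.yml')      # would override the defaults with the YAML values
if cfg.sample_method in ('identity_uniform', 'identity_random'):
    cfg.batch_size = cfg.p_size * cfg.k_size   # PK sampling: P identities x K images each
cfg.freeze()                                   # later assignments now raise an error
print(cfg.batch_size)                          # 64
```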
/IDKL/utils/__pycache__/calc_acc.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/utils/__pycache__/calc_acc.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/utils/__pycache__/eval_regdb.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/utils/__pycache__/eval_regdb.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/utils/__pycache__/eval_sysu.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/utils/__pycache__/eval_sysu.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/utils/__pycache__/neighbor.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/utils/__pycache__/neighbor.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/utils/__pycache__/rerank.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1KK077/IDKL/8676b620098798662296df17765b6d0bd62c3d4f/IDKL/utils/__pycache__/rerank.cpython-37.pyc
--------------------------------------------------------------------------------
/IDKL/utils/calc_acc.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def calc_acc(logits, label, ignore_index=-100, mode="multiclass"):
5 | if mode == "binary":
6 | indices = torch.round(logits).type(label.type())
7 | elif mode == "multiclass":
8 | indices = torch.max(logits, dim=1)[1]
9 |
10 | if label.size() == logits.size():
11 | ignore = 1 - torch.round(label.sum(dim=1))
12 | label = torch.max(label, dim=1)[1]
13 | else:
14 | ignore = torch.eq(label, ignore_index).view(-1)
15 |
16 | correct = torch.eq(indices, label).view(-1)
17 | num_correct = torch.sum(correct)
18 | num_examples = logits.shape[0] - ignore.sum()
19 |
20 | return num_correct.float() / num_examples.float()
21 |
--------------------------------------------------------------------------------
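A toy call of `calc_acc` above (not part of the repo): predictions are the per-row argmax, and samples whose label equals `ignore_index` are dropped from the denominator:

```Python
import torch
from utils.calc_acc import calc_acc      # import path assumes the repo layout above

logits = torch.tensor([[2.0, 0.1, 0.3],
                       [0.2, 1.5, 0.1],
                       [0.1, 0.2, 3.0],
                       [0.9, 0.1, 0.4]])
labels = torch.tensor([0, 1, 1, -100])   # last sample carries ignore_index

# argmax predictions are [0, 1, 2, 0]; two of the three non-ignored samples are correct
print(calc_acc(logits, labels))          # tensor(0.6667)
```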
/IDKL/utils/eval_llcm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import numpy as np
4 | import torch
5 | from sklearn.preprocessing import normalize
6 | from torch.nn import functional as F
7 | from .rerank import re_ranking, pairwise_distance
8 |
9 |
10 | def get_gallery_names(perm, cams, ids, trial_id, num_shots=1):
11 | names = []
12 | for cam in cams:
13 | cam_perm = perm[cam - 1][0].squeeze()
14 | for i in ids:
15 | instance_id = cam_perm[i - 1][trial_id][:num_shots]
16 | names.extend(['cam{}/{:0>4d}/{:0>4d}'.format(cam, i, ins) for ins in instance_id.tolist()])
17 |
18 | return names
19 |
20 |
21 | def get_unique(array):
22 | _, idx = np.unique(array, return_index=True)
23 | return array[np.sort(idx)]
24 |
25 |
26 | def get_cmc(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids):
27 | gallery_unique_count = get_unique(gallery_ids).shape[0]
28 | match_counter = np.zeros((gallery_unique_count,))
29 |
30 | result = gallery_ids[sorted_indices]
31 | cam_locations_result = gallery_cam_ids[sorted_indices]
32 |
33 | valid_probe_sample_count = 0
34 |
35 | for probe_index in range(sorted_indices.shape[0]):
36 | # remove gallery samples from the same camera of the probe
37 | result_i = result[probe_index, :]
38 | result_i[np.equal(cam_locations_result[probe_index], query_cam_ids[probe_index])] = -1
39 |
40 | # remove the -1 entries from the label result
41 | result_i = np.array([i for i in result_i if i != -1])
42 |
43 | # remove duplicated id in "stable" manner
44 | result_i_unique = get_unique(result_i)
45 |
46 | # match for probe i
47 | match_i = np.equal(result_i_unique, query_ids[probe_index])
48 |
49 | if np.sum(match_i) != 0: # if there is true matching in gallery
50 | valid_probe_sample_count += 1
51 | match_counter += match_i
52 |
53 | rank = match_counter / valid_probe_sample_count
54 | cmc = np.cumsum(rank)
55 | return cmc
56 |
57 |
58 | def get_mAP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids):
59 | result = gallery_ids[sorted_indices]
60 | cam_locations_result = gallery_cam_ids[sorted_indices]
61 |
62 | valid_probe_sample_count = 0
63 | avg_precision_sum = 0
64 |
65 | for probe_index in range(sorted_indices.shape[0]):
66 | # remove gallery samples from the same camera of the probe
67 | result_i = result[probe_index, :]
68 | result_i[cam_locations_result[probe_index, :] == query_cam_ids[probe_index]] = -1
69 |
70 | # remove the -1 entries from the label result
71 | result_i = np.array([i for i in result_i if i != -1])
72 |
73 | # match for probe i
74 | match_i = result_i == query_ids[probe_index]
75 | true_match_count = np.sum(match_i)
76 |
77 | if true_match_count != 0: # if there is true matching in gallery
78 | valid_probe_sample_count += 1
79 | true_match_rank = np.where(match_i)[0]
80 |
81 | ap = np.mean(np.arange(1, true_match_count + 1) / (true_match_rank + 1))
82 | avg_precision_sum += ap
83 |
84 | mAP = avg_precision_sum / valid_probe_sample_count
85 | return mAP
86 |
87 |
88 | # def eval_llcm(query_feats, q_pids, q_camids, gallery_feats, g_pids, g_camids, max_rank=20, rerank=False):
89 | # """Evaluation with sysu metric
90 | # Key: for each query identity, its gallery images from the same camera view are discarded. "Following the original setting in ite dataset"
91 | # """
92 | # ptr = 0
93 | # query_feat = np.zeros((nquery, 2048))
94 | # query_feat_att = np.zeros((nquery, 2048))
95 | # with torch.no_grad():
96 | # for batch_idx, (input, label) in enumerate(query_loader):
97 | # batch_num = input.size(0)
98 | # input = Variable(input.cuda())
99 | # feat, feat_att = net(input, input, test_mode[1])
100 | # query_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy()
101 | # query_feat_att[ptr:ptr + batch_num, :] = feat_att.detach().cpu().numpy()
102 | # ptr = ptr + batch_num
103 | # distmat = -np.matmul(query_feats.cpu().numpy(), np.transpose(gallery_feats.cpu().numpy()))
104 | # num_q, num_g = distmat.shape
105 | # if num_g < max_rank:
106 | # max_rank = num_g
107 | # print("Note: number of gallery samples is quite small, got {}".format(num_g))
108 | # indices = np.argsort(distmat, axis=1)
109 | # pred_label = g_pids[indices]
110 | # matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
111 | #
112 | # # compute cmc curve for each query
113 | # new_all_cmc = []
114 | # all_cmc = []
115 | # all_AP = []
116 | # all_INP = []
117 | # num_valid_q = 0. # number of valid query
118 | # for q_idx in range(num_q):
119 | # # get query pid and camid
120 | # q_pid = q_pids[q_idx]
121 | # q_camid = q_camids[q_idx]
122 | #
123 | # # remove gallery samples that have the same pid and camid with query
124 | #
125 | # order = indices[q_idx]
126 | # remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
127 | # keep = np.invert(remove)
128 | #
129 | # # compute cmc curve
130 | # # the cmc calculation is different from standard protocol
131 | # # we follow the protocol of the author's released code
132 | # new_cmc = pred_label[q_idx][keep]
133 | # new_index = np.unique(new_cmc, return_index=True)[1]
134 | #
135 | # new_cmc = [new_cmc[index] for index in sorted(new_index)]
136 | #
137 | # new_match = (new_cmc == q_pid).astype(np.int32)
138 | # new_cmc = new_match.cumsum()
139 | # new_all_cmc.append(new_cmc[:max_rank])
140 | #
141 | # orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
142 | # if not np.any(orig_cmc):
143 | # # this condition is true when query identity does not appear in gallery
144 | # continue
145 | #
146 | # cmc = orig_cmc.cumsum()
147 | #
148 | # # compute mINP
149 | # # refernece Deep Learning for Person Re-identification: A Survey and Outlook
150 | # pos_idx = np.where(orig_cmc == 1)
151 | # pos_max_idx = np.max(pos_idx)
152 | # inp = cmc[pos_max_idx] / (pos_max_idx + 1.0)
153 | # all_INP.append(inp)
154 | #
155 | # cmc[cmc > 1] = 1
156 | #
157 | # all_cmc.append(cmc[:max_rank])
158 | # num_valid_q += 1.
159 | #
160 | # # compute average precision
161 | # # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
162 | # num_rel = orig_cmc.sum()
163 | # tmp_cmc = orig_cmc.cumsum()
164 | # tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
165 | # tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
166 | # AP = tmp_cmc.sum() / num_rel
167 | # all_AP.append(AP)
168 | #
169 | # assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
170 | #
171 | # all_cmc = np.asarray(all_cmc).astype(np.float32)
172 | # all_cmc = all_cmc.sum(0) / num_valid_q # standard CMC
173 | #
174 | # new_all_cmc = np.asarray(new_all_cmc).astype(np.float32)
175 | # new_all_cmc = new_all_cmc.sum(0) / num_valid_q
176 | # mAP = np.mean(all_AP)
177 | # mINP = np.mean(all_INP)
178 | # return new_all_cmc, mAP, mINP
179 |
180 |
181 | def eval_llcm(query_feats, query_ids, query_cam_ids, gallery_feats, gallery_ids, gallery_cam_ids, gallery_img_paths, rerank=False):
182 | # gallery_feats = F.normalize(gallery_feats, dim=1)
183 | # query_feats = F.normalize(query_feats, dim=1)
184 |
185 | if rerank:
186 | dist_mat = re_ranking(query_feats, gallery_feats, eval_type=False)
187 | else:
188 | dist_mat = pairwise_distance(query_feats, gallery_feats)
189 | # dist_mat = -torch.mm(query_feats, gallery_feats.t())
190 |
191 | sorted_indices = np.argsort(dist_mat, axis=1)
192 |
193 | mAP = get_mAP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids)
194 | cmc = get_cmc(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids)
195 |
196 | r1 = cmc[0]
197 | r5 = cmc[4]
198 | r10 = cmc[9]
199 | r20 = cmc[19]
200 |
201 | r1 = r1 * 100
202 | r5 = r5 * 100
203 | r10 = r10 * 100
204 | r20 = r20 * 100
205 | mAP = mAP * 100
206 |
207 | perf = 'r1 precision = {:.2f} , r10 precision = {:.2f} , r20 precision = {:.2f}, mAP = {:.2f}'
208 | logging.info(perf.format(r1, r10, r20, mAP))
209 |
210 | return mAP, r1, r5, r10, r20
211 |
--------------------------------------------------------------------------------
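`get_cmc` and `get_mAP` above operate on per-probe gallery rankings after same-camera gallery images are filtered out. A toy example (not part of the repo) with two probes and four gallery images; the probe camera is chosen so that no filtering happens, which keeps the trace easy to follow:

```Python
import numpy as np
from utils.eval_llcm import get_cmc, get_mAP   # import path assumes the repo layout above

# toy gallery: identities 1, 2, 3, 1 captured by cameras 1, 1, 2, 2
gallery_ids = np.array([1, 2, 3, 1])
gallery_cam_ids = np.array([1, 1, 2, 2])

# two probes from a third camera, so the same-camera filter removes nothing
query_ids = np.array([1, 3])
query_cam_ids = np.array([3, 3])

# each row lists gallery indices sorted by ascending distance for that probe
sorted_indices = np.array([[0, 3, 1, 2],    # probe id 1: both id-1 images ranked first
                           [1, 2, 0, 3]])   # probe id 3: its single match ranked second

print(get_cmc(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids))
# [0.5 1.  1. ]   -> rank-1 = 50%, rank-2 = 100%
print(get_mAP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids))
# 0.75            -> average of per-probe APs 1.0 and 0.5
```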
/IDKL/utils/eval_regdb.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import numpy as np
4 | import torch
5 | from sklearn.preprocessing import normalize
6 | from torch.nn import functional as F
7 | from .rerank import re_ranking, pairwise_distance
8 |
9 |
10 | def get_gallery_names(perm, cams, ids, trial_id, num_shots=1):
11 | names = []
12 | for cam in cams:
13 | cam_perm = perm[cam - 1][0].squeeze()
14 | for i in ids:
15 | instance_id = cam_perm[i - 1][trial_id][:num_shots]
16 | names.extend(['cam{}/{:0>4d}/{:0>4d}'.format(cam, i, ins) for ins in instance_id.tolist()])
17 |
18 | return names
19 |
20 |
21 | def get_unique(array):
22 | _, idx = np.unique(array, return_index=True)
23 | return array[np.sort(idx)]
24 |
25 |
26 | def get_cmc(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids):
27 | gallery_unique_count = get_unique(gallery_ids).shape[0]
28 | match_counter = np.zeros((gallery_unique_count,))
29 |
30 | result = gallery_ids[sorted_indices]
31 | cam_locations_result = gallery_cam_ids[sorted_indices]
32 |
33 | valid_probe_sample_count = 0
34 |
35 | for probe_index in range(sorted_indices.shape[0]):
36 | # remove gallery samples from the same camera of the probe
37 | result_i = result[probe_index, :]
38 | result_i[np.equal(cam_locations_result[probe_index], query_cam_ids[probe_index])] = -1
39 |
40 | # remove the -1 entries from the label result
41 | result_i = np.array([i for i in result_i if i != -1])
42 |
43 | # remove duplicated id in "stable" manner
44 | result_i_unique = get_unique(result_i)
45 |
46 | # match for probe i
47 | match_i = np.equal(result_i_unique, query_ids[probe_index])
48 |
49 | if np.sum(match_i) != 0: # if there is true matching in gallery
50 | valid_probe_sample_count += 1
51 | match_counter += match_i
52 |
53 | rank = match_counter / valid_probe_sample_count
54 | cmc = np.cumsum(rank)
55 | return cmc
56 |
57 |
58 | def get_mAP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids):
59 | result = gallery_ids[sorted_indices]
60 | cam_locations_result = gallery_cam_ids[sorted_indices]
61 |
62 | valid_probe_sample_count = 0
63 | avg_precision_sum = 0
64 |
65 | for probe_index in range(sorted_indices.shape[0]):
66 | # remove gallery samples from the same camera of the probe
67 | result_i = result[probe_index, :]
68 | result_i[cam_locations_result[probe_index, :] == query_cam_ids[probe_index]] = -1
69 |
70 | # remove the -1 entries from the label result
71 | result_i = np.array([i for i in result_i if i != -1])
72 |
73 | # match for probe i
74 | match_i = result_i == query_ids[probe_index]
75 | true_match_count = np.sum(match_i)
76 |
77 | if true_match_count != 0: # if there is true matching in gallery
78 | valid_probe_sample_count += 1
79 | true_match_rank = np.where(match_i)[0]
80 |
81 | ap = np.mean(np.arange(1, true_match_count + 1) / (true_match_rank + 1))
82 | avg_precision_sum += ap
83 |
84 | mAP = avg_precision_sum / valid_probe_sample_count
85 | return mAP
86 |
87 | def eval_regdb(query_feats, query_ids, query_cam_ids, gallery_feats, gallery_ids, gallery_cam_ids, gallery_img_paths, rerank=False):
88 | # gallery_feats = F.normalize(gallery_feats, dim=1)
89 | # query_feats = F.normalize(query_feats, dim=1)
90 |
91 | if rerank:
92 | dist_mat = re_ranking(query_feats, gallery_feats, eval_type=False)
93 | else:
94 | dist_mat = pairwise_distance(query_feats, gallery_feats)
95 | # dist_mat = -torch.mm(query_feats, gallery_feats.t())
96 |
97 | sorted_indices = np.argsort(dist_mat, axis=1)
98 |
99 | mAP = get_mAP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids)
100 | cmc = get_cmc(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids)
101 |
102 | r1 = cmc[0]
103 | r5 = cmc[4]
104 | r10 = cmc[9]
105 | r20 = cmc[19]
106 |
107 | r1 = r1 * 100
108 | r5 = r5 * 100
109 | r10 = r10 * 100
110 | r20 = r20 * 100
111 | mAP = mAP * 100
112 |
113 | perf = 'r1 precision = {:.2f} , r10 precision = {:.2f} , r20 precision = {:.2f}, mAP = {:.2f}'
114 | logging.info(perf.format(r1, r10, r20, mAP))
115 |
116 | return mAP, r1, r5, r10, r20
117 |
--------------------------------------------------------------------------------
/IDKL/utils/eval_sysu.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import torch
4 | import numpy as np
5 | from sklearn.preprocessing import normalize
6 | from .rerank import re_ranking, pairwise_distance
7 | from torch.nn import functional as F
8 |
9 |
10 | def get_gallery_names(perm, cams, ids, trial_id, num_shots=1):
11 | names = []
12 | for cam in cams:
13 | cam_perm = perm[cam - 1][0].squeeze()
14 | for i in ids:
15 | instance_id = cam_perm[i - 1][trial_id][:num_shots]
16 | names.extend(['cam{}/{:0>4d}/{:0>4d}'.format(cam, i, ins) for ins in instance_id.tolist()])
17 |
18 | return names
19 |
20 |
21 | def get_unique(array):
22 | _, idx = np.unique(array, return_index=True)
23 | return array[np.sort(idx)]
24 |
25 |
26 | def get_cmc(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids):
27 | gallery_unique_count = get_unique(gallery_ids).shape[0]
28 | match_counter = np.zeros((gallery_unique_count,))
29 |
30 | result = gallery_ids[sorted_indices]
31 | cam_locations_result = gallery_cam_ids[sorted_indices]
32 |
33 | valid_probe_sample_count = 0
34 |
35 | for probe_index in range(sorted_indices.shape[0]):
36 | # remove gallery samples from the same camera of the probe
37 | result_i = result[probe_index, :]
38 | result_i[np.equal(cam_locations_result[probe_index], query_cam_ids[probe_index])] = -1
39 |
40 | # remove the -1 entries from the label result
41 | result_i = np.array([i for i in result_i if i != -1])
42 |
43 | # remove duplicated id in "stable" manner
44 | result_i_unique = get_unique(result_i)
45 |
46 | # match for probe i
47 | match_i = np.equal(result_i_unique, query_ids[probe_index])
48 |
49 | if np.sum(match_i) != 0: # if there is true matching in gallery
50 | valid_probe_sample_count += 1
51 | match_counter += match_i
52 |
53 | rank = match_counter / valid_probe_sample_count
54 | cmc = np.cumsum(rank)
55 | return cmc
56 |
57 |
58 | def get_mAP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids):
59 | result = gallery_ids[sorted_indices]
60 | cam_locations_result = gallery_cam_ids[sorted_indices]
61 |
62 | valid_probe_sample_count = 0
63 | avg_precision_sum = 0
64 |
65 | for probe_index in range(sorted_indices.shape[0]):
66 | # remove gallery samples from the same camera of the probe
67 | result_i = result[probe_index, :]
68 | result_i[cam_locations_result[probe_index, :] == query_cam_ids[probe_index]] = -1
69 |
70 | # remove the -1 entries from the label result
71 | result_i = np.array([i for i in result_i if i != -1])
72 |
73 | # match for probe i
74 | match_i = result_i == query_ids[probe_index]
75 | true_match_count = np.sum(match_i)
76 |
77 | if true_match_count != 0: # if there is true matching in gallery
78 | valid_probe_sample_count += 1
79 | true_match_rank = np.where(match_i)[0]
80 |
81 | ap = np.mean(np.arange(1, true_match_count + 1) / (true_match_rank + 1))
82 | avg_precision_sum += ap
83 |
84 | mAP = avg_precision_sum / valid_probe_sample_count
85 | return mAP
86 |
87 |
88 | def eval_sysu(query_feats, query_ids, query_cam_ids, gallery_feats, gallery_ids, gallery_cam_ids, gallery_img_paths,
89 | perm, mode='all', num_shots=1, num_trials=10, rerank=False):
90 | assert mode in ['indoor', 'all']
91 |
92 | gallery_cams = [1, 2] if mode == 'indoor' else [1, 2, 4, 5]
93 |
94 | # cam2 and cam3 are in the same location
95 | query_cam_ids[np.equal(query_cam_ids, 3)] = 2
96 | query_feats = F.normalize(query_feats, dim=1)
97 |
98 | gallery_indices = np.in1d(gallery_cam_ids, gallery_cams)
99 |
100 | gallery_feats = gallery_feats[gallery_indices]
101 | gallery_feats = F.normalize(gallery_feats, dim=1)
102 | gallery_cam_ids = gallery_cam_ids[gallery_indices]
103 | gallery_ids = gallery_ids[gallery_indices]
104 | gallery_img_paths = gallery_img_paths[gallery_indices]
105 | gallery_names = np.array(['/'.join(os.path.splitext(path)[0].split('/')[-3:]) for path in gallery_img_paths])
106 |
107 | gallery_id_set = np.unique(gallery_ids)
108 |
109 | mAP, r1, r5, r10, r20 = 0, 0, 0, 0, 0
110 | for t in range(num_trials):
111 | names = get_gallery_names(perm, gallery_cams, gallery_id_set, t, num_shots)
112 | flag = np.in1d(gallery_names, names)
113 |
114 | g_feat = gallery_feats[flag]
115 | g_ids = gallery_ids[flag]
116 | g_cam_ids = gallery_cam_ids[flag]
117 |
118 | if rerank:
119 | dist_mat = re_ranking(query_feats, g_feat)
120 | else:
121 | dist_mat = pairwise_distance(query_feats, g_feat)
122 | # dist_mat = -torch.mm(query_feats, g_feat.permute(1,0))
123 |
124 | sorted_indices = np.argsort(dist_mat, axis=1)
125 |
126 | mAP += get_mAP(sorted_indices, query_ids, query_cam_ids, g_ids, g_cam_ids)
127 | cmc = get_cmc(sorted_indices, query_ids, query_cam_ids, g_ids, g_cam_ids)
128 |
129 | r1 += cmc[0]
130 | r5 += cmc[4]
131 | r10 += cmc[9]
132 | r20 += cmc[19]
133 |
134 | r1 = r1 / num_trials * 100
135 | r5 = r5 / num_trials * 100
136 | r10 = r10 / num_trials * 100
137 | r20 = r20 / num_trials * 100
138 | mAP = mAP / num_trials * 100
139 |
140 | perf = '{} num-shot:{} r1 precision = {:.2f} , r10 precision = {:.2f} , r20 precision = {:.2f}, mAP = {:.2f}'
141 | logging.info(perf.format(mode, num_shots, r1, r10, r20, mAP))
142 |
143 | return mAP, r1, r5, r10, r20
144 |
--------------------------------------------------------------------------------
/IDKL/utils/rerank.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | def k_reciprocal_neigh(initial_rank, i, k1):
5 | forward_k_neigh_index = initial_rank[i,:k1+1]
6 | backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1]
7 | fi = np.where(backward_k_neigh_index==i)[0]
8 | return forward_k_neigh_index[fi]
9 |
10 | def pairwise_distance(query_features, gallery_features):
11 | x = query_features
12 | y = gallery_features
13 | m, n = x.size(0), y.size(0)
14 | x = x.view(m, -1)
15 | y = y.view(n, -1)
16 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
17 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
18 |     dist.addmm_(x, y.t(), beta=1, alpha=-2)  # dist = dist - 2 * x @ y.t()
19 | return dist
20 |
21 | def re_ranking(q_feat, g_feat, k1=20, k2=6, lambda_value=0.3, eval_type=True):
22 | # The following naming, e.g. gallery_num, is different from outer scope.
23 | # Don't care about it.
24 | feats = torch.cat([q_feat, g_feat], 0)
25 | dist = pairwise_distance(feats, feats)
26 | original_dist = dist.detach().cpu().numpy()
27 | all_num = original_dist.shape[0]
28 | original_dist = np.transpose(original_dist / np.max(original_dist, axis=0))
29 | V = np.zeros_like(original_dist).astype(np.float16)
30 |
31 | query_num = q_feat.size(0)
32 | all_num = original_dist.shape[0]
33 | if eval_type:
34 | dist[:, query_num:] = dist.max()
35 | dist = dist.detach().cpu().numpy()
36 | initial_rank = np.argsort(dist).astype(np.int32)
37 |
38 | # print("start re-ranking")
39 | for i in range(all_num):
40 | # k-reciprocal neighbors
41 | forward_k_neigh_index = initial_rank[i, :k1 + 1]
42 | backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
43 | fi = np.where(backward_k_neigh_index == i)[0]
44 | k_reciprocal_index = forward_k_neigh_index[fi]
45 | k_reciprocal_expansion_index = k_reciprocal_index
46 | for j in range(len(k_reciprocal_index)):
47 | candidate = k_reciprocal_index[j]
48 | candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1]
49 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,
50 | :int(np.around(k1 / 2)) + 1]
51 | # import pdb
52 | # pdb.set_trace()
53 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
54 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
55 | if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len(
56 | candidate_k_reciprocal_index):
57 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index)
58 |
59 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
60 | weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
61 | V[i, k_reciprocal_expansion_index] = weight / np.sum(weight)
62 | original_dist = original_dist[:query_num, ]
63 | if k2 != 1:
64 | V_qe = np.zeros_like(V, dtype=np.float16)
65 | for i in range(all_num):
66 | V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
67 | V = V_qe
68 | del V_qe
69 | del initial_rank
70 | invIndex = []
71 | for i in range(all_num):
72 | invIndex.append(np.where(V[:, i] != 0)[0])
73 |
74 | jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)
75 |
76 |
77 | for i in range(query_num):
78 | temp_min = np.zeros(shape=[1, all_num], dtype=np.float16)
79 | indNonZero = np.where(V[i, :] != 0)[0]
80 | indImages = []
81 | indImages = [invIndex[ind] for ind in indNonZero]
82 | for j in range(len(indNonZero)):
83 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]],
84 | V[indImages[j], indNonZero[j]])
85 | jaccard_dist[i] = 1 - temp_min / (2 - temp_min)
86 |
87 | final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value
88 | del original_dist
89 | del V
90 | del jaccard_dist
91 | final_dist = final_dist[:query_num, query_num:]
92 | return final_dist
--------------------------------------------------------------------------------
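`pairwise_distance` above expands ||q - g||^2 as ||q||^2 + ||g||^2 - 2 q·g, and `re_ranking` fuses that distance with a k-reciprocal Jaccard distance. A quick sanity sketch (not part of the repo; `k1`/`k2` are shrunk because the toy sets are far smaller than a real gallery):

```Python
import torch
from utils.rerank import pairwise_distance, re_ranking   # import path assumes the repo layout above

torch.manual_seed(0)
q = torch.randn(4, 16)    # 4 query features
g = torch.randn(6, 16)    # 6 gallery features

dist = pairwise_distance(q, g)
# matches the squared Euclidean distance
print(torch.allclose(dist, torch.cdist(q, g).pow(2), atol=1e-4))   # True

final = re_ranking(q, g, k1=3, k2=2)
print(final.shape)        # (4, 6) fused distance matrix, smaller is more similar
```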
/IDKL/utils/tsne.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | import scipy.io as sio
5 | import matplotlib as mpl
6 |
7 | mpl.use('AGG')
8 | import matplotlib.pyplot as plt
9 | from sklearn.manifold import TSNE
10 |
11 | if __name__ == '__main__':
12 | test_ids = [
13 | 6, 10, 17, 21, 24, 25, 27, 28, 31, 34, 36, 37, 40, 41, 42, 43, 44, 45, 49, 50, 51, 54, 63, 69, 75, 80, 81, 82,
14 | 83, 84, 85, 86, 87, 88, 89, 90, 93, 102, 104, 105, 106, 108, 112, 116, 117, 122, 125, 129, 130, 134, 138, 139,
15 | 150, 152, 162, 166, 167, 170, 172, 176, 185, 190, 192, 202, 204, 207, 210, 215, 223, 229, 232, 237, 252, 253,
16 | 257, 259, 263, 266, 269, 272, 273, 274, 275, 282, 285, 291, 300, 301, 302, 303, 307, 312, 315, 318, 331, 333
17 | ]
18 | random.seed(0)
19 | tsne = TSNE(n_components=2, init='pca')
20 | selected_ids = random.sample(test_ids, 20)
21 | plt.figure(figsize=(5, 5))
22 |
23 | # features without dual path
24 | q_mat_path = 'features/sysu/query-sysu-test-nodual-nore-adam-16x8-grey_model_150.mat'
25 | g_mat_path = 'features/sysu/gallery-sysu-test-nodual-nore-adam-16x8-grey_model_150.mat'
26 |
27 | mat = sio.loadmat(q_mat_path)
28 | q_feats = mat["feat"]
29 | q_ids = mat["ids"].squeeze()
30 | flag = np.in1d(q_ids, selected_ids)
31 | q_feats = q_feats[flag]
32 |
33 | mat = sio.loadmat(g_mat_path)
34 | g_feats = mat["feat"]
35 | g_ids = mat["ids"].squeeze()
36 | flag = np.in1d(g_ids, selected_ids)
37 | g_feats = g_feats[flag]
38 |
39 | embed = tsne.fit_transform(np.concatenate([q_feats, g_feats], axis=0))
40 | c = ['r'] * q_feats.shape[0] + ['b'] * g_feats.shape[0]
41 | # plt.subplot(1, 2, 1)
42 | plt.scatter(embed[:, 0], embed[:, 1], c=c)
43 |
44 | # # features with dual path
45 | # q_mat_path = 'features/sysu/query-sysu-test-dual-nore-separatelayer12-0.05_model_30.mat'
46 | # g_mat_path = 'features/sysu/gallery-sysu-test-dual-nore-separatelayer12-0.05_model_30.mat'
47 | #
48 | # mat = sio.loadmat(q_mat_path)
49 | # q_feats = mat["feat"]
50 | # q_ids = mat["ids"].squeeze()
51 | # flag = np.in1d(q_ids, selected_ids)
52 | # q_feats = q_feats[flag]
53 | #
54 | # mat = sio.loadmat(g_mat_path)
55 | # g_feats = mat["feat"]
56 | # g_ids = mat["ids"].squeeze()
57 | # flag = np.in1d(g_ids, selected_ids)
58 | # g_feats = g_feats[flag]
59 | #
60 | # embed = tsne.fit_transform(np.concatenate([q_feats, g_feats], axis=0))
61 | # c = ['r'] * q_feats.shape[0] + ['b'] * g_feats.shape[0]
62 | # plt.subplot(1, 2, 2)
63 | # plt.scatter(embed[:, 0], embed[:, 1], c=c)
64 |
65 | plt.tight_layout()
66 | plt.savefig('tsne-adv-layer2-separate-l2.jpg')
67 |
--------------------------------------------------------------------------------
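tsne.py expects `.mat` feature dumps that are not in this repository, so it cannot run as-is. A self-contained sketch of the same visualization idea on random stand-in features (the file name, feature statistics, and sample counts are all illustrative):

```Python
import numpy as np
import matplotlib
matplotlib.use('Agg')                     # headless backend, as in tsne.py
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

rng = np.random.default_rng(0)
q_feats = rng.normal(0.0, 1.0, size=(100, 2048))   # stand-in for the query .mat features
g_feats = rng.normal(0.5, 1.0, size=(100, 2048))   # stand-in for the gallery .mat features

embed = TSNE(n_components=2, init='pca', random_state=0).fit_transform(
    np.concatenate([q_feats, g_feats], axis=0))

colors = ['r'] * len(q_feats) + ['b'] * len(g_feats)   # red = query, blue = gallery
plt.figure(figsize=(5, 5))
plt.scatter(embed[:, 0], embed[:, 1], c=colors, s=5)
plt.tight_layout()
plt.savefig('tsne_demo.jpg')
```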
/README.md:
--------------------------------------------------------------------------------
1 | # [CVPR 2024] IDKL: Implicit Discriminative Knowledge Learning for Visible-Infrared Person Re-Identification (https://arxiv.org/abs/2403.11708)
2 |
3 | ## Environmental requirements:
4 |
5 | PyTorch == 1.10.0
6 |
7 | ignite == 0.2.1
8 |
9 | torchvision == 0.11.2
10 |
11 | apex == 0.1
12 |
13 | **Training:**
14 |
15 | To train the model, use one of the following commands:
16 |
17 | SYSU-MM01:
18 | ```Shell
19 | python train.py --cfg ./configs/SYSU.yml
20 | ```
21 |
22 | RegDB:
23 | ```Shell
24 | python train.py --cfg ./configs/RegDB.yml
25 | ```
26 |
27 | LLCM:
28 | ```Shell
29 | python train.py --cfg ./configs/LLCM.yml
30 | ```
31 |
32 |
33 |
--------------------------------------------------------------------------------