├── ChannelAug.py
├── GAT.py
├── HyperGraphs.py
├── README.md
├── Transformer.py
├── data_loader.py
├── data_manager.py
├── eval_metrics.py
├── loss.py
├── model.py
├── model_sle.py
├── model_sle_hsl.py
├── model_sle_hsl_cfl.py
├── pre_process_sysu.py
├── resnet.py
├── test_sle.py
├── test_sle_hsl.py
├── test_sle_hsl_cfl.py
├── train_hos_net.py
├── train_sle.py
├── train_sle_hsl.py
├── train_sle_hsl_cfl.py
├── utils.py
└── whitening.py

/ChannelAug.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from torchvision.transforms import *
4 |
5 | #from PIL import Image
6 | import random
7 | import math
8 | #import numpy as np
9 | #import torch
10 |
11 |
12 | class ChannelAdap(object):
13 |     """ Randomly replaces all three channels of a tensor image with one of
14 |     its R, G, or B channels, or leaves the image unchanged (one of four
15 |     equally likely outcomes).
16 |
17 |     Args:
18 |         probability: The probability that the operation will be performed
19 |             (stored but currently unused; the check below is commented out).
20 |     """
21 |
22 |     def __init__(self, probability = 0.5):
23 |         self.probability = probability
24 |
25 |
26 |     def __call__(self, img):
27 |
28 |         # if random.uniform(0, 1) > self.probability:
29 |         #     return img
30 |
31 |         idx = random.randint(0, 3)
32 |
33 |         if idx ==0:
34 |             # randomly selected the R channel
35 |             img[1, :,:] = img[0,:,:]
36 |             img[2, :,:] = img[0,:,:]
37 |         elif idx ==1:
38 |             # randomly selected the G channel
39 |             img[0, :,:] = img[1,:,:]
40 |             img[2, :,:] = img[1,:,:]
41 |         elif idx ==2:
42 |             # randomly selected the B channel
43 |             img[0, :,:] = img[2,:,:]
44 |             img[1, :,:] = img[2,:,:]
45 |         else:
46 |             img = img
47 |
48 |         return img
49 |
50 |
51 | class ChannelAdapGray(object):
52 |     """ Randomly replaces all three channels with one of the R, G, or B
53 |     channels; in the remaining case, converts the image to three-channel
54 |     grayscale with the given probability, otherwise leaves it unchanged.
55 |
56 |     Args:
57 |         probability: The probability of the grayscale conversion in the
58 |             fourth branch.
59 |     """
60 |
61 |     def __init__(self, probability = 0.5):
62 |         self.probability = probability
63 |
64 |
65 |     def __call__(self, img):
66 |
67 |         # if random.uniform(0, 1) > self.probability:
68 |         #     return img
69 |
70 |         idx = random.randint(0, 3)
71 |
72 |         if idx ==0:
73 |             # randomly selected the R channel
74 |             img[1, :,:] = img[0,:,:]
75 |             img[2, :,:] = img[0,:,:]
76 |         elif idx ==1:
77 |             # randomly selected the G channel
78 |             img[0, :,:] = img[1,:,:]
79 |             img[2, :,:] = img[1,:,:]
80 |         elif idx ==2:
81 |             # randomly selected the B channel
82 |             img[0, :,:] = img[2,:,:]
83 |             img[1, :,:] = img[2,:,:]
84 |         else:
85 |             if random.uniform(0, 1) > self.probability:
86 |                 # return img
87 |                 img = img
88 |             else:
89 |                 tmp_img = 0.2989 * img[0,:,:] + 0.5870 * img[1,:,:] + 0.1140 * img[2,:,:]
90 |                 img[0,:,:] = tmp_img
91 |                 img[1,:,:] = tmp_img
92 |                 img[2,:,:] = tmp_img
93 |         return img
94 |
95 | class ChannelRandomErasing(object):
96 |     """ Randomly selects a rectangle region in an image and erases its pixels.
97 |         'Random Erasing Data Augmentation' by Zhong et al.
98 |         See https://arxiv.org/pdf/1708.04896.pdf
99 |     Args:
100 |         probability: The probability that the Random Erasing operation will be performed.
101 |         sl: Minimum proportion of erased area against input image.
102 |         sh: Maximum proportion of erased area against input image.
103 |         r1: Minimum aspect ratio of erased area.
104 |         mean: Erasing value.
105 | """ 106 | 107 | def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]): 108 | 109 | self.probability = probability 110 | self.mean = mean 111 | self.sl = sl 112 | self.sh = sh 113 | self.r1 = r1 114 | 115 | def __call__(self, img): 116 | 117 | if random.uniform(0, 1) > self.probability: 118 | return img 119 | 120 | for attempt in range(100): 121 | area = img.size()[1] * img.size()[2] 122 | 123 | target_area = random.uniform(self.sl, self.sh) * area 124 | aspect_ratio = random.uniform(self.r1, 1/self.r1) 125 | 126 | h = int(round(math.sqrt(target_area * aspect_ratio))) 127 | w = int(round(math.sqrt(target_area / aspect_ratio))) 128 | 129 | if w < img.size()[2] and h < img.size()[1]: 130 | x1 = random.randint(0, img.size()[1] - h) 131 | y1 = random.randint(0, img.size()[2] - w) 132 | if img.size()[0] == 3: 133 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 134 | img[1, x1:x1+h, y1:y1+w] = self.mean[1] 135 | img[2, x1:x1+h, y1:y1+w] = self.mean[2] 136 | else: 137 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 138 | return img 139 | 140 | return img -------------------------------------------------------------------------------- /GAT.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | # References : https://github.com/ohhhyeahhh/SiamGAT 4 | # -------------------------------------------------------- 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch 8 | 9 | class Graph_Attention_Union(nn.Module): 10 | def __init__(self, in_channel, out_channel, meanw): 11 | super(Graph_Attention_Union, self).__init__() 12 | self.meanw = meanw 13 | 14 | # search region nodes linear transformation 15 | self.support = nn.Conv2d(in_channel, in_channel, 1, 1) 16 | 17 | # target template nodes linear transformation 18 | self.query = nn.Conv2d(in_channel, in_channel, 1, 1) 19 | 20 | # linear transformation for message passing 21 | self.g = nn.Sequential( 22 | nn.Conv2d(in_channel, in_channel, 1, 1), 23 | nn.BatchNorm2d(in_channel), 24 | nn.ReLU(inplace=True), 25 | ) 26 | 27 | 28 | self.init_weights() 29 | 30 | def init_weights(self): 31 | for n, m in self.named_modules(): 32 | if isinstance(m, nn.Conv2d): 33 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 34 | elif isinstance(m, nn.BatchNorm2d): 35 | nn.init.ones_(m.weight) 36 | nn.init.zeros_(m.bias) 37 | 38 | def forward(self, zf, xf): 39 | # linear transformation 40 | 41 | xf_trans = self.query(xf) 42 | zf_trans = self.support(zf) 43 | 44 | # linear transformation for message passing 45 | # xf_g = self.g(xf) 46 | zf_g = self.g(zf) 47 | 48 | # calculate similarity 49 | shape_x = xf_trans.shape 50 | shape_z = zf_trans.shape 51 | 52 | zf_trans_plain = zf_trans.view(-1, shape_z[1], shape_z[2] * shape_z[3]) 53 | zf_g_plain = zf_g.view(-1, shape_z[1], shape_z[2] * shape_z[3]).permute(0, 2, 1) 54 | xf_trans_plain = xf_trans.view(-1, shape_x[1], shape_x[2] * shape_x[3]).permute(0, 2, 1) 55 | 56 | similar = torch.matmul(xf_trans_plain, zf_trans_plain) 57 | similar = F.softmax(similar, dim=2) 58 | if self.meanw != 0.0: 59 | mean_ = torch.mean(similar, dim=[2], keepdim=True) 60 | 61 | similar = torch.where(similar > self.meanw*mean_, similar, 0) 62 | 63 | embedding = torch.matmul(similar, zf_g_plain).permute(0, 2, 1) 64 | embedding = embedding.view(-1, shape_x[1], shape_x[2], shape_x[3]) 65 | 66 | 67 | return embedding 68 | 69 | 70 | 71 | 72 | 73 | 74 | 
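A minimal smoke test for `Graph_Attention_Union` (a sketch, not from the original repo: the shapes follow the 18×9 feature maps used elsewhere in this codebase, and `meanw=1.0` zeroes attention weights at or below the row mean; note that `out_channel` is accepted but never used inside the module):

```python
import torch
from GAT import Graph_Attention_Union

gau = Graph_Attention_Union(in_channel=2048, out_channel=2048, meanw=1.0)
zf = torch.rand(2, 2048, 18, 9)   # "support" (template) features
xf = torch.rand(2, 2048, 18, 9)   # "query" (search-region) features
print(gau(zf, xf).shape)          # torch.Size([2, 2048, 18, 9])
```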
--------------------------------------------------------------------------------
/HyperGraphs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: UTF-8 -*-
3 | # References: https://github.com/GouravWadhwa/Hypergraphs-Image-Inpainting
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class HypergraphConv(nn.Module):
10 |     def __init__(
11 |         self,
12 |         in_features=1024,
13 |         out_features=1024,
14 |         features_height=18,
15 |         features_width=9,
16 |         edges=256,
17 |         filters=128,
18 |         apply_bias=True,
19 |         theta1 = 0.0
20 |     ):
21 |         super().__init__()
22 |
23 |         self.in_features = in_features
24 |         self.out_features = out_features
25 |         self.features_height = features_height
26 |         self.features_width = features_width
27 |         self.vertices = self.features_height * self.features_width
28 |         self.edges = edges
29 |         self.apply_bias = apply_bias
30 |         self.filters = filters
31 |         self.theta1 = theta1
32 |
33 |         self.phi_conv = nn.Conv2d(self.in_features, self.filters, kernel_size=1, stride=1, padding=0)
34 |         self.A_conv = nn.Conv2d(self.in_features, self.filters, kernel_size=1, stride=1, padding=0)
35 |         self.M_conv = nn.Conv2d(self.in_features, self.edges, kernel_size=7, stride=1, padding=3)
36 |
37 |         self.weight_2 = nn.Parameter(torch.empty(self.in_features, self.out_features))
38 |         nn.init.xavier_normal_(self.weight_2)
39 |
40 |         if apply_bias:
41 |             self.bias_2 = nn.Parameter(torch.empty(1, self.out_features))
42 |             nn.init.xavier_normal_(self.bias_2)
43 |
44 |     def forward(self, x):
45 |         phi = self.phi_conv(x)  # vertex embeddings: (B, filters, H, W)
46 |         phi = torch.permute(phi, (0, 2, 3, 1)).contiguous()
47 |         phi = phi.view(-1, self.vertices, self.filters)  # (B, V, filters)
48 |
49 |         A = F.avg_pool2d(x, kernel_size=(self.features_height, self.features_width))  # global context
50 |         A = self.A_conv(A)
51 |         A = torch.permute(A, (0, 2, 3, 1)).contiguous()
52 |
53 |         A = torch.diag_embed(A.squeeze())  # diagonal weight matrix: (B, filters, filters), checked
54 |
55 |         M = self.M_conv(x)
56 |         M = torch.permute(M, (0, 2, 3, 1)).contiguous()
57 |         M = M.view(-1, self.vertices, self.edges)  # (B, V, E)
58 |
59 |
60 |         H = torch.matmul(phi, torch.matmul(A, torch.matmul(phi.transpose(1, 2), M)))  # incidence matrix: (B, V, E)
61 |         H = torch.abs(H)
62 |
63 |         if self.theta1 != 0.0:
64 |             mean_H = self.theta1*torch.mean(H,dim=[1,2],keepdim=True)
65 |             H = torch.where(H < mean_H, 0.0, H)  # prune incidences below theta1 * mean
66 |         D = H.sum(dim=2)  # vertex degrees
67 |         D_H = torch.mul(torch.unsqueeze(torch.pow(D + 1e-10, -0.5), dim=-1), H)
68 |         B = H.sum(dim=1)  # hyperedge degrees
69 |         B = torch.diag_embed(torch.pow(B + 1e-10, -1))
70 |         x_ = torch.permute(x, (0, 2, 3, 1)).contiguous()
71 |         features = x_.view(-1, self.vertices, self.in_features)
72 |
73 |         out = features - torch.matmul(D_H, torch.matmul(B, torch.matmul(D_H.transpose(1, 2), features)))  # Laplacian-style propagation
74 |         out = torch.matmul(out, self.weight_2)
75 |
76 |         if self.apply_bias:
77 |             out = out + self.bias_2
78 |         out = torch.permute(out, (0, 2, 1)).contiguous()
79 |         out = out.view(-1, self.out_features, self.features_height, self.features_width)
80 |
81 |         return out
--------------------------------------------------------------------------------
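A quick shape check for `HypergraphConv`, in the spirit of the commented example at the bottom of Transformer.py (a sketch; the defaults correspond to an 18×9 grid of 162 vertices and 256 hyperedges):

```python
import torch
from HyperGraphs import HypergraphConv

hconv = HypergraphConv(in_features=1024, out_features=1024,
                       features_height=18, features_width=9,
                       edges=256, filters=128, theta1=0.0)  # theta1=0 disables pruning
x = torch.rand(2, 1024, 18, 9)
print(hconv(x).shape)  # torch.Size([2, 1024, 18, 9])
```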
/README.md:
--------------------------------------------------------------------------------
1 | # High-Order Structure Based Middle-Feature Learning for Visible-Infrared Person Re-Identification (AAAI 2024)
2 |
3 |
4 |
5 |
6 | ### 1. Prepare the datasets.
7 |
8 |
9 | - (1) SYSU-MM01 Dataset [1]: The SYSU-MM01 dataset can be downloaded from this [website](http://isee.sysu.edu.cn/project/RGBIRReID.htm).
10 |
11 |   - run `python pre_process_sysu.py` to prepare the dataset; the training data will be stored in `.npy` format.
12 |
13 |
14 | - (2) RegDB Dataset [2]: The RegDB dataset can be downloaded from this [website](http://dm.dongguk.edu/link.html) by submitting a copyright form.
15 |
16 |   - (Named: "Dongguk Body-based Person Recognition Database (DBPerson-Recog-DB1)" on their website).
17 |
18 |   - A private download link can be requested by sending an email to mangye16@gmail.com.
19 |
20 | - (3) LLCM Dataset [3]: The LLCM dataset can be downloaded from this [website](https://github.com/ZYK100/LLCM) by submitting a copyright form.
21 |   - Please send a signed dataset release agreement copy to zhangyk@stu.xmu.edu.cn
22 |
23 | ### 2. Training and Testing.
24 | Before training HOS-Net, download the baseline checkpoint from [CAJ](https://drive.google.com/drive/folders/107vztbRqim8-oQAuEBh_9S8oTKdhuV2j?usp=sharing) and put it in `./baseline/`. The results might be further improved by fine-tuning the hyper-parameters.
25 |
26 |
27 | ```bash
28 | python train_hos_net.py
29 | ```
30 |
31 | Checkpoints and logs are saved in `./sle_ckpt/`, `./sle_hsl_ckpt/`, and `./sle_hsl_cfl_ckpt/`.
32 |
33 |
34 | ### 3. Citation
35 |
36 | Please kindly cite this paper in your publications if it helps your research:
37 | ```
38 | @inproceedings{qiu2024high,
39 |   title={High-Order Structure Based Middle-Feature Learning for Visible-Infrared Person Re-Identification},
40 |   author={Qiu, Liuxiang and Chen, Si and Yan, Yan and Xue, Jing-Hao and Wang, Da-Han and Zhu, Shunzhi},
41 |   booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
42 |   volume={38},
43 |   number={5},
44 |   pages={4596--4604},
45 |   year={2024}
46 | }
47 | ```
48 |
49 | ### 4. References.
50 |
51 | [1] A. Wu, W.-S. Zheng, H.-X. Yu, S. Gong, and J. Lai. RGB-infrared cross-modality person re-identification. In IEEE International Conference on Computer Vision (ICCV), pages 5380–5389, 2017.
52 |
53 | [2] D. T. Nguyen, H. G. Hong, K. W. Kim, and K. R. Park. Person recognition system based on a combination of body images from visible light and thermal cameras. Sensors, 17(3):605, 2017.
54 |
55 | [3] Y. Zhang and H. Wang. Diverse embedding expansion network and low-light cross-modality benchmark for visible-infrared person re-identification. In IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 2153–2162, 2023.
56 |
57 | ### Questions
58 |
59 | Q1: How can we get the baseline checkpoints (e.g., [CAJ-SYSU](https://drive.google.com/drive/folders/107vztbRqim8-oQAuEBh_9S8oTKdhuV2j?usp=sharing))?
60 |
61 | A1: You can train CAJ to get the baseline checkpoint ([CAJ-Code](https://github.com/mangye16/Cross-Modal-Re-ID-baseline/tree/master/ICCV21_CAJ)).
62 |
63 | ### Contact
64 | If you have any questions, please feel free to contact us. E-mail: liuxiangqiu007@gmail.com
65 |
--------------------------------------------------------------------------------
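To sanity-check step 1 above, the four `.npy` arrays that `data_loader.SYSUData` expects can be inspected directly (a sketch; `data_dir` is a placeholder for wherever SYSU-MM01 is stored):

```python
import numpy as np

data_dir = './SYSU-MM01/'  # placeholder path
for name in ('train_rgb_resized_img', 'train_rgb_resized_label',
             'train_ir_resized_img', 'train_ir_resized_label'):
    arr = np.load(data_dir + name + '.npy')
    print(name, arr.shape, arr.dtype)  # images: (N, H, W, 3); labels: (N,)
```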
/Transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
9 | # DeiT: https://github.com/facebookresearch/deit
10 | # --------------------------------------------------------
11 |
12 | from functools import partial
13 |
14 | import torch
15 | import torch.nn as nn
16 |
17 | from timm.models.vision_transformer import Block
18 |
19 |
20 | class ViT(nn.Module):
21 |     """ VisionTransformer backbone
22 |     """
23 |
24 |     def __init__(self, img_size=18 * 9, embed_dim=2048, depth=2, num_heads=4,
25 |                  mlp_ratio=4., norm_layer=nn.LayerNorm):
26 |         super().__init__()
27 |
28 |         # --------------------------------------------------------------------------
29 |         # MAE encoder specifics
30 |         self.pos_embed = nn.Parameter(torch.zeros(1, img_size, embed_dim))
31 |
32 |         self.blocks = nn.ModuleList([
33 |             Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
34 |             for i in range(depth)])
35 |         self.norm_1 = norm_layer(embed_dim)
36 |         self.norm = norm_layer(embed_dim)
37 |
38 |         self.initialize_weights()
39 |
40 |     def initialize_weights(self):
41 |         # initialization
42 |         # initialize pos_embed with a normal distribution (it stays learnable; no sin-cos embedding or freezing here)
43 |
44 |         # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
45 |         torch.nn.init.normal_(self.pos_embed, std=.02)
46 |
47 |         # initialize nn.Linear and nn.LayerNorm
48 |         self.apply(self._init_weights)
49 |
50 |     def _init_weights(self, m):
51 |         if isinstance(m, nn.Linear):
52 |             # we use xavier_uniform following official JAX ViT:
53 |             torch.nn.init.xavier_uniform_(m.weight)
54 |             if isinstance(m, nn.Linear) and m.bias is not None:
55 |                 nn.init.constant_(m.bias, 0)
56 |         elif isinstance(m, nn.LayerNorm):
57 |             nn.init.constant_(m.bias, 0)
58 |             nn.init.constant_(m.weight, 1.0)
59 |
60 |     def forward(self, x):
61 |         # embed patches
62 |         b, c, h, w = x.shape
63 |         x = self.norm_1(x.reshape(b, c, h * w).transpose(1, 2))
64 |         x = x + self.pos_embed
65 |         for blk in self.blocks:
66 |             x = blk(x)
67 |         x = self.norm(x)
68 |         x = x.transpose(1, 2).reshape(b, c, h, w)
69 |
70 |         return x
71 |
72 |
73 |
74 |
75 |
76 | # model = ViT(img_size=18 * 9, embed_dim=1024)
77 | # x = torch.rand(3, 1024, 18, 9)
78 | # print(model(x).shape)
79 |
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import torch.utils.data as data
4 | from ChannelAug import ChannelAdap, ChannelAdapGray, ChannelRandomErasing
5 | import torchvision.transforms as transforms
6 | import random
7 | import math
8 |
9 | class ChannelExchange(object):
10 |     """ Randomly copies one of the R, G, or B channels into all three
11 |     channels, or converts the image to three-channel grayscale.
12 |     Args:
13 |         gray: upper bound (inclusive) for random.randint; an index of 0, 1,
14 |             or 2 selects the corresponding channel, while any larger index
15 |             triggers the grayscale branch (with the default gray = 2 the
16 |             grayscale branch is never reached).
17 | """ 18 | 19 | def __init__(self, gray = 2): 20 | self.gray = gray 21 | 22 | def __call__(self, img): 23 | 24 | idx = random.randint(0, self.gray) 25 | 26 | if idx ==0: 27 | # random select R Channel 28 | img[1, :,:] = img[0,:,:] 29 | img[2, :,:] = img[0,:,:] 30 | elif idx ==1: 31 | # random select B Channel 32 | img[0, :,:] = img[1,:,:] 33 | img[2, :,:] = img[1,:,:] 34 | elif idx ==2: 35 | # random select G Channel 36 | img[0, :,:] = img[2,:,:] 37 | img[1, :,:] = img[2,:,:] 38 | else: 39 | tmp_img = 0.2989 * img[0,:,:] + 0.5870 * img[1,:,:] + 0.1140 * img[2,:,:] 40 | img[0,:,:] = tmp_img 41 | img[1,:,:] = tmp_img 42 | img[2,:,:] = tmp_img 43 | return img 44 | 45 | 46 | 47 | class SYSUData(data.Dataset): 48 | def __init__(self, data_dir, transform=None, colorIndex = None, thermalIndex = None): 49 | 50 | # Load training images (path) and labels 51 | train_color_image = np.load(data_dir + 'train_rgb_resized_img.npy') 52 | self.train_color_label = np.load(data_dir + 'train_rgb_resized_label.npy') 53 | 54 | train_thermal_image = np.load(data_dir + 'train_ir_resized_img.npy') 55 | self.train_thermal_label = np.load(data_dir + 'train_ir_resized_label.npy') 56 | 57 | # BGR to RGB 58 | self.train_color_image = train_color_image 59 | self.train_thermal_image = train_thermal_image 60 | self.transform = transform 61 | self.cIndex = colorIndex 62 | self.tIndex = thermalIndex 63 | 64 | 65 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 66 | self.transform_thermal = transforms.Compose( [ 67 | transforms.ToPILImage(), 68 | transforms.Pad(10), 69 | transforms.RandomCrop((288, 144)), 70 | transforms.RandomHorizontalFlip(), 71 | transforms.ToTensor(), 72 | normalize, 73 | ChannelRandomErasing(probability = 0.5), 74 | ChannelAdapGray(probability =0.5)]) 75 | 76 | self.transform_color = transforms.Compose( [ 77 | transforms.ToPILImage(), 78 | transforms.Pad(10), 79 | transforms.RandomCrop((288, 144)), 80 | transforms.RandomHorizontalFlip(), 81 | # transforms.RandomGrayscale(p = 0.1), 82 | transforms.ToTensor(), 83 | normalize, 84 | ChannelRandomErasing(probability = 0.5)]) 85 | 86 | self.transform_color1 = transforms.Compose( [ 87 | transforms.ToPILImage(), 88 | transforms.Pad(10), 89 | transforms.RandomCrop((288, 144)), 90 | transforms.RandomHorizontalFlip(), 91 | transforms.ToTensor(), 92 | normalize, 93 | ChannelRandomErasing(probability = 0.5), 94 | ChannelExchange(gray = 2)]) 95 | 96 | def __getitem__(self, index): 97 | 98 | img1, target1 = self.train_color_image[self.cIndex[index]], self.train_color_label[self.cIndex[index]] 99 | img2, target2 = self.train_thermal_image[self.tIndex[index]], self.train_thermal_label[self.tIndex[index]] 100 | 101 | img1_0 = self.transform_color(img1) 102 | img1_1 = self.transform_color1(img1) 103 | img2 = self.transform_thermal(img2) 104 | 105 | return img1_0, img1_1, img2, target1, target2 106 | 107 | def __len__(self): 108 | return len(self.train_color_label) 109 | 110 | 111 | class RegDBData(data.Dataset): 112 | def __init__(self, data_dir, trial, transform=None, colorIndex = None, thermalIndex = None): 113 | # Load training images (path) and labels 114 | train_color_list = data_dir + 'idx/train_visible_{}'.format(trial)+ '.txt' 115 | train_thermal_list = data_dir + 'idx/train_thermal_{}'.format(trial)+ '.txt' 116 | 117 | color_img_file, train_color_label = load_data(train_color_list) 118 | thermal_img_file, train_thermal_label = load_data(train_thermal_list) 119 | 120 | train_color_image = [] 121 | for i in 
range(len(color_img_file)): 122 | 123 | img = Image.open(data_dir+ color_img_file[i]) 124 | img = img.resize((144, 288), Image.ANTIALIAS) 125 | pix_array = np.array(img) 126 | train_color_image.append(pix_array) 127 | train_color_image = np.array(train_color_image) 128 | 129 | train_thermal_image = [] 130 | for i in range(len(thermal_img_file)): 131 | img = Image.open(data_dir+ thermal_img_file[i]) 132 | img = img.resize((144, 288), Image.ANTIALIAS) 133 | pix_array = np.array(img) 134 | train_thermal_image.append(pix_array) 135 | train_thermal_image = np.array(train_thermal_image) 136 | 137 | # BGR to RGB 138 | self.train_color_image = train_color_image 139 | self.train_color_label = train_color_label 140 | 141 | # BGR to RGB 142 | self.train_thermal_image = train_thermal_image 143 | self.train_thermal_label = train_thermal_label 144 | 145 | self.transform = transform 146 | self.cIndex = colorIndex 147 | self.tIndex = thermalIndex 148 | 149 | 150 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 151 | self.transform_thermal = transforms.Compose( [ 152 | transforms.ToPILImage(), 153 | transforms.Pad(10), 154 | transforms.RandomCrop((288, 144)), 155 | transforms.RandomHorizontalFlip(), 156 | transforms.ToTensor(), 157 | normalize, 158 | ChannelRandomErasing(probability = 0.5), 159 | ChannelAdapGray(probability =0.5)]) 160 | 161 | self.transform_color = transforms.Compose( [ 162 | transforms.ToPILImage(), 163 | transforms.Pad(10), 164 | transforms.RandomCrop((288, 144)), 165 | transforms.RandomHorizontalFlip(), 166 | # transforms.RandomGrayscale(p = 0.1), 167 | transforms.ToTensor(), 168 | normalize, 169 | ChannelRandomErasing(probability = 0.5)]) 170 | 171 | self.transform_color1 = transforms.Compose( [ 172 | transforms.ToPILImage(), 173 | transforms.Pad(10), 174 | transforms.RandomCrop((288, 144)), 175 | transforms.RandomHorizontalFlip(), 176 | transforms.ToTensor(), 177 | normalize, 178 | ChannelRandomErasing(probability = 0.5), 179 | ChannelExchange(gray = 2)]) 180 | 181 | def __getitem__(self, index): 182 | 183 | img1, target1 = self.train_color_image[self.cIndex[index]], self.train_color_label[self.cIndex[index]] 184 | img2, target2 = self.train_thermal_image[self.tIndex[index]], self.train_thermal_label[self.tIndex[index]] 185 | 186 | img1_0 = self.transform_color(img1) 187 | img1_1 = self.transform_color1(img1) 188 | img2 = self.transform_thermal(img2) 189 | 190 | return img1_0, img1_1, img2, target1, target2 191 | 192 | def __len__(self): 193 | return len(self.train_color_label) 194 | 195 | 196 | class LLCMData(data.Dataset): 197 | def __init__(self, data_dir, trial, transform=None, colorIndex=None, thermalIndex=None): 198 | # Load training images (path) and labels 199 | train_color_list = data_dir + 'idx/train_vis.txt' 200 | train_thermal_list = data_dir + 'idx/train_nir.txt' 201 | 202 | color_img_file, train_color_label = load_data(train_color_list) 203 | thermal_img_file, train_thermal_label = load_data(train_thermal_list) 204 | 205 | train_color_image = [] 206 | for i in range(len(color_img_file)): 207 | img = Image.open(data_dir + color_img_file[i]) 208 | img = img.resize((144, 288), Image.ANTIALIAS) 209 | pix_array = np.array(img) 210 | train_color_image.append(pix_array) 211 | train_color_image = np.array(train_color_image) 212 | 213 | train_thermal_image = [] 214 | for i in range(len(thermal_img_file)): 215 | img = Image.open(data_dir + thermal_img_file[i]) 216 | img = img.resize((144, 288), Image.ANTIALIAS) 217 | pix_array = 
np.array(img) 218 | train_thermal_image.append(pix_array) 219 | # print(pix_array.shape) 220 | train_thermal_image = np.array(train_thermal_image) 221 | 222 | # BGR to RGB 223 | self.train_color_image = train_color_image 224 | self.train_color_label = train_color_label 225 | 226 | # BGR to RGB 227 | self.train_thermal_image = train_thermal_image 228 | self.train_thermal_label = train_thermal_label 229 | 230 | self.transform = transform 231 | self.cIndex = colorIndex 232 | self.tIndex = thermalIndex 233 | 234 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 235 | self.transform_thermal = transforms.Compose([ 236 | transforms.ToPILImage(), 237 | transforms.Pad(10), 238 | transforms.RandomCrop((288, 144)), 239 | transforms.RandomHorizontalFlip(), 240 | transforms.ToTensor(), 241 | normalize, 242 | ChannelRandomErasing(probability=0.5), 243 | ChannelAdapGray(probability=0.5)]) 244 | 245 | self.transform_color = transforms.Compose([ 246 | transforms.ToPILImage(), 247 | transforms.Pad(10), 248 | transforms.RandomCrop((288, 144)), 249 | transforms.RandomHorizontalFlip(), 250 | # transforms.RandomGrayscale(p = 0.1), 251 | transforms.ToTensor(), 252 | normalize, 253 | ChannelRandomErasing(probability=0.5)]) 254 | 255 | self.transform_color1 = transforms.Compose([ 256 | transforms.ToPILImage(), 257 | transforms.Pad(10), 258 | transforms.RandomCrop((288, 144)), 259 | transforms.RandomHorizontalFlip(), 260 | transforms.ToTensor(), 261 | normalize, 262 | ChannelRandomErasing(probability=0.5), 263 | ChannelExchange(gray=2)]) 264 | 265 | def __getitem__(self, index): 266 | 267 | img1, target1 = self.train_color_image[self.cIndex[index]], self.train_color_label[self.cIndex[index]] 268 | img2, target2 = self.train_thermal_image[self.tIndex[index]], self.train_thermal_label[self.tIndex[index]] 269 | 270 | img1_0 = self.transform_color(img1) 271 | img1_1 = self.transform_color1(img1) 272 | img2 = self.transform_thermal(img2) 273 | 274 | return img1_0, img1_1, img2, target1, target2 275 | 276 | def __len__(self): 277 | return len(self.train_color_label) 278 | 279 | 280 | 281 | 282 | class TestData(data.Dataset): 283 | def __init__(self, test_img_file, test_label, transform=None, img_size = (144,288)): 284 | 285 | test_image = [] 286 | for i in range(len(test_img_file)): 287 | img = Image.open(test_img_file[i]) 288 | img = img.resize((img_size[0], img_size[1]), Image.ANTIALIAS) 289 | pix_array = np.array(img) 290 | test_image.append(pix_array) 291 | test_image = np.array(test_image) 292 | self.test_image = test_image 293 | self.test_label = test_label 294 | self.transform = transform 295 | 296 | def __getitem__(self, index): 297 | img1, target1 = self.test_image[index], self.test_label[index] 298 | img1 = self.transform(img1) 299 | return img1, target1 300 | 301 | def __len__(self): 302 | return len(self.test_image) 303 | 304 | class TestDataOld(data.Dataset): 305 | def __init__(self, data_dir, test_img_file, test_label, transform=None, img_size = (144,288)): 306 | 307 | test_image = [] 308 | for i in range(len(test_img_file)): 309 | img = Image.open(data_dir + test_img_file[i]) 310 | img = img.resize((img_size[0], img_size[1]), Image.ANTIALIAS) 311 | pix_array = np.array(img) 312 | test_image.append(pix_array) 313 | test_image = np.array(test_image) 314 | self.test_image = test_image 315 | self.test_label = test_label 316 | self.transform = transform 317 | 318 | def __getitem__(self, index): 319 | img1, target1 = self.test_image[index], self.test_label[index] 320 | 
img1 = self.transform(img1) 321 | return img1, target1 322 | 323 | def __len__(self): 324 | return len(self.test_image) 325 | def load_data(input_data_path ): 326 | with open(input_data_path) as f: 327 | data_file_list = open(input_data_path, 'rt').read().splitlines() 328 | # Get full list of image and labels 329 | file_image = [s.split(' ')[0] for s in data_file_list] 330 | file_label = [int(s.split(' ')[1]) for s in data_file_list] 331 | 332 | return file_image, file_label -------------------------------------------------------------------------------- /data_manager.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os 3 | import numpy as np 4 | import random 5 | 6 | 7 | def process_query_llcm(data_path, mode=1, relabel=False): 8 | if mode == 1: 9 | cameras = ['test_vis/cam1', 'test_vis/cam2', 'test_vis/cam3', 'test_vis/cam4', 'test_vis/cam5', 'test_vis/cam6', 10 | 'test_vis/cam7', 'test_vis/cam8', 'test_vis/cam9'] 11 | elif mode == 2: 12 | cameras = ['test_nir/cam1', 'test_nir/cam2', 'test_nir/cam4', 'test_nir/cam5', 'test_nir/cam6', 'test_nir/cam7', 13 | 'test_nir/cam8', 'test_nir/cam9'] 14 | 15 | file_path = os.path.join(data_path, 'idx/test_id.txt') 16 | files_rgb = [] 17 | files_ir = [] 18 | 19 | with open(file_path, 'r') as file: 20 | ids = file.read().splitlines() 21 | ids = [int(y) for y in ids[0].split(',')] 22 | ids = ["%04d" % x for x in ids] 23 | 24 | for id in sorted(ids): 25 | for cam in cameras: 26 | img_dir = os.path.join(data_path, cam, id) 27 | if os.path.isdir(img_dir): 28 | new_files = sorted([img_dir + '/' + i for i in os.listdir(img_dir)]) 29 | files_ir.extend(new_files) 30 | query_img = [] 31 | query_id = [] 32 | query_cam = [] 33 | for img_path in files_ir: 34 | camid, pid = int(img_path.split('cam')[1][0]), int(img_path.split('cam')[1][2:6]) 35 | query_img.append(img_path) 36 | query_id.append(pid) 37 | query_cam.append(camid) 38 | return query_img, np.array(query_id), np.array(query_cam) 39 | 40 | 41 | def process_gallery_llcm(data_path, mode=1, trial=0, relabel=False): 42 | random.seed(trial) 43 | 44 | if mode == 1: 45 | cameras = ['test_vis/cam1', 'test_vis/cam2', 'test_vis/cam3', 'test_vis/cam4', 'test_vis/cam5', 'test_vis/cam6', 46 | 'test_vis/cam7', 'test_vis/cam8', 'test_vis/cam9'] 47 | elif mode == 2: 48 | cameras = ['test_nir/cam1', 'test_nir/cam2', 'test_nir/cam4', 'test_nir/cam5', 'test_nir/cam6', 'test_nir/cam7', 49 | 'test_nir/cam8', 'test_nir/cam9'] 50 | 51 | file_path = os.path.join(data_path, 'idx/test_id.txt') 52 | files_rgb = [] 53 | with open(file_path, 'r') as file: 54 | ids = file.read().splitlines() 55 | ids = [int(y) for y in ids[0].split(',')] 56 | ids = ["%04d" % x for x in ids] 57 | 58 | for id in sorted(ids): 59 | for cam in cameras: 60 | img_dir = os.path.join(data_path, cam, id) 61 | if os.path.isdir(img_dir): 62 | new_files = sorted([img_dir + '/' + i for i in os.listdir(img_dir)]) 63 | files_rgb.append(random.choice(new_files)) 64 | gall_img = [] 65 | gall_id = [] 66 | gall_cam = [] 67 | for img_path in files_rgb: 68 | camid, pid = int(img_path.split('cam')[1][0]), int(img_path.split('cam')[1][2:6]) 69 | gall_img.append(img_path) 70 | gall_id.append(pid) 71 | gall_cam.append(camid) 72 | return gall_img, np.array(gall_id), np.array(gall_cam) 73 | 74 | 75 | 76 | def process_query_sysu(data_path, mode = 'all', relabel=False): 77 | if mode== 'all': 78 | ir_cameras = ['cam3','cam6'] 79 | elif mode =='indoor': 80 | ir_cameras = 
['cam3','cam6'] 81 | 82 | file_path = os.path.join(data_path,'exp/test_id.txt') 83 | files_rgb = [] 84 | files_ir = [] 85 | 86 | with open(file_path, 'r') as file: 87 | ids = file.read().splitlines() 88 | ids = [int(y) for y in ids[0].split(',')] 89 | ids = ["%04d" % x for x in ids] 90 | 91 | for id in sorted(ids): 92 | for cam in ir_cameras: 93 | img_dir = os.path.join(data_path,cam,id) 94 | if os.path.isdir(img_dir): 95 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 96 | files_ir.extend(new_files) 97 | query_img = [] 98 | query_id = [] 99 | query_cam = [] 100 | for img_path in files_ir: 101 | camid, pid = int(img_path[-15]), int(img_path[-13:-9]) 102 | query_img.append(img_path) 103 | query_id.append(pid) 104 | query_cam.append(camid) 105 | return query_img, np.array(query_id), np.array(query_cam) 106 | 107 | def process_gallery_sysu(data_path, mode = 'all', trial = 0, relabel=False): 108 | 109 | random.seed(trial) 110 | 111 | if mode== 'all': 112 | rgb_cameras = ['cam1','cam2','cam4','cam5'] 113 | elif mode =='indoor': 114 | rgb_cameras = ['cam1','cam2'] 115 | 116 | file_path = os.path.join(data_path,'exp/test_id.txt') 117 | files_rgb = [] 118 | with open(file_path, 'r') as file: 119 | ids = file.read().splitlines() 120 | ids = [int(y) for y in ids[0].split(',')] 121 | ids = ["%04d" % x for x in ids] 122 | 123 | for id in sorted(ids): 124 | for cam in rgb_cameras: 125 | img_dir = os.path.join(data_path,cam,id) 126 | if os.path.isdir(img_dir): 127 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 128 | files_rgb.append(random.choice(new_files)) 129 | gall_img = [] 130 | gall_id = [] 131 | gall_cam = [] 132 | for img_path in files_rgb: 133 | camid, pid = int(img_path[-15]), int(img_path[-13:-9]) 134 | gall_img.append(img_path) 135 | gall_id.append(pid) 136 | gall_cam.append(camid) 137 | return gall_img, np.array(gall_id), np.array(gall_cam) 138 | 139 | 140 | 141 | 142 | def process_test_regdb(img_dir, trial = 1, modal = 'visible'): 143 | if modal=='visible': 144 | input_data_path = img_dir + 'idx/test_visible_{}'.format(trial) + '.txt' 145 | elif modal=='thermal': 146 | input_data_path = img_dir + 'idx/test_thermal_{}'.format(trial) + '.txt' 147 | 148 | with open(input_data_path) as f: 149 | data_file_list = open(input_data_path, 'rt').read().splitlines() 150 | # Get full list of image and labels 151 | file_image = [img_dir + '/' + s.split(' ')[0] for s in data_file_list] 152 | file_label = [int(s.split(' ')[1]) for s in data_file_list] 153 | 154 | return file_image, np.array(file_label) -------------------------------------------------------------------------------- /eval_metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import numpy as np 3 | """Cross-Modality ReID""" 4 | import pdb 5 | 6 | 7 | def eval_llcm(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=20): 8 | """Evaluation with sysu metric 9 | Key: for each query identity, its gallery images from the same camera view are discarded. 
"Following the original setting in ite dataset" 10 | """ 11 | num_q, num_g = distmat.shape 12 | if num_g < max_rank: 13 | max_rank = num_g 14 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 15 | indices = np.argsort(distmat, axis=1) 16 | pred_label = g_pids[indices] 17 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 18 | 19 | # compute cmc curve for each query 20 | new_all_cmc = [] 21 | all_cmc = [] 22 | all_AP = [] 23 | all_INP = [] 24 | num_valid_q = 0. # number of valid query 25 | for q_idx in range(num_q): 26 | # get query pid and camid 27 | q_pid = q_pids[q_idx] 28 | q_camid = q_camids[q_idx] 29 | 30 | # remove gallery samples that have the same pid and camid with query 31 | 32 | order = indices[q_idx] 33 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 34 | keep = np.invert(remove) 35 | 36 | # compute cmc curve 37 | # the cmc calculation is different from standard protocol 38 | # we follow the protocol of the author's released code 39 | new_cmc = pred_label[q_idx][keep] 40 | new_index = np.unique(new_cmc, return_index=True)[1] 41 | 42 | new_cmc = [new_cmc[index] for index in sorted(new_index)] 43 | 44 | new_match = (new_cmc == q_pid).astype(np.int32) 45 | new_cmc = new_match.cumsum() 46 | new_all_cmc.append(new_cmc[:max_rank]) 47 | 48 | orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 49 | if not np.any(orig_cmc): 50 | # this condition is true when query identity does not appear in gallery 51 | continue 52 | 53 | cmc = orig_cmc.cumsum() 54 | 55 | # compute mINP 56 | # refernece Deep Learning for Person Re-identification: A Survey and Outlook 57 | pos_idx = np.where(orig_cmc == 1) 58 | pos_max_idx = np.max(pos_idx) 59 | inp = cmc[pos_max_idx] / (pos_max_idx + 1.0) 60 | all_INP.append(inp) 61 | 62 | cmc[cmc > 1] = 1 63 | 64 | all_cmc.append(cmc[:max_rank]) 65 | num_valid_q += 1. 66 | 67 | # compute average precision 68 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 69 | num_rel = orig_cmc.sum() 70 | tmp_cmc = orig_cmc.cumsum() 71 | tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)] 72 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 73 | AP = tmp_cmc.sum() / num_rel 74 | all_AP.append(AP) 75 | 76 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 77 | 78 | all_cmc = np.asarray(all_cmc).astype(np.float32) 79 | all_cmc = all_cmc.sum(0) / num_valid_q # standard CMC 80 | 81 | new_all_cmc = np.asarray(new_all_cmc).astype(np.float32) 82 | new_all_cmc = new_all_cmc.sum(0) / num_valid_q 83 | mAP = np.mean(all_AP) 84 | mINP = np.mean(all_INP) 85 | return new_all_cmc, mAP, mINP 86 | 87 | 88 | 89 | 90 | 91 | def eval_sysu(distmat, q_pids, g_pids, q_camids, g_camids, max_rank = 20): 92 | """Evaluation with sysu metric 93 | Key: for each query identity, its gallery images from the same camera view are discarded. "Following the original setting in ite dataset" 94 | """ 95 | num_q, num_g = distmat.shape 96 | if num_g < max_rank: 97 | max_rank = num_g 98 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 99 | indices = np.argsort(distmat, axis=1) 100 | pred_label = g_pids[indices] 101 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 102 | 103 | # compute cmc curve for each query 104 | new_all_cmc = [] 105 | all_cmc = [] 106 | all_AP = [] 107 | all_INP = [] 108 | num_valid_q = 0. 
109 |     for q_idx in range(num_q):
110 |         # get query pid and camid
111 |         q_pid = q_pids[q_idx]
112 |         q_camid = q_camids[q_idx]
113 |
114 |         # remove gallery samples that have the same pid and camid with query
115 |         order = indices[q_idx]
116 |         remove = (q_camid == 3) & (g_camids[order] == 2)
117 |         keep = np.invert(remove)
118 |
119 |         # compute cmc curve
120 |         # the cmc calculation is different from standard protocol
121 |         # we follow the protocol of the author's released code
122 |         new_cmc = pred_label[q_idx][keep]
123 |         new_index = np.unique(new_cmc, return_index=True)[1]
124 |         new_cmc = [new_cmc[index] for index in sorted(new_index)]
125 |
126 |         new_match = (new_cmc == q_pid).astype(np.int32)
127 |         new_cmc = new_match.cumsum()
128 |         new_all_cmc.append(new_cmc[:max_rank])
129 |
130 |         orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
131 |         if not np.any(orig_cmc):
132 |             # this condition is true when query identity does not appear in gallery
133 |             continue
134 |
135 |         cmc = orig_cmc.cumsum()
136 |
137 |         # compute mINP
138 |         # reference: Deep Learning for Person Re-identification: A Survey and Outlook
139 |         pos_idx = np.where(orig_cmc == 1)
140 |         pos_max_idx = np.max(pos_idx)
141 |         inp = cmc[pos_max_idx]/ (pos_max_idx + 1.0)
142 |         all_INP.append(inp)
143 |
144 |         cmc[cmc > 1] = 1
145 |
146 |         all_cmc.append(cmc[:max_rank])
147 |         num_valid_q += 1.
148 |
149 |         # compute average precision
150 |         # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
151 |         num_rel = orig_cmc.sum()
152 |         tmp_cmc = orig_cmc.cumsum()
153 |         tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)]
154 |         tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
155 |         AP = tmp_cmc.sum() / num_rel
156 |         all_AP.append(AP)
157 |
158 |     assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
159 |
160 |     all_cmc = np.asarray(all_cmc).astype(np.float32)
161 |     all_cmc = all_cmc.sum(0) / num_valid_q # standard CMC
162 |
163 |     new_all_cmc = np.asarray(new_all_cmc).astype(np.float32)
164 |     new_all_cmc = new_all_cmc.sum(0) / num_valid_q
165 |     mAP = np.mean(all_AP)
166 |     mINP = np.mean(all_INP)
167 |     return new_all_cmc, mAP, mINP
168 |
169 |
170 |
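# Usage sketch (not part of the original file), mirroring the commented example
# style in Transformer.py: a self-contained smoke test for eval_sysu with
# synthetic features and labels. The similarity matrix is negated because
# eval_sysu ranks by ascending distance. All names below are illustrative.
#
# rng = np.random.default_rng(0)
# qf = rng.standard_normal((100, 2048)).astype(np.float32)   # query features
# gf = rng.standard_normal((300, 2048)).astype(np.float32)   # gallery features
# q_pids = np.repeat(np.arange(25), 4)             # 25 identities, 4 queries each
# g_pids = np.repeat(np.arange(25), 12)            # 12 gallery images per identity
# q_camids = np.full(100, 3)                       # IR queries come from cam3
# g_camids = np.tile(np.array([1, 2, 4, 5]), 75)   # RGB gallery cameras
# distmat = -np.matmul(qf, gf.T)                   # smaller value = better match
# cmc, mAP, mINP = eval_sysu(distmat, q_pids, g_pids, q_camids, g_camids)
# print('Rank-1: {:.2%}, mAP: {:.2%}, mINP: {:.2%}'.format(cmc[0], mAP, mINP))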
171 | def eval_regdb(distmat, q_pids, g_pids, max_rank = 20):
172 |     num_q, num_g = distmat.shape
173 |     if num_g < max_rank:
174 |         max_rank = num_g
175 |         print("Note: number of gallery samples is quite small, got {}".format(num_g))
176 |     indices = np.argsort(distmat, axis=1)
177 |     matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
178 |
179 |     # compute cmc curve for each query
180 |     all_cmc = []
181 |     all_AP = []
182 |     all_INP = []
183 |     num_valid_q = 0. # number of valid query
184 |
185 |     # only two cameras
186 |     q_camids = np.ones(num_q).astype(np.int32)
187 |     g_camids = 2* np.ones(num_g).astype(np.int32)
188 |
189 |     for q_idx in range(num_q):
190 |         # get query pid and camid
191 |         q_pid = q_pids[q_idx]
192 |         q_camid = q_camids[q_idx]
193 |
194 |         # remove gallery samples that have the same pid and camid with query
195 |         order = indices[q_idx]
196 |         remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
197 |         keep = np.invert(remove)
198 |
199 |         # compute cmc curve
200 |         raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
201 |         if not np.any(raw_cmc):
202 |             # this condition is true when query identity does not appear in gallery
203 |             continue
204 |
205 |         cmc = raw_cmc.cumsum()
206 |
207 |         # compute mINP
208 |         # reference: Deep Learning for Person Re-identification: A Survey and Outlook
209 |         pos_idx = np.where(raw_cmc == 1)
210 |         pos_max_idx = np.max(pos_idx)
211 |         inp = cmc[pos_max_idx]/ (pos_max_idx + 1.0)
212 |         all_INP.append(inp)
213 |
214 |         cmc[cmc > 1] = 1
215 |
216 |         all_cmc.append(cmc[:max_rank])
217 |         num_valid_q += 1.
218 |
219 |         # compute average precision
220 |         # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
221 |         num_rel = raw_cmc.sum()
222 |         tmp_cmc = raw_cmc.cumsum()
223 |         tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)]
224 |         tmp_cmc = np.asarray(tmp_cmc) * raw_cmc
225 |         AP = tmp_cmc.sum() / num_rel
226 |         all_AP.append(AP)
227 |
228 |     assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
229 |
230 |     all_cmc = np.asarray(all_cmc).astype(np.float32)
231 |     all_cmc = all_cmc.sum(0) / num_valid_q
232 |     mAP = np.mean(all_AP)
233 |     mINP = np.mean(all_INP)
234 |     return all_cmc, mAP, mINP
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 |
7 |
8 |
9 | class OriTripletLoss(nn.Module):
10 |     """Triplet loss with hard positive/negative mining.
11 |
12 |     Reference:
13 |     Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
14 |     Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.
15 |
16 |     Args:
17 |     - margin (float): margin for triplet.
18 |     """
19 |
20 |     def __init__(self, batch_size, margin=0.3):
21 |         super(OriTripletLoss, self).__init__()
22 |         self.margin = margin  # batch_size is accepted for interface compatibility but unused
23 |         self.ranking_loss = nn.MarginRankingLoss(margin=margin)
24 |
25 |     def forward(self, inputs, targets):
26 |         """
27 |         Args:
28 |         - inputs: feature matrix with shape (batch_size, feat_dim)
29 |         - targets: ground truth labels with shape (batch_size)
30 |         """
31 |         n = inputs.size(0)
32 |
33 |         # Compute pairwise distance, replace by the official when merged
34 |         dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
35 |         dist = dist + dist.t()
36 |         dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)  # dist - 2 * inputs @ inputs.T
37 |         dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
38 |
39 |         # For each anchor, find the hardest positive and negative
40 |         mask = targets.expand(n, n).eq(targets.expand(n, n).t())
41 |         dist_ap, dist_an = [], []
42 |         for i in range(n):
43 |             dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
44 |             dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
45 |         dist_ap = torch.cat(dist_ap)
46 |         dist_an = torch.cat(dist_an)
47 |
48 |         # Compute ranking hinge loss
49 |         y = torch.ones_like(dist_an)
50 |         loss = self.ranking_loss(dist_an, dist_ap, y)
51 |
52 |         # compute accuracy
53 |         correct = torch.ge(dist_an, dist_ap).sum().item()
54 |         return loss, correct
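# Usage sketch (not part of the original file): OriTripletLoss expects the
# batch to contain several samples per identity; `correct` counts anchors
# whose hardest negative lies farther away than their hardest positive.
#
# feats = torch.randn(8, 2048)                     # hypothetical features
# labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])  # 4 identities x 2 samples
# criterion = OriTripletLoss(batch_size=8, margin=0.3)
# loss, correct = criterion(feats, labels)
# print(loss.item(), correct)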
55 |
56 | # Adaptive weights
57 |
58 |
59 | def softmax_weights(dist, mask):
60 |     max_v = torch.max(dist * mask, dim=1, keepdim=True)[0]
61 |     diff = dist - max_v
62 |     Z = torch.sum(torch.exp(diff) * mask, dim=1, keepdim=True) + 1e-6 # avoid division by zero
63 |     W = torch.exp(diff) * mask / Z
64 |     return W
65 |
66 |
67 | def normalize(x, axis=-1):
68 |     """Normalizing to unit length along the specified dimension.
69 |     Args:
70 |       x: pytorch Variable
71 |     Returns:
72 |       x: pytorch Variable, same shape as input
73 |     """
74 |     x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
75 |     return x
76 |
77 |
78 | class TripletLoss_WRT(nn.Module):
79 |     """Weighted Regularized Triplet loss."""
80 |
81 |     def __init__(self):
82 |         super(TripletLoss_WRT, self).__init__()
83 |         self.ranking_loss = nn.SoftMarginLoss()
84 |
85 |     def forward(self, inputs, targets, normalize_feature=False):
86 |         if normalize_feature:
87 |             inputs = normalize(inputs, axis=-1)
88 |         dist_mat = pdist_torch(inputs, inputs)
89 |
90 |         N = dist_mat.size(0)
91 |         # shape [N, N]
92 |         is_pos = targets.expand(N, N).eq(targets.expand(N, N).t()).float()
93 |         is_neg = targets.expand(N, N).ne(targets.expand(N, N).t()).float()
94 |
95 |         # `dist_ap` means distance(anchor, positive)
96 |         # both `dist_ap` and `dist_an` are masked distance matrices with shape [N, N]
97 |         dist_ap = dist_mat * is_pos
98 |         dist_an = dist_mat * is_neg
99 |
100 |         weights_ap = softmax_weights(dist_ap, is_pos)
101 |         weights_an = softmax_weights(-dist_an, is_neg)
102 |         furthest_positive = torch.sum(dist_ap * weights_ap, dim=1)
103 |         closest_negative = torch.sum(dist_an * weights_an, dim=1)
104 |
105 |         y = furthest_positive.new().resize_as_(furthest_positive).fill_(1)
106 |         loss = self.ranking_loss(closest_negative - furthest_positive, y)
107 |
108 |         # compute accuracy
109 |         correct = torch.ge(closest_negative, furthest_positive).sum().item()
110 |         return loss, correct
111 |
112 |
113 | class TripletLoss_ADP(nn.Module):
114 |     """Weighted Regularized Triplet loss with adaptive (alpha/gamma) weighting."""
115 |
116 |     def __init__(self, alpha=1, gamma=1, square=0):
117 |         super(TripletLoss_ADP, self).__init__()
118 |         self.ranking_loss = nn.SoftMarginLoss()
119 |         self.alpha = alpha
120 |         self.gamma = gamma
121 |         self.square = square
122 |
123 |     def forward(self, inputs, targets, normalize_feature=False):
124 |         if normalize_feature:
125 |             inputs = normalize(inputs, axis=-1)
126 |         dist_mat = pdist_torch(inputs, inputs)
127 |
128 |         N = dist_mat.size(0)
129 |         # shape [N, N]
130 |         is_pos = targets.expand(N, N).eq(targets.expand(N, N).t()).float()
131 |         is_neg = targets.expand(N, N).ne(targets.expand(N, N).t()).float()
132 |
133 |         # `dist_ap` means distance(anchor, positive)
134 |         # both `dist_ap` and `dist_an` are masked distance matrices with shape [N, N]
135 |         dist_ap = dist_mat * is_pos
136 |         dist_an = dist_mat * is_neg
137 |
138 |         weights_ap = softmax_weights(dist_ap * self.alpha, is_pos)
139 |         weights_an = softmax_weights(-dist_an * self.alpha, is_neg)
140 |         furthest_positive = torch.sum(dist_ap * weights_ap, dim=1)
141 |         closest_negative = torch.sum(dist_an * weights_an, dim=1)
142 |
143 |         # ranking_loss = nn.SoftMarginLoss(reduction = 'none')
144 |         # loss1 = ranking_loss(closest_negative - furthest_positive, y)
145 |
146 |         # squared difference
147 |         if self.square == 0:
148 |             y = furthest_positive.new().resize_as_(furthest_positive).fill_(1)
149 |             loss = self.ranking_loss(self.gamma * (closest_negative - furthest_positive), y)
150 |         else:
151 |             diff_pow = torch.pow(furthest_positive - closest_negative, 2) * self.gamma
152 |             diff_pow = torch.clamp_max(diff_pow, max=44)
153 |
154 |             # Compute ranking hinge loss
155 |             y1 = (furthest_positive > closest_negative).float()
156 |             y2 = y1 - 1
157 |             y = -(y1 + y2)
158 |
159 |             loss = self.ranking_loss(diff_pow, y)
160 |
161 |         # loss = self.ranking_loss(self.gamma*(closest_negative - furthest_positive), y)
162 |
163 |         # compute accuracy
164 |         correct = torch.ge(closest_negative, furthest_positive).sum().item()
165 |         return loss, correct
166 |
167 |
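# Usage sketch (not part of the original file): the two weighted variants
# share the same call signature; square=1 selects the clamped
# squared-difference branch above.
#
# feats = torch.randn(8, 2048)
# labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])
# loss_wrt, _ = TripletLoss_WRT()(feats, labels, normalize_feature=True)
# loss_adp, _ = TripletLoss_ADP(alpha=1, gamma=1, square=1)(feats, labels)
# print(loss_wrt.item(), loss_adp.item())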
168 | class KLDivLoss(nn.Module):
169 |     def __init__(self):
170 |         super(KLDivLoss, self).__init__()
171 |
172 |     def forward(self, pred, label):
173 |         # pred: 2D matrix (batch_size, num_classes) of student logits
174 |         # label: 2D matrix (batch_size, num_classes) of teacher logits
175 |         T = 3
176 |
177 |         predict = F.log_softmax(pred / T, dim=1)
178 |         target_data = F.softmax(label / T, dim=1)
179 |         target_data = target_data + 10 ** (-7)
180 |         target = Variable(target_data.data.cuda(), requires_grad=False)
181 |         loss = T * T * ((target * (target.log() - predict)).sum(1).sum() / target.size()[0])
182 |         return loss
183 |
184 |
185 | def pdist_torch(emb1, emb2):
186 |     '''
187 |     compute the euclidean distance matrix between embeddings1 and embeddings2
188 |     using gpu
189 |     '''
190 |     m, n = emb1.shape[0], emb2.shape[0]
191 |     emb1_pow = torch.pow(emb1, 2).sum(dim=1, keepdim=True).expand(m, n)
192 |     emb2_pow = torch.pow(emb2, 2).sum(dim=1, keepdim=True).expand(n, m).t()
193 |     dist_mtx = emb1_pow + emb2_pow
194 |     dist_mtx = dist_mtx.addmm_(emb1, emb2.t(), beta=1, alpha=-2)  # dist - 2 * emb1 @ emb2.T
195 |     # dist_mtx = dist_mtx.clamp(min = 1e-12)
196 |     dist_mtx = dist_mtx.clamp(min=1e-12).sqrt()
197 |     return dist_mtx
198 |
199 |
200 | def pdist_np(emb1, emb2):
201 |     '''
202 |     compute the euclidean distance matrix between embeddings1 and embeddings2
203 |     using cpu
204 |     '''
205 |     m, n = emb1.shape[0], emb2.shape[0]
206 |     emb1_pow = np.square(emb1).sum(axis=1)[..., np.newaxis]
207 |     emb2_pow = np.square(emb2).sum(axis=1)[np.newaxis, ...]
208 |     dist_mtx = -2 * np.matmul(emb1, emb2.T) + emb1_pow + emb2_pow
209 |     # dist_mtx = np.sqrt(dist_mtx.clip(min = 1e-12))
210 |     return dist_mtx
211 |
212 |
213 |
214 |
215 | class MRIC(nn.Module):
216 |     def __init__(self):
217 |         super(MRIC, self).__init__()
218 |         self.adaptive = nn.Softmax(dim=0)
219 |
220 |     def Adaptive_Identity(self, x):
221 |         sim = torch.sum(torch.mm(x, x.T), dim=1)
222 |         return self.adaptive(sim).view(4, -1)
223 |
224 |     def forward(self, x1, x2):
225 |         x1 = F.normalize(x1)
226 |         x2 = F.normalize(x2)
227 |         b = x2.shape[0] // 4
228 |         listx1 = []
229 |         listx2 = []
230 |         for i in range(b):
231 |             listx1.append(
232 |                 torch.sum(self.Adaptive_Identity(x1[i * 4:(i + 1) * 4]) * x1[i * 4:(i + 1) * 4], dim=0).view(1, -1))
233 |             listx2.append(
234 |                 torch.sum(self.Adaptive_Identity(x2[i * 4:(i + 1) * 4]) * x2[i * 4:(i + 1) * 4], dim=0).view(1, -1))
235 |
236 |         x1 = torch.cat(listx1, dim=0)
237 |         x2 = torch.cat(listx2, dim=0)
238 |
239 |         x1 = F.normalize(x1)
240 |         x2 = F.normalize(x2)
241 |         center_loss = ((x1 - x2).norm(dim=1, keepdim=True)).mean()
242 |
243 |         sim = torch.mm(x1, x2.T)
244 |         labels = torch.arange(b).cuda()
245 |         loss_t = F.cross_entropy(sim, labels)
246 |         loss_i = F.cross_entropy(sim.T, labels)
247 |         return (loss_t + loss_i) + center_loss, x1, x2
248 |
249 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | from resnet import resnet50, resnet18
5 |
6 | class Normalize(nn.Module):
7 |     def __init__(self, power=2):
8 |         super(Normalize, self).__init__()
9 |         self.power = power
10 |
11 |     def forward(self, x):
12 |         norm = x.pow(self.power).sum(1, keepdim=True).pow(1.
/ self.power)
13 |         out = x.div(norm)
14 |         return out
15 |
16 | class Non_local(nn.Module):
17 |     def __init__(self, in_channels, reduc_ratio=2):
18 |         super(Non_local, self).__init__()
19 |
20 |         self.in_channels = in_channels
21 |         self.inter_channels = reduc_ratio//reduc_ratio  # note: this evaluates to 1, as in the released code
22 |
23 |         self.g = nn.Sequential(
24 |             nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1,
25 |                       padding=0),
26 |         )
27 |
28 |         self.W = nn.Sequential(
29 |             nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
30 |                       kernel_size=1, stride=1, padding=0),
31 |             nn.BatchNorm2d(self.in_channels),
32 |         )
33 |         nn.init.constant_(self.W[1].weight, 0.0)
34 |         nn.init.constant_(self.W[1].bias, 0.0)
35 |
36 |
37 |
38 |         self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
39 |                                kernel_size=1, stride=1, padding=0)
40 |
41 |         self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
42 |                              kernel_size=1, stride=1, padding=0)
43 |
44 |     def forward(self, x):
45 |         '''
46 |         :param x: (b, c, h, w)
47 |         :return:
48 |         '''
49 |
50 |         batch_size = x.size(0)
51 |         g_x = self.g(x).view(batch_size, self.inter_channels, -1)
52 |         g_x = g_x.permute(0, 2, 1)
53 |
54 |         theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
55 |         theta_x = theta_x.permute(0, 2, 1)
56 |         phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
57 |         f = torch.matmul(theta_x, phi_x)
58 |         N = f.size(-1)
59 |         # f_div_C = torch.nn.functional.softmax(f, dim=-1)
60 |         f_div_C = f / N
61 |
62 |         y = torch.matmul(f_div_C, g_x)
63 |         y = y.permute(0, 2, 1).contiguous()
64 |         y = y.view(batch_size, self.inter_channels, *x.size()[2:])
65 |         W_y = self.W(y)
66 |         z = W_y + x
67 |
68 |         return z
69 |
70 |
71 | # #####################################################################
72 | def weights_init_kaiming(m):
73 |     classname = m.__class__.__name__
74 |     # print(classname)
75 |     if classname.find('Conv') != -1:
76 |         init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
77 |     elif classname.find('Linear') != -1:
78 |         init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')
79 |         init.zeros_(m.bias.data)
80 |     elif classname.find('BatchNorm1d') != -1:
81 |         init.normal_(m.weight.data, 1.0, 0.01)
82 |         init.zeros_(m.bias.data)
83 |
84 | def weights_init_classifier(m):
85 |     classname = m.__class__.__name__
86 |     if classname.find('Linear') != -1:
87 |         init.normal_(m.weight.data, 0, 0.001)
88 |         if m.bias is not None:
89 |             init.zeros_(m.bias.data)
90 |
91 |
92 |
93 | class visible_module(nn.Module):
94 |     def __init__(self, arch='resnet50'):
95 |         super(visible_module, self).__init__()
96 |
97 |         model_v = resnet50(pretrained=True,
98 |                            last_conv_stride=1, last_conv_dilation=1)
99 |         # avg pooling to global pooling
100 |         self.visible = model_v
101 |
102 |     def forward(self, x):
103 |         x = self.visible.conv1(x)
104 |         x = self.visible.bn1(x)
105 |         x = self.visible.relu(x)
106 |         x = self.visible.maxpool(x)
107 |         return x
108 |
109 |
110 | class thermal_module(nn.Module):
111 |     def __init__(self, arch='resnet50'):
112 |         super(thermal_module, self).__init__()
113 |
114 |         model_t = resnet50(pretrained=True,
115 |                            last_conv_stride=1, last_conv_dilation=1)
116 |         # avg pooling to global pooling
117 |         self.thermal = model_t
118 |
119 |     def forward(self, x):
120 |         x = self.thermal.conv1(x)
121 |         x = self.thermal.bn1(x)
122 |         x = self.thermal.relu(x)
123 |         x = self.thermal.maxpool(x)
124 |         return x
125 |
126 |
127 | class base_resnet(nn.Module):
128 |     def __init__(self, arch='resnet50'):
129 | 
super(base_resnet, self).__init__() 130 | 131 | model_base = resnet50(pretrained=True, 132 | last_conv_stride=1, last_conv_dilation=1) 133 | # avg pooling to global pooling 134 | model_base.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 135 | self.base = model_base 136 | 137 | def forward(self, x): 138 | x = self.base.layer1(x) 139 | x = self.base.layer2(x) 140 | x = self.base.layer3(x) 141 | x = self.base.layer4(x) 142 | return x 143 | 144 | 145 | class embed_net(nn.Module): 146 | def __init__(self, class_num, no_local= 'on', gm_pool = 'on', arch='resnet50'): 147 | super(embed_net, self).__init__() 148 | 149 | self.thermal_module = thermal_module(arch=arch) 150 | self.visible_module = visible_module(arch=arch) 151 | self.base_resnet = base_resnet(arch=arch) 152 | self.non_local = no_local 153 | if self.non_local =='on': 154 | layers=[3, 4, 6, 3] 155 | non_layers=[0,2,3,0] 156 | self.NL_1 = nn.ModuleList( 157 | [Non_local(256) for i in range(non_layers[0])]) 158 | self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])]) 159 | self.NL_2 = nn.ModuleList( 160 | [Non_local(512) for i in range(non_layers[1])]) 161 | self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])]) 162 | self.NL_3 = nn.ModuleList( 163 | [Non_local(1024) for i in range(non_layers[2])]) 164 | self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])]) 165 | self.NL_4 = nn.ModuleList( 166 | [Non_local(2048) for i in range(non_layers[3])]) 167 | self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])]) 168 | 169 | 170 | pool_dim = 2048 171 | self.l2norm = Normalize(2) 172 | self.bottleneck = nn.BatchNorm1d(pool_dim) 173 | self.bottleneck.bias.requires_grad_(False) # no shift 174 | 175 | self.classifier = nn.Linear(pool_dim, class_num, bias=False) 176 | 177 | self.bottleneck.apply(weights_init_kaiming) 178 | self.classifier.apply(weights_init_classifier) 179 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 180 | self.gm_pool = gm_pool 181 | 182 | def forward(self, x1, x2, modal=0): 183 | if modal == 0: 184 | x1 = self.visible_module(x1) 185 | x2 = self.thermal_module(x2) 186 | x = torch.cat((x1, x2), 0) 187 | elif modal == 1: 188 | x = self.visible_module(x1) 189 | elif modal == 2: 190 | x = self.thermal_module(x2) 191 | 192 | # shared block 193 | if self.non_local == 'on': 194 | NL1_counter = 0 195 | if len(self.NL_1_idx) == 0: self.NL_1_idx = [-1] 196 | for i in range(len(self.base_resnet.base.layer1)): 197 | x = self.base_resnet.base.layer1[i](x) 198 | if i == self.NL_1_idx[NL1_counter]: 199 | _, C, H, W = x.shape 200 | x = self.NL_1[NL1_counter](x) 201 | NL1_counter += 1 202 | # Layer 2 203 | NL2_counter = 0 204 | if len(self.NL_2_idx) == 0: self.NL_2_idx = [-1] 205 | for i in range(len(self.base_resnet.base.layer2)): 206 | x = self.base_resnet.base.layer2[i](x) 207 | if i == self.NL_2_idx[NL2_counter]: 208 | _, C, H, W = x.shape 209 | x = self.NL_2[NL2_counter](x) 210 | NL2_counter += 1 211 | # Layer 3 212 | NL3_counter = 0 213 | if len(self.NL_3_idx) == 0: self.NL_3_idx = [-1] 214 | for i in range(len(self.base_resnet.base.layer3)): 215 | x = self.base_resnet.base.layer3[i](x) 216 | if i == self.NL_3_idx[NL3_counter]: 217 | _, C, H, W = x.shape 218 | x = self.NL_3[NL3_counter](x) 219 | NL3_counter += 1 220 | # Layer 4 221 | NL4_counter = 0 222 | if len(self.NL_4_idx) == 0: self.NL_4_idx = [-1] 223 | for i in range(len(self.base_resnet.base.layer4)): 224 | x = self.base_resnet.base.layer4[i](x) 225 | if i == self.NL_4_idx[NL4_counter]: 226 | _, C, H, W = 
x.shape 227 | x = self.NL_4[NL4_counter](x) 228 | NL4_counter += 1 229 | else: 230 | x = self.base_resnet(x) 231 | if self.gm_pool == 'on': 232 | b, c, h, w = x.shape 233 | x = x.view(b, c, -1) 234 | p = 3.0 235 | x_pool = (torch.mean(x**p, dim=-1) + 1e-12)**(1/p) 236 | else: 237 | x_pool = self.avgpool(x) 238 | x_pool = x_pool.view(x_pool.size(0), x_pool.size(1)) 239 | 240 | feat = self.bottleneck(x_pool) 241 | 242 | if self.training: 243 | return x_pool, self.classifier(feat) 244 | else: 245 | return self.l2norm(x_pool), self.l2norm(feat) 246 | 247 | 248 | class embed_net_vis(nn.Module): 249 | def __init__(self, class_num, no_local= 'on', gm_pool = 'on', arch='resnet50'): 250 | super(embed_net_vis, self).__init__() 251 | 252 | self.thermal_module = thermal_module(arch=arch) 253 | self.visible_module = visible_module(arch=arch) 254 | self.base_resnet = base_resnet(arch=arch) 255 | self.non_local = no_local 256 | if self.non_local =='on': 257 | layers=[3, 4, 6, 3] 258 | non_layers=[0,2,3,0] 259 | self.NL_1 = nn.ModuleList( 260 | [Non_local(256) for i in range(non_layers[0])]) 261 | self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])]) 262 | self.NL_2 = nn.ModuleList( 263 | [Non_local(512) for i in range(non_layers[1])]) 264 | self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])]) 265 | self.NL_3 = nn.ModuleList( 266 | [Non_local(1024) for i in range(non_layers[2])]) 267 | self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])]) 268 | self.NL_4 = nn.ModuleList( 269 | [Non_local(2048) for i in range(non_layers[3])]) 270 | self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])]) 271 | 272 | 273 | pool_dim = 2048 274 | self.l2norm = Normalize(2) 275 | self.bottleneck = nn.BatchNorm1d(pool_dim) 276 | self.bottleneck.bias.requires_grad_(False) # no shift 277 | 278 | self.classifier = nn.Linear(pool_dim, class_num, bias=False) 279 | 280 | self.bottleneck.apply(weights_init_kaiming) 281 | self.classifier.apply(weights_init_classifier) 282 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 283 | self.gm_pool = gm_pool 284 | 285 | def forward(self, x1): 286 | 287 | x = self.thermal_module(x1) 288 | 289 | # shared block 290 | if self.non_local == 'on': 291 | NL1_counter = 0 292 | if len(self.NL_1_idx) == 0: self.NL_1_idx = [-1] 293 | for i in range(len(self.base_resnet.base.layer1)): 294 | x = self.base_resnet.base.layer1[i](x) 295 | if i == self.NL_1_idx[NL1_counter]: 296 | _, C, H, W = x.shape 297 | x = self.NL_1[NL1_counter](x) 298 | NL1_counter += 1 299 | # Layer 2 300 | NL2_counter = 0 301 | if len(self.NL_2_idx) == 0: self.NL_2_idx = [-1] 302 | for i in range(len(self.base_resnet.base.layer2)): 303 | x = self.base_resnet.base.layer2[i](x) 304 | if i == self.NL_2_idx[NL2_counter]: 305 | _, C, H, W = x.shape 306 | x = self.NL_2[NL2_counter](x) 307 | NL2_counter += 1 308 | # Layer 3 309 | NL3_counter = 0 310 | if len(self.NL_3_idx) == 0: self.NL_3_idx = [-1] 311 | for i in range(len(self.base_resnet.base.layer3)): 312 | x = self.base_resnet.base.layer3[i](x) 313 | if i == self.NL_3_idx[NL3_counter]: 314 | _, C, H, W = x.shape 315 | x = self.NL_3[NL3_counter](x) 316 | NL3_counter += 1 317 | # Layer 4 318 | NL4_counter = 0 319 | if len(self.NL_4_idx) == 0: self.NL_4_idx = [-1] 320 | for i in range(len(self.base_resnet.base.layer4)): 321 | x = self.base_resnet.base.layer4[i](x) 322 | if i == self.NL_4_idx[NL4_counter]: 323 | _, C, H, W = x.shape 324 | x = self.NL_4[NL4_counter](x) 325 | NL4_counter += 1 326 | else: 327 | x = 
self.base_resnet(x) 328 | if self.gm_pool == 'on': 329 | b, c, h, w = x.shape 330 | x = x.view(b, c, -1) 331 | p = 3.0 332 | x_pool = (torch.mean(x**p, dim=-1) + 1e-12)**(1/p) 333 | else: 334 | x_pool = self.avgpool(x) 335 | x_pool = x_pool.view(x_pool.size(0), x_pool.size(1)) 336 | 337 | feat = self.bottleneck(x_pool) 338 | 339 | if self.training: 340 | return x_pool, self.classifier(feat) 341 | else: 342 | return self.l2norm(x_pool) 343 | -------------------------------------------------------------------------------- /model_sle.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | import sys,os 4 | sys.path.append(os.path.dirname(__file__) + os.sep + '../') 5 | sys.path.append("..") 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import init 9 | import torch.nn.functional as F 10 | from resnet import resnet50 11 | from Transformer import ViT 12 | 13 | class Normalize(nn.Module): 14 | def __init__(self, power=2): 15 | super(Normalize, self).__init__() 16 | self.power = power 17 | 18 | def forward(self, x): 19 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power) 20 | out = x.div(norm) 21 | return out 22 | 23 | 24 | class Bottleneck12(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None): 28 | super(Bottleneck12, self).__init__() 29 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 30 | self.bn1 = nn.BatchNorm2d(planes) 31 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 32 | padding=1, bias=False) 33 | self.bn2 = nn.BatchNorm2d(planes) 34 | self.conv3 = nn.Conv2d(planes, planes, kernel_size=1, bias=False) 35 | self.bn3 = nn.BatchNorm2d(planes) 36 | self.relu = nn.ReLU(inplace=True) 37 | self.downsample = downsample 38 | self.stride = stride 39 | 40 | def forward(self, x): 41 | residual = x 42 | 43 | out = self.conv1(x) 44 | out = self.bn1(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv2(out) 48 | out = self.bn2(out) 49 | out = self.relu(out) 50 | 51 | out = self.conv3(out) 52 | out = self.bn3(out) 53 | 54 | if self.downsample is not None: 55 | residual = self.downsample(x) 56 | 57 | out += residual 58 | out = self.relu(out) 59 | 60 | return out 61 | 62 | 63 | class Non_local(nn.Module): 64 | def __init__(self, in_channels, reduc_ratio=2): 65 | super(Non_local, self).__init__() 66 | 67 | self.in_channels = in_channels 68 | self.inter_channels = reduc_ratio // reduc_ratio 69 | 70 | self.g = nn.Sequential( 71 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, 72 | padding=0), 73 | ) 74 | 75 | self.W = nn.Sequential( 76 | nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, 77 | kernel_size=1, stride=1, padding=0), 78 | nn.BatchNorm2d(self.in_channels), 79 | ) 80 | nn.init.constant_(self.W[1].weight, 0.0) 81 | nn.init.constant_(self.W[1].bias, 0.0) 82 | 83 | self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 84 | kernel_size=1, stride=1, padding=0) 85 | 86 | self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 87 | kernel_size=1, stride=1, padding=0) 88 | 89 | def forward(self, x): 90 | ''' 91 | :param x: (b, c, t, h, w) 92 | :return: 93 | ''' 94 | 95 | batch_size = x.size(0) 96 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 97 | g_x = g_x.permute(0, 2, 1) 98 | 99 | theta_x = self.theta(x).view(batch_size, self.inter_channels, 
-1) 100 | theta_x = theta_x.permute(0, 2, 1) 101 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 102 | f = torch.matmul(theta_x, phi_x) 103 | N = f.size(-1) 104 | # f_div_C = torch.nn.functional.softmax(f, dim=-1) 105 | f_div_C = f / N 106 | 107 | y = torch.matmul(f_div_C, g_x) 108 | y = y.permute(0, 2, 1).contiguous() 109 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 110 | W_y = self.W(y) 111 | z = W_y + x 112 | 113 | return z 114 | 115 | 116 | # ##################################################################### 117 | def weights_init_kaiming(m): 118 | classname = m.__class__.__name__ 119 | # print(classname) 120 | if classname.find('Conv') != -1: 121 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 122 | elif classname.find('Linear') != -1: 123 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 124 | init.zeros_(m.bias.data) 125 | elif classname.find('BatchNorm1d') != -1: 126 | init.normal_(m.weight.data, 1.0, 0.01) 127 | init.zeros_(m.bias.data) 128 | 129 | 130 | def weights_init_classifier(m): 131 | classname = m.__class__.__name__ 132 | if classname.find('Linear') != -1: 133 | init.normal_(m.weight.data, 0, 0.001) 134 | if m.bias: 135 | init.zeros_(m.bias.data) 136 | 137 | 138 | class visible_module(nn.Module): 139 | def __init__(self, arch='resnet50'): 140 | super(visible_module, self).__init__() 141 | 142 | model_v = resnet50(pretrained=True, 143 | last_conv_stride=1, last_conv_dilation=1) 144 | # avg pooling to global pooling 145 | self.visible = model_v 146 | 147 | def forward(self, x): 148 | x = self.visible.conv1(x) 149 | x = self.visible.bn1(x) 150 | x = self.visible.relu(x) 151 | x = self.visible.maxpool(x) 152 | return x 153 | 154 | 155 | class thermal_module(nn.Module): 156 | def __init__(self, arch='resnet50'): 157 | super(thermal_module, self).__init__() 158 | 159 | model_t = resnet50(pretrained=True, 160 | last_conv_stride=1, last_conv_dilation=1) 161 | # avg pooling to global pooling 162 | self.thermal = model_t 163 | 164 | def forward(self, x): 165 | x = self.thermal.conv1(x) 166 | x = self.thermal.bn1(x) 167 | x = self.thermal.relu(x) 168 | x = self.thermal.maxpool(x) 169 | return x 170 | 171 | 172 | class base_resnet(nn.Module): 173 | def __init__(self, arch='resnet50'): 174 | super(base_resnet, self).__init__() 175 | 176 | model_base = resnet50(pretrained=True, 177 | last_conv_stride=1, last_conv_dilation=1) 178 | # avg pooling to global pooling 179 | model_base.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 180 | self.base = model_base 181 | 182 | def forward(self, x): 183 | x = self.base.layer1(x) 184 | x = self.base.layer2(x) 185 | x = self.base.layer3(x) 186 | x = self.base.layer4(x) 187 | return x 188 | 189 | 190 | class embed_net(nn.Module): 191 | def __init__(self, class_num, no_local='on', gm_pool='on', arch='resnet50', dataset="sysu", plearn=0,stage=-1, depth=-1, head=-1): 192 | super(embed_net, self).__init__() 193 | self.thermal_module = thermal_module(arch=arch) 194 | self.visible_module = visible_module(arch=arch) 195 | self.base_resnet = base_resnet(arch=arch) 196 | self.non_local = no_local 197 | if self.non_local == 'on': 198 | layers = [3, 4, 6, 3] 199 | non_layers = [0, 2, 3, 0] 200 | self.NL_1 = nn.ModuleList( 201 | [Non_local(256) for i in range(non_layers[0])]) 202 | self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])]) 203 | self.NL_2 = nn.ModuleList( 204 | [Non_local(512) for i in range(non_layers[1])]) 205 | self.NL_2_idx = sorted([layers[1] - (i + 1) for i in 
range(non_layers[1])]) 206 | self.NL_3 = nn.ModuleList( 207 | [Non_local(1024) for i in range(non_layers[2])]) 208 | self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])]) 209 | self.NL_4 = nn.ModuleList( 210 | [Non_local(2048) for i in range(non_layers[3])]) 211 | self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])]) 212 | 213 | pool_dim = 2048 214 | self.l2norm = Normalize(2) 215 | self.bottleneck = nn.BatchNorm1d(pool_dim) 216 | self.bottleneck.bias.requires_grad_(False) # no shift 217 | 218 | self.classifier = nn.Linear(pool_dim, class_num, bias=False) 219 | 220 | self.bottleneck.apply(weights_init_kaiming) 221 | self.classifier.apply(weights_init_classifier) 222 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 223 | self.gm_pool = gm_pool 224 | 225 | self.num_stripes = 6 226 | local_conv_out_channels = 256 227 | 228 | self.local_conv_list = nn.ModuleList() 229 | for _ in range(self.num_stripes): 230 | conv = nn.Conv2d(pool_dim, local_conv_out_channels, 1) 231 | conv.apply(weights_init_kaiming) 232 | self.local_conv_list.append(nn.Sequential( 233 | conv, 234 | nn.BatchNorm2d(local_conv_out_channels), 235 | nn.ReLU(inplace=True) 236 | )) 237 | 238 | self.fc_list = nn.ModuleList() 239 | for _ in range(self.num_stripes): 240 | fc = nn.Linear(local_conv_out_channels, class_num) 241 | init.normal_(fc.weight, std=0.001) 242 | init.constant_(fc.bias, 0) 243 | self.fc_list.append(fc) 244 | if plearn == 1: # better or worse 245 | self.p = nn.Parameter(torch.ones(1) * 3.0) 246 | if dataset == 'sysu': 247 | self.p1 = nn.Parameter(torch.ones(1) * 3.0) 248 | else: 249 | self.p1 = nn.Parameter(torch.ones(1) * 10.0) 250 | elif plearn == 0: 251 | self.p = 3.0 252 | if dataset == 'sysu': 253 | self.p1 = 3.0 254 | else: 255 | self.p1 = 10.0 256 | self.stage = stage 257 | self.depth = depth 258 | self.head = head 259 | self.cnn23 = Bottleneck12(inplanes=1024, planes=1024) 260 | if self.stage == 23: 261 | self.vit = ViT(img_size=18 * 9, embed_dim=1024, depth=self.depth, num_heads=self.head) 262 | 263 | 264 | 265 | def forward(self, x1, x2, modal=0): 266 | if modal == 0: 267 | x1 = self.visible_module(x1) 268 | x2 = self.thermal_module(x2) 269 | x = torch.cat((x1, x2), 0) 270 | 271 | elif modal == 1: 272 | x = self.visible_module(x1) 273 | elif modal == 2: 274 | x = self.thermal_module(x2) 275 | 276 | # shared block 277 | if self.non_local == 'on': 278 | NL1_counter = 0 279 | if len(self.NL_1_idx) == 0: self.NL_1_idx = [-1] 280 | for i in range(len(self.base_resnet.base.layer1)): 281 | x = self.base_resnet.base.layer1[i](x) 282 | if i == self.NL_1_idx[NL1_counter]: 283 | _, C, H, W = x.shape 284 | x = self.NL_1[NL1_counter](x) 285 | NL1_counter += 1 286 | # Layer 2 287 | NL2_counter = 0 288 | if len(self.NL_2_idx) == 0: self.NL_2_idx = [-1] 289 | for i in range(len(self.base_resnet.base.layer2)): 290 | x = self.base_resnet.base.layer2[i](x) 291 | if i == self.NL_2_idx[NL2_counter]: 292 | _, C, H, W = x.shape 293 | x = self.NL_2[NL2_counter](x) 294 | NL2_counter += 1 295 | # Layer 3 296 | NL3_counter = 0 297 | if len(self.NL_3_idx) == 0: self.NL_3_idx = [-1] 298 | for i in range(len(self.base_resnet.base.layer3)): 299 | x = self.base_resnet.base.layer3[i](x) 300 | if i == self.NL_3_idx[NL3_counter]: 301 | _, C, H, W = x.shape 302 | x = self.NL_3[NL3_counter](x) 303 | NL3_counter += 1 304 | # sle 305 | out_23 = x 306 | x = self.cnn23(x) 307 | if self.stage == 23 and self.training: 308 | out_23_shape = out_23.shape[0]//3 309 | temp = torch.cat((out_23[0:out_23_shape], 
out_23[2*out_23_shape:3*out_23_shape]), dim=0) 310 | x_vit = self.vit(temp) 311 | 312 | x = torch.cat((x, x_vit), dim=0) 313 | # end sle 314 | # Layer 4 315 | NL4_counter = 0 316 | if len(self.NL_4_idx) == 0: self.NL_4_idx = [-1] 317 | for i in range(len(self.base_resnet.base.layer4)): 318 | x = self.base_resnet.base.layer4[i](x) 319 | if i == self.NL_4_idx[NL4_counter]: 320 | _, C, H, W = x.shape 321 | x = self.NL_4[NL4_counter](x) 322 | NL4_counter += 1 323 | else: 324 | pass 325 | 326 | # for partial feature 327 | feat = x 328 | assert feat.size(2) % self.num_stripes == 0 329 | stripe_h = int(feat.size(2) / self.num_stripes) 330 | local_feat_list = [] 331 | logits_list = [] 332 | for i in range(self.num_stripes): 333 | if self.gm_pool == 'on': 334 | # gm pool 335 | local_feat = feat[:, :, i * stripe_h: (i + 1) * stripe_h, :] 336 | b, c, h, w = local_feat.shape 337 | local_feat = local_feat.view(b, c, -1) 338 | local_feat = (torch.mean(local_feat ** self.p1, dim=-1) + 1e-12) ** (1 / self.p1) 339 | else: 340 | 341 | local_feat = F.max_pool2d(feat[:, :, i * stripe_h: (i + 1) * stripe_h, :], 342 | (stripe_h, feat.size(-1))) 343 | local_feat = self.local_conv_list[i](local_feat.view(feat.size(0), feat.size(1), 1, 1)) 344 | 345 | # shape [N, c] 346 | local_feat = local_feat.view(local_feat.size(0), -1) 347 | local_feat_list.append(local_feat) 348 | 349 | if hasattr(self, 'fc_list'): 350 | logits_list.append(self.fc_list[i](local_feat)) 351 | 352 | feat_all = [lf for lf in local_feat_list] 353 | feat_all = torch.cat(feat_all, dim=1) 354 | 355 | # for global feature 356 | if self.gm_pool == 'on': 357 | b, c, h, w = x.shape 358 | x = x.view(b, c, -1) 359 | x_pool = (torch.mean(x ** self.p, dim=-1) + 1e-12) ** (1 / self.p) 360 | else: 361 | x_pool = self.avgpool(x) 362 | x_pool = x_pool.view(x_pool.size(0), x_pool.size(1)) 363 | feat = self.bottleneck(x_pool) 364 | if self.training: 365 | return x_pool, self.classifier(feat), local_feat_list, logits_list, feat_all 366 | else: 367 | x_pool_1 = torch.cat((x_pool, feat_all), dim=1) 368 | feat_1 = torch.cat((feat, feat_all), dim=1) 369 | return self.l2norm(x_pool_1), self.l2norm(feat_1) 370 | -------------------------------------------------------------------------------- /model_sle_hsl.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | import sys,os 4 | sys.path.append(os.path.dirname(__file__) + os.sep + '../') 5 | sys.path.append("..") 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import init 9 | import torch.nn.functional as F 10 | from resnet import resnet50 11 | from Transformer import ViT 12 | from HyperGraphs import HypergraphConv 13 | import whitening 14 | 15 | class whitening_scale_shift(nn.Module): 16 | def __init__(self, planes, group_size, affine=True): 17 | super(whitening_scale_shift, self).__init__() 18 | self.planes = planes 19 | self.group_size = group_size 20 | self.affine = affine 21 | 22 | self.wh = whitening.WTransform2d(self.planes, 23 | self.group_size) 24 | if self.affine: 25 | self.gamma = nn.Parameter(torch.ones(self.planes, 1, 1)) 26 | self.beta = nn.Parameter(torch.zeros(self.planes, 1, 1)) 27 | 28 | 29 | def forward(self, x): 30 | return self.wh(x) * self.gamma + self.beta + x 31 | 32 | class Normalize(nn.Module): 33 | def __init__(self, power=2): 34 | super(Normalize, self).__init__() 35 | self.power = power 36 | 37 | def forward(self, x): 38 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. 
/ self.power) 39 | out = x.div(norm) 40 | return out 41 | 42 | class Bottleneck12(nn.Module): 43 | expansion = 1 44 | 45 | def __init__(self, inplanes, planes, stride=1, downsample=None): 46 | super(Bottleneck12, self).__init__() 47 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 48 | self.bn1 = nn.BatchNorm2d(planes) 49 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 50 | padding=1, bias=False) 51 | self.bn2 = nn.BatchNorm2d(planes) 52 | self.conv3 = nn.Conv2d(planes, planes, kernel_size=1, bias=False) 53 | self.bn3 = nn.BatchNorm2d(planes) 54 | self.relu = nn.ReLU(inplace=True) 55 | self.downsample = downsample 56 | self.stride = stride 57 | 58 | def forward(self, x): 59 | residual = x 60 | 61 | out = self.conv1(x) 62 | out = self.bn1(out) 63 | out = self.relu(out) 64 | 65 | out = self.conv2(out) 66 | out = self.bn2(out) 67 | out = self.relu(out) 68 | 69 | out = self.conv3(out) 70 | out = self.bn3(out) 71 | 72 | if self.downsample is not None: 73 | residual = self.downsample(x) 74 | 75 | out += residual 76 | out = self.relu(out) 77 | 78 | return out 79 | 80 | class Non_local(nn.Module): 81 | def __init__(self, in_channels, reduc_ratio=2): 82 | super(Non_local, self).__init__() 83 | 84 | self.in_channels = in_channels 85 | self.inter_channels = reduc_ratio // reduc_ratio 86 | 87 | self.g = nn.Sequential( 88 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, 89 | padding=0), 90 | ) 91 | 92 | self.W = nn.Sequential( 93 | nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, 94 | kernel_size=1, stride=1, padding=0), 95 | nn.BatchNorm2d(self.in_channels), 96 | ) 97 | nn.init.constant_(self.W[1].weight, 0.0) 98 | nn.init.constant_(self.W[1].bias, 0.0) 99 | 100 | self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 101 | kernel_size=1, stride=1, padding=0) 102 | 103 | self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 104 | kernel_size=1, stride=1, padding=0) 105 | 106 | def forward(self, x): 107 | ''' 108 | :param x: (b, c, t, h, w) 109 | :return: 110 | ''' 111 | 112 | batch_size = x.size(0) 113 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 114 | g_x = g_x.permute(0, 2, 1) 115 | 116 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 117 | theta_x = theta_x.permute(0, 2, 1) 118 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 119 | f = torch.matmul(theta_x, phi_x) 120 | N = f.size(-1) 121 | # f_div_C = torch.nn.functional.softmax(f, dim=-1) 122 | f_div_C = f / N 123 | 124 | y = torch.matmul(f_div_C, g_x) 125 | y = y.permute(0, 2, 1).contiguous() 126 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 127 | W_y = self.W(y) 128 | z = W_y + x 129 | 130 | return z 131 | 132 | 133 | # ##################################################################### 134 | def weights_init_kaiming(m): 135 | classname = m.__class__.__name__ 136 | # print(classname) 137 | if classname.find('Conv') != -1: 138 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 139 | elif classname.find('Linear') != -1: 140 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 141 | init.zeros_(m.bias.data) 142 | elif classname.find('BatchNorm1d') != -1: 143 | init.normal_(m.weight.data, 1.0, 0.01) 144 | init.zeros_(m.bias.data) 145 | 146 | 147 | def weights_init_classifier(m): 148 | classname = m.__class__.__name__ 149 | if classname.find('Linear') != -1: 150 | init.normal_(m.weight.data, 
0, 0.001) 151 | if m.bias: 152 | init.zeros_(m.bias.data) 153 | 154 | 155 | class visible_module(nn.Module): 156 | def __init__(self, arch='resnet50'): 157 | super(visible_module, self).__init__() 158 | 159 | model_v = resnet50(pretrained=True, 160 | last_conv_stride=1, last_conv_dilation=1) 161 | # avg pooling to global pooling 162 | self.visible = model_v 163 | 164 | def forward(self, x): 165 | x = self.visible.conv1(x) 166 | x = self.visible.bn1(x) 167 | x = self.visible.relu(x) 168 | x = self.visible.maxpool(x) 169 | return x 170 | 171 | 172 | class thermal_module(nn.Module): 173 | def __init__(self, arch='resnet50'): 174 | super(thermal_module, self).__init__() 175 | 176 | model_t = resnet50(pretrained=True, 177 | last_conv_stride=1, last_conv_dilation=1) 178 | # avg pooling to global pooling 179 | self.thermal = model_t 180 | 181 | def forward(self, x): 182 | x = self.thermal.conv1(x) 183 | x = self.thermal.bn1(x) 184 | x = self.thermal.relu(x) 185 | x = self.thermal.maxpool(x) 186 | return x 187 | 188 | 189 | class base_resnet(nn.Module): 190 | def __init__(self, arch='resnet50'): 191 | super(base_resnet, self).__init__() 192 | 193 | model_base = resnet50(pretrained=True, 194 | last_conv_stride=1, last_conv_dilation=1) 195 | # avg pooling to global pooling 196 | model_base.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 197 | self.base = model_base 198 | 199 | def forward(self, x): 200 | x = self.base.layer1(x) 201 | x = self.base.layer2(x) 202 | x = self.base.layer3(x) 203 | x = self.base.layer4(x) 204 | return x 205 | 206 | 207 | class embed_net(nn.Module): 208 | def __init__(self, class_num, no_local='on', gm_pool='on', arch='resnet50', dataset="sysu", plearn=0, stage=23, depth=-1, head=-1, graphw=-1, theta1=0.0): 209 | super(embed_net, self).__init__() 210 | self.thermal_module = thermal_module(arch=arch) 211 | self.visible_module = visible_module(arch=arch) 212 | self.base_resnet = base_resnet(arch=arch) 213 | self.non_local = no_local 214 | if self.non_local == 'on': 215 | layers = [3, 4, 6, 3] 216 | non_layers = [0, 2, 3, 0] 217 | self.NL_1 = nn.ModuleList( 218 | [Non_local(256) for i in range(non_layers[0])]) 219 | self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])]) 220 | self.NL_2 = nn.ModuleList( 221 | [Non_local(512) for i in range(non_layers[1])]) 222 | self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])]) 223 | self.NL_3 = nn.ModuleList( 224 | [Non_local(1024) for i in range(non_layers[2])]) 225 | self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])]) 226 | self.NL_4 = nn.ModuleList( 227 | [Non_local(2048) for i in range(non_layers[3])]) 228 | self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])]) 229 | 230 | pool_dim = 2048 231 | self.l2norm = Normalize(2) 232 | self.bottleneck = nn.BatchNorm1d(pool_dim) 233 | self.bottleneck.bias.requires_grad_(False) # no shift 234 | 235 | self.classifier = nn.Linear(pool_dim, class_num, bias=False) 236 | 237 | self.bottleneck.apply(weights_init_kaiming) 238 | self.classifier.apply(weights_init_classifier) 239 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 240 | self.gm_pool = gm_pool 241 | 242 | self.num_stripes = 6 243 | local_conv_out_channels = 256 244 | 245 | self.local_conv_list = nn.ModuleList() 246 | for _ in range(self.num_stripes): 247 | conv = nn.Conv2d(pool_dim, local_conv_out_channels, 1) 248 | conv.apply(weights_init_kaiming) 249 | self.local_conv_list.append(nn.Sequential( 250 | conv, 251 | nn.BatchNorm2d(local_conv_out_channels), 252 | 
nn.ReLU(inplace=True) 253 | )) 254 | 255 | self.fc_list = nn.ModuleList() 256 | for _ in range(self.num_stripes): 257 | fc = nn.Linear(local_conv_out_channels, class_num) 258 | init.normal_(fc.weight, std=0.001) 259 | init.constant_(fc.bias, 0) 260 | self.fc_list.append(fc) 261 | if plearn == 1: # better or worse 262 | self.p = nn.Parameter(torch.ones(1) * 3.0) 263 | if dataset == 'sysu': 264 | self.p1 = nn.Parameter(torch.ones(1) * 3.0) 265 | else: 266 | self.p1 = nn.Parameter(torch.ones(1) * 10.0) 267 | elif plearn == 0: 268 | self.p = 3.0 269 | if dataset == 'sysu': 270 | self.p1 = 3.0 271 | else: 272 | self.p1 = 10.0 273 | self.stage = stage 274 | self.depth = depth 275 | self.head = head 276 | self.graphw = graphw 277 | self.theta1 = theta1 278 | self.cnn23 = Bottleneck12(inplanes=1024, planes=1024) 279 | if self.stage == 23: 280 | self.vit = ViT(img_size=18 * 9, embed_dim=1024, depth=self.depth, num_heads=self.head) 281 | self.hypergraph = HypergraphConv(theta1=self.theta1) 282 | 283 | self.whiten_o = whitening_scale_shift(1024, 1) 284 | 285 | 286 | def forward(self, x1, x2, modal=0): 287 | if modal == 0: 288 | x1 = self.visible_module(x1) 289 | x2 = self.thermal_module(x2) 290 | x = torch.cat((x1, x2), 0) 291 | 292 | elif modal == 1: 293 | x = self.visible_module(x1) 294 | elif modal == 2: 295 | x = self.thermal_module(x2) 296 | 297 | # shared block 298 | if self.non_local == 'on': 299 | NL1_counter = 0 300 | if len(self.NL_1_idx) == 0: self.NL_1_idx = [-1] 301 | for i in range(len(self.base_resnet.base.layer1)): 302 | x = self.base_resnet.base.layer1[i](x) 303 | if i == self.NL_1_idx[NL1_counter]: 304 | _, C, H, W = x.shape 305 | x = self.NL_1[NL1_counter](x) 306 | NL1_counter += 1 307 | # Layer 2 308 | NL2_counter = 0 309 | if len(self.NL_2_idx) == 0: self.NL_2_idx = [-1] 310 | for i in range(len(self.base_resnet.base.layer2)): 311 | x = self.base_resnet.base.layer2[i](x) 312 | if i == self.NL_2_idx[NL2_counter]: 313 | _, C, H, W = x.shape 314 | x = self.NL_2[NL2_counter](x) 315 | NL2_counter += 1 316 | # Layer 3 317 | NL3_counter = 0 318 | if len(self.NL_3_idx) == 0: self.NL_3_idx = [-1] 319 | for i in range(len(self.base_resnet.base.layer3)): 320 | x = self.base_resnet.base.layer3[i](x) 321 | if i == self.NL_3_idx[NL3_counter]: 322 | _, C, H, W = x.shape 323 | x = self.NL_3[NL3_counter](x) 324 | NL3_counter += 1 325 | out_23=x 326 | x = self.cnn23(x) 327 | if self.stage == 23 and self.training: 328 | out_23_shape = out_23.shape[0]//3 329 | temp = torch.cat((out_23[0:out_23_shape],out_23[2*out_23_shape:3*out_23_shape]), dim=0) 330 | x_vit = self.vit(temp) 331 | 332 | x = torch.cat((x, x_vit), dim=0) 333 | 334 | x = x + self.graphw * self.hypergraph(self.whiten_o(x)) 335 | 336 | # Layer 4 337 | NL4_counter = 0 338 | if len(self.NL_4_idx) == 0: self.NL_4_idx = [-1] 339 | for i in range(len(self.base_resnet.base.layer4)): 340 | x = self.base_resnet.base.layer4[i](x) 341 | if i == self.NL_4_idx[NL4_counter]: 342 | _, C, H, W = x.shape 343 | x = self.NL_4[NL4_counter](x) 344 | NL4_counter += 1 345 | else: 346 | pass 347 | 348 | feat = x 349 | assert feat.size(2) % self.num_stripes == 0 350 | stripe_h = int(feat.size(2) / self.num_stripes) 351 | local_feat_list = [] 352 | logits_list = [] 353 | for i in range(self.num_stripes): 354 | # shape [N, C, 1, 1] 355 | 356 | # average pool 357 | # local_feat = F.avg_pool2d(feat[:, :, i * stripe_h: (i + 1) * stripe_h, :],(stripe_h, feat.size(-1))) 358 | if self.gm_pool == 'on': 359 | # gm pool 360 | local_feat = feat[:, :, i * stripe_h: (i + 1) 
* stripe_h, :] 361 | b, c, h, w = local_feat.shape 362 | local_feat = local_feat.view(b, c, -1) 363 | local_feat = (torch.mean(local_feat ** self.p1, dim=-1) + 1e-12) ** (1 / self.p1) 364 | else: 365 | # average pool 366 | # local_feat = F.avg_pool2d(feat[:, :, i * stripe_h: (i + 1) * stripe_h, :],(stripe_h, feat.size(-1))) 367 | local_feat = F.max_pool2d(feat[:, :, i * stripe_h: (i + 1) * stripe_h, :], 368 | (stripe_h, feat.size(-1))) 369 | 370 | # shape [N, c, 1, 1] 371 | local_feat = self.local_conv_list[i](local_feat.view(feat.size(0), feat.size(1), 1, 1)) 372 | 373 | # shape [N, c] 374 | local_feat = local_feat.view(local_feat.size(0), -1) 375 | local_feat_list.append(local_feat) 376 | 377 | if hasattr(self, 'fc_list'): 378 | logits_list.append(self.fc_list[i](local_feat)) 379 | 380 | feat_all = [lf for lf in local_feat_list] 381 | feat_all = torch.cat(feat_all, dim=1) 382 | 383 | ## golable 384 | if self.gm_pool == 'on': 385 | b, c, h, w = x.shape 386 | x = x.view(b, c, -1) 387 | x_pool = (torch.mean(x ** self.p, dim=-1) + 1e-12) ** (1 / self.p) 388 | else: 389 | x_pool = self.avgpool(x) 390 | x_pool = x_pool.view(x_pool.size(0), x_pool.size(1)) 391 | feat = self.bottleneck(x_pool) 392 | if self.training: 393 | return x_pool, self.classifier(feat), local_feat_list, logits_list, feat_all 394 | else: 395 | x_pool_1 = torch.cat((x_pool, feat_all), dim=1) 396 | feat_1 = torch.cat((feat, feat_all), dim=1) 397 | return self.l2norm(x_pool_1), self.l2norm(feat_1) 398 | -------------------------------------------------------------------------------- /model_sle_hsl_cfl.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | 4 | import sys, os 5 | sys.path.append(os.path.dirname(__file__) + os.sep + '../') 6 | sys.path.append("..") 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn import init 11 | import torch.nn.functional as F 12 | from resnet import resnet50 13 | from Transformer import ViT 14 | from HyperGraphs import HypergraphConv 15 | from GAT import Graph_Attention_Union 16 | import whitening 17 | 18 | class whitening_scale_shift(nn.Module): 19 | def __init__(self, planes, group_size, affine=True): 20 | super(whitening_scale_shift, self).__init__() 21 | self.planes = planes 22 | self.group_size = group_size 23 | self.affine = affine 24 | 25 | self.wh = whitening.WTransform2d(self.planes, 26 | self.group_size) 27 | if self.affine: 28 | self.gamma = nn.Parameter(torch.ones(self.planes, 1, 1)) 29 | self.beta = nn.Parameter(torch.zeros(self.planes, 1, 1)) 30 | 31 | def forward(self, x): 32 | out = self.wh(x) 33 | if self.affine: 34 | out = out * self.gamma + self.beta + x 35 | return out 36 | 37 | class Bottleneck12(nn.Module): 38 | expansion = 1 39 | 40 | def __init__(self, inplanes, planes, stride=1, downsample=None): 41 | super(Bottleneck12, self).__init__() 42 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 43 | self.bn1 = nn.BatchNorm2d(planes) 44 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 45 | padding=1, bias=False) 46 | self.bn2 = nn.BatchNorm2d(planes) 47 | self.conv3 = nn.Conv2d(planes, planes, kernel_size=1, bias=False) 48 | self.bn3 = nn.BatchNorm2d(planes) 49 | self.relu = nn.ReLU(inplace=True) 50 | self.downsample = downsample 51 | self.stride = stride 52 | 53 | def forward(self, x): 54 | residual = x 55 | 56 | out = self.conv1(x) 57 | out = self.bn1(out) 58 | out = self.relu(out) 59 | 60 | out = self.conv2(out) 61 | out = 
self.bn2(out) 62 | out = self.relu(out) 63 | 64 | out = self.conv3(out) 65 | out = self.bn3(out) 66 | 67 | if self.downsample is not None: 68 | residual = self.downsample(x) 69 | 70 | out += residual 71 | out = self.relu(out) 72 | 73 | return out 74 | 75 | class Normalize(nn.Module): 76 | def __init__(self, power=2): 77 | super(Normalize, self).__init__() 78 | self.power = power 79 | 80 | def forward(self, x): 81 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power) 82 | out = x.div(norm) 83 | return out 84 | 85 | 86 | class Non_local(nn.Module): 87 | def __init__(self, in_channels, reduc_ratio=2): 88 | super(Non_local, self).__init__() 89 | 90 | self.in_channels = in_channels 91 | self.inter_channels = reduc_ratio // reduc_ratio 92 | 93 | self.g = nn.Sequential( 94 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, 95 | padding=0), 96 | ) 97 | 98 | self.W = nn.Sequential( 99 | nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, 100 | kernel_size=1, stride=1, padding=0), 101 | nn.BatchNorm2d(self.in_channels), 102 | ) 103 | nn.init.constant_(self.W[1].weight, 0.0) 104 | nn.init.constant_(self.W[1].bias, 0.0) 105 | 106 | self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 107 | kernel_size=1, stride=1, padding=0) 108 | 109 | self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 110 | kernel_size=1, stride=1, padding=0) 111 | 112 | def forward(self, x): 113 | ''' 114 | :param x: (b, c, t, h, w) 115 | :return: 116 | ''' 117 | 118 | batch_size = x.size(0) 119 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 120 | g_x = g_x.permute(0, 2, 1) 121 | 122 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 123 | theta_x = theta_x.permute(0, 2, 1) 124 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 125 | f = torch.matmul(theta_x, phi_x) 126 | N = f.size(-1) 127 | # f_div_C = torch.nn.functional.softmax(f, dim=-1) 128 | f_div_C = f / N 129 | 130 | y = torch.matmul(f_div_C, g_x) 131 | y = y.permute(0, 2, 1).contiguous() 132 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 133 | W_y = self.W(y) 134 | z = W_y + x 135 | 136 | return z 137 | 138 | 139 | # ##################################################################### 140 | def weights_init_kaiming(m): 141 | classname = m.__class__.__name__ 142 | # print(classname) 143 | if classname.find('Conv') != -1: 144 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 145 | elif classname.find('Linear') != -1: 146 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 147 | init.zeros_(m.bias.data) 148 | elif classname.find('BatchNorm1d') != -1: 149 | init.normal_(m.weight.data, 1.0, 0.01) 150 | init.zeros_(m.bias.data) 151 | 152 | 153 | def weights_init_classifier(m): 154 | classname = m.__class__.__name__ 155 | if classname.find('Linear') != -1: 156 | init.normal_(m.weight.data, 0, 0.001) 157 | if m.bias: 158 | init.zeros_(m.bias.data) 159 | 160 | 161 | class visible_module(nn.Module): 162 | def __init__(self, arch='resnet50'): 163 | super(visible_module, self).__init__() 164 | 165 | model_v = resnet50(pretrained=True, 166 | last_conv_stride=1, last_conv_dilation=1) 167 | # avg pooling to global pooling 168 | self.visible = model_v 169 | 170 | def forward(self, x): 171 | x = self.visible.conv1(x) 172 | x = self.visible.bn1(x) 173 | x = self.visible.relu(x) 174 | x = self.visible.maxpool(x) 175 | return x 176 | 177 | 178 | class 
thermal_module(nn.Module): 179 | def __init__(self, arch='resnet50'): 180 | super(thermal_module, self).__init__() 181 | 182 | model_t = resnet50(pretrained=True, 183 | last_conv_stride=1, last_conv_dilation=1) 184 | # avg pooling to global pooling 185 | self.thermal = model_t 186 | 187 | def forward(self, x): 188 | x = self.thermal.conv1(x) 189 | x = self.thermal.bn1(x) 190 | x = self.thermal.relu(x) 191 | x = self.thermal.maxpool(x) 192 | return x 193 | 194 | 195 | class base_resnet(nn.Module): 196 | def __init__(self, arch='resnet50'): 197 | super(base_resnet, self).__init__() 198 | 199 | model_base = resnet50(pretrained=True, 200 | last_conv_stride=1, last_conv_dilation=1) 201 | # avg pooling to global pooling 202 | model_base.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 203 | self.base = model_base 204 | 205 | def forward(self, x): 206 | x = self.base.layer1(x) 207 | x = self.base.layer2(x) 208 | x = self.base.layer3(x) 209 | x = self.base.layer4(x) 210 | return x 211 | 212 | 213 | class embed_net(nn.Module): 214 | def __init__(self, class_num, no_local='on', gm_pool='on', arch='resnet50', dataset="sysu", plearn=0, stage=23, 215 | depth=-1, head=-1, graphw=-1, gatw=-1.0, meanw=1.3, whiten=0, theta1=0.0, lambda1=1.3,edge=256): 216 | super(embed_net, self).__init__() 217 | self.thermal_module = thermal_module(arch=arch) 218 | self.visible_module = visible_module(arch=arch) 219 | self.base_resnet = base_resnet(arch=arch) 220 | self.non_local = no_local 221 | if self.non_local == 'on': 222 | layers = [3, 4, 6, 3] 223 | non_layers = [0, 2, 3, 0] 224 | self.NL_1 = nn.ModuleList( 225 | [Non_local(256) for i in range(non_layers[0])]) 226 | self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])]) 227 | self.NL_2 = nn.ModuleList( 228 | [Non_local(512) for i in range(non_layers[1])]) 229 | self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])]) 230 | self.NL_3 = nn.ModuleList( 231 | [Non_local(1024) for i in range(non_layers[2])]) 232 | self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])]) 233 | self.NL_4 = nn.ModuleList( 234 | [Non_local(2048) for i in range(non_layers[3])]) 235 | self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])]) 236 | 237 | pool_dim = 2048 238 | self.l2norm = Normalize(2) 239 | self.bottleneck = nn.BatchNorm1d(pool_dim) 240 | self.bottleneck.bias.requires_grad_(False) # no shift 241 | 242 | self.classifier = nn.Linear(pool_dim, class_num, bias=False) 243 | 244 | self.bottleneck.apply(weights_init_kaiming) 245 | self.classifier.apply(weights_init_classifier) 246 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 247 | self.gm_pool = gm_pool 248 | 249 | self.num_stripes = 6 250 | local_conv_out_channels = 256 251 | 252 | self.local_conv_list = nn.ModuleList() 253 | for _ in range(self.num_stripes): 254 | conv = nn.Conv2d(pool_dim, local_conv_out_channels, 1) 255 | conv.apply(weights_init_kaiming) 256 | self.local_conv_list.append(nn.Sequential( 257 | conv, 258 | nn.BatchNorm2d(local_conv_out_channels), 259 | nn.ReLU(inplace=True) 260 | )) 261 | 262 | self.fc_list = nn.ModuleList() 263 | for _ in range(self.num_stripes): 264 | fc = nn.Linear(local_conv_out_channels, class_num) 265 | init.normal_(fc.weight, std=0.001) 266 | init.constant_(fc.bias, 0) 267 | self.fc_list.append(fc) 268 | if plearn == 1:# better or worse 269 | self.p = nn.Parameter(torch.ones(1) * 3.0) 270 | if dataset == 'sysu': 271 | self.p1 = self.p 272 | else: 273 | self.p1 = nn.Parameter(torch.ones(1) * 10.0) 274 | elif plearn == 0: 275 
| self.p = 3.0 276 | if dataset == 'sysu': 277 | self.p1 = 3.0 278 | else: 279 | self.p1 = 10.0 280 | self.stage = stage 281 | self.depth = depth 282 | self.head = head 283 | self.graphw = graphw 284 | self.whiten = whiten 285 | self.gatw = gatw 286 | self.cnn23 = Bottleneck12(inplanes=1024, planes=1024) 287 | self.lambda1 = lambda1 288 | if self.stage == 23: 289 | self.vit = ViT(img_size=18 * 9, embed_dim=1024, depth=self.depth, num_heads=self.head) 290 | self.hypergraph = HypergraphConv(theta1=theta1,edges=edge) 291 | self.soft = nn.Softmax(dim=0) 292 | 293 | self.gat = Graph_Attention_Union(1024, 1024, meanw=self.lambda1) 294 | if self.whiten == 1: 295 | self.whiten_o = whitening_scale_shift(1024, 1) 296 | print(self.whiten) 297 | 298 | self.fi = nn.Sequential( 299 | nn.Conv2d(1024 * 3, 1024, 1, 1), 300 | nn.BatchNorm2d(1024), 301 | nn.ReLU(inplace=True), 302 | ) 303 | self.init_weights() 304 | 305 | def init_weights(self): 306 | for n, m in self.named_modules(): 307 | if isinstance(m, nn.Conv2d): 308 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 309 | elif isinstance(m, nn.BatchNorm2d): 310 | nn.init.ones_(m.weight) 311 | nn.init.zeros_(m.bias) 312 | 313 | def forward(self, x1, x2, modal=0): 314 | if modal == 0: 315 | x1 = self.visible_module(x1) 316 | x2 = self.thermal_module(x2) 317 | x = torch.cat((x1, x2), 0) 318 | 319 | elif modal == 1: 320 | x = self.visible_module(x1) 321 | elif modal == 2: 322 | x = self.thermal_module(x2) 323 | 324 | # shared block 325 | if self.non_local == 'on': 326 | NL1_counter = 0 327 | if len(self.NL_1_idx) == 0: self.NL_1_idx = [-1] 328 | for i in range(len(self.base_resnet.base.layer1)): 329 | x = self.base_resnet.base.layer1[i](x) 330 | if i == self.NL_1_idx[NL1_counter]: 331 | _, C, H, W = x.shape 332 | x = self.NL_1[NL1_counter](x) 333 | NL1_counter += 1 334 | # Layer 2 335 | NL2_counter = 0 336 | if len(self.NL_2_idx) == 0: self.NL_2_idx = [-1] 337 | for i in range(len(self.base_resnet.base.layer2)): 338 | x = self.base_resnet.base.layer2[i](x) 339 | if i == self.NL_2_idx[NL2_counter]: 340 | _, C, H, W = x.shape 341 | x = self.NL_2[NL2_counter](x) 342 | NL2_counter += 1 343 | # Layer 3 344 | NL3_counter = 0 345 | if len(self.NL_3_idx) == 0: self.NL_3_idx = [-1] 346 | for i in range(len(self.base_resnet.base.layer3)): 347 | x = self.base_resnet.base.layer3[i](x) 348 | if i == self.NL_3_idx[NL3_counter]: 349 | _, C, H, W = x.shape 350 | x = self.NL_3[NL3_counter](x) 351 | NL3_counter += 1 352 | # SLE 353 | out_23 = x 354 | x = self.cnn23(x) 355 | if self.stage == 23 and self.training: 356 | out_23_shape = out_23.shape[0]//3 357 | temp = torch.cat((out_23[0:out_23_shape],out_23[2*out_23_shape:3*out_23_shape]), dim=0) 358 | x_vit = self.vit(temp) 359 | 360 | x = torch.cat((x, x_vit), dim=0) 361 | # HSL 362 | if self.whiten == 1: 363 | x = x + self.graphw * self.hypergraph(self.whiten_o(x)) 364 | else: 365 | x = x + self.graphw * self.hypergraph(x) 366 | # CFL 367 | if self.training: 368 | x_shape = x.shape[0] // 5 369 | CNN_RGB = x[0 * x_shape:1 * x_shape] 370 | CNN_IR = x[2 * x_shape:3 * x_shape] 371 | ViT_RGB = x[3 * x_shape:4 * x_shape] 372 | ViT_IR = x[4 * x_shape:5 * x_shape] 373 | 374 | GAT_CNN_RGB_ALL = self.gat(CNN_IR, CNN_RGB) + self.gat(ViT_RGB, CNN_RGB) + self.gat(ViT_IR, CNN_RGB) 375 | 376 | GAT_CNN_IR_ALL = self.gat(CNN_RGB, CNN_IR) + self.gat(ViT_RGB, CNN_IR) + self.gat(ViT_IR, CNN_IR) 377 | GAT_ViT_RGB_ALL = self.gat(CNN_RGB, ViT_RGB) + self.gat(CNN_IR, ViT_RGB) + self.gat(ViT_IR, ViT_RGB) 378 | GAT_ViT_IR_ALL 
= self.gat(CNN_RGB, ViT_IR) + self.gat(CNN_IR, ViT_IR) + self.gat(ViT_RGB, ViT_IR) 379 | 380 | x = torch.cat((x, self.gatw * GAT_CNN_RGB_ALL + CNN_RGB, self.gatw * GAT_CNN_IR_ALL + CNN_IR, 381 | self.gatw * GAT_ViT_RGB_ALL + ViT_RGB, self.gatw * GAT_ViT_IR_ALL + ViT_IR), dim=0) 382 | 383 | 384 | 385 | # Layer 4 386 | NL4_counter = 0 387 | if len(self.NL_4_idx) == 0: self.NL_4_idx = [-1] 388 | for i in range(len(self.base_resnet.base.layer4)): 389 | x = self.base_resnet.base.layer4[i](x) 390 | if i == self.NL_4_idx[NL4_counter]: 391 | _, C, H, W = x.shape 392 | x = self.NL_4[NL4_counter](x) 393 | NL4_counter += 1 394 | else: 395 | pass 396 | 397 | feat = x 398 | assert feat.size(2) % self.num_stripes == 0 399 | stripe_h = int(feat.size(2) / self.num_stripes) 400 | local_feat_list = [] 401 | logits_list = [] 402 | for i in range(self.num_stripes): 403 | # shape [N, C, 1, 1] 404 | 405 | # average pool 406 | # local_feat = F.avg_pool2d(feat[:, :, i * stripe_h: (i + 1) * stripe_h, :],(stripe_h, feat.size(-1))) 407 | if self.gm_pool == 'on': 408 | # gm pool 409 | local_feat = feat[:, :, i * stripe_h: (i + 1) * stripe_h, :] 410 | b, c, h, w = local_feat.shape 411 | local_feat = local_feat.view(b, c, -1) 412 | local_feat = (torch.mean(local_feat ** self.p1, dim=-1) + 1e-12) ** (1 / self.p1) 413 | else: 414 | # average pool 415 | # local_feat = F.avg_pool2d(feat[:, :, i * stripe_h: (i + 1) * stripe_h, :],(stripe_h, feat.size(-1))) 416 | local_feat = F.max_pool2d(feat[:, :, i * stripe_h: (i + 1) * stripe_h, :], 417 | (stripe_h, feat.size(-1))) 418 | 419 | # shape [N, c, 1, 1] 420 | local_feat = self.local_conv_list[i](local_feat.view(feat.size(0), feat.size(1), 1, 1)) 421 | 422 | # shape [N, c] 423 | local_feat = local_feat.view(local_feat.size(0), -1) 424 | local_feat_list.append(local_feat) 425 | 426 | if hasattr(self, 'fc_list'): 427 | logits_list.append(self.fc_list[i](local_feat)) 428 | 429 | feat_all = [lf for lf in local_feat_list] 430 | feat_all = torch.cat(feat_all, dim=1) 431 | 432 | ## golable 433 | if self.gm_pool == 'on': 434 | b, c, h, w = x.shape 435 | x = x.view(b, c, -1) 436 | x_pool = (torch.mean(x ** self.p, dim=-1) + 1e-12) ** (1 / self.p) 437 | else: 438 | x_pool = self.avgpool(x) 439 | x_pool = x_pool.view(x_pool.size(0), x_pool.size(1)) 440 | feat = self.bottleneck(x_pool) 441 | if self.training: 442 | temp = torch.cat((feat, feat_all), dim=1) 443 | return x_pool, self.classifier(feat), local_feat_list, logits_list, feat_all, temp 444 | else: 445 | x_pool_1 = torch.cat((x_pool, feat_all), dim=1) 446 | feat_1 = torch.cat((feat, feat_all), dim=1) 447 | return self.l2norm(x_pool_1), self.l2norm(feat_1) 448 | 449 | 450 | if __name__ == '__main__': 451 | x1 = torch.Tensor(8, 3, 288, 144).cuda() 452 | x2 = torch.Tensor(4, 3, 288, 144).cuda() 453 | model = embed_net(class_num=100, no_local='on', gm_pool='on', stage=23, head=4, depth=2).cuda() 454 | model(x1, x2) 455 | -------------------------------------------------------------------------------- /pre_process_sysu.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import pdb 4 | import os 5 | 6 | data_path = '/home/datasets/prml/computervision/re-id/sysu-mm01/ori_data' 7 | 8 | rgb_cameras = ['cam1','cam2','cam4','cam5'] 9 | ir_cameras = ['cam3','cam6'] 10 | 11 | # load id info 12 | file_path_train = os.path.join(data_path,'exp/train_id.txt') 13 | file_path_val = os.path.join(data_path,'exp/val_id.txt') 14 | with open(file_path_train, 'r') as file: 
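# exp/train_id.txt stores all training identities on a single comma-separated line (e.g. "1,2,4,..."), so only ids[0] is split below; each id is then zero-padded to match the 4-digit folder names under the camera directories (e.g. cam1/0001/).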
15 | ids = file.read().splitlines() 16 | ids = [int(y) for y in ids[0].split(',')] 17 | id_train = ["%04d" % x for x in ids] 18 | 19 | with open(file_path_val, 'r') as file: 20 | ids = file.read().splitlines() 21 | ids = [int(y) for y in ids[0].split(',')] 22 | id_val = ["%04d" % x for x in ids] 23 | 24 | # combine train and val split 25 | id_train.extend(id_val) 26 | 27 | files_rgb = [] 28 | files_ir = [] 29 | for id in sorted(id_train): 30 | for cam in rgb_cameras: 31 | img_dir = os.path.join(data_path,cam,id) 32 | if os.path.isdir(img_dir): 33 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 34 | files_rgb.extend(new_files) 35 | 36 | for cam in ir_cameras: 37 | img_dir = os.path.join(data_path,cam,id) 38 | if os.path.isdir(img_dir): 39 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 40 | files_ir.extend(new_files) 41 | 42 | # relabel 43 | pid_container = set() 44 | for img_path in files_ir: 45 | pid = int(img_path[-13:-9]) # the 4-digit identity folder sits at a fixed offset from the end of the path 46 | pid_container.add(pid) 47 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 48 | fix_image_width = 144 49 | fix_image_height = 288 50 | def read_imgs(train_image): 51 | train_img = [] 52 | train_label = [] 53 | for img_path in train_image: 54 | # img 55 | img = Image.open(img_path) 56 | img = img.resize((fix_image_width, fix_image_height), Image.LANCZOS) # LANCZOS is the renamed ANTIALIAS filter (removed in Pillow 10) 57 | pix_array = np.array(img) 58 | 59 | train_img.append(pix_array) 60 | 61 | # label 62 | pid = int(img_path[-13:-9]) 63 | pid = pid2label[pid] 64 | train_label.append(pid) 65 | return np.array(train_img), np.array(train_label) 66 | 67 | # rgb images 68 | train_img, train_label = read_imgs(files_rgb) 69 | np.save(os.path.join(data_path, 'train_rgb_resized_img.npy'), train_img) 70 | np.save(os.path.join(data_path, 'train_rgb_resized_label.npy'), train_label) 71 | 72 | # ir images 73 | train_img, train_label = read_imgs(files_ir) 74 | np.save(os.path.join(data_path, 'train_ir_resized_img.npy'), train_img) 75 | np.save(os.path.join(data_path, 'train_ir_resized_label.npy'), train_label) 76 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152'] 7 | 8 | model_urls = { 9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 14 | } 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1, dilation=1): 18 | """3x3 convolution with padding""" 19 | # original padding is 1; original dilation is 1 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=dilation, bias=False, dilation=dilation) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(inplanes, planes, stride, dilation) 30 | self.bn1 = nn.BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = nn.BatchNorm2d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36
| 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 60 | super(Bottleneck, self).__init__() 61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 62 | self.bn1 = nn.BatchNorm2d(planes) 63 | # original padding is 1; original dilation is 1 64 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=dilation, bias=False, dilation=dilation) 65 | self.bn2 = nn.BatchNorm2d(planes) 66 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 67 | self.bn3 = nn.BatchNorm2d(planes * 4) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class ResNet(nn.Module): 96 | 97 | def __init__(self, block, layers, last_conv_stride=2, last_conv_dilation=1): 98 | 99 | self.inplanes = 64 100 | super(ResNet, self).__init__() 101 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 102 | bias=False) 103 | self.bn1 = nn.BatchNorm2d(64) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 106 | self.layer1 = self._make_layer(block, 64, layers[0]) 107 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 108 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 109 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_conv_stride, dilation=last_conv_dilation) 110 | 111 | for m in self.modules(): 112 | if isinstance(m, nn.Conv2d): 113 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 114 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 115 | elif isinstance(m, nn.BatchNorm2d): 116 | m.weight.data.fill_(1) 117 | m.bias.data.zero_() 118 | 119 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 120 | downsample = None 121 | if stride != 1 or self.inplanes != planes * block.expansion: 122 | downsample = nn.Sequential( 123 | nn.Conv2d(self.inplanes, planes * block.expansion, 124 | kernel_size=1, stride=stride, bias=False), 125 | nn.BatchNorm2d(planes * block.expansion), 126 | ) 127 | 128 | layers = [] 129 | layers.append(block(self.inplanes, planes, stride, downsample, dilation)) 130 | self.inplanes = planes * block.expansion 131 | for i in range(1, blocks): 132 | layers.append(block(self.inplanes, planes)) 133 | 134 | return nn.Sequential(*layers) 135 | 136 | def forward(self, x): 137 | x = self.conv1(x) 138 | x = self.bn1(x) 139 | x = self.relu(x) 140 | x = self.maxpool(x) 141 | 142 | x = self.layer1(x) 143 | x = self.layer2(x) 144 | x = self.layer3(x) 145 | x = self.layer4(x) 146 | 147 | return x 148 | 149 | 150 | def remove_fc(state_dict): 151 | """Remove the fc layer parameters from state_dict.""" 152 | # for key, value in state_dict.items(): 153 | for key, value in list(state_dict.items()): 154 | if key.startswith('fc.'): 155 | del state_dict[key] 156 | return state_dict 157 | 158 | 159 | def resnet18(pretrained=False, **kwargs): 160 | """Constructs a ResNet-18 model. 161 | Args: 162 | pretrained (bool): If True, returns a model pre-trained on ImageNet 163 | """ 164 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 165 | if pretrained: 166 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet18']))) 167 | return model 168 | 169 | 170 | def resnet34(pretrained=False, **kwargs): 171 | """Constructs a ResNet-34 model. 172 | Args: 173 | pretrained (bool): If True, returns a model pre-trained on ImageNet 174 | """ 175 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 176 | if pretrained: 177 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet34']))) 178 | return model 179 | 180 | 181 | def resnet50(pretrained=False, **kwargs): 182 | """Constructs a ResNet-50 model. 183 | Args: 184 | pretrained (bool): If True, returns a model pre-trained on ImageNet 185 | """ 186 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 187 | if pretrained: 188 | # model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet50']))) 189 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet50']))) 190 | return model 191 | 192 | 193 | def resnet101(pretrained=False, **kwargs): 194 | """Constructs a ResNet-101 model. 195 | Args: 196 | pretrained (bool): If True, returns a model pre-trained on ImageNet 197 | """ 198 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 199 | if pretrained: 200 | model.load_state_dict( 201 | remove_fc(model_zoo.load_url(model_urls['resnet101']))) 202 | return model 203 | 204 | 205 | def resnet152(pretrained=False, **kwargs): 206 | """Constructs a ResNet-152 model. 
207 | Args: 208 | pretrained (bool): If True, returns a model pre-trained on ImageNet 209 | """ 210 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 211 | if pretrained: 212 | model.load_state_dict( 213 | remove_fc(model_zoo.load_url(model_urls['resnet152']))) 214 | return model -------------------------------------------------------------------------------- /test_sle.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import time 4 | import torch.nn as nn 5 | import torch.backends.cudnn as cudnn 6 | from torch.autograd import Variable 7 | import torch.utils.data as data 8 | import torchvision.transforms as transforms 9 | from data_loader import TestData 10 | from data_manager import * 11 | from eval_metrics import eval_sysu, eval_regdb, eval_llcm 12 | from model_sle import embed_net as basline_pcb 13 | from utils import * 14 | 15 | parser = argparse.ArgumentParser(description='PyTorch Cross-Modality Testing') 16 | parser.add_argument('--dataset', default='sysu', help='dataset name: regdb, sysu or llcm') 17 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate, 0.00035 for adam') 18 | parser.add_argument('--optim', default='sgd', type=str, help='optimizer') 19 | parser.add_argument('--arch', default='resnet50', type=str, 20 | help='network baseline: resnet50') 21 | parser.add_argument('--resume', '-r', default='', type=str, 22 | help='resume from checkpoint') 23 | parser.add_argument('--test-only', action='store_true', help='test only') 24 | parser.add_argument('--model_path', default='save_model/', type=str, 25 | help='model save path') 26 | parser.add_argument('--save_epoch', default=20, type=int, 27 | metavar='s', help='save model every save_epoch epochs') 28 | parser.add_argument('--log_path', default='log/', type=str, 29 | help='log save path') 30 | parser.add_argument('--vis_log_path', default='log/vis_log/', type=str, 31 | help='visualization log save path') 32 | parser.add_argument('--workers', default=4, type=int, metavar='N', 33 | help='number of data loading workers (default: 4)') 34 | parser.add_argument('--img_w', default=144, type=int, 35 | metavar='imgw', help='img width') 36 | parser.add_argument('--img_h', default=288, type=int, 37 | metavar='imgh', help='img height') 38 | parser.add_argument('--batch-size', default=8, type=int, 39 | metavar='B', help='training batch size') 40 | parser.add_argument('--test-batch', default=128, type=int, 41 | metavar='tb', help='testing batch size') 42 | parser.add_argument('--method', default='awg', type=str, 43 | metavar='m', help='method type: base or awg') 44 | parser.add_argument('--margin', default=0.3, type=float, 45 | metavar='margin', help='triplet loss margin') 46 | parser.add_argument('--num_pos', default=4, type=int, 47 | help='num of pos per identity in each modality') 48 | parser.add_argument('--trial', default=1, type=int, 49 | metavar='t', help='trial (only for RegDB dataset)') 50 | parser.add_argument('--seed', default=0, type=int, 51 | metavar='t', help='random seed') 52 | parser.add_argument('--gpu', default='0', type=str, 53 | help='gpu device ids for CUDA_VISIBLE_DEVICES') 54 | parser.add_argument('--mode', default='all', type=str, help='all or indoor for sysu') 55 | parser.add_argument('--tvsearch', action='store_true', help='whether thermal to visible search on RegDB') 56 | parser.add_argument('--depth', default=2, type=int, 57 | metavar='t', help='depth (number of transformer blocks) of the SLE ViT branch') 58 | parser.add_argument('--head', default=4, type=int, 59 | metavar='t', help='number of attention heads in the SLE ViT branch')
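# Example invocation (the checkpoint file name is illustrative, not part of this repo):
#   python test_sle.py --dataset sysu --mode all --gpu 0 --resume sysu_sle_best.t7 --depth 2 --head 4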
| metavar='t', help='head for the model') 60 | args = parser.parse_args() 61 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 62 | 63 | dataset = args.dataset 64 | if dataset == 'sysu': 65 | data_path = '/SSD_dataset/CMReID/SYSU-MM01/ori_data/' 66 | n_class = 395 67 | test_mode = [1, 2] 68 | elif dataset == 'regdb': 69 | data_path = '/SSD_dataset/CMReID/RegDB/' 70 | n_class = 206 71 | test_mode = [2, 1] 72 | 73 | elif dataset == 'llcm': 74 | data_path = '/SSD_dataset/CMReID/LLCM/LLCM/' 75 | log_path = args.log_path + 'llcm_log/' 76 | n_class = 713 77 | test_mode = [1, 2] # [2, 1]: VIS to IR; [1, 2]: IR to VIS 78 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 79 | best_acc = 0 # best test accuracy 80 | start_epoch = 0 81 | pool_dim = 3584 82 | print('==> Building model..') 83 | if args.method == 'base': 84 | net = basline_pcb(n_class, no_local='on', gm_pool='on', arch=args.arch, dataset=dataset, plearn=1, stage=23, depth=args.depth, head=args.head) 85 | else: # note: both branches currently build the same network 86 | net = basline_pcb(n_class, no_local='on', gm_pool='on', arch=args.arch, dataset=dataset, plearn=1, stage=23, depth=args.depth, head=args.head) 87 | net.to(device) 88 | cudnn.benchmark = True 89 | 90 | checkpoint_path = args.model_path 91 | 92 | if args.method == 'id': 93 | criterion = nn.CrossEntropyLoss() 94 | criterion.to(device) 95 | 96 | print('==> Loading data..') 97 | # Data loading code 98 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 99 | 100 | transform_test = transforms.Compose([ 101 | transforms.ToPILImage(), 102 | transforms.Resize((args.img_h, args.img_w)), 103 | transforms.ToTensor(), 104 | normalize, 105 | ]) 106 | 107 | 108 | class ChannelRGB(object): 109 | """ Keeps a single RGB channel by copying it over the other two. 110 | Args: 111 | idx: index of the channel to keep (0 = R, 1 = G, 2 = B); 112 | any other value leaves the image unchanged. 113 | 114 | 115 |
116 | """ 117 | 118 | def __init__(self, idx = 0): 119 | self.idx = idx 120 | 121 | 122 | def __call__(self, img): 123 | 124 | if self.idx ==0: 125 | # random select R Channel 126 | img[1, :,:] = img[0,:,:] 127 | img[2, :,:] = img[0,:,:] 128 | elif self.idx ==1: 129 | # random select G Channel 130 | img[0, :,:] = img[1,:,:] 131 | img[2, :,:] = img[1,:,:] 132 | elif self.idx ==2: 133 | # random select B Channel 134 | img[0, :,:] = img[2,:,:] 135 | img[1, :,:] = img[2,:,:] 136 | else: 137 | img = img 138 | 139 | return img 140 | 141 | transform_visible_list = [ 142 | transforms.ToPILImage(), 143 | transforms.Resize((args.img_h, args.img_w)), 144 | transforms.ToTensor(), 145 | normalize] 146 | 147 | 148 | 149 | # transform_visible_list = transform_visible_list + [ChannelRGB(idx=2)] 150 | 151 | transform_visible = transforms.Compose( transform_visible_list ) 152 | 153 | end = time.time() 154 | 155 | def fliplr(img): 156 | '''flip horizontal''' 157 | inv_idx = torch.arange(img.size(3)-1,-1,-1).long() # N x C x H x W 158 | img_flip = img.index_select(3,inv_idx) 159 | return img_flip 160 | 161 | def extract_gall_feat(gall_loader): 162 | net.eval() 163 | print ('Extracting Gallery Feature...') 164 | start = time.time() 165 | ptr = 0 166 | gall_feat_pool = np.zeros((ngall, pool_dim)) 167 | gall_feat_fc = np.zeros((ngall, pool_dim)) 168 | with torch.no_grad(): 169 | for batch_idx, (input, label ) in enumerate(gall_loader): 170 | batch_num = input.size(0) 171 | 172 | flip_input = fliplr(input) 173 | 174 | input = Variable(input.cuda()) 175 | feat_pool, feat_fc = net(input, input, test_mode[0]) 176 | 177 | flip_input = Variable(flip_input.cuda()) 178 | feat_pool_1, feat_fc_1 = net(flip_input, flip_input, test_mode[0]) 179 | 180 | feature_pool = (feat_pool.detach() + feat_pool_1.detach())/2 181 | feature_fc = (feat_fc.detach() + feat_fc_1.detach())/2 182 | 183 | # feature_pool = feat_pool.detach() 184 | # feature_fc = feat_fc.detach() 185 | 186 | fnorm_pool = torch.norm(feature_pool, p=2, dim=1, keepdim=True) 187 | feature_pool = feature_pool.div(fnorm_pool.expand_as(feature_pool)) 188 | 189 | fnorm_fc = torch.norm(feature_fc, p=2, dim=1, keepdim=True) 190 | feature_fc = feature_fc.div(fnorm_fc.expand_as(feature_fc)) 191 | 192 | gall_feat_pool[ptr:ptr+batch_num,: ] = feature_pool.cpu().numpy() 193 | gall_feat_fc[ptr:ptr+batch_num,: ] = feature_fc.cpu().numpy() 194 | ptr = ptr + batch_num 195 | print('Extracting Time:\t {:.3f}'.format(time.time()-start)) 196 | return gall_feat_pool, gall_feat_fc 197 | 198 | def extract_query_feat(query_loader): 199 | net.eval() 200 | print ('Extracting Query Feature...') 201 | start = time.time() 202 | ptr = 0 203 | query_feat_pool = np.zeros((nquery, pool_dim)) 204 | query_feat_fc = np.zeros((nquery, pool_dim)) 205 | with torch.no_grad(): 206 | for batch_idx, (input, label ) in enumerate(query_loader): 207 | batch_num = input.size(0) 208 | flip_input = fliplr(input) 209 | 210 | input = Variable(input.cuda()) 211 | feat_pool, feat_fc = net(input, input, test_mode[1]) 212 | 213 | flip_input = Variable(flip_input.cuda()) 214 | feat_pool_1, feat_fc_1 = net(flip_input, flip_input, test_mode[1]) 215 | 216 | feature_pool = (feat_pool.detach() + feat_pool_1.detach())/2 217 | feature_fc = (feat_fc.detach() + feat_fc_1.detach())/2 218 | 219 | # feature_pool = feat_pool.detach() 220 | # feature_fc = feat_fc.detach() 221 | 222 | 223 | fnorm_pool = torch.norm(feature_pool, p=2, dim=1, keepdim=True) 224 | feature_pool = feature_pool.div(fnorm_pool.expand_as(feature_pool)) 225 | 226 
| fnorm_fc = torch.norm(feature_fc, p=2, dim=1, keepdim=True) 227 | feature_fc = feature_fc.div(fnorm_fc.expand_as(feature_fc)) 228 | 229 | query_feat_pool[ptr:ptr+batch_num, :] = feature_pool.cpu().numpy() 230 | query_feat_fc[ptr:ptr+batch_num, :] = feature_fc.cpu().numpy() 231 | 232 | ptr = ptr + batch_num 233 | print('Extracting Time:\t {:.3f}'.format(time.time()-start)) 234 | return query_feat_pool, query_feat_fc 235 | 236 | 237 | if dataset == 'sysu': 238 | print('==> Resuming from checkpoint..') 239 | model_path = "./sle_ckpt/sysu/save_model/sysu_adp_joint_co_nog_ch_nog_sq1_aug_G_erase_0.5_p4_n8_lr_0.1_seed_0_localtri_0.0_otri_1_stage_23_depth_2_head_4_pha_1.0_resume_1_49_best.t" 240 | if os.path.isfile(model_path): 241 | print('==> loading checkpoint {}'.format(model_path)) 242 | checkpoint = torch.load(model_path) 243 | net.load_state_dict(checkpoint['net']) 244 | print('==> loaded checkpoint {} (epoch {})' 245 | .format(model_path, checkpoint['epoch'])) 246 | else: 247 | print('==> no checkpoint found at {}'.format(model_path)) 248 | 249 | # testing set 250 | query_img, query_label, query_cam = process_query_sysu(data_path, mode=args.mode) 251 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode, trial=0) 252 | 253 | nquery = len(query_label) 254 | ngall = len(gall_label) 255 | print("Dataset statistics:") 256 | print("  ------------------------------") 257 | print("  subset   | # ids | # images") 258 | print("  ------------------------------") 259 | print("  query    | {:5d} | {:8d}".format(len(np.unique(query_label)), nquery)) 260 | print("  gallery  | {:5d} | {:8d}".format(len(np.unique(gall_label)), ngall)) 261 | print("  ------------------------------") 262 | 263 | queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 264 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=4) 265 | print('Data Loading Time:\t {:.3f}'.format(time.time() - end)) 266 | localtime = time.asctime(time.localtime(time.time())) 267 | print("Local time:", localtime) 268 | query_feat_pool, query_feat_fc = extract_query_feat(query_loader) 269 | for trial in range(10): 270 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode, trial=trial) 271 | 272 | trial_gallset = TestData(gall_img, gall_label, transform=transform_visible, img_size=(args.img_w, args.img_h)) 273 | trial_gall_loader = data.DataLoader(trial_gallset, batch_size=args.test_batch, shuffle=False, num_workers=4) 274 | 275 | gall_feat_pool, gall_feat_fc = extract_gall_feat(trial_gall_loader) 276 | 277 | # pool5 feature 278 | distmat_pool = np.matmul(query_feat_pool, np.transpose(gall_feat_pool)) 279 | cmc_pool, mAP_pool, mINP_pool = eval_sysu(-distmat_pool, query_label, gall_label, query_cam, gall_cam) 280 | 281 | # fc feature 282 | distmat = np.matmul(query_feat_fc, np.transpose(gall_feat_fc)) 283 | cmc, mAP, mINP = eval_sysu(-distmat, query_label, gall_label, query_cam, gall_cam) 284 | if trial == 0: 285 | all_cmc = cmc 286 | all_mAP = mAP 287 | all_mINP = mINP 288 | all_cmc_pool = cmc_pool 289 | all_mAP_pool = mAP_pool 290 | all_mINP_pool = mINP_pool 291 | else: 292 | all_cmc = all_cmc + cmc 293 | all_mAP = all_mAP + mAP 294 | all_mINP = all_mINP + mINP 295 | all_cmc_pool = all_cmc_pool + cmc_pool 296 | all_mAP_pool = all_mAP_pool + mAP_pool 297 | all_mINP_pool = all_mINP_pool + mINP_pool 298 | 299 | print('Test Trial: {}'.format(trial)) 300 | print( 301 | 'FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP:
{:.2%}| mINP: {:.2%}'.format( 302 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 303 | print( 304 | 'POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 305 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19], mAP_pool, mINP_pool)) 306 | 307 | if dataset == 'llcm': 308 | print('==> Resuming from checkpoint..') 309 | model_path = ""  # set this to your trained LLCM checkpoint 310 | if os.path.isfile(model_path): 311 | print('==> loading checkpoint {}'.format(model_path)) 312 | checkpoint = torch.load(model_path) 313 | net.load_state_dict(checkpoint['net']) 314 | print('==> loaded checkpoint {} (epoch {})' 315 | .format(model_path, checkpoint['epoch'])) 316 | else: 317 | print('==> no checkpoint found at {}'.format(model_path)) 318 | # testing set 319 | query_img, query_label, query_cam = process_query_llcm(data_path, mode=test_mode[1]) 320 | gall_img, gall_label, gall_cam = process_gallery_llcm(data_path, mode=test_mode[0], trial=0) 321 | 322 | nquery = len(query_label) 323 | ngall = len(gall_label) 324 | print("Dataset statistics:") 325 | print("  ------------------------------") 326 | print("  subset   | # ids | # images") 327 | print("  ------------------------------") 328 | print("  query    | {:5d} | {:8d}".format(len(np.unique(query_label)), nquery)) 329 | print("  gallery  | {:5d} | {:8d}".format(len(np.unique(gall_label)), ngall)) 330 | print("  ------------------------------") 331 | 332 | queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 333 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=4) 334 | print('Data Loading Time:\t {:.3f}'.format(time.time() - end)) 335 | localtime = time.asctime(time.localtime(time.time())) 336 | print("Local time:", localtime) 337 | query_feat_pool, query_feat_fc = extract_query_feat(query_loader) 338 | for trial in range(10): 339 | gall_img, gall_label, gall_cam = process_gallery_llcm(data_path, mode=test_mode[0], trial=trial) 340 | 341 | trial_gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 342 | trial_gall_loader = data.DataLoader(trial_gallset, batch_size=args.test_batch, shuffle=False, num_workers=4) 343 | 344 | gall_feat_pool, gall_feat_fc = extract_gall_feat(trial_gall_loader) 345 | 346 | # pool5 feature 347 | distmat_pool = np.matmul(query_feat_pool, np.transpose(gall_feat_pool)) 348 | cmc_pool, mAP_pool, mINP_pool = eval_llcm(-distmat_pool, query_label, gall_label, query_cam, gall_cam) 349 | 350 | # fc feature 351 | distmat = np.matmul(query_feat_fc, np.transpose(gall_feat_fc)) 352 | cmc, mAP, mINP = eval_llcm(-distmat, query_label, gall_label, query_cam, gall_cam) 353 | if trial == 0: 354 | all_cmc = cmc 355 | all_mAP = mAP 356 | all_mINP = mINP 357 | all_cmc_pool = cmc_pool 358 | all_mAP_pool = mAP_pool 359 | all_mINP_pool = mINP_pool 360 | else: 361 | all_cmc = all_cmc + cmc 362 | all_mAP = all_mAP + mAP 363 | all_mINP = all_mINP + mINP 364 | all_cmc_pool = all_cmc_pool + cmc_pool 365 | all_mAP_pool = all_mAP_pool + mAP_pool 366 | all_mINP_pool = all_mINP_pool + mINP_pool 367 | 368 | print('Test Trial: {}'.format(trial)) 369 | print( 370 | 'FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 371 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 372 | print( 373 | 'POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 374 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19], mAP_pool, mINP_pool)) 375 | 376 | 
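# Note on the ranking above: extract_query_feat/extract_gall_feat L2-normalize every feature row,
# so np.matmul(query_feat, gall_feat.T) is a cosine-similarity matrix and its negation behaves as
# a distance (smaller = more similar), which is the convention eval_sysu/eval_llcm expect.
# A minimal sketch of the same top-1 ranking, kept in comments here; rank1_accuracy is an
# illustrative helper for this note, not a function of this repo:
#
#     import numpy as np
#
#     def rank1_accuracy(query_feat, gall_feat, query_label, gall_label):
#         # rows are assumed L2-normalized
#         distmat = -np.matmul(query_feat, gall_feat.T)  # negated cosine similarity
#         best = np.argmin(distmat, axis=1)              # closest gallery sample per query
#         return float(np.mean(gall_label[best] == query_label))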
elif dataset == 'regdb': 377 | 378 | for trial in range(10): 379 | test_trial = trial + 1 380 | # model_path = checkpoint_path + args.resume 381 | model_path = ''.format(test_trial)  # fill in your per-trial RegDB checkpoint pattern 382 | 383 | 384 | if os.path.isfile(model_path): 385 | print('==> loading checkpoint {}'.format(model_path)) 386 | checkpoint = torch.load(model_path) 387 | net.load_state_dict(checkpoint['net']) 388 | print('==> loaded checkpoint {} (epoch {})' 389 | .format(model_path, checkpoint['epoch'])) 390 | else: 391 | print('==> no checkpoint found at {}'.format(model_path)) 392 | # testing set 393 | query_img, query_label = process_test_regdb(data_path, trial=test_trial, modal='visible') 394 | 395 | gall_img, gall_label = process_test_regdb(data_path, trial=test_trial, modal='thermal') 396 | 397 | gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 398 | gall_loader = data.DataLoader(gallset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 399 | 400 | nquery = len(query_label) 401 | ngall = len(gall_label) 402 | 403 | queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 404 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=4) 405 | print('Data Loading Time:\t {:.3f}'.format(time.time() - end)) 406 | 407 | localtime = time.asctime(time.localtime(time.time())) 408 | print("Local time:", localtime) 409 | query_feat_pool, query_feat_fc = extract_query_feat(query_loader) 410 | gall_feat_pool, gall_feat_fc = extract_gall_feat(gall_loader) 411 | 412 | if args.tvsearch: 413 | # pool5 feature 414 | distmat_pool = np.matmul(gall_feat_pool, np.transpose(query_feat_pool)) 415 | cmc_pool, mAP_pool, mINP_pool = eval_regdb(-distmat_pool, gall_label, query_label) 416 | 417 | # fc feature 418 | distmat = np.matmul(gall_feat_fc, np.transpose(query_feat_fc)) 419 | cmc, mAP, mINP = eval_regdb(-distmat, gall_label, query_label) 420 | else: 421 | # pool5 feature 422 | distmat_pool = np.matmul(query_feat_pool, np.transpose(gall_feat_pool)) 423 | cmc_pool, mAP_pool, mINP_pool = eval_regdb(-distmat_pool, query_label, gall_label) 424 | 425 | # fc feature 426 | distmat = np.matmul(query_feat_fc, np.transpose(gall_feat_fc)) 427 | cmc, mAP, mINP = eval_regdb(-distmat, query_label, gall_label) 428 | 429 | 430 | if trial == 0: 431 | all_cmc = cmc 432 | all_mAP = mAP 433 | all_mINP = mINP 434 | all_cmc_pool = cmc_pool 435 | all_mAP_pool = mAP_pool 436 | all_mINP_pool = mINP_pool 437 | else: 438 | all_cmc = all_cmc + cmc 439 | all_mAP = all_mAP + mAP 440 | all_mINP = all_mINP + mINP 441 | all_cmc_pool = all_cmc_pool + cmc_pool 442 | all_mAP_pool = all_mAP_pool + mAP_pool 443 | all_mINP_pool = all_mINP_pool + mINP_pool 444 | 445 | print('Test Trial: {}'.format(trial+1)) 446 | print( 447 | 'FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 448 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 449 | print( 450 | 'POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 451 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19], mAP_pool, mINP_pool)) 452 | 453 | cmc = all_cmc / 10 454 | mAP = all_mAP / 10 455 | mINP = all_mINP / 10 456 | 457 | cmc_pool = all_cmc_pool / 10 458 | mAP_pool = all_mAP_pool / 10 459 | mINP_pool = all_mINP_pool / 10 460 | print('All Average:') 461 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP:
{:.2%}'.format( 462 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 463 | print('POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 464 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19], mAP_pool, mINP_pool)) 465 | localtime = time.asctime(time.localtime(time.time())) 466 | print("Local time:", localtime) -------------------------------------------------------------------------------- /test_sle_hsl.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import time 4 | import torch.nn as nn 5 | import torch.backends.cudnn as cudnn 6 | from torch.autograd import Variable 7 | import torch.utils.data as data 8 | import torchvision.transforms as transforms 9 | from data_loader import SYSUData, RegDBData, TestData 10 | from data_manager import * 11 | from eval_metrics import eval_sysu, eval_regdb, eval_llcm 12 | from model_sle_hsl import embed_net 13 | from utils import * 14 | 15 | parser = argparse.ArgumentParser(description='PyTorch Cross-Modality Testing') 16 | parser.add_argument('--dataset', default='sysu', help='dataset name: regdb, sysu or llcm') 17 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate, 0.00035 for adam') 18 | parser.add_argument('--optim', default='sgd', type=str, help='optimizer') 19 | parser.add_argument('--arch', default='resnet50', type=str, 20 | help='network baseline: resnet50') 21 | parser.add_argument('--resume', '-r', default='', type=str, 22 | help='resume from checkpoint') 23 | parser.add_argument('--test-only', action='store_true', help='test only') 24 | parser.add_argument('--model_path', default='save_model/', type=str, 25 | help='model save path') 26 | parser.add_argument('--save_epoch', default=20, type=int, 27 | metavar='s', help='interval (epochs) between saved checkpoints') 28 | parser.add_argument('--log_path', default='log/', type=str, 29 | help='log save path') 30 | parser.add_argument('--vis_log_path', default='log/vis_log/', type=str, 31 | help='visual log save path') 32 | parser.add_argument('--workers', default=4, type=int, metavar='N', 33 | help='number of data loading workers (default: 4)') 34 | parser.add_argument('--img_w', default=144, type=int, 35 | metavar='imgw', help='img width') 36 | parser.add_argument('--img_h', default=288, type=int, 37 | metavar='imgh', help='img height') 38 | parser.add_argument('--batch-size', default=8, type=int, 39 | metavar='B', help='training batch size') 40 | parser.add_argument('--test-batch', default=64, type=int, 41 | metavar='tb', help='testing batch size') 42 | parser.add_argument('--method', default='awg', type=str, 43 | metavar='m', help='method type: base or awg') 44 | parser.add_argument('--margin', default=0.3, type=float, 45 | metavar='margin', help='triplet loss margin') 46 | parser.add_argument('--num_pos', default=4, type=int, 47 | help='num of pos per identity in each modality') 48 | parser.add_argument('--trial', default=1, type=int, 49 | metavar='t', help='trial (only for RegDB dataset)') 50 | parser.add_argument('--seed', default=0, type=int, 51 | metavar='t', help='random seed') 52 | parser.add_argument('--gpu', default='0', type=str, 53 | help='gpu device ids for CUDA_VISIBLE_DEVICES') 54 | parser.add_argument('--mode', default='all', type=str, help='all or indoor for sysu') 55 | parser.add_argument('--tvsearch', action='store_true', help='whether thermal to visible search on RegDB') 56 | parser.add_argument('--stage', default=23,
type=int, 57 | metavar='t', help='stage for the model') 58 | parser.add_argument('--depth', default=2, type=int, 59 | metavar='t', help='depth for the model') 60 | parser.add_argument('--head', default=4, type=int, 61 | metavar='t', help='head for the model') 62 | parser.add_argument('--reduce', default=0, type=int, 63 | metavar='t', help='reduce for the model') 64 | parser.add_argument('--graphw', default=1.0, type=float, 65 | metavar='t', help='weight of the graph branch') 66 | parser.add_argument('--theta1', default=0.5, type=float, metavar='theta1', help='theta1 weight') 67 | parser.add_argument('--vt', default=0, type=int, 68 | metavar='t', help='LLCM test direction: 1 for VIS to IR, 0 for IR to VIS') 69 | 70 | args = parser.parse_args() 71 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 72 | 73 | dataset = args.dataset 74 | if dataset == 'sysu': 75 | data_path = '/SSD_dataset/CMReID/SYSU-MM01/ori_data/' 76 | n_class = 395 77 | test_mode = [1, 2] 78 | elif dataset == 'regdb': 79 | data_path = '/SSD_dataset/CMReID/RegDB/' 80 | n_class = 206 81 | test_mode = [2, 1] 82 | elif dataset == 'llcm': 83 | data_path = '/SSD_dataset/CMReID/LLCM/LLCM/' 84 | log_path = args.log_path + 'llcm_log/' 85 | n_class = 713 86 | if args.vt == 1: 87 | test_mode = [2, 1] # [2, 1]: VIS to IR; [1, 2]: IR to VIS 88 | else: 89 | test_mode = [1, 2] 90 | 91 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 92 | best_acc = 0 # best test accuracy 93 | start_epoch = 0 94 | pool_dim = 3584 95 | print('==> Building model..') 96 | if args.method == 'base': 97 | net = embed_net(n_class, no_local='on', gm_pool='on', arch=args.arch, dataset=dataset, plearn=1, 98 | stage=23, depth=args.depth, head=args.head, graphw=args.graphw, theta1=args.theta1) 99 | 100 | else: # note: both branches currently build the same network 101 | net = embed_net(n_class, no_local='on', gm_pool='on', arch=args.arch, dataset=dataset, plearn=1, 102 | stage=23, depth=args.depth, head=args.head, graphw=args.graphw, theta1=args.theta1) 103 | 104 | net.to(device) 105 | cudnn.benchmark = True 106 | 107 | checkpoint_path = args.model_path 108 | 109 | if args.method == 'id': 110 | criterion = nn.CrossEntropyLoss() 111 | criterion.to(device) 112 | 113 | print('==> Loading data..') 114 | # Data loading code 115 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 116 | 117 | transform_test = transforms.Compose([ 118 | transforms.ToPILImage(), 119 | transforms.Resize((args.img_h, args.img_w)), 120 | transforms.ToTensor(), 121 | normalize, 122 | ]) 123 | 124 | 125 | class ChannelRGB(object): 126 | """ Keeps a single RGB channel by copying it over the other two. 127 | Args: 128 | idx: index of the channel to keep (0 = R, 1 = G, 2 = B); 129 | any other value leaves the image unchanged. 130 | 131 | 132 |
133 | """ 134 | 135 | def __init__(self, idx=0): 136 | self.idx = idx 137 | 138 | def __call__(self, img): 139 | 140 | if self.idx == 0: 141 | # random select R Channel 142 | img[1, :, :] = img[0, :, :] 143 | img[2, :, :] = img[0, :, :] 144 | elif self.idx == 1: 145 | # random select G Channel 146 | img[0, :, :] = img[1, :, :] 147 | img[2, :, :] = img[1, :, :] 148 | elif self.idx == 2: 149 | # random select B Channel 150 | img[0, :, :] = img[2, :, :] 151 | img[1, :, :] = img[2, :, :] 152 | else: 153 | img = img 154 | 155 | return img 156 | 157 | 158 | transform_visible_list = [ 159 | transforms.ToPILImage(), 160 | transforms.Resize((args.img_h, args.img_w)), 161 | transforms.ToTensor(), 162 | normalize] 163 | 164 | # transform_visible_list = transform_visible_list + [ChannelRGB(idx=2)] 165 | 166 | transform_visible = transforms.Compose(transform_visible_list) 167 | 168 | end = time.time() 169 | 170 | 171 | def fliplr(img): 172 | '''flip horizontal''' 173 | inv_idx = torch.arange(img.size(3) - 1, -1, -1).long() # N x C x H x W 174 | img_flip = img.index_select(3, inv_idx) 175 | return img_flip 176 | 177 | 178 | def extract_gall_feat(gall_loader): 179 | net.eval() 180 | print('Extracting Gallery Feature...') 181 | start = time.time() 182 | ptr = 0 183 | gall_feat_pool = np.zeros((ngall, pool_dim)) 184 | gall_feat_fc = np.zeros((ngall, pool_dim)) 185 | with torch.no_grad(): 186 | for batch_idx, (input, label) in enumerate(gall_loader): 187 | batch_num = input.size(0) 188 | 189 | flip_input = fliplr(input) 190 | 191 | input = Variable(input.cuda()) 192 | feat_pool, feat_fc = net(input, input, test_mode[0]) 193 | 194 | flip_input = Variable(flip_input.cuda()) 195 | feat_pool_1, feat_fc_1 = net(flip_input, flip_input, test_mode[0]) 196 | 197 | feature_pool = (feat_pool.detach() + feat_pool_1.detach()) / 2 198 | feature_fc = (feat_fc.detach() + feat_fc_1.detach()) / 2 199 | 200 | # feature_pool = feat_pool.detach() 201 | # feature_fc = feat_fc.detach() 202 | 203 | fnorm_pool = torch.norm(feature_pool, p=2, dim=1, keepdim=True) 204 | feature_pool = feature_pool.div(fnorm_pool.expand_as(feature_pool)) 205 | 206 | fnorm_fc = torch.norm(feature_fc, p=2, dim=1, keepdim=True) 207 | feature_fc = feature_fc.div(fnorm_fc.expand_as(feature_fc)) 208 | 209 | gall_feat_pool[ptr:ptr + batch_num, :] = feature_pool.cpu().numpy() 210 | gall_feat_fc[ptr:ptr + batch_num, :] = feature_fc.cpu().numpy() 211 | ptr = ptr + batch_num 212 | print('Extracting Time:\t {:.3f}'.format(time.time() - start)) 213 | return gall_feat_pool, gall_feat_fc 214 | 215 | 216 | def extract_query_feat(query_loader): 217 | net.eval() 218 | print('Extracting Query Feature...') 219 | start = time.time() 220 | ptr = 0 221 | query_feat_pool = np.zeros((nquery, pool_dim)) 222 | query_feat_fc = np.zeros((nquery, pool_dim)) 223 | with torch.no_grad(): 224 | for batch_idx, (input, label) in enumerate(query_loader): 225 | batch_num = input.size(0) 226 | flip_input = fliplr(input) 227 | 228 | input = Variable(input.cuda()) 229 | feat_pool, feat_fc = net(input, input, test_mode[1]) 230 | 231 | flip_input = Variable(flip_input.cuda()) 232 | feat_pool_1, feat_fc_1 = net(flip_input, flip_input, test_mode[1]) 233 | 234 | feature_pool = (feat_pool.detach() + feat_pool_1.detach()) / 2 235 | feature_fc = (feat_fc.detach() + feat_fc_1.detach()) / 2 236 | 237 | # feature_pool = feat_pool.detach() 238 | # feature_fc = feat_fc.detach() 239 | 240 | fnorm_pool = torch.norm(feature_pool, p=2, dim=1, keepdim=True) 241 | feature_pool = 
feature_pool.div(fnorm_pool.expand_as(feature_pool)) 242 | 243 | fnorm_fc = torch.norm(feature_fc, p=2, dim=1, keepdim=True) 244 | feature_fc = feature_fc.div(fnorm_fc.expand_as(feature_fc)) 245 | 246 | query_feat_pool[ptr:ptr + batch_num, :] = feature_pool.cpu().numpy() 247 | query_feat_fc[ptr:ptr + batch_num, :] = feature_fc.cpu().numpy() 248 | 249 | ptr = ptr + batch_num 250 | print('Extracting Time:\t {:.3f}'.format(time.time() - start)) 251 | return query_feat_pool, query_feat_fc 252 | 253 | if dataset == 'llcm': 254 | print('==> Resuming from checkpoint..') 255 | model_path = "" 256 | if os.path.isfile(model_path): 257 | print('==> loading checkpoint {}'.format(model_path)) 258 | checkpoint = torch.load(model_path) 259 | net.load_state_dict(checkpoint['net']) 260 | print('==> loaded checkpoint {} (epoch {})' 261 | .format(args.resume, checkpoint['epoch'])) 262 | assert os.path.isfile(model_path) 263 | 264 | # testing set 265 | query_img, query_label, query_cam = process_query_llcm(data_path, mode=test_mode[1]) 266 | gall_img, gall_label, gall_cam = process_gallery_llcm(data_path, mode=test_mode[0], trial=0) 267 | 268 | nquery = len(query_label) 269 | ngall = len(gall_label) 270 | print("Dataset statistics:") 271 | print(" ------------------------------") 272 | print(" subset | # ids | # images") 273 | print(" ------------------------------") 274 | print(" query | {:5d} | {:8d}".format(len(np.unique(query_label)), nquery)) 275 | print(" gallery | {:5d} | {:8d}".format(len(np.unique(gall_label)), ngall)) 276 | print(" ------------------------------") 277 | 278 | queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 279 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=4) 280 | print('Data Loading Time:\t {:.3f}'.format(time.time() - end)) 281 | 282 | query_feat_pool, query_feat_fc = extract_query_feat(query_loader) 283 | for trial in range(10): 284 | gall_img, gall_label, gall_cam = process_gallery_llcm(data_path, mode=test_mode[0], trial=trial) 285 | 286 | trial_gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 287 | trial_gall_loader = data.DataLoader(trial_gallset, batch_size=args.test_batch, shuffle=False, num_workers=4) 288 | 289 | gall_feat_pool, gall_feat_fc = extract_gall_feat(trial_gall_loader) 290 | 291 | # pool5 feature 292 | distmat_pool = np.matmul(query_feat_pool, np.transpose(gall_feat_pool)) 293 | cmc_pool, mAP_pool, mINP_pool = eval_llcm(-distmat_pool, query_label, gall_label, query_cam, gall_cam) 294 | 295 | # fc feature 296 | distmat = np.matmul(query_feat_fc, np.transpose(gall_feat_fc)) 297 | cmc, mAP, mINP = eval_llcm(-distmat, query_label, gall_label, query_cam, gall_cam) 298 | if trial == 0: 299 | all_cmc = cmc 300 | all_mAP = mAP 301 | all_mINP = mINP 302 | all_cmc_pool = cmc_pool 303 | all_mAP_pool = mAP_pool 304 | all_mINP_pool = mINP_pool 305 | else: 306 | all_cmc = all_cmc + cmc 307 | all_mAP = all_mAP + mAP 308 | all_mINP = all_mINP + mINP 309 | all_cmc_pool = all_cmc_pool + cmc_pool 310 | all_mAP_pool = all_mAP_pool + mAP_pool 311 | all_mINP_pool = all_mINP_pool + mINP_pool 312 | 313 | print('Test Trial: {}'.format(trial)) 314 | print( 315 | 'FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 316 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 317 | print( 318 | 'POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: 
{:.2%}'.format( 319 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19], mAP_pool, mINP_pool)) 320 | 321 | 322 | 323 | 324 | if dataset == 'sysu': 325 | 326 | print('==> Resuming from checkpoint..') 327 | 328 | model_path = "./sle_hsl_ckpt/sysu/save_model/sysu_adp_joint_co_nog_ch_nog_sq1_aug_G_erase_0.5_p4_n8_lr_0.1_seed_0_localtri_0.0_otri_1_stage_23_depth_2_head_4_pha_1.0_graphw_1.0_theta1_0.5_best.t" 329 | if os.path.isfile(model_path): 330 | print('==> loading checkpoint {}'.format(model_path)) 331 | checkpoint = torch.load(model_path) 332 | net.load_state_dict(checkpoint['net'], strict=True) 333 | print('==> loaded checkpoint {} (epoch {})' 334 | .format(model_path, checkpoint['epoch'])) 335 | else: print('==> no checkpoint found at {}'.format(model_path)) 336 | # testing set 337 | query_img, query_label, query_cam = process_query_sysu(data_path, mode=args.mode) 338 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode, trial=0) 339 | 340 | nquery = len(query_label) 341 | ngall = len(gall_label) 342 | print("Dataset statistics:") 343 | print("  ------------------------------") 344 | print("  subset   | # ids | # images") 345 | print("  ------------------------------") 346 | print("  query    | {:5d} | {:8d}".format(len(np.unique(query_label)), nquery)) 347 | print("  gallery  | {:5d} | {:8d}".format(len(np.unique(gall_label)), ngall)) 348 | print("  ------------------------------") 349 | 350 | queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 351 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=4) 352 | print('Data Loading Time:\t {:.3f}'.format(time.time() - end)) 353 | localtime = time.asctime(time.localtime(time.time())) 354 | print("Local time:", localtime) 355 | query_feat_pool, query_feat_fc = extract_query_feat(query_loader) 356 | for trial in range(10): 357 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode, trial=trial) 358 | 359 | trial_gallset = TestData(gall_img, gall_label, transform=transform_visible, img_size=(args.img_w, args.img_h)) 360 | trial_gall_loader = data.DataLoader(trial_gallset, batch_size=args.test_batch, shuffle=False, num_workers=4) 361 | 362 | gall_feat_pool, gall_feat_fc = extract_gall_feat(trial_gall_loader) 363 | 364 | # pool5 feature 365 | distmat_pool = np.matmul(query_feat_pool, np.transpose(gall_feat_pool)) 366 | cmc_pool, mAP_pool, mINP_pool = eval_sysu(-distmat_pool, query_label, gall_label, query_cam, gall_cam) 367 | 368 | # fc feature 369 | distmat = np.matmul(query_feat_fc, np.transpose(gall_feat_fc)) 370 | cmc, mAP, mINP = eval_sysu(-distmat, query_label, gall_label, query_cam, gall_cam) 371 | if trial == 0: 372 | all_cmc = cmc 373 | all_mAP = mAP 374 | all_mINP = mINP 375 | all_cmc_pool = cmc_pool 376 | all_mAP_pool = mAP_pool 377 | all_mINP_pool = mINP_pool 378 | else: 379 | all_cmc = all_cmc + cmc 380 | all_mAP = all_mAP + mAP 381 | all_mINP = all_mINP + mINP 382 | all_cmc_pool = all_cmc_pool + cmc_pool 383 | all_mAP_pool = all_mAP_pool + mAP_pool 384 | all_mINP_pool = all_mINP_pool + mINP_pool 385 | 386 | print('Test Trial: {}'.format(trial)) 387 | print( 388 | 'FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 389 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 390 | print( 391 | 'POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 392 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19], mAP_pool,
mINP_pool)) 393 | 394 | 395 | elif dataset == 'regdb': 396 | 397 | for trial in range(10): 398 | test_trial = trial + 1 399 | # model_path = checkpoint_path + args.resume 400 | model_path = ''.format(test_trial)  # fill in your per-trial RegDB checkpoint pattern 401 | if os.path.isfile(model_path): 402 | print('==> loading checkpoint {}'.format(model_path)) 403 | checkpoint = torch.load(model_path) 404 | net.load_state_dict(checkpoint['net']) 405 | else: 406 | print('==> no checkpoint found at {}'.format(model_path)) 407 | # training set 408 | # trainset = RegDBData(data_path, test_trial, transform=transform_train) 409 | # # generate the idx of each person identity 410 | # color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label) 411 | 412 | # testing set 413 | query_img, query_label = process_test_regdb(data_path, trial=test_trial, modal='visible') 414 | gall_img, gall_label = process_test_regdb(data_path, trial=test_trial, modal='thermal') 415 | 416 | gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 417 | gall_loader = data.DataLoader(gallset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 418 | 419 | nquery = len(query_label) 420 | ngall = len(gall_label) 421 | 422 | queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 423 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=4) 424 | print('Data Loading Time:\t {:.3f}'.format(time.time() - end)) 425 | localtime = time.asctime(time.localtime(time.time())) 426 | print("Local time:", localtime) 427 | query_feat_pool, query_feat_fc = extract_query_feat(query_loader) 428 | gall_feat_pool, gall_feat_fc = extract_gall_feat(gall_loader) 429 | 430 | if args.tvsearch: 431 | # pool5 feature 432 | distmat_pool = np.matmul(gall_feat_pool, np.transpose(query_feat_pool)) 433 | cmc_pool, mAP_pool, mINP_pool = eval_regdb(-distmat_pool, gall_label, query_label) 434 | 435 | # fc feature 436 | distmat = np.matmul(gall_feat_fc, np.transpose(query_feat_fc)) 437 | cmc, mAP, mINP = eval_regdb(-distmat, gall_label, query_label) 438 | else: 439 | # pool5 feature 440 | distmat_pool = np.matmul(query_feat_pool, np.transpose(gall_feat_pool)) 441 | cmc_pool, mAP_pool, mINP_pool = eval_regdb(-distmat_pool, query_label, gall_label) 442 | 443 | # fc feature 444 | distmat = np.matmul(query_feat_fc, np.transpose(gall_feat_fc)) 445 | cmc, mAP, mINP = eval_regdb(-distmat, query_label, gall_label) 446 | 447 | if trial == 0: 448 | all_cmc = cmc 449 | all_mAP = mAP 450 | all_mINP = mINP 451 | all_cmc_pool = cmc_pool 452 | all_mAP_pool = mAP_pool 453 | all_mINP_pool = mINP_pool 454 | else: 455 | all_cmc = all_cmc + cmc 456 | all_mAP = all_mAP + mAP 457 | all_mINP = all_mINP + mINP 458 | all_cmc_pool = all_cmc_pool + cmc_pool 459 | all_mAP_pool = all_mAP_pool + mAP_pool 460 | all_mINP_pool = all_mINP_pool + mINP_pool 461 | 462 | print('Test Trial: {}'.format(test_trial)) 463 | print( 464 | 'FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 465 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 466 | print( 467 | 'POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 468 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19], mAP_pool, mINP_pool)) 469 | 470 | cmc = all_cmc / 10 471 | mAP = all_mAP / 10 472 | mINP = all_mINP / 10 473 | 474 | cmc_pool = all_cmc_pool / 10 475 | mAP_pool = all_mAP_pool / 10 476 | mINP_pool = all_mINP_pool / 10 477 | print('All
Average:') 478 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 479 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 480 | print('POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 481 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19], mAP_pool, mINP_pool)) 482 | localtime = time.asctime(time.localtime(time.time())) 483 | print("Local time:", localtime) -------------------------------------------------------------------------------- /test_sle_hsl_cfl.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import time 4 | import torch.nn as nn 5 | import torch.backends.cudnn as cudnn 6 | from torch.autograd import Variable 7 | import torch.utils.data as data 8 | import torchvision.transforms as transforms 9 | from data_loader import SYSUData, RegDBData, TestData 10 | from data_manager import * 11 | from eval_metrics import eval_sysu, eval_regdb, eval_llcm 12 | from model_sle_hsl_cfl import embed_net 13 | 14 | 15 | from utils import * 16 | 17 | parser = argparse.ArgumentParser(description='PyTorch Cross-Modality Testing') 18 | parser.add_argument('--dataset', default='sysu', help='dataset name: regdb, sysu or llcm') 19 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate, 0.00035 for adam') 20 | parser.add_argument('--optim', default='sgd', type=str, help='optimizer') 21 | parser.add_argument('--arch', default='resnet50', type=str, 22 | help='network baseline: resnet50') 23 | parser.add_argument('--resume', '-r', default='', type=str, 24 | help='resume from checkpoint') 25 | parser.add_argument('--test-only', action='store_true', help='test only') 26 | parser.add_argument('--model_path', default='save_model/', type=str, 27 | help='model save path') 28 | parser.add_argument('--save_epoch', default=20, type=int, 29 | metavar='s', help='interval (epochs) between saved checkpoints') 30 | parser.add_argument('--log_path', default='log/', type=str, 31 | help='log save path') 32 | parser.add_argument('--vis_log_path', default='log/vis_log/', type=str, 33 | help='visual log save path') 34 | parser.add_argument('--workers', default=4, type=int, metavar='N', 35 | help='number of data loading workers (default: 4)') 36 | parser.add_argument('--img_w', default=144, type=int, 37 | metavar='imgw', help='img width') 38 | parser.add_argument('--img_h', default=288, type=int, 39 | metavar='imgh', help='img height') 40 | parser.add_argument('--batch-size', default=8, type=int, 41 | metavar='B', help='training batch size') 42 | parser.add_argument('--test-batch', default=64, type=int, 43 | metavar='tb', help='testing batch size') 44 | parser.add_argument('--method', default='awg', type=str, 45 | metavar='m', help='method type: base or awg') 46 | parser.add_argument('--margin', default=0.3, type=float, 47 | metavar='margin', help='triplet loss margin') 48 | parser.add_argument('--num_pos', default=4, type=int, 49 | help='num of pos per identity in each modality') 50 | parser.add_argument('--trial', default=1, type=int, 51 | metavar='t', help='trial (only for RegDB dataset)') 52 | parser.add_argument('--seed', default=0, type=int, 53 | metavar='t', help='random seed') 54 | parser.add_argument('--gpu', default='0', type=str, 55 | help='gpu device ids for CUDA_VISIBLE_DEVICES') 56 | parser.add_argument('--mode', default='all', type=str, help='all or indoor for sysu') 57 | parser.add_argument('--tvsearch',
action='store_true', help='whether thermal to visible search on RegDB') 58 | parser.add_argument('--stage', default=23, type=int, 59 | metavar='t', help='stage for the model') 60 | parser.add_argument('--depth', default=2, type=int, 61 | metavar='t', help='depth for the model') 62 | parser.add_argument('--head', default=4, type=int, 63 | metavar='t', help='head for the model') 64 | parser.add_argument('--reduce', default=0, type=int, 65 | metavar='t', help='reduce for the model') 66 | parser.add_argument('--graphw', default=1.0, type=float, 67 | metavar='t', help='weight of the graph branch') 68 | parser.add_argument('--gatw', default=1.0, type=float, 69 | metavar='t', help='weight of the GAT branch') 70 | parser.add_argument('--vt', default=0, type=int, 71 | metavar='t', help='LLCM test direction: 1 for VIS to IR, 0 for IR to VIS') 72 | parser.add_argument('--whiten', default=1, type=int, 73 | metavar='whiten', help='use feature whitening (1) or not (0)') 74 | args = parser.parse_args() 75 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 76 | 77 | dataset = args.dataset 78 | if dataset == 'sysu': 79 | data_path = '/SSD_dataset/CMReID/SYSU-MM01/ori_data/' 80 | n_class = 395 81 | test_mode = [1, 2] 82 | elif dataset == 'regdb': 83 | data_path = '/SSD_dataset/CMReID/RegDB/' 84 | n_class = 206 85 | test_mode = [2, 1] 86 | elif dataset == 'llcm': 87 | data_path = '/SSD_dataset/CMReID/LLCM/LLCM/' 88 | log_path = args.log_path + 'llcm_log/' 89 | n_class = 713 90 | if args.vt == 1: 91 | test_mode = [2, 1] # [2, 1]: VIS to IR; [1, 2]: IR to VIS 92 | else: 93 | test_mode = [1, 2] 94 | 95 | 96 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 97 | best_acc = 0 # best test accuracy 98 | start_epoch = 0 99 | pool_dim = 3584 100 | print('==> Building model..') 101 | if args.method == 'base': 102 | net = embed_net(n_class, no_local='on', gm_pool='on', arch=args.arch, dataset=dataset, plearn=1, 103 | stage=23, depth=args.depth, head=args.head, graphw=args.graphw, gatw=args.gatw, whiten=args.whiten) 104 | 105 | else: # note: both branches currently build the same network 106 | net = embed_net(n_class, no_local='on', gm_pool='on', arch=args.arch, dataset=dataset, plearn=1, 107 | stage=23, depth=args.depth, head=args.head, graphw=args.graphw, gatw=args.gatw, whiten=args.whiten) 108 | 109 | net.to(device) 110 | cudnn.benchmark = True 111 | 112 | checkpoint_path = args.model_path 113 | 114 | if args.method == 'id': 115 | criterion = nn.CrossEntropyLoss() 116 | criterion.to(device) 117 | 118 | print('==> Loading data..') 119 | # Data loading code 120 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 121 | 122 | transform_test = transforms.Compose([ 123 | transforms.ToPILImage(), 124 | transforms.Resize((args.img_h, args.img_w)), 125 | transforms.ToTensor(), 126 | normalize, 127 | ]) 128 | 129 | 130 | class ChannelRGB(object): 131 | """ Keeps a single RGB channel by copying it over the other two. 132 | Args: 133 | idx: index of the channel to keep (0 = R, 1 = G, 2 = B); 134 | any other value leaves the image unchanged. 135 | 136 | 137 |
138 | """ 139 | 140 | def __init__(self, idx=0): 141 | self.idx = idx 142 | 143 | def __call__(self, img): 144 | 145 | if self.idx == 0: 146 | # random select R Channel 147 | img[1, :, :] = img[0, :, :] 148 | img[2, :, :] = img[0, :, :] 149 | elif self.idx == 1: 150 | # random select G Channel 151 | img[0, :, :] = img[1, :, :] 152 | img[2, :, :] = img[1, :, :] 153 | elif self.idx == 2: 154 | # random select B Channel 155 | img[0, :, :] = img[2, :, :] 156 | img[1, :, :] = img[2, :, :] 157 | else: 158 | img = img 159 | 160 | return img 161 | 162 | 163 | transform_visible_list = [ 164 | transforms.ToPILImage(), 165 | transforms.Resize((args.img_h, args.img_w)), 166 | transforms.ToTensor(), 167 | normalize] 168 | 169 | # transform_visible_list = transform_visible_list + [ChannelRGB(idx=2)] 170 | 171 | transform_visible = transforms.Compose(transform_visible_list) 172 | 173 | end = time.time() 174 | 175 | 176 | def fliplr(img): 177 | '''flip horizontal''' 178 | inv_idx = torch.arange(img.size(3) - 1, -1, -1).long() # N x C x H x W 179 | img_flip = img.index_select(3, inv_idx) 180 | return img_flip 181 | 182 | 183 | def extract_gall_feat(gall_loader): 184 | net.eval() 185 | print('Extracting Gallery Feature...') 186 | start = time.time() 187 | ptr = 0 188 | gall_feat_pool = np.zeros((ngall, pool_dim)) 189 | gall_feat_fc = np.zeros((ngall, pool_dim)) 190 | with torch.no_grad(): 191 | for batch_idx, (input, label) in enumerate(gall_loader): 192 | batch_num = input.size(0) 193 | 194 | flip_input = fliplr(input) 195 | 196 | input = Variable(input.cuda()) 197 | feat_pool, feat_fc = net(input, input, test_mode[0]) 198 | 199 | flip_input = Variable(flip_input.cuda()) 200 | feat_pool_1, feat_fc_1 = net(flip_input, flip_input, test_mode[0]) 201 | 202 | feature_pool = (feat_pool.detach() + feat_pool_1.detach()) / 2 203 | feature_fc = (feat_fc.detach() + feat_fc_1.detach()) / 2 204 | 205 | # feature_pool = feat_pool.detach() 206 | # feature_fc = feat_fc.detach() 207 | 208 | fnorm_pool = torch.norm(feature_pool, p=2, dim=1, keepdim=True) 209 | feature_pool = feature_pool.div(fnorm_pool.expand_as(feature_pool)) 210 | 211 | fnorm_fc = torch.norm(feature_fc, p=2, dim=1, keepdim=True) 212 | feature_fc = feature_fc.div(fnorm_fc.expand_as(feature_fc)) 213 | 214 | gall_feat_pool[ptr:ptr + batch_num, :] = feature_pool.cpu().numpy() 215 | gall_feat_fc[ptr:ptr + batch_num, :] = feature_fc.cpu().numpy() 216 | ptr = ptr + batch_num 217 | print('Extracting Time:\t {:.3f}'.format(time.time() - start)) 218 | return gall_feat_pool, gall_feat_fc 219 | 220 | 221 | def extract_query_feat(query_loader): 222 | net.eval() 223 | print('Extracting Query Feature...') 224 | start = time.time() 225 | ptr = 0 226 | query_feat_pool = np.zeros((nquery, pool_dim)) 227 | query_feat_fc = np.zeros((nquery, pool_dim)) 228 | with torch.no_grad(): 229 | for batch_idx, (input, label) in enumerate(query_loader): 230 | batch_num = input.size(0) 231 | flip_input = fliplr(input) 232 | 233 | input = Variable(input.cuda()) 234 | feat_pool, feat_fc = net(input, input, test_mode[1]) 235 | 236 | flip_input = Variable(flip_input.cuda()) 237 | feat_pool_1, feat_fc_1 = net(flip_input, flip_input, test_mode[1]) 238 | 239 | feature_pool = (feat_pool.detach() + feat_pool_1.detach()) / 2 240 | feature_fc = (feat_fc.detach() + feat_fc_1.detach()) / 2 241 | 242 | # feature_pool = feat_pool.detach() 243 | # feature_fc = feat_fc.detach() 244 | 245 | fnorm_pool = torch.norm(feature_pool, p=2, dim=1, keepdim=True) 246 | feature_pool = 
feature_pool.div(fnorm_pool.expand_as(feature_pool)) 247 | 248 | fnorm_fc = torch.norm(feature_fc, p=2, dim=1, keepdim=True) 249 | feature_fc = feature_fc.div(fnorm_fc.expand_as(feature_fc)) 250 | 251 | query_feat_pool[ptr:ptr + batch_num, :] = feature_pool.cpu().numpy() 252 | query_feat_fc[ptr:ptr + batch_num, :] = feature_fc.cpu().numpy() 253 | 254 | ptr = ptr + batch_num 255 | print('Extracting Time:\t {:.3f}'.format(time.time() - start)) 256 | return query_feat_pool, query_feat_fc 257 | 258 | if dataset == 'llcm': 259 | print('==> Resuming from checkpoint..') 260 | model_path = "" 261 | if os.path.isfile(model_path): 262 | print('==> loading checkpoint {}'.format(args.resume)) 263 | checkpoint = torch.load(model_path) 264 | net.load_state_dict(checkpoint['net']) 265 | print('==> loaded checkpoint {} (epoch {})' 266 | .format(args.resume, checkpoint['epoch'])) 267 | else: 268 | print('==> no checkpoint found at {}'.format(args.resume)) 269 | 270 | # testing set 271 | query_img, query_label, query_cam = process_query_llcm(data_path, mode=test_mode[1]) 272 | gall_img, gall_label, gall_cam = process_gallery_llcm(data_path, mode=test_mode[0], trial=0) 273 | 274 | nquery = len(query_label) 275 | ngall = len(gall_label) 276 | print("Dataset statistics:") 277 | print(" ------------------------------") 278 | print(" subset | # ids | # images") 279 | print(" ------------------------------") 280 | print(" query | {:5d} | {:8d}".format(len(np.unique(query_label)), nquery)) 281 | print(" gallery | {:5d} | {:8d}".format(len(np.unique(gall_label)), ngall)) 282 | print(" ------------------------------") 283 | 284 | queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 285 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=4) 286 | print('Data Loading Time:\t {:.3f}'.format(time.time() - end)) 287 | 288 | query_feat_pool, query_feat_fc = extract_query_feat(query_loader) 289 | for trial in range(10): 290 | gall_img, gall_label, gall_cam = process_gallery_llcm(data_path, mode=test_mode[0], trial=trial) 291 | 292 | trial_gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 293 | trial_gall_loader = data.DataLoader(trial_gallset, batch_size=args.test_batch, shuffle=False, num_workers=4) 294 | 295 | gall_feat_pool, gall_feat_fc = extract_gall_feat(trial_gall_loader) 296 | 297 | # pool5 feature 298 | distmat_pool = np.matmul(query_feat_pool, np.transpose(gall_feat_pool)) 299 | cmc_pool, mAP_pool, mINP_pool = eval_llcm(-distmat_pool, query_label, gall_label, query_cam, gall_cam) 300 | 301 | # fc feature 302 | distmat = np.matmul(query_feat_fc, np.transpose(gall_feat_fc)) 303 | cmc, mAP, mINP = eval_llcm(-distmat, query_label, gall_label, query_cam, gall_cam) 304 | if trial == 0: 305 | all_cmc = cmc 306 | all_mAP = mAP 307 | all_mINP = mINP 308 | all_cmc_pool = cmc_pool 309 | all_mAP_pool = mAP_pool 310 | all_mINP_pool = mINP_pool 311 | else: 312 | all_cmc = all_cmc + cmc 313 | all_mAP = all_mAP + mAP 314 | all_mINP = all_mINP + mINP 315 | all_cmc_pool = all_cmc_pool + cmc_pool 316 | all_mAP_pool = all_mAP_pool + mAP_pool 317 | all_mINP_pool = all_mINP_pool + mINP_pool 318 | 319 | print('Test Trial: {}'.format(trial)) 320 | print( 321 | 'FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 322 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 323 | print( 324 | 'POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: 
{:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 325 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19], mAP_pool, mINP_pool)) 326 | 327 | 328 | 329 | 330 | if dataset == 'sysu': 331 | 332 | print('==> Resuming from checkpoint..') 333 | model_path = "./sle_hsl_cfl_ckpt/sysu/save_model/sysu_adp_joint_co_nog_ch_nog_sq1_aug_G_erase_0.5_p4_n8_lr_0.1_seed_0_otri_1_stage_23_depth_2_head_4_resume_1_graphw_1.0_gatw_1.0_mricT_3_whiten_1_theta_0.5_aicT_3_lambda1_1.3_best.t" 334 | 335 | if os.path.isfile(model_path): 336 | print('==> loading checkpoint {}'.format(model_path)) 337 | checkpoint = torch.load(model_path) 338 | net.load_state_dict(checkpoint['net'], strict=True) 339 | print('==> loaded checkpoint {} (epoch {})' 340 | .format(model_path, checkpoint['epoch'])) 341 | else: 342 | print('==> no checkpoint found at {}'.format(model_path)) 343 | 344 | # testing set 345 | query_img, query_label, query_cam = process_query_sysu(data_path, mode=args.mode) 346 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode, trial=0) 347 | 348 | nquery = len(query_label) 349 | ngall = len(gall_label) 350 | print("Dataset statistics:") 351 | print(" ------------------------------") 352 | print(" subset | # ids | # images") 353 | print(" ------------------------------") 354 | print(" query | {:5d} | {:8d}".format(len(np.unique(query_label)), nquery)) 355 | print(" gallery | {:5d} | {:8d}".format(len(np.unique(gall_label)), ngall)) 356 | print(" ------------------------------") 357 | 358 | queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 359 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=4) 360 | print('Data Loading Time:\t {:.3f}'.format(time.time() - end)) 361 | 362 | query_feat_pool, query_feat_fc = extract_query_feat(query_loader) 363 | for trial in range(10): 364 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode, trial=trial) 365 | 366 | trial_gallset = TestData(gall_img, gall_label, transform=transform_visible, img_size=(args.img_w, args.img_h)) 367 | trial_gall_loader = data.DataLoader(trial_gallset, batch_size=args.test_batch, shuffle=False, num_workers=4) 368 | 369 | gall_feat_pool, gall_feat_fc = extract_gall_feat(trial_gall_loader) 370 | 371 | # pool5 feature 372 | distmat_pool = np.matmul(query_feat_pool, np.transpose(gall_feat_pool)) 373 | cmc_pool, mAP_pool, mINP_pool = eval_sysu(-distmat_pool, query_label, gall_label, query_cam, gall_cam) 374 | 375 | # fc feature 376 | distmat = np.matmul(query_feat_fc, np.transpose(gall_feat_fc)) 377 | cmc, mAP, mINP = eval_sysu(-distmat, query_label, gall_label, query_cam, gall_cam) 378 | if trial == 0: 379 | all_cmc = cmc 380 | all_mAP = mAP 381 | all_mINP = mINP 382 | all_cmc_pool = cmc_pool 383 | all_mAP_pool = mAP_pool 384 | all_mINP_pool = mINP_pool 385 | else: 386 | all_cmc = all_cmc + cmc 387 | all_mAP = all_mAP + mAP 388 | all_mINP = all_mINP + mINP 389 | all_cmc_pool = all_cmc_pool + cmc_pool 390 | all_mAP_pool = all_mAP_pool + mAP_pool 391 | all_mINP_pool = all_mINP_pool + mINP_pool 392 | 393 | print('Test Trial: {}'.format(trial)) 394 | print( 395 | 'FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 396 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 397 | print( 398 | 'POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 399 | cmc_pool[0], cmc_pool[4], cmc_pool[9], 
cmc_pool[19], mAP_pool, mINP_pool)) 400 | 401 | 402 | elif dataset == 'regdb': 403 | 404 | for trial in range(10): 405 | test_trial = trial + 1 406 | model_path = ''.format( 407 | test_trial) 408 | if os.path.isfile(model_path): 409 | print('==> loading checkpoint {}'.format(model_path)) 410 | checkpoint = torch.load(model_path) 411 | net.load_state_dict(checkpoint['net']) 412 | 413 | 414 | # testing set 415 | query_img, query_label = process_test_regdb(data_path, trial=test_trial, modal='visible') 416 | gall_img, gall_label = process_test_regdb(data_path, trial=test_trial, modal='thermal') 417 | 418 | gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 419 | gall_loader = data.DataLoader(gallset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 420 | 421 | nquery = len(query_label) 422 | ngall = len(gall_label) 423 | 424 | queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 425 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=4) 426 | print('Data Loading Time:\t {:.3f}'.format(time.time() - end)) 427 | 428 | query_feat_pool, query_feat_fc = extract_query_feat(query_loader) 429 | gall_feat_pool, gall_feat_fc = extract_gall_feat(gall_loader) 430 | 431 | if args.tvsearch: 432 | # pool5 feature 433 | distmat_pool = np.matmul(gall_feat_pool, np.transpose(query_feat_pool)) 434 | cmc_pool, mAP_pool, mINP_pool = eval_regdb(-distmat_pool, gall_label, query_label) 435 | 436 | # fc feature 437 | distmat = np.matmul(gall_feat_fc, np.transpose(query_feat_fc)) 438 | cmc, mAP, mINP = eval_regdb(-distmat, gall_label, query_label) 439 | else: 440 | # pool5 feature 441 | distmat_pool = np.matmul(query_feat_pool, np.transpose(gall_feat_pool)) 442 | cmc_pool, mAP_pool, mINP_pool = eval_regdb(-distmat_pool, query_label, gall_label) 443 | 444 | # fc feature 445 | distmat = np.matmul(query_feat_fc, np.transpose(gall_feat_fc)) 446 | cmc, mAP, mINP = eval_regdb(-distmat, query_label, gall_label) 447 | 448 | if trial == 0: 449 | all_cmc = cmc 450 | all_mAP = mAP 451 | all_mINP = mINP 452 | all_cmc_pool = cmc_pool 453 | all_mAP_pool = mAP_pool 454 | all_mINP_pool = mINP_pool 455 | else: 456 | all_cmc = all_cmc + cmc 457 | all_mAP = all_mAP + mAP 458 | all_mINP = all_mINP + mINP 459 | all_cmc_pool = all_cmc_pool + cmc_pool 460 | all_mAP_pool = all_mAP_pool + mAP_pool 461 | all_mINP_pool = all_mINP_pool + mINP_pool 462 | 463 | print('Test Trial: {}'.format(trial)) 464 | print( 465 | 'FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 466 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 467 | print( 468 | 'POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 469 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19], mAP_pool, mINP_pool)) 470 | 471 | cmc = all_cmc / 10 472 | mAP = all_mAP / 10 473 | mINP = all_mINP / 10 474 | 475 | cmc_pool = all_cmc_pool / 10 476 | mAP_pool = all_mAP_pool / 10 477 | mINP_pool = all_mINP_pool / 10 478 | print('All Average:') 479 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 480 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 481 | print('POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 482 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19], mAP_pool, mINP_pool)) 483 | 
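# All three test scripts above repeat the gallery sampling for 10 trials and report the
# arithmetic mean of CMC/mAP/mINP over those trials. Below is a minimal sketch of that
# accumulation pattern (illustrative only; evaluate_trial stands in for the per-trial
# eval_sysu/eval_regdb/eval_llcm calls above, and average_over_trials is not a function
# of this repo):

def average_over_trials(evaluate_trial, n_trials=10):
    all_cmc, all_mAP, all_mINP = None, 0.0, 0.0
    for trial in range(n_trials):
        cmc, mAP, mINP = evaluate_trial(trial)  # cmc: numpy array of rank-k accuracies
        all_cmc = cmc if all_cmc is None else all_cmc + cmc
        all_mAP += mAP
        all_mINP += mINP
    return all_cmc / n_trials, all_mAP / n_trials, all_mINP / n_trials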
--------------------------------------------------------------------------------
/train_hos_net.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: UTF-8 -*-
3 | '''
4 | @Project :RemoteHOS
5 | @File :train_hos_net.py
6 | @Author :yauloucoeng
7 | @Date :2024/4/28 21:36
8 | '''
9 | 
10 | import os
11 | 
12 | # step 1: train SLE
13 | cmd = "python train_sle.py --dataset sysu --gpu 0"
14 | os.system(cmd)
15 | 
16 | # step 2: train SLE+HSL
17 | cmd = "python train_sle_hsl.py --dataset sysu --gpu 0"
18 | os.system(cmd)
19 | 
20 | # step 3: train SLE+HSL+CFL (the final HOS model)
21 | cmd = "python train_sle_hsl_cfl.py --dataset sysu --gpu 0"
22 | os.system(cmd)
23 | 
24 | # step 4: test SLE+HSL+CFL
25 | cmd = "python test_sle_hsl_cfl.py --dataset sysu --gpu 0"
26 | os.system(cmd)
27 | 
28 | 
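train_hos_net.py hard-codes the dataset and GPU into each command. A sketch of a parameterized variant; beyond the repo's script names and their --dataset/--gpu flags, everything here is assumed:

#!/usr/bin/env python
# hypothetical variant of train_hos_net.py that takes the dataset and GPU as CLI arguments
import argparse
import subprocess
import sys

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='sysu', choices=['sysu', 'regdb'])
parser.add_argument('--gpu', default='0')
args = parser.parse_args()

stages = ['train_sle.py', 'train_sle_hsl.py', 'train_sle_hsl_cfl.py', 'test_sle_hsl_cfl.py']
for script in stages:
    # subprocess.run with check=True aborts the pipeline if a stage fails,
    # a failure mode that os.system in the original driver silently ignores
    subprocess.run([sys.executable, script, '--dataset', args.dataset, '--gpu', args.gpu],
                   check=True)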
--------------------------------------------------------------------------------
/train_sle.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | import time
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | import torch.backends.cudnn as cudnn
7 | from torch.autograd import Variable
8 | import torch.utils.data as data
9 | import torchvision.transforms as transforms
10 | from data_loader import SYSUData, RegDBData, LLCMData, TestData
11 | from data_manager import *
12 | from eval_metrics import eval_sysu, eval_regdb, eval_llcm
13 | from model_sle import embed_net as basline_pcb
14 | from utils import *
15 | from loss import OriTripletLoss, KLDivLoss
16 | from tensorboardX import SummaryWriter
17 | from ChannelAug import ChannelAdapGray, ChannelRandomErasing
18 | 
19 | parser = argparse.ArgumentParser(description='PyTorch Cross-Modality Training')
20 | parser.add_argument('--dataset', default='sysu', help='dataset name: sysu, regdb or llcm')
21 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate, 0.00035 for adam')
22 | parser.add_argument('--optim', default='sgd', type=str, help='optimizer')
23 | parser.add_argument('--arch', default='resnet50', type=str,
24 |                     help='network baseline: resnet18 or resnet50')
25 | parser.add_argument('--resume', '-r', default=1, type=int,
26 |                     help='resume from checkpoint')
27 | parser.add_argument('--test-only', action='store_true', help='test only')
28 | parser.add_argument('--model_path', default='save_model/', type=str,
29 |                     help='model save path')
30 | parser.add_argument('--save_epoch', default=20, type=int,
31 |                     metavar='s', help='interval (in epochs) at which the model is saved')
32 | parser.add_argument('--log_path', default='log/', type=str,
33 |                     help='log save path')
34 | parser.add_argument('--vis_log_path', default='log/vis_log/', type=str,
35 |                     help='tensorboard log save path')
36 | parser.add_argument('--workers', default=4, type=int, metavar='N',
37 |                     help='number of data loading workers (default: 4)')
38 | parser.add_argument('--img_w', default=144, type=int,
39 |                     metavar='imgw', help='img width')
40 | parser.add_argument('--img_h', default=288, type=int,
41 |                     metavar='imgh', help='img height')
42 | parser.add_argument('--batch-size', default=8, type=int,
43 |                     metavar='B', help='training batch size')
44 | parser.add_argument('--test-batch', default=64, type=int,
45 |                     metavar='tb', help='testing batch size')
46 | parser.add_argument('--method', default='adp', type=str,
47 |                     metavar='m', help='method type: base, agw or adp')
48 | parser.add_argument('--margin', default=0.3, type=float,
49 |                     metavar='margin', help='triplet loss margin')
50 | parser.add_argument('--num_pos', default=4, type=int,
51 |                     help='num of pos per identity in each modality')
52 | parser.add_argument('--trial', default=1, type=int,
53 |                     metavar='t', help='trial (only for RegDB dataset)')
54 | parser.add_argument('--seed', default=0, type=int,
55 |                     metavar='t', help='random seed')
56 | parser.add_argument('--gpu', default='0', type=str,
57 |                     help='gpu device ids for CUDA_VISIBLE_DEVICES')
58 | parser.add_argument('--mode', default='all', type=str, help='all or indoor')
59 | parser.add_argument('--augc', default=1, type=int,
60 |                     metavar='aug', help='use channel aug or not')
61 | parser.add_argument('--rande', default=0.5, type=float,
62 |                     metavar='ra', help='random erasing probability (0 disables it)')
63 | parser.add_argument('--kl', default=0, type=float,
64 |                     metavar='kl', help='use kl loss and the weight')
65 | parser.add_argument('--alpha', default=1, type=int,
66 |                     metavar='alpha', help='magnification for the hard mining')
67 | parser.add_argument('--gamma', default=1, type=int,
68 |                     metavar='gamma', help='gamma for the hard mining')
69 | parser.add_argument('--square', default=1, type=int,
70 |                     metavar='square', help='square for the hard mining')
71 | parser.add_argument('--otri', default=1, type=int,
72 |                     metavar='otri', help='use the original triplet loss (instead of the ADP variant)')
73 | parser.add_argument('--pl', default=1, type=int, metavar='pl', help='pl for the model')
74 | parser.add_argument('--stage', default=23, type=int, metavar='stage', help='stage for the model')
75 | parser.add_argument('--depth', default=2, type=int, metavar='depth', help='depth for the model')
76 | parser.add_argument('--head', default=4, type=int, metavar='head', help='head for the model')
77 | parser.add_argument('--goballoss', default=1.0, type=float, metavar='gid', help='use global loss and the weight')
78 | parser.add_argument('--localloss', default=1.0, type=float, metavar='lid', help='use local loss and the weight')
79 | parser.add_argument('--gobaltri', default=1.0, type=float, metavar='gtri', help='use global tri loss and the weight')
80 | parser.add_argument('--localtri', default=0.0, type=float, metavar='ltri', help='use local tri loss and the weight')
81 | parser.add_argument('--pha', default=1.0, type=float, metavar='pha', help='weight balancing the intra- and cross-modality loss terms')
82 | args = parser.parse_args()
83 | 
84 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
85 | set_seed(args.seed)
86 | dataset = args.dataset
87 | 
88 | base_path = "./sle_ckpt/"
89 | 
90 | if dataset == 'sysu':
91 |     # path to your sysu-mm01 dataset
92 |     data_path = '/SSD_dataset/CMReID/SYSU-MM01/ori_data/'
93 |     log_path = base_path + dataset + "/" + args.model_path + '/' + args.log_path + 'sysu_log/'
94 |     test_mode = [1, 2]  # thermal to visible
95 | elif dataset == 'regdb':
96 |     # path to your regdb dataset
97 |     data_path = '/SSD_dataset/CMReID/RegDB/'
98 |     log_path = base_path + dataset + "/" + args.model_path + '/' + args.log_path + 'regdb_log/'
99 |     test_mode = [2, 1]  # visible to thermal
100 | elif dataset == 'llcm':
101 |     # path to your llcm dataset
102 |     data_path = '/SSD_dataset/CMReID/LLCM/LLCM/'
103 |     log_path = base_path + dataset + "/" + args.model_path + '/' + args.log_path + 'llcm_log/'
104 |     test_mode = [2, 1]  # [2, 1]: VIS to IR; [1, 2]: IR to VIS
105 | 
106 | checkpoint_path = base_path + dataset + "/" + args.model_path + '/'
107 | 
108 | if not os.path.isdir(log_path):
109 |     os.makedirs(log_path)
110 | if not os.path.isdir(checkpoint_path):
111 |     os.makedirs(checkpoint_path)
112 | if not
os.path.isdir(args.vis_log_path): 113 | os.makedirs(args.vis_log_path) 114 | 115 | suffix = dataset 116 | if args.method == 'adp': 117 | suffix = suffix + '_{}_joint_co_nog_ch_nog_sq{}'.format(args.method, args.square) 118 | else: 119 | suffix = suffix + '_{}'.format(args.method) 120 | # suffix = suffix + '_KL_{}'.format(args.kl) 121 | if args.augc == 1: 122 | suffix = suffix + '_aug_G' 123 | if args.rande > 0: 124 | suffix = suffix + '_erase_{}'.format(args.rande) 125 | 126 | suffix = suffix + '_p{}_n{}_lr_{}_seed_{}_localtri_{}_otri_{}_stage_{}_depth_{}_head_{}_pha_{}'.format( 127 | args.num_pos, args.batch_size, 128 | args.lr, args.seed, args.localtri, args.otri, args.stage, args.depth, args.head, args.pha) 129 | 130 | if not args.optim == 'sgd': 131 | suffix = suffix + '_' + args.optim 132 | 133 | if dataset == 'regdb': 134 | suffix = suffix + '_trial_{}'.format(args.trial) 135 | 136 | sys.stdout = Logger(log_path + suffix + '_os.txt') 137 | 138 | vis_log_dir = args.vis_log_path + suffix + '/' 139 | 140 | if not os.path.isdir(vis_log_dir): 141 | os.makedirs(vis_log_dir) 142 | writer = SummaryWriter(vis_log_dir) 143 | print("==========\nArgs:{}\n==========".format(args)) 144 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 145 | best_acc = 0 # best test accuracy 146 | start_epoch = 0 147 | 148 | print('==> Loading data..') 149 | # Data loading code 150 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 151 | transform_train_list = [ 152 | transforms.ToPILImage(), 153 | transforms.Pad(10), 154 | transforms.RandomCrop((args.img_h, args.img_w)), 155 | transforms.RandomHorizontalFlip(), 156 | transforms.ToTensor(), 157 | normalize] 158 | 159 | transform_test = transforms.Compose([ 160 | transforms.ToPILImage(), 161 | transforms.Resize((args.img_h, args.img_w)), 162 | transforms.ToTensor(), 163 | normalize]) 164 | 165 | if args.rande > 0: 166 | transform_train_list = transform_train_list + [ChannelRandomErasing(probability=args.rande)] 167 | 168 | if args.augc == 1: 169 | # transform_train_list = transform_train_list + [ChannelAdap(probability =0.5)] 170 | transform_train_list = transform_train_list + [ChannelAdapGray(probability=0.5)] 171 | 172 | transform_train = transforms.Compose(transform_train_list) 173 | 174 | end = time.time() 175 | if dataset == 'sysu': 176 | # training set 177 | trainset = SYSUData(data_path, transform=transform_train) 178 | # generate the idx of each person identity 179 | color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label) 180 | 181 | # testing set 182 | query_img, query_label, query_cam = process_query_sysu(data_path, mode=args.mode) 183 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode, trial=0) 184 | 185 | elif dataset == 'regdb': 186 | # training set 187 | trainset = RegDBData(data_path, args.trial, transform=transform_train) 188 | # generate the idx of each person identity 189 | color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label) 190 | 191 | # testing set 192 | query_img, query_label = process_test_regdb(data_path, trial=args.trial, modal='visible') 193 | gall_img, gall_label = process_test_regdb(data_path, trial=args.trial, modal='thermal') 194 | 195 | elif dataset == 'llcm': 196 | # training set 197 | trainset = LLCMData(data_path, args.trial, transform=transform_train) 198 | # generate the idx of each person identity 199 | color_pos, thermal_pos = GenIdx(trainset.train_color_label, 
trainset.train_thermal_label) 200 | 201 | # testing set 202 | query_img, query_label, query_cam = process_query_llcm(data_path, mode=test_mode[1]) 203 | gall_img, gall_label, gall_cam = process_gallery_llcm(data_path, mode=test_mode[0], trial=0) 204 | 205 | gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 206 | queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 207 | 208 | # testing data loader 209 | gall_loader = data.DataLoader(gallset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 210 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 211 | 212 | n_class = len(np.unique(trainset.train_color_label)) 213 | nquery = len(query_label) 214 | ngall = len(gall_label) 215 | 216 | print('Dataset {} statistics:'.format(dataset)) 217 | print(' ------------------------------') 218 | print(' subset | # ids | # images') 219 | print(' ------------------------------') 220 | print(' visible | {:5d} | {:8d}'.format(n_class, len(trainset.train_color_label))) 221 | print(' thermal | {:5d} | {:8d}'.format(n_class, len(trainset.train_thermal_label))) 222 | print(' ------------------------------') 223 | print(' query | {:5d} | {:8d}'.format(len(np.unique(query_label)), nquery)) 224 | print(' gallery | {:5d} | {:8d}'.format(len(np.unique(gall_label)), ngall)) 225 | print(' ------------------------------') 226 | print('Data Loading Time:\t {:.3f}'.format(time.time() - end)) 227 | 228 | print('==> Building model..') 229 | 230 | net = basline_pcb(n_class, no_local='on', gm_pool='on', arch=args.arch, dataset=dataset, plearn=args.pl, 231 | stage=args.stage, depth=args.depth, head=args.head) 232 | 233 | net.to(device) 234 | cudnn.benchmark = True 235 | 236 | 237 | 238 | if dataset == 'sysu': 239 | model_path = "./baseline/sysu_adp_joint_co_nog_ch_nog_sq1_aug_G_erase_0.5_p4_n8_lr_0.1_seed_0_best.t" 240 | if args.dataset == 'regdb': 241 | model_path = './baseline/regdb_'.format(args.trial) 242 | if dataset == 'llcm': 243 | model_path = './baseline/llcm_' 244 | if os.path.isfile(model_path): 245 | print('==> loading checkpoint {}'.format(model_path)) 246 | checkpoint = torch.load(model_path) 247 | net.load_state_dict(checkpoint['net'], strict=False) 248 | print('==> loaded checkpoint {}' 249 | .format(model_path)) 250 | 251 | assert os.path.isfile(model_path) 252 | 253 | # define loss function 254 | criterion_id = nn.CrossEntropyLoss() 255 | if args.method == 'agw': 256 | pass 257 | elif args.method == 'adp': 258 | # criterion_tri = TripletLoss_ADP(alpha=args.alpha, gamma=args.gamma, square=args.square) 259 | # criterion_tri_l = TripletLoss_ADP(alpha=args.alpha, gamma=args.gamma, square=args.square) 260 | if args.otri == 1: 261 | loader_batch = args.batch_size * args.num_pos 262 | criterion_tri = OriTripletLoss(batch_size=loader_batch, margin=args.margin) 263 | criterion_tri_l = OriTripletLoss(batch_size=loader_batch, margin=args.margin) 264 | 265 | else: 266 | pass 267 | criterion_kl = KLDivLoss() 268 | criterion_id.to(device) 269 | criterion_tri_l.to(device) 270 | criterion_tri.to(device) 271 | criterion_kl.to(device) 272 | 273 | if args.optim == 'sgd': 274 | ignored_params = list(map(id, net.bottleneck.parameters())) \ 275 | + list(map(id, net.vit.parameters())) \ 276 | + list(map(id, net.classifier.parameters())) \ 277 | + list(map(id, net.local_conv_list.parameters())) \ 278 | + list(map(id, net.fc_list.parameters())) 279 | 280 
| base_params = filter(lambda p: id(p) not in ignored_params, net.parameters())
281 | 
282 |     optimizer = optim.SGD([
283 |         {'params': base_params, 'lr': 0.1 * args.lr},
284 |         {'params': net.bottleneck.parameters(), 'lr': args.lr},
285 |         {'params': net.vit.parameters(), 'lr': args.lr},
286 |         {'params': net.classifier.parameters(), 'lr': args.lr},
287 |         {'params': net.local_conv_list.parameters(), 'lr': args.lr},
288 |         {'params': net.fc_list.parameters(), 'lr': args.lr}
289 |     ],
290 |         weight_decay=5e-4, momentum=0.9, nesterov=True)
291 | 
292 | 
293 | # exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
294 | def adjust_learning_rate(optimizer, epoch):
295 |     """Linear warmup over the first 10 epochs, then decay the LR by 10x at epoch 20 and again at epoch 50"""
296 |     if epoch < 10:
297 |         lr = args.lr * (epoch + 1) / 10
298 |     elif epoch >= 10 and epoch < 20:
299 |         lr = args.lr
300 |     elif epoch >= 20 and epoch < 50:
301 |         lr = args.lr * 0.1
302 |     elif epoch >= 50:
303 |         lr = args.lr * 0.01
304 | 
305 |     optimizer.param_groups[0]['lr'] = 0.1 * lr
306 |     for i in range(len(optimizer.param_groups) - 1):
307 |         optimizer.param_groups[i + 1]['lr'] = lr
308 | 
309 |     return lr
310 | 
311 | 
312 | def train(epoch):
313 |     current_lr = adjust_learning_rate(optimizer, epoch)
314 |     train_loss = AverageMeter()
315 |     id_loss_g = AverageMeter()
316 |     id_loss_l = AverageMeter()
317 |     tri_loss_g = AverageMeter()
318 |     tri_loss_l = AverageMeter()
319 |     data_time = AverageMeter()
320 |     batch_time = AverageMeter()
321 |     correct = 0
322 |     total = 0
323 | 
324 |     # switch to train mode
325 |     net.train()
326 |     end = time.time()
327 | 
328 |     for batch_idx, (input10, input11, input2, label1, label2) in enumerate(trainloader):
329 | 
330 |         labels = torch.cat((label1, label1, label2), 0)
331 |         labels_vit = torch.cat((label1, label2), 0)
332 |         input2 = Variable(input2.cuda())
333 |         input10 = Variable(input10.cuda())
334 |         input11 = Variable(input11.cuda())
335 |         labels = Variable(labels.cuda())
336 |         labels_vit = Variable(labels_vit.cuda())
337 |         input1 = torch.cat((input10, input11,), 0)
338 |         input2 = Variable(input2.cuda())  # redundant: input2 was already moved to the GPU above
339 |         data_time.update(time.time() - end)
340 | 
341 |         featG, outG, feat, out0, feat_all = net(input1, input2)
342 |         b_x = featG.shape[0]
343 |         b_x = b_x // 5
344 | 
345 |         loss_id = args.pha * criterion_id(out0[0][0:b_x * 3], labels) + (2.0 - args.pha) * criterion_id(
346 |             out0[0][b_x * 3:], labels_vit)
347 |         for i in range(len(feat) - 1):
348 |             loss_id_temp = args.pha * criterion_id(out0[i + 1][0:b_x * 3], labels) + (2.0 - args.pha) * criterion_id(
349 |                 out0[i + 1][b_x * 3:], labels_vit)
350 |             loss_id += loss_id_temp
351 | 
352 |         loss_tri_l = args.pha * criterion_tri_l(feat_all[0:b_x * 3], labels)[0] + (2.0 - args.pha) * \
353 |             criterion_tri_l(feat_all[b_x * 3:], labels_vit)[0]
354 | 
355 |         loss_id_G = args.pha * criterion_id(outG[0:b_x * 3], labels) + (2.0 - args.pha) * criterion_id(
356 |             outG[b_x * 3:], labels_vit)
357 |         loss_tri, batch_acc = criterion_tri(featG[0:b_x * 3], labels)
358 |         loss_tri = args.pha * loss_tri
359 |         loss_tri += (2.0 - args.pha) * criterion_tri(featG[b_x * 3:], labels_vit)[0]
360 |         correct += (batch_acc / 2)
361 |         _, predicted = outG[0:b_x * 3].max(1)
362 |         correct += (predicted.eq(labels).sum().item() / 2)
363 | 
364 |         correct += batch_acc
365 |         loss_id = args.localloss * loss_id
366 |         loss_id_G = args.goballoss * loss_id_G
367 |         loss_tri = args.gobaltri * loss_tri
368 |         loss_tri_l = args.localtri * loss_tri_l
369 | 
370 |         loss = loss_id + loss_id_G + loss_tri + loss_tri_l
371 | 
372 | 
optimizer.zero_grad() 373 | loss.backward() 374 | optimizer.step() 375 | 376 | # update P 377 | train_loss.update(loss.item(), 2 * input1.size(0)) 378 | 379 | id_loss_l.update(loss_id.item(), 2 * input1.size(0)) 380 | id_loss_g.update(loss_id_G.item(), 2 * input1.size(0)) 381 | 382 | tri_loss_g.update(loss_tri.item(), 2 * input1.size(0)) 383 | tri_loss_l.update(loss_tri_l.item(), 2 * input1.size(0)) 384 | 385 | total += labels.size(0) 386 | 387 | # measure elapsed time 388 | batch_time.update(time.time() - end) 389 | end = time.time() 390 | if batch_idx % 50 == 0: 391 | print('Epoch: [{}][{}/{}] ' 392 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 393 | 'lr:{:.5f} ' 394 | 'Loss: {train_loss.val:.4f} ({train_loss.avg:.4f}) ' 395 | 'G-Loss: {id_loss_g.val:.4f} ({id_loss_g.avg:.4f}) ' 396 | 'L-Loss: {id_loss_l.val:.4f} ({id_loss_l.avg:.4f}) ' 397 | 'G-TLoss: {tri_loss_g.val:.4f} ({tri_loss_g.avg:.4f}) ' 398 | 'L-TLoss: {tri_loss_l.val:.4f} ({tri_loss_l.avg:.4f}) ' 399 | 'Accu: {:.2f}'.format( 400 | epoch, batch_idx, len(trainloader), 401 | current_lr, 402 | 100. * correct / total, batch_time=batch_time, 403 | train_loss=train_loss, 404 | id_loss_g=id_loss_g, 405 | id_loss_l=id_loss_l, 406 | tri_loss_g=tri_loss_g, 407 | tri_loss_l=tri_loss_l)) 408 | 409 | writer.add_scalar('total_loss', train_loss.avg, epoch) 410 | writer.add_scalar('id_loss_l', id_loss_l.avg, epoch) 411 | writer.add_scalar('id_loss_g', id_loss_g.avg, epoch) 412 | writer.add_scalar('tri_loss_g', tri_loss_g.avg, epoch) 413 | writer.add_scalar('tri_loss_l', tri_loss_l.avg, epoch) 414 | writer.add_scalar('lr', current_lr, epoch) 415 | 416 | 417 | def test(epoch): 418 | # switch to evaluation mode 419 | net.eval() 420 | print('Extracting Gallery Feature...') 421 | start = time.time() 422 | ptr = 0 423 | gall_feat = np.zeros((ngall, 3584)) 424 | gall_feat_att = np.zeros((ngall, 3584)) 425 | with torch.no_grad(): 426 | for batch_idx, (input, label) in enumerate(gall_loader): 427 | batch_num = input.size(0) 428 | input = Variable(input.cuda()) 429 | feat, feat_att = net(input, input, test_mode[0]) 430 | gall_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy() 431 | gall_feat_att[ptr:ptr + batch_num, :] = feat_att.detach().cpu().numpy() 432 | ptr = ptr + batch_num 433 | print('Extracting Time:\t {:.3f}'.format(time.time() - start)) 434 | 435 | # switch to evaluation 436 | net.eval() 437 | print('Extracting Query Feature...') 438 | start = time.time() 439 | ptr = 0 440 | query_feat = np.zeros((nquery, 3584)) 441 | query_feat_att = np.zeros((nquery, 3584)) 442 | with torch.no_grad(): 443 | for batch_idx, (input, label) in enumerate(query_loader): 444 | batch_num = input.size(0) 445 | input = Variable(input.cuda()) 446 | feat, feat_att = net(input, input, test_mode[1]) 447 | query_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy() 448 | query_feat_att[ptr:ptr + batch_num, :] = feat_att.detach().cpu().numpy() 449 | ptr = ptr + batch_num 450 | print('Extracting Time:\t {:.3f}'.format(time.time() - start)) 451 | 452 | start = time.time() 453 | # compute the similarity 454 | distmat = np.matmul(query_feat, np.transpose(gall_feat)) 455 | distmat_att = np.matmul(query_feat_att, np.transpose(gall_feat_att)) 456 | 457 | # evaluation 458 | if dataset == 'regdb': 459 | cmc, mAP, mINP = eval_regdb(-distmat, query_label, gall_label) 460 | cmc_att, mAP_att, mINP_att = eval_regdb(-distmat_att, query_label, gall_label) 461 | elif dataset == 'sysu': 462 | cmc, mAP, mINP = eval_sysu(-distmat, query_label, gall_label, query_cam, 
gall_cam) 463 | cmc_att, mAP_att, mINP_att = eval_sysu(-distmat_att, query_label, gall_label, query_cam, gall_cam) 464 | elif dataset == 'llcm': 465 | cmc, mAP, mINP = eval_llcm(-distmat, query_label, gall_label, query_cam, gall_cam) 466 | cmc_att, mAP_att, mINP_att = eval_llcm(-distmat_att, query_label, gall_label, query_cam, gall_cam) 467 | 468 | print('Evaluation Time:\t {:.3f}'.format(time.time() - start)) 469 | 470 | writer.add_scalar('rank1', cmc[0], epoch) 471 | writer.add_scalar('mAP', mAP, epoch) 472 | writer.add_scalar('mINP', mINP, epoch) 473 | writer.add_scalar('rank1_att', cmc_att[0], epoch) 474 | writer.add_scalar('mAP_att', mAP_att, epoch) 475 | writer.add_scalar('mINP_att', mINP_att, epoch) 476 | return cmc, mAP, mINP, cmc_att, mAP_att, mINP_att 477 | 478 | 479 | # training 480 | print('==> Start Training...') 481 | for epoch in range(start_epoch, 60): 482 | 483 | print('==> Preparing Data Loader...') 484 | # identity sampler 485 | sampler = IdentitySampler(trainset.train_color_label, \ 486 | trainset.train_thermal_label, color_pos, thermal_pos, args.num_pos, args.batch_size, 487 | epoch) 488 | 489 | trainset.cIndex = sampler.index1 # color index 490 | trainset.tIndex = sampler.index2 # thermal index 491 | print(epoch) 492 | print(trainset.cIndex) 493 | print(trainset.tIndex) 494 | 495 | loader_batch = args.batch_size * args.num_pos 496 | 497 | trainloader = data.DataLoader(trainset, batch_size=loader_batch, \ 498 | sampler=sampler, num_workers=args.workers, drop_last=True) 499 | 500 | # training 501 | train(epoch) 502 | 503 | if epoch >= 0: 504 | print('Test Epoch: {}'.format(epoch)) 505 | 506 | # testing 507 | cmc, mAP, mINP, cmc_att, mAP_att, mINP_att = test(epoch) 508 | # save model 509 | if cmc_att[0] + mAP_att > best_acc: # not the real best for sysu-mm01 510 | best_acc = cmc_att[0] + mAP_att 511 | best_epoch = epoch 512 | state = { 513 | 'net': net.state_dict(), 514 | 'cmc': cmc_att, 515 | 'mAP': mAP_att, 516 | 'mINP': mINP_att, 517 | 'epoch': epoch, 518 | } 519 | torch.save(state, checkpoint_path + suffix + '_best.t') 520 | 521 | print( 522 | 'POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 523 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 524 | print( 525 | 'FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 526 | cmc_att[0], cmc_att[4], cmc_att[9], cmc_att[19], mAP_att, mINP_att)) 527 | print('Best Epoch [{}]'.format(best_epoch)) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data.sampler import Sampler 4 | import sys 5 | import os.path as osp 6 | import torch 7 | import random 8 | def load_data(input_data_path ): 9 | with open(input_data_path) as f: 10 | data_file_list = open(input_data_path, 'rt').read().splitlines() 11 | # Get full list of color image and labels 12 | file_image = [s.split(' ')[0] for s in data_file_list] 13 | file_label = [int(s.split(' ')[1]) for s in data_file_list] 14 | 15 | return file_image, file_label 16 | 17 | 18 | def GenIdx( train_color_label, train_thermal_label): 19 | color_pos = [] 20 | unique_label_color = np.unique(train_color_label) 21 | for i in range(len(unique_label_color)): 22 | tmp_pos = [k for k,v in enumerate(train_color_label) if v==unique_label_color[i]] 23 | color_pos.append(tmp_pos) 24 | 25 | thermal_pos = [] 26 | 
unique_label_thermal = np.unique(train_thermal_label)
27 |     for i in range(len(unique_label_thermal)):
28 |         tmp_pos = [k for k,v in enumerate(train_thermal_label) if v==unique_label_thermal[i]]
29 |         thermal_pos.append(tmp_pos)
30 |     return color_pos, thermal_pos
31 | 
32 | def GenCamIdx(gall_img, gall_label, mode):
33 |     if mode =='indoor':
34 |         camIdx = [1,2]
35 |     else:
36 |         camIdx = [1,2,4,5]
37 |     gall_cam = []
38 |     for i in range(len(gall_img)):
39 |         gall_cam.append(int(gall_img[i][-10]))
40 | 
41 |     sample_pos = []
42 |     unique_label = np.unique(gall_label)
43 |     for i in range(len(unique_label)):
44 |         for j in range(len(camIdx)):
45 |             id_pos = [k for k,v in enumerate(gall_label) if v==unique_label[i] and gall_cam[k]==camIdx[j]]
46 |             if id_pos:
47 |                 sample_pos.append(id_pos)
48 |     return sample_pos
49 | 
50 | def ExtractCam(gall_img):
51 |     gall_cam = []
52 |     for i in range(len(gall_img)):
53 |         cam_id = int(gall_img[i][-10])
54 |         # if cam_id ==3:
55 |         #     cam_id = 2
56 |         gall_cam.append(cam_id)
57 | 
58 |     return np.array(gall_cam)
59 | 
60 | class IdentitySampler(Sampler):
61 |     """Sample person identities evenly in each batch.
62 |     Args:
63 |         train_color_label, train_thermal_label: labels of two modalities
64 |         color_pos, thermal_pos: positions of each identity
65 |         batchSize: batch size
66 |     """
67 | 
68 |     def __init__(self, train_color_label, train_thermal_label, color_pos, thermal_pos, num_pos, batchSize, epoch):
69 |         uni_label = np.unique(train_color_label)
70 |         self.n_classes = len(uni_label)
71 | 
72 | 
73 |         N = np.maximum(len(train_color_label), len(train_thermal_label))
74 |         for j in range(int(N/(batchSize*num_pos))+1):
75 |             batch_idx = np.random.choice(uni_label, batchSize, replace = False)
76 |             for i in range(batchSize):
77 |                 sample_color = np.random.choice(color_pos[batch_idx[i]], num_pos)
78 |                 sample_thermal = np.random.choice(thermal_pos[batch_idx[i]], num_pos)
79 | 
80 |                 if j ==0 and i==0:
81 |                     index1= sample_color
82 |                     index2= sample_thermal
83 |                 else:
84 |                     index1 = np.hstack((index1, sample_color))
85 |                     index2 = np.hstack((index2, sample_thermal))
86 | 
87 |         self.index1 = index1
88 |         self.index2 = index2
89 |         self.N = N
90 | 
91 |     def __iter__(self):
92 |         return iter(np.arange(len(self.index1)))
93 | 
94 |     def __len__(self):
95 |         return self.N
96 | 
97 | class AverageMeter(object):
98 |     """Computes and stores the average and current value"""
99 |     def __init__(self):
100 |         self.reset()
101 | 
102 |     def reset(self):
103 |         self.val = 0
104 |         self.avg = 0
105 |         self.sum = 0
106 |         self.count = 0
107 | 
108 |     def update(self, val, n=1):
109 |         self.val = val
110 |         self.sum += val * n
111 |         self.count += n
112 |         self.avg = self.sum / self.count
113 | 
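A quick illustration of the AverageMeter defined above, with toy loss values (all numbers hypothetical):

meter = AverageMeter()
for loss in [0.9, 0.7, 0.5]:
    meter.update(loss, n=32)      # n is typically the batch size
print(meter.val)                  # 0.5  (most recent value)
print(meter.avg)                  # 0.7  (sample-weighted running average)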
114 | def mkdir_if_missing(directory):
115 |     import errno  # note: errno is used below but was never imported at module level
116 |     if not osp.exists(directory):
117 |         try:
118 |             os.makedirs(directory)
119 |         except OSError as e:
120 |             if e.errno != errno.EEXIST: raise
121 | class Logger(object):
122 |     """
123 |     Write console output to external text file.
124 |     Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.
125 |     """
126 |     def __init__(self, fpath=None):
127 |         self.console = sys.stdout
128 |         self.file = None
129 |         if fpath is not None:
130 |             mkdir_if_missing(osp.dirname(fpath))
131 |             self.file = open(fpath, 'w')
132 | 
133 |     def __del__(self):
134 |         self.close()
135 | 
136 |     def __enter__(self):
137 |         return self  # so the logger can be used as a context manager
138 | 
139 |     def __exit__(self, *args):
140 |         self.close()
141 | 
142 |     def write(self, msg):
143 |         self.console.write(msg)
144 |         if self.file is not None:
145 |             self.file.write(msg)
146 | 
147 |     def flush(self):
148 |         self.console.flush()
149 |         if self.file is not None:
150 |             self.file.flush()
151 |             os.fsync(self.file.fileno())
152 | 
153 |     def close(self):
154 |         self.console.close()
155 |         if self.file is not None:
156 |             self.file.close()
157 | 
158 | def set_seed(seed, cuda=True):
159 |     np.random.seed(seed)
160 |     torch.manual_seed(seed)
161 |     if cuda:
162 |         torch.cuda.manual_seed(seed)
163 | 
164 | def set_requires_grad(nets, requires_grad=False):
165 |     """Set requires_grad=False for all the networks to avoid unnecessary computations
166 |     Parameters:
167 |         nets (network list) -- a list of networks
168 |         requires_grad (bool) -- whether the networks require gradients or not
169 |     """
170 |     if not isinstance(nets, list):
171 |         nets = [nets]
172 |     for net in nets:
173 |         if net is not None:
174 |             for param in net.parameters():
175 |                 param.requires_grad = requires_grad
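set_requires_grad is handy for freezing sub-networks during staged training. A short illustrative use; the classifier attribute mirrors the head created in train_sle.py, and the freezing pattern itself is only an example:

# freeze the whole network, then re-enable gradients only for the classifier head
set_requires_grad(net, requires_grad=False)          # a single module is wrapped into a list
set_requires_grad([net.classifier], requires_grad=True)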
--------------------------------------------------------------------------------
/whitening.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: UTF-8 -*-
3 | # References: https://github.com/roysubhankar/dwt-domain-adaptation
4 | # --------------------------------------------------------
5 | import torch
6 | import torch.nn as nn
7 | from torch.nn.functional import conv2d
8 | 
9 | class _Whitening(nn.Module):
10 | 
11 |     def __init__(self, num_features, group_size, momentum=0.1, eps=1e-3, alpha=1):
12 |         super(_Whitening, self).__init__()
13 |         self.num_features = num_features
14 |         self.momentum = momentum
15 |         self.eps = eps
16 |         self.alpha = alpha
17 |         self.group_size = min(self.num_features, group_size)
18 |         self.num_groups = self.num_features // self.group_size
19 |         self.register_buffer('running_mean', torch.zeros([1, self.num_features, 1, 1], out=torch.cuda.FloatTensor() if torch.cuda.is_available() else torch.FloatTensor()))
20 |         self.register_buffer('running_variance', torch.ones([self.num_groups, self.group_size, self.group_size], out=torch.cuda.FloatTensor() if torch.cuda.is_available() else torch.FloatTensor()))
21 |         # self.register_buffer('running_mean', torch.zeros([1, self.num_features, 1, 1], out=torch.cuda.FloatTensor() if torch.cuda.is_available() else torch.FloatTensor()))
22 |         # self.register_buffer('running_variance', torch.ones([self.num_groups, self.group_size, self.group_size], out=torch.cuda.FloatTensor() if torch.cuda.is_available() else torch.FloatTensor()))
23 |         """
24 |         if self.track_running_stats:
25 |             self.register_buffer('running_mean', torch.zeros([1, self.num_features, 1, 1], out=torch.cuda.FloatTensor() if torch.cuda.is_available() else torch.FloatTensor()))
26 |             self.register_buffer('running_variance', torch.zeros([self.num_groups, self.group_size, self.group_size], out=torch.cuda.FloatTensor() if torch.cuda.is_available() else torch.FloatTensor()))
27 |         """
28 | 
29 |     def _check_input_dim(self, input):
30 |         raise NotImplementedError
31 | 
32 |     def _check_group_size(self):
33 |         raise NotImplementedError
34 | 
35 |     def forward(self, x):
36 |         self._check_input_dim(x)
37 |         self._check_group_size()
38 | 
39 |         m = x.mean(0).view(self.num_features, -1).mean(-1).view(1, -1, 1, 1)
40 |         if not self.training:  # for inference
41 |             m = self.running_mean
42 |         xn = x - m
43 | 
44 |         T = xn.permute(1,0,2,3).contiguous().view(self.num_groups, self.group_size,-1)
45 |         f_cov = torch.bmm(T, T.permute(0,2,1)) / T.shape[-1]  # per-group channel covariance
46 |         f_cov_shrinked = (1-self.eps) * f_cov + self.eps * torch.eye(self.group_size, out=torch.cuda.FloatTensor() if torch.cuda.is_available() else torch.FloatTensor()).repeat(self.num_groups, 1, 1)
47 | 
48 |         if not self.training:  # for inference
49 |             f_cov_shrinked = (1-self.eps) * self.running_variance + self.eps * torch.eye(self.group_size, out=torch.cuda.FloatTensor() if torch.cuda.is_available() else torch.FloatTensor()).repeat(self.num_groups, 1, 1)
50 |         inv_sqrt = torch.inverse(torch.cholesky(f_cov_shrinked)).contiguous().view(self.num_features, self.group_size, 1, 1)  # inverse Cholesky factor acts as the whitening operator
51 | 
52 |         decorrelated = conv2d(xn, inv_sqrt, groups = self.num_groups)
53 | 
54 |         if self.training:
55 |             self.running_mean = torch.add(self.momentum * m.detach(), (1 - self.momentum) * self.running_mean, out=self.running_mean)
56 |             self.running_variance = torch.add(self.momentum * f_cov.detach(), (1 - self.momentum) * self.running_variance, out=self.running_variance)
57 | 
58 |         return decorrelated
59 | 
60 | class WTransform2d(_Whitening):
61 |     def _check_input_dim(self, input):
62 |         if input.dim() != 4:
63 |             raise ValueError('expected 4D input (got {}D input)'.format(input.dim()))
64 | 
65 |     def _check_group_size(self):
66 |         if self.num_features % self.group_size != 0:
67 |             raise ValueError('expected number of channels divisible by group_size (got {} group_size\
68 |                 for {} number of features)'.format(self.group_size, self.num_features))
69 | 
--------------------------------------------------------------------------------
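To close, a minimal usage sketch for WTransform2d, assuming an older PyTorch where torch.cholesky is still available (newer releases replace it with torch.linalg.cholesky); all shapes here are illustrative:

import torch
from whitening import WTransform2d

# 64 channels whitened in 4 groups of 16; num_features must be divisible by group_size
wt = WTransform2d(num_features=64, group_size=16)
x = torch.randn(8, 64, 24, 12)                      # (N, C, H, W)
if torch.cuda.is_available():
    wt, x = wt.cuda(), x.cuda()
wt.train()                                          # training mode also updates the running stats
y = wt(x)                                           # same shape, channels decorrelated per group
print(y.shape)                                      # torch.Size([8, 64, 24, 12])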