├── README.md ├── VD_project.py ├── data ├── __pycache__ │ ├── dataloader.cpython-36.pyc │ ├── dataloader.cpython-37.pyc │ ├── dataloader.cpython-38.pyc │ ├── dataset.cpython-36.pyc │ ├── dataset.cpython-37.pyc │ ├── dataset.cpython-38.pyc │ ├── random_erasing.cpython-36.pyc │ ├── random_erasing.cpython-37.pyc │ └── random_erasing.cpython-38.pyc ├── dataloader.py └── dataset.py ├── loss ├── Id_loss.py ├── RankingLoss.py ├── __pycache__ │ ├── Id_loss.cpython-36.pyc │ ├── Id_loss.cpython-37.pyc │ ├── Id_loss.cpython-38.pyc │ ├── RankingLoss.cpython-36.pyc │ ├── RankingLoss.cpython-37.pyc │ ├── RankingLoss.cpython-38.pyc │ └── loss.cpython-36.pyc └── loss.py ├── loss_TransREID ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── arcface.cpython-38.pyc │ ├── center_loss.cpython-38.pyc │ ├── make_loss.cpython-38.pyc │ ├── metric_learning.cpython-38.pyc │ ├── softmax_loss.cpython-38.pyc │ └── triplet_loss.cpython-38.pyc ├── arcface.py ├── center_loss.py ├── make_loss.py ├── metric_learning.py ├── softmax_loss.py └── triplet_loss.py ├── model ├── DETR_model.py ├── __pycache__ │ ├── DETR_model.cpython-37.pyc │ ├── DETR_model.cpython-38.pyc │ ├── model.cpython-36.pyc │ ├── model.cpython-37.pyc │ ├── model.cpython-38.pyc │ ├── text_feature_extract.cpython-36.pyc │ ├── text_feature_extract.cpython-37.pyc │ └── text_feature_extract.cpython-38.pyc ├── model.py └── text_feature_extract.py ├── model_TransREID ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ └── make_model.cpython-38.pyc ├── backbones │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── resnet.cpython-38.pyc │ │ └── vit_pytorch.cpython-38.pyc │ ├── resnet.py │ └── vit_pytorch.py └── make_model.py ├── models ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── transformer.cpython-37.pyc │ └── transformer.cpython-38.pyc ├── backbone.py ├── matcher.py ├── position_encoding.py ├── segmentation.py └── transformer.py ├── option ├── __pycache__ │ ├── options.cpython-36.pyc │ ├── options.cpython-37.pyc │ └── options.cpython-38.pyc └── options.py ├── processed_data_singledata_CUHK.py ├── processed_data_singledata_ICFG.py ├── random_erasing.py ├── read_json.py ├── reidtools.py ├── test_ICFG_my.py ├── test_during_train.py ├── train_mydecoder_pixelvit_txtimg_3_bert.py ├── utils ├── __pycache__ │ ├── random_erasing.cpython-38.pyc │ ├── read_write_data.cpython-36.pyc │ ├── read_write_data.cpython-37.pyc │ └── read_write_data.cpython-38.pyc └── read_write_data.py ├── utils_RVN ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── directory.cpython-38.pyc │ └── metric.cpython-38.pyc ├── directory.py ├── metric.py └── visualize.py ├── vit_pytorch ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── vit_pytorch.cpython-37.pyc │ └── vit_pytorch.cpython-38.pyc ├── distill.py ├── efficient.py ├── mpp.py ├── t2t.py ├── train_2module_guide.py └── vit_pytorch.py └── vit_pytorch_TransREID ├── __init__.py ├── distill.py ├── efficient.py ├── mpp.py ├── t2t.py └── vit_pytorch.py /README.md: -------------------------------------------------------------------------------- 1 | # Learning Granularity-Unified Representations for Text-to-Image Person Re-identification 2 | 3 | This is the codebase for our [ACM MM 2022 paper](https://arxiv.org/abs/2207.07802). 
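### Prepare the datasets

Organize the CUHK-PEDES and ICFG-PEDES data in the following directory structure: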
4 | ```bash 5 | datasets 6 | └── cuhkpedes 7 | ├── captions.json 8 | └── imgs 9 | ├── cam_a 10 | ├── cam_b 11 | ├── CUHK01 12 | ├── CUHK03 13 | ├── Market 14 | ├── test_query 15 | └── train_query 16 | └──icfgpedes 17 | ├── ICFG-PEDES.json 18 | └── ICFG_PEDES 19 | ├── test 20 | └── train 21 | 22 | ``` 23 | 24 | ### Download DeiT-small weights 25 | ```bash 26 | wget https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth 27 | ``` 28 | ### Process image and text datasets 29 | ```bash 30 | python processed_data_singledata_CUHK.py 31 | python processed_data_singledata_ICFG.py 32 | ``` 33 | 34 | 35 | ### Train 36 | ```bash 37 | python train_mydecoder_pixelvit_txtimg_3_bert.py 38 | ``` 39 | 40 | ## Citation 41 | If you find this project useful for your research, please use the following BibTeX entry. 42 | ``` 43 | @inproceedings{shao2022learning, 44 | title={Learning Granularity-Unified Representations for Text-to-Image Person Re-identification}, 45 | author={Shao, Zhiyin and Zhang, Xinyu and Fang, Meng and Lin, Zhifeng and Wang, Jian and Ding, Changxing}, 46 | booktitle={Proceedings of the 30th ACM International Conference on Multimedia}, 47 | year={2022} 48 | } 49 | ``` 50 | -------------------------------------------------------------------------------- /VD_project.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | # from ..utils import concat_all_gather 5 | import torch.distributed as dist 6 | # def concat_all_gather(tensor): 7 | # """ 8 | # Performs all_gather operation on the provided tensors. 9 | # *** Warning ***: torch.distributed.all_gather has no gradient. 10 | # """ 11 | # tensors_gather = [ 12 | # torch.ones_like(tensor) 13 | # for _ in range(torch.distributed.get_world_size()) 14 | # ] 15 | # torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 16 | # 17 | # output = torch.cat(tensors_gather, dim=0) 18 | # return output 19 | def ema_inplace(moving_avg, new, decay): 20 | moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay)) 21 | 22 | def ema_tensor_inplace(moving_avg, new, decay): 23 | new_out = torch.mul(new,1.0-decay) 24 | moving_avg.data.mul_(decay).add_(new_out.detach()) 25 | 26 | def sum_inplace(sum_data,new): 27 | sum_data.data.add_(new) 28 | 29 | def laplace_smoothing(x, n_categories, eps=1e-5): 30 | return (x + eps) / (x.sum() + n_categories * eps) 31 | 32 | def laplace_smoothing_dim(x, n_categories,dim=1, eps=1e-5): 33 | return (x + eps) / (x.sum(dim=dim,keepdim=True) + n_categories * eps) 34 | 35 | class SOHO_Pre_VD(nn.Module): 36 | def __init__(self,num_tokens,token_dim,decay=0.1,max_decay=0.99,eps=1e-5): 37 | super(SOHO_Pre_VD, self).__init__() 38 | self.token_dim = token_dim 39 | self.num_tokens = num_tokens 40 | embed = torch.randn(num_tokens, token_dim) 41 | self.register_buffer('embed', embed) 42 | nn.init.normal_(self.embed) 43 | self.register_buffer('cluster_size', torch.zeros(num_tokens)) 44 | self.register_buffer('cluster_sum', torch.zeros(num_tokens)) 45 | self.register_buffer('embed_avg', torch.zeros(num_tokens,token_dim)) 46 | 47 | self.decay = decay 48 | self.eps = eps 49 | self.curr_decay=self.decay 50 | self.max_decay=max_decay 51 | 52 | 53 | def set_decay_updates(self,num_update): 54 | self.curr_decay=min(self.decay*num_update,self.max_decay) 55 | 56 | def forward(self,inputs_flatten): 57 | 58 | distances = (torch.sum(inputs_flatten ** 2, dim=1, keepdim=True) 59 | + torch.sum(self.embed.data ** 
2, dim=1) 60 | - 2 * torch.matmul(inputs_flatten, self.embed.data.t())) 61 | 62 | """ 63 | encoding_indices: Tensor containing the discrete encoding indices, ie 64 | which element of the quantized space each input element was mapped to. 65 | """ 66 | 67 | 68 | encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) 69 | encodings = torch.zeros(encoding_indices.shape[0],self.num_tokens, dtype=torch.float,device=inputs_flatten.device) 70 | encodings.scatter_(1, encoding_indices, 1) 71 | 72 | if self.training: 73 | 74 | tmp_sum = torch.sum(encodings,dim=0,keepdim=True) 75 | encoding_sum = torch.sum(tmp_sum, dim=0) 76 | 77 | sum_inplace(self.cluster_sum,encoding_sum) 78 | ema_tensor_inplace(self.cluster_size, encoding_sum, self.curr_decay) 79 | embed_sum_tmp = torch.matmul(encodings.t(), inputs_flatten) 80 | 81 | embed_sum = torch.sum(embed_sum_tmp.unsqueeze(dim=0),dim=0) 82 | ema_tensor_inplace(self.embed_avg, embed_sum, self.curr_decay) 83 | 84 | cluster_size = laplace_smoothing(self.cluster_size, self.num_tokens, self.eps) * self.cluster_size.sum() 85 | embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) 86 | 87 | # world_size = dist.get_world_size() 88 | # dist.all_reduce(embed_normalized.div_(world_size)) 89 | self.embed.data.copy_(embed_normalized) 90 | # print('embed') 91 | # print(self.embed) 92 | # print('encodings') 93 | # print(encodings) 94 | # print('embed') 95 | # print(self.embed) 96 | quantize = torch.matmul(encodings, self.embed) 97 | #quantize = inputs_flatten 98 | quantize = (quantize - inputs_flatten).detach() + inputs_flatten 99 | # print('quantize') 100 | # print(quantize) 101 | return quantize, encoding_indices 102 | 103 | # vq = SOHO_Pre_VD(3, 5, decay=0.1,max_decay=0.99) 104 | # inputs_flatten = torch.FloatTensor([[2,1,3,4,1],[6,1,7,1,3],[15,1,14,4,111]]) 105 | # inputs_flatten_norm = inputs_flatten/inputs_flatten.norm(p=2,dim=1,keepdim=True) 106 | # embed_norm = vq.embed/vq.embed.norm(p=2,dim=1,keepdim=True) 107 | # a = torch.matmul(inputs_flatten_norm,embed_norm.t()) 108 | # print(a) 109 | # print(vq.embed) 110 | # print(vq.embed_avg) 111 | # quantized_pt, indices = vq(inputs_flatten) 112 | # inputs_flatten_norm = inputs_flatten/inputs_flatten.norm(p=2,dim=1,keepdim=True) 113 | # embed_norm = vq.embed/vq.embed.norm(p=2,dim=1,keepdim=True) 114 | # a = torch.matmul(inputs_flatten_norm,embed_norm.t()) 115 | # print(a) 116 | # print(quantized_pt) 117 | # print(indices) -------------------------------------------------------------------------------- /data/__pycache__/dataloader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/data/__pycache__/dataloader.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/dataloader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/data/__pycache__/dataloader.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/dataloader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/data/__pycache__/dataloader.cpython-38.pyc 
-------------------------------------------------------------------------------- /data/__pycache__/dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/data/__pycache__/dataset.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/data/__pycache__/dataset.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/data/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /data/__pycache__/random_erasing.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/data/__pycache__/random_erasing.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/random_erasing.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/data/__pycache__/random_erasing.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/random_erasing.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/data/__pycache__/random_erasing.cpython-38.pyc -------------------------------------------------------------------------------- /data/dataloader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat., JUL. 
20(th), 2019 at 16:51 4 | 5 | @author: zifyloo 6 | """ 7 | 8 | from torchvision import transforms 9 | from PIL import Image 10 | import torch 11 | from data.dataset import CUHKPEDEDataset, CUHKPEDE_img_dateset, CUHKPEDE_txt_dateset 12 | 13 | 14 | def get_dataloader(opt): 15 | """ 16 | tranforms the image, downloads the image with the id by data.DataLoader 17 | """ 18 | 19 | if opt.mode == 'train': 20 | transform_list = [ 21 | transforms.Resize((384, 128), interpolation=3), 22 | transforms.Pad(10), 23 | transforms.RandomCrop((384, 128)), 24 | transforms.RandomHorizontalFlip(), 25 | transforms.ToTensor(), 26 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 27 | ] 28 | tran = transforms.Compose(transform_list) 29 | 30 | dataset = CUHKPEDEDataset(opt, tran) 31 | 32 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, 33 | shuffle=True, drop_last=True, num_workers=3) 34 | print('{}-{} has {} pohtos'.format(opt.dataset, opt.mode, len(dataset))) 35 | 36 | return dataloader 37 | 38 | else: 39 | tran = transforms.Compose([ 40 | transforms.Resize((384, 128), interpolation=3), 41 | transforms.ToTensor(), 42 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 43 | ] 44 | ) 45 | 46 | img_dataset = CUHKPEDE_img_dateset(opt, tran) 47 | 48 | img_dataloader = torch.utils.data.DataLoader(img_dataset, batch_size=opt.batch_size, 49 | shuffle=False, drop_last=False, num_workers=3) 50 | 51 | txt_dataset = CUHKPEDE_txt_dateset(opt) 52 | 53 | txt_dataloader = torch.utils.data.DataLoader(txt_dataset, batch_size=opt.batch_size, 54 | shuffle=False, drop_last=False, num_workers=3) 55 | 56 | print('{}-{} has {} pohtos, {} text'.format(opt.dataset, opt.mode, len(img_dataset), len(txt_dataset))) 57 | 58 | return img_dataloader, txt_dataloader 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | make the CUHK-PEDE dataset 4 | 5 | Created on Fri., Aug. 
1(st), 2019 at 10:42 6 | 7 | @author: zifyloo 8 | """ 9 | 10 | import torch 11 | import torch.utils.data as data 12 | import numpy as np 13 | from PIL import Image 14 | import os 15 | from utils.read_write_data import read_dict 16 | import cv2 17 | import torchvision.transforms.functional as F 18 | import random 19 | 20 | 21 | def fliplr(img, dim): 22 | """ 23 | flip horizontal 24 | :param img: 25 | :return: 26 | """ 27 | inv_idx = torch.arange(img.size(dim) - 1, -1, -1).long() # N x C x H x W 28 | img_flip = img.index_select(dim, inv_idx) 29 | return img_flip 30 | 31 | 32 | class CUHKPEDEDataset(data.Dataset): 33 | def __init__(self, opt, tran): 34 | 35 | self.opt = opt 36 | self.flip_flag = (self.opt.mode == 'train') 37 | data_save = read_dict(opt.pkl_root + opt.mode + '_save.pkl') 38 | print(data_save.keys()) 39 | if opt.dataset == 'CUHK-PEDES': 40 | self.img_path = [os.path.join(opt.dataroot, img_path) for img_path in data_save['img_path']] 41 | elif opt.dataset == 'MSMT-PEDES': 42 | self.img_path = [os.path.join('/home/zhiyin/ICFG_PEDES/ICFG_PEDES', img_path.split('/')[-3],img_path.split('/')[-2],img_path.split('/')[-1]) for img_path in data_save['img_path']] 43 | 44 | self.label = data_save['id'] 45 | if self.opt.wordtype == 'bert': 46 | self.caption_code = data_save['bert_caption_id'] 47 | elif self.opt.wordtype == 'lstm': 48 | self.caption_code = data_save['lstm_caption_id'] 49 | self.transform = tran 50 | 51 | self.num_data = len(self.img_path) 52 | 53 | def __getitem__(self, index): 54 | """ 55 | :param index: 56 | :return: image and its label 57 | """ 58 | 59 | image = Image.open(self.img_path[index]) 60 | image = self.transform(image) 61 | 62 | label = torch.from_numpy(np.array([self.label[index]])).long() 63 | 64 | caption_code, caption_length, caption_mask= self.caption_mask(self.caption_code[index]) 65 | 66 | return image, label, caption_code, caption_length, caption_mask 67 | 68 | def caption_mask(self, caption): 69 | caption_length = len(caption) 70 | caption = torch.from_numpy(np.array(caption)).view(-1).float() 71 | if caption_length < self.opt.caption_length_max: 72 | zero_padding = torch.zeros(self.opt.caption_length_max - caption_length) 73 | caption = torch.cat([caption, zero_padding], 0) 74 | else: 75 | caption = caption[:self.opt.caption_length_max] 76 | caption_length = self.opt.caption_length_max 77 | caption_mask = np.where(caption != 0, 1, 0) 78 | return caption, caption_length , caption_mask 79 | 80 | def __len__(self): 81 | return self.num_data 82 | 83 | 84 | class CUHKPEDE_img_dateset(data.Dataset): 85 | def __init__(self, opt, tran): 86 | 87 | self.opt = opt 88 | 89 | data_save = read_dict(opt.pkl_root + opt.mode + '_save.pkl') 90 | 91 | if opt.dataset == 'CUHK-PEDES': 92 | self.img_path = [os.path.join(opt.dataroot, img_path) for img_path in data_save['img_path']] 93 | elif opt.dataset == 'MSMT-PEDES': 94 | self.img_path = [os.path.join('/home/zhiyin/ICFG_PEDES/ICFG_PEDES', img_path.split('/')[-3],img_path.split('/')[-2],img_path.split('/')[-1]) for img_path in data_save['img_path']] 95 | 96 | self.label = data_save['id'] 97 | 98 | self.transform = tran 99 | 100 | self.num_data = len(self.img_path) 101 | 102 | def __getitem__(self, index): 103 | """ 104 | :param index: 105 | :return: image and its label 106 | """ 107 | 108 | image = Image.open(self.img_path[index]) 109 | image = self.transform(image) 110 | 111 | label = torch.from_numpy(np.array([self.label[index]])).long() 112 | 113 | return image, label 114 | 115 | def __len__(self): 116 | return 
self.num_data 117 | 118 | 119 | class CUHKPEDE_txt_dateset(data.Dataset): 120 | def __init__(self, opt): 121 | 122 | self.opt = opt 123 | 124 | data_save = read_dict(opt.pkl_root + opt.mode + '_save.pkl') 125 | 126 | self.label = data_save['caption_label'] 127 | if self.opt.wordtype == 'bert': 128 | self.caption_code = data_save['bert_caption_id'] 129 | elif self.opt.wordtype == 'lstm': 130 | self.caption_code = data_save['lstm_caption_id'] 131 | 132 | self.caption_matching_img_index = data_save['caption_matching_img_index'] 133 | 134 | self.num_data = len(self.caption_code) 135 | 136 | def __getitem__(self, index): 137 | """ 138 | :param index: 139 | :return: image and its label 140 | """ 141 | 142 | label = torch.from_numpy(np.array([self.label[index]])).long() 143 | 144 | caption_code, caption_length , caption_mask= self.caption_mask(self.caption_code[index]) 145 | caption_matching_img_index = self.caption_matching_img_index[index] 146 | return label, caption_code, caption_length, caption_mask,caption_matching_img_index 147 | 148 | def caption_mask(self, caption): 149 | caption_length = len(caption) 150 | caption = torch.from_numpy(np.array(caption)).view(-1).float() 151 | if caption_length < self.opt.caption_length_max: 152 | zero_padding = torch.zeros(self.opt.caption_length_max - caption_length) 153 | caption = torch.cat([caption, zero_padding], 0) 154 | else: 155 | caption = caption[:self.opt.caption_length_max] 156 | caption_length = self.opt.caption_length_max 157 | caption_mask = np.where(caption != 0, 1, 0) 158 | return caption, caption_length, caption_mask 159 | 160 | def __len__(self): 161 | return self.num_data 162 | 163 | -------------------------------------------------------------------------------- /loss/Id_loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat., Aug. 
17(rd), 2019 at 15:33 4 | 5 | @author: zifyloo 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn.parameter import Parameter 11 | from torch.nn import init 12 | 13 | 14 | def weights_init_classifier(m): 15 | classname = m.__class__.__name__ 16 | if classname.find('Linear') != -1: 17 | init.normal_(m.weight.data, std=0.001) 18 | init.constant_(m.bias.data, 0.0) 19 | 20 | 21 | class classifier(nn.Module): 22 | 23 | def __init__(self, input_dim, output_dim): 24 | super(classifier, self).__init__() 25 | 26 | self.block = nn.Linear(input_dim, output_dim) 27 | self.block.apply(weights_init_classifier) 28 | 29 | def forward(self, x): 30 | x = self.block(x) 31 | return x 32 | 33 | 34 | class Id_Loss(nn.Module): 35 | 36 | def __init__(self, opt): 37 | super(Id_Loss, self).__init__() 38 | 39 | self.opt = opt 40 | 41 | self.W = classifier(opt.feature_length, opt.class_num) 42 | # self.W_txt = classifier(opt.feature_length, opt.class_num) 43 | 44 | def calculate_IdLoss(self, image_embedding, text_embedding, label): 45 | 46 | label = label.view(label.size(0)) 47 | 48 | criterion = nn.CrossEntropyLoss(reduction='mean') 49 | 50 | score_i2t = self.W(image_embedding) 51 | score_t2i = self.W(text_embedding) 52 | Lipt_local = criterion(score_i2t, label) 53 | Ltpi_local = criterion(score_t2i, label) 54 | pred_i2t = torch.mean((torch.argmax(score_i2t, dim=1) == label).float()) 55 | pred_t2i = torch.mean((torch.argmax(score_t2i, dim=1) == label).float()) 56 | loss = (Lipt_local + Ltpi_local) 57 | 58 | return loss, pred_i2t, pred_t2i 59 | 60 | def forward(self, image_embedding, text_embedding, label): 61 | 62 | loss, pred_i2t, pred_t2i = self.calculate_IdLoss(image_embedding, text_embedding, label) 63 | 64 | return loss, pred_i2t, pred_t2i 65 | 66 | 67 | class Id_Loss_2(nn.Module): 68 | 69 | def __init__(self, opt): 70 | super(Id_Loss_2, self).__init__() 71 | 72 | self.opt = opt 73 | 74 | self.W = classifier(opt.feature_length, opt.class_num) 75 | # self.W_txt = classifier(opt.feature_length, opt.class_num) 76 | 77 | def calculate_IdLoss(self, image_embedding, text_embedding, label): 78 | 79 | label = label.view(label.size(0)) 80 | 81 | criterion = nn.CrossEntropyLoss(reduction='mean') 82 | 83 | score_i2t = self.W(image_embedding) 84 | score_t2i = self.W(text_embedding) 85 | Lipt_local = criterion(score_i2t, label) 86 | Ltpi_local = criterion(score_t2i, label) 87 | pred_i2t = torch.mean((torch.argmax(score_i2t, dim=1) == label).float()) 88 | pred_t2i = torch.mean((torch.argmax(score_t2i, dim=1) == label).float()) 89 | # loss = (Lipt_local + Ltpi_local) 90 | 91 | return Lipt_local, Ltpi_local,pred_i2t, pred_t2i 92 | 93 | def forward(self, image_embedding, text_embedding, label): 94 | 95 | Lipt_local, Ltpi_local, pred_i2t, pred_t2i = self.calculate_IdLoss(image_embedding, text_embedding, label) 96 | 97 | return Lipt_local, Ltpi_local, pred_i2t, pred_t2i 98 | 99 | 100 | class Id_Loss_3(nn.Module): 101 | 102 | def __init__(self, opt): 103 | super(Id_Loss_3, self).__init__() 104 | 105 | self.opt = opt 106 | 107 | self.W = classifier(opt.feature_length, opt.class_num) 108 | # self.W_txt = classifier(opt.feature_length, opt.class_num) 109 | 110 | def calculate_IdLoss(self, image_embedding, image_embedding_2 , text_embedding, label): 111 | 112 | label = label.view(label.size(0)) 113 | 114 | criterion = nn.CrossEntropyLoss(reduction='mean') 115 | 116 | score_i2t = self.W(image_embedding) 117 | score_i2t_MPN = self.W(image_embedding_2) 118 | score_t2i = self.W(text_embedding) 119 | Lipt_local = 
criterion(score_i2t, label) 120 | Ltpi_local = criterion(score_t2i, label) 121 | Lipt_local_MPN = criterion(score_i2t_MPN, label) 122 | pred_i2t = torch.mean((torch.argmax(score_i2t, dim=1) == label).float()) 123 | pred_t2i = torch.mean((torch.argmax(score_t2i, dim=1) == label).float()) 124 | pred_i2t_MPN = torch.mean((torch.argmax(score_i2t_MPN, dim=1) == label).float()) 125 | # loss = (Lipt_local + Ltpi_local) 126 | 127 | return Lipt_local, Ltpi_local,Lipt_local_MPN,pred_i2t, pred_t2i , pred_i2t_MPN 128 | 129 | def forward(self, image_embedding, image_embedding_2 , text_embedding, label): 130 | 131 | Lipt_local, Ltpi_local,Lipt_local_MPN,pred_i2t, pred_t2i , pred_i2t_MPN = self.calculate_IdLoss(image_embedding, image_embedding_2 , text_embedding, label) 132 | 133 | return Lipt_local, Ltpi_local,Lipt_local_MPN,pred_i2t, pred_t2i , pred_i2t_MPN 134 | # class Id_Loss_part(nn.Module): 135 | # 136 | # def __init__(self, opt): 137 | # super(Id_Loss_part, self).__init__() 138 | # 139 | # self.opt = opt 140 | # self.W = nn.ModuleList() 141 | # for _ in range(opt.num_query): 142 | # self.W.append(classifier(opt.d_model, opt.class_num)) 143 | # # self.W_txt = classifier(opt.feature_length, opt.class_num) 144 | # 145 | # def calculate_IdLoss(self, image_embedding, text_embedding, label): 146 | # 147 | # label = label.view(label.size(0)) 148 | # 149 | # criterion = nn.CrossEntropyLoss(reduction='mean') 150 | # 151 | # score_i2t = self.W(image_embedding) 152 | # score_t2i = self.W(text_embedding) 153 | # Lipt_local = criterion(score_i2t, label) 154 | # Ltpi_local = criterion(score_t2i, label) 155 | # pred_i2t = torch.mean((torch.argmax(score_i2t, dim=1) == label).float()) 156 | # pred_t2i = torch.mean((torch.argmax(score_t2i, dim=1) == label).float()) 157 | # loss = (Lipt_local + Ltpi_local) 158 | # 159 | # return loss, pred_i2t, pred_t2i 160 | # 161 | # def forward(self, image_embedding, text_embedding, label): 162 | # 163 | # loss, pred_i2t, pred_t2i = self.calculate_IdLoss(image_embedding, text_embedding, label) 164 | # 165 | # return loss, pred_i2t, pred_t2i -------------------------------------------------------------------------------- /loss/RankingLoss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat., Aug. 
17(rd), 2019 at 15:41 4 | 5 | @author: zifyloo 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import numpy as np 11 | import torch.nn.functional as F 12 | 13 | 14 | def calculate_similarity_global(image_embedding, text_embedding): 15 | image_embedding_norm = image_embedding / (image_embedding.norm(dim=1, keepdim=True) + 1e-8) 16 | text_embedding_norm = text_embedding / (text_embedding.norm(dim=1, keepdim=True) + 1e-8) 17 | 18 | similarity = torch.mm(image_embedding_norm, text_embedding_norm.t()) 19 | 20 | return similarity 21 | 22 | 23 | class RankingLoss(nn.Module): 24 | 25 | def __init__(self, opt): 26 | super(RankingLoss, self).__init__() 27 | 28 | self.margin = opt.margin 29 | self.device = opt.device 30 | 31 | def semi_hard_negative(self, loss): 32 | negative_index = np.where(np.logical_and(loss < self.margin, loss > 0))[0] 33 | return np.random.choice(negative_index) if len(negative_index) > 0 else None 34 | 35 | def get_triplets(self, similarity, labels): 36 | similarity = similarity.cpu().data.numpy() 37 | 38 | labels = labels.cpu().data.numpy() 39 | triplets = [] 40 | 41 | for idx, label in enumerate(labels): # same class calculate together 42 | 43 | negative = np.where(labels != label)[0] 44 | 45 | ap_sim = similarity[idx, idx] 46 | # print(ap_combination_list.shape, ap_distances_list.shape) 47 | 48 | loss = similarity[idx, negative] - ap_sim + self.margin 49 | 50 | negetive_index = self.semi_hard_negative(loss) 51 | 52 | if negetive_index is not None: 53 | triplets.append([idx, idx, negative[negetive_index]]) 54 | 55 | if len(triplets) == 0: 56 | triplets.append([idx, idx, negative[0]]) 57 | 58 | triplets = np.array(triplets) 59 | 60 | return torch.LongTensor(triplets) 61 | 62 | def forward(self, similarity, label): 63 | 64 | image_triplets = self.get_triplets(similarity, label) 65 | text_triplets = self.get_triplets(similarity.t(), label) 66 | 67 | # print(image_triplets.size(), text_triplets.size()) 68 | image_anchor_loss = F.relu(self.margin 69 | - similarity[image_triplets[:, 0], image_triplets[:, 1]] 70 | + similarity[image_triplets[:, 0], image_triplets[:, 2]]) 71 | 72 | texy_anchor_loss = F.relu(self.margin 73 | - similarity[text_triplets[:, 0], text_triplets[:, 1]] 74 | + similarity[text_triplets[:, 0], text_triplets[:, 2]]) 75 | 76 | loss = torch.sum(image_anchor_loss) + torch.sum(texy_anchor_loss) 77 | # loss = CMPM_loss + CMPC_loss 78 | 79 | return loss 80 | 81 | 82 | """ 83 | # test code 84 | SEED = 0 85 | torch.manual_seed(SEED) 86 | np.random.seed(SEED) 87 | 88 | image_embeddings = torch.rand(8, 512) 89 | text_embeddings = torch.rand(8, 512) 90 | label = torch.LongTensor([1, 2, 2, 2, 2, 2, 2, 2]) 91 | 92 | triplet_loss = RankingLoss(0.3) 93 | print(triplet_loss(image_embeddings, text_embeddings, label)) 94 | """ 95 | 96 | -------------------------------------------------------------------------------- /loss/__pycache__/Id_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/loss/__pycache__/Id_loss.cpython-36.pyc -------------------------------------------------------------------------------- /loss/__pycache__/Id_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/loss/__pycache__/Id_loss.cpython-37.pyc 
-------------------------------------------------------------------------------- /loss/__pycache__/Id_loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/loss/__pycache__/Id_loss.cpython-38.pyc -------------------------------------------------------------------------------- /loss/__pycache__/RankingLoss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/loss/__pycache__/RankingLoss.cpython-36.pyc -------------------------------------------------------------------------------- /loss/__pycache__/RankingLoss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/loss/__pycache__/RankingLoss.cpython-37.pyc -------------------------------------------------------------------------------- /loss/__pycache__/RankingLoss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/loss/__pycache__/RankingLoss.cpython-38.pyc -------------------------------------------------------------------------------- /loss/__pycache__/loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/loss/__pycache__/loss.cpython-36.pyc -------------------------------------------------------------------------------- /loss_TransREID/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_loss import make_loss 2 | from .arcface import ArcFace -------------------------------------------------------------------------------- /loss_TransREID/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/loss_TransREID/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /loss_TransREID/__pycache__/arcface.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/loss_TransREID/__pycache__/arcface.cpython-38.pyc -------------------------------------------------------------------------------- /loss_TransREID/__pycache__/center_loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/loss_TransREID/__pycache__/center_loss.cpython-38.pyc -------------------------------------------------------------------------------- /loss_TransREID/__pycache__/make_loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/loss_TransREID/__pycache__/make_loss.cpython-38.pyc -------------------------------------------------------------------------------- /loss_TransREID/__pycache__/metric_learning.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/loss_TransREID/__pycache__/metric_learning.cpython-38.pyc -------------------------------------------------------------------------------- /loss_TransREID/__pycache__/softmax_loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/loss_TransREID/__pycache__/softmax_loss.cpython-38.pyc -------------------------------------------------------------------------------- /loss_TransREID/__pycache__/triplet_loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/loss_TransREID/__pycache__/triplet_loss.cpython-38.pyc -------------------------------------------------------------------------------- /loss_TransREID/arcface.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Parameter 5 | import math 6 | 7 | 8 | class ArcFace(nn.Module): 9 | def __init__(self, in_features, out_features, s=30.0, m=0.50, bias=False): 10 | super(ArcFace, self).__init__() 11 | self.in_features = in_features 12 | self.out_features = out_features 13 | self.s = s 14 | self.m = m 15 | self.cos_m = math.cos(m) 16 | self.sin_m = math.sin(m) 17 | 18 | self.th = math.cos(math.pi - m) 19 | self.mm = math.sin(math.pi - m) * m 20 | 21 | self.weight = Parameter(torch.Tensor(out_features, in_features)) 22 | if bias: 23 | self.bias = Parameter(torch.Tensor(out_features)) 24 | else: 25 | self.register_parameter('bias', None) 26 | self.reset_parameters() 27 | 28 | def reset_parameters(self): 29 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 30 | if self.bias is not None: 31 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) 32 | bound = 1 / math.sqrt(fan_in) 33 | nn.init.uniform_(self.bias, -bound, bound) 34 | 35 | def forward(self, input, label): 36 | cosine = F.linear(F.normalize(input), F.normalize(self.weight)) 37 | sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1)) 38 | phi = cosine * self.cos_m - sine * self.sin_m 39 | phi = torch.where(cosine > self.th, phi, cosine - self.mm) 40 | # --------------------------- convert label to one-hot --------------------------- 41 | # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda') 42 | one_hot = torch.zeros(cosine.size(), device='cuda') 43 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 44 | # -------------torch.where(out_i = {x_i if condition_i else y_i) ------------- 45 | output = (one_hot * phi) + ( 46 | (1.0 - one_hot) * cosine) # you can use torch.where if your torch.__version__ is 0.4 47 | output *= self.s 48 | # print(output) 49 | 50 | return output 51 | 52 | class CircleLoss(nn.Module): 53 | def __init__(self, in_features, num_classes, s=256, m=0.25): 54 | super(CircleLoss, self).__init__() 55 | self.weight = Parameter(torch.Tensor(num_classes, in_features)) 56 | self.s = s 57 | self.m = m 58 | self._num_classes = num_classes 59 | self.reset_parameters() 60 | 61 | 62 | def reset_parameters(self): 63 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 64 | 65 | def __call__(self, bn_feat, targets): 66 | 67 | sim_mat = F.linear(F.normalize(bn_feat), 
F.normalize(self.weight)) 68 | alpha_p = torch.clamp_min(-sim_mat.detach() + 1 + self.m, min=0.) 69 | alpha_n = torch.clamp_min(sim_mat.detach() + self.m, min=0.) 70 | delta_p = 1 - self.m 71 | delta_n = self.m 72 | 73 | s_p = self.s * alpha_p * (sim_mat - delta_p) 74 | s_n = self.s * alpha_n * (sim_mat - delta_n) 75 | 76 | targets = F.one_hot(targets, num_classes=self._num_classes) 77 | 78 | pred_class_logits = targets * s_p + (1.0 - targets) * s_n 79 | 80 | return pred_class_logits -------------------------------------------------------------------------------- /loss_TransREID/center_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class CenterLoss(nn.Module): 8 | """Center loss. 9 | 10 | Reference: 11 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. 12 | 13 | Args: 14 | num_classes (int): number of classes. 15 | feat_dim (int): feature dimension. 16 | """ 17 | 18 | def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True): 19 | super(CenterLoss, self).__init__() 20 | self.num_classes = num_classes 21 | self.feat_dim = feat_dim 22 | self.use_gpu = use_gpu 23 | 24 | if self.use_gpu: 25 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda()) 26 | else: 27 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) 28 | 29 | def forward(self, x, labels): 30 | """ 31 | Args: 32 | x: feature matrix with shape (batch_size, feat_dim). 33 | labels: ground truth labels with shape (num_classes). 34 | """ 35 | assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)" 36 | 37 | batch_size = x.size(0) 38 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \ 39 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() 40 | distmat.addmm_(1, -2, x, self.centers.t()) 41 | 42 | classes = torch.arange(self.num_classes).long() 43 | if self.use_gpu: classes = classes.cuda() 44 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) 45 | mask = labels.eq(classes.expand(batch_size, self.num_classes)) 46 | 47 | dist = [] 48 | for i in range(batch_size): 49 | value = distmat[i][mask[i]] 50 | value = value.clamp(min=1e-12, max=1e+12) # for numerical stability 51 | dist.append(value) 52 | dist = torch.cat(dist) 53 | loss = dist.mean() 54 | return loss 55 | 56 | 57 | if __name__ == '__main__': 58 | use_gpu = False 59 | center_loss = CenterLoss(use_gpu=use_gpu) 60 | features = torch.rand(16, 2048) 61 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).long() 62 | if use_gpu: 63 | features = torch.rand(16, 2048).cuda() 64 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).cuda() 65 | 66 | loss = center_loss(features, targets) 67 | print(loss) 68 | -------------------------------------------------------------------------------- /loss_TransREID/make_loss.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch.nn.functional as F 8 | from .softmax_loss import CrossEntropyLabelSmooth, LabelSmoothingCrossEntropy 9 | from .triplet_loss import TripletLoss 10 | from .center_loss import CenterLoss 11 | 12 | 13 | def make_loss(cfg , num_classes): # modified by gu 14 | sampler = 
cfg.DATALOADER.SAMPLER 15 | feat_dim = 2048 16 | center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True) # center loss 17 | if 'triplet' in cfg.MODEL.METRIC_LOSS_TYPE: 18 | if cfg.MODEL.NO_MARGIN: 19 | triplet = TripletLoss() 20 | print("using soft triplet loss for training") 21 | else: 22 | triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss 23 | print("using triplet loss with margin:{}".format(cfg.SOLVER.MARGIN)) 24 | else: 25 | print('expected METRIC_LOSS_TYPE should be triplet' 26 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE)) 27 | 28 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 29 | xent = CrossEntropyLabelSmooth(num_classes=num_classes) 30 | print("label smooth on, numclasses:", num_classes) 31 | 32 | if sampler == 'softmax': 33 | def loss_func(score, feat, target): 34 | return F.cross_entropy(score, target) 35 | 36 | elif cfg.DATALOADER.SAMPLER == 'softmax_triplet': 37 | def loss_func(score, feat, target, carema): 38 | if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet': 39 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 40 | if isinstance(score, list): 41 | ID_LOSS = [xent(scor, target) for scor in score[1:]] 42 | ID_LOSS = sum(ID_LOSS) / len(ID_LOSS) 43 | ID_LOSS = 0.5 * ID_LOSS + 0.5 * xent(score[0], target) 44 | else: 45 | ID_LOSS = xent(score, target) 46 | 47 | if isinstance(feat, list): 48 | TRI_LOSS = [triplet(feats, target)[0] for feats in feat[1:]] 49 | TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS) 50 | TRI_LOSS = 0.5 * TRI_LOSS + 0.5 * triplet(feat[0], target)[0] 51 | else: 52 | TRI_LOSS = triplet(feat, target)[0] 53 | 54 | return cfg.MODEL.ID_LOSS_WEIGHT * ID_LOSS + \ 55 | cfg.MODEL.TRIPLET_LOSS_WEIGHT * TRI_LOSS 56 | else: 57 | if isinstance(score, list): 58 | ID_LOSS = [F.cross_entropy(scor, target) for scor in score[1:]] 59 | ID_LOSS = sum(ID_LOSS) / len(ID_LOSS) 60 | ID_LOSS = 0.5 * ID_LOSS + 0.5 * F.cross_entropy(score[0], target) 61 | else: 62 | ID_LOSS = F.cross_entropy(score, target) 63 | 64 | if isinstance(feat, list): 65 | TRI_LOSS = [triplet(feats, target)[0] for feats in feat[1:]] 66 | TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS) 67 | TRI_LOSS = 0.5 * TRI_LOSS + 0.5 * triplet(feat[0], target)[0] 68 | else: 69 | TRI_LOSS = triplet(feat, target)[0] 70 | 71 | return cfg.MODEL.ID_LOSS_WEIGHT * ID_LOSS + \ 72 | cfg.MODEL.TRIPLET_LOSS_WEIGHT * TRI_LOSS 73 | else: 74 | print('expected METRIC_LOSS_TYPE should be triplet' 75 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE)) 76 | 77 | else: 78 | print('expected sampler should be softmax, triplet, softmax_triplet or softmax_triplet_center' 79 | 'but got {}'.format(cfg.DATALOADER.SAMPLER)) 80 | return loss_func, center_criterion 81 | 82 | 83 | -------------------------------------------------------------------------------- /loss_TransREID/metric_learning.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.autograd 5 | from torch.nn import Parameter 6 | import math 7 | 8 | 9 | class ContrastiveLoss(nn.Module): 10 | def __init__(self, margin=0.3, **kwargs): 11 | super(ContrastiveLoss, self).__init__() 12 | self.margin = margin 13 | 14 | def forward(self, inputs, targets): 15 | n = inputs.size(0) 16 | # Compute similarity matrix 17 | sim_mat = torch.matmul(inputs, inputs.t()) 18 | targets = targets 19 | loss = list() 20 | c = 0 21 | 22 | for i in range(n): 23 | pos_pair_ = torch.masked_select(sim_mat[i], targets == targets[i]) 24 | 25 | # move itself 26 | pos_pair_ = 
torch.masked_select(pos_pair_, pos_pair_ < 1) 27 | neg_pair_ = torch.masked_select(sim_mat[i], targets != targets[i]) 28 | 29 | pos_pair_ = torch.sort(pos_pair_)[0] 30 | neg_pair_ = torch.sort(neg_pair_)[0] 31 | 32 | neg_pair = torch.masked_select(neg_pair_, neg_pair_ > self.margin) 33 | 34 | neg_loss = 0 35 | 36 | pos_loss = torch.sum(-pos_pair_ + 1) 37 | if len(neg_pair) > 0: 38 | neg_loss = torch.sum(neg_pair) 39 | loss.append(pos_loss + neg_loss) 40 | 41 | loss = sum(loss) / n 42 | return loss 43 | 44 | 45 | class CircleLoss(nn.Module): 46 | def __init__(self, in_features, num_classes, s=256, m=0.25): 47 | super(CircleLoss, self).__init__() 48 | self.weight = Parameter(torch.Tensor(num_classes, in_features)) 49 | self.s = s 50 | self.m = m 51 | self._num_classes = num_classes 52 | self.reset_parameters() 53 | 54 | 55 | def reset_parameters(self): 56 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 57 | 58 | def __call__(self, bn_feat, targets): 59 | 60 | sim_mat = F.linear(F.normalize(bn_feat), F.normalize(self.weight)) 61 | alpha_p = torch.clamp_min(-sim_mat.detach() + 1 + self.m, min=0.) 62 | alpha_n = torch.clamp_min(sim_mat.detach() + self.m, min=0.) 63 | delta_p = 1 - self.m 64 | delta_n = self.m 65 | 66 | s_p = self.s * alpha_p * (sim_mat - delta_p) 67 | s_n = self.s * alpha_n * (sim_mat - delta_n) 68 | 69 | targets = F.one_hot(targets, num_classes=self._num_classes) 70 | 71 | pred_class_logits = targets * s_p + (1.0 - targets) * s_n 72 | 73 | return pred_class_logits 74 | 75 | 76 | class Arcface(nn.Module): 77 | r"""Implement of large margin arc distance: : 78 | Args: 79 | in_features: size of each input sample 80 | out_features: size of each output sample 81 | s: norm of input feature 82 | m: margin 83 | cos(theta + m) 84 | """ 85 | def __init__(self, in_features, out_features, s=30.0, m=0.30, easy_margin=False, ls_eps=0.0): 86 | super(Arcface, self).__init__() 87 | self.in_features = in_features 88 | self.out_features = out_features 89 | self.s = s 90 | self.m = m 91 | self.ls_eps = ls_eps # label smoothing 92 | self.weight = Parameter(torch.FloatTensor(out_features, in_features)) 93 | nn.init.xavier_uniform_(self.weight) 94 | 95 | self.easy_margin = easy_margin 96 | self.cos_m = math.cos(m) 97 | self.sin_m = math.sin(m) 98 | self.th = math.cos(math.pi - m) 99 | self.mm = math.sin(math.pi - m) * m 100 | 101 | def forward(self, input, label): 102 | # --------------------------- cos(theta) & phi(theta) --------------------------- 103 | cosine = F.linear(F.normalize(input), F.normalize(self.weight)) 104 | sine = torch.sqrt(1.0 - torch.pow(cosine, 2)) 105 | phi = cosine * self.cos_m - sine * self.sin_m 106 | phi = phi.type_as(cosine) 107 | if self.easy_margin: 108 | phi = torch.where(cosine > 0, phi, cosine) 109 | else: 110 | phi = torch.where(cosine > self.th, phi, cosine - self.mm) 111 | # --------------------------- convert label to one-hot --------------------------- 112 | # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda') 113 | one_hot = torch.zeros(cosine.size(), device='cuda') 114 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 115 | if self.ls_eps > 0: 116 | one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features 117 | # -------------torch.where(out_i = {x_i if condition_i else y_i) ------------- 118 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine) 119 | output *= self.s 120 | 121 | return output 122 | 123 | 124 | class Cosface(nn.Module): 125 | r"""Implement of large margin cosine distance: : 126 | Args: 127 | 
in_features: size of each input sample 128 | out_features: size of each output sample 129 | s: norm of input feature 130 | m: margin 131 | cos(theta) - m 132 | """ 133 | 134 | def __init__(self, in_features, out_features, s=30.0, m=0.30): 135 | super(Cosface, self).__init__() 136 | self.in_features = in_features 137 | self.out_features = out_features 138 | self.s = s 139 | self.m = m 140 | self.weight = Parameter(torch.FloatTensor(out_features, in_features)) 141 | nn.init.xavier_uniform_(self.weight) 142 | 143 | def forward(self, input, label): 144 | # --------------------------- cos(theta) & phi(theta) --------------------------- 145 | cosine = F.linear(F.normalize(input), F.normalize(self.weight)) 146 | phi = cosine - self.m 147 | # --------------------------- convert label to one-hot --------------------------- 148 | one_hot = torch.zeros(cosine.size(), device='cuda') 149 | # one_hot = one_hot.cuda() if cosine.is_cuda else one_hot 150 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 151 | # -------------torch.where(out_i = {x_i if condition_i else y_i) ------------- 152 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine) # you can use torch.where if your torch.__version__ is 0.4 153 | output *= self.s 154 | # print(output) 155 | 156 | return output 157 | 158 | def __repr__(self): 159 | return self.__class__.__name__ + '(' \ 160 | + 'in_features=' + str(self.in_features) \ 161 | + ', out_features=' + str(self.out_features) \ 162 | + ', s=' + str(self.s) \ 163 | + ', m=' + str(self.m) + ')' 164 | 165 | 166 | class AMSoftmax(nn.Module): 167 | def __init__(self, in_features, out_features, s=30.0, m=0.30): 168 | super(AMSoftmax, self).__init__() 169 | self.m = m 170 | self.s = s 171 | self.in_feats = in_features 172 | self.W = torch.nn.Parameter(torch.randn(in_features, out_features), requires_grad=True) 173 | self.ce = nn.CrossEntropyLoss() 174 | nn.init.xavier_normal_(self.W, gain=1) 175 | 176 | def forward(self, x, lb): 177 | assert x.size()[0] == lb.size()[0] 178 | assert x.size()[1] == self.in_feats 179 | x_norm = torch.norm(x, p=2, dim=1, keepdim=True).clamp(min=1e-12) 180 | x_norm = torch.div(x, x_norm) 181 | w_norm = torch.norm(self.W, p=2, dim=0, keepdim=True).clamp(min=1e-12) 182 | w_norm = torch.div(self.W, w_norm) 183 | costh = torch.mm(x_norm, w_norm) 184 | # print(x_norm.shape, w_norm.shape, costh.shape) 185 | lb_view = lb.view(-1, 1) 186 | delt_costh = torch.zeros(costh.size(), device='cuda').scatter_(1, lb_view, self.m) 187 | costh_m = costh - delt_costh 188 | costh_m_s = self.s * costh_m 189 | return costh_m_s -------------------------------------------------------------------------------- /loss_TransREID/softmax_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | class CrossEntropyLabelSmooth(nn.Module): 5 | """Cross entropy loss with label smoothing regularizer. 6 | 7 | Reference: 8 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 9 | Equation: y = (1 - epsilon) * y + epsilon / K. 10 | 11 | Args: 12 | num_classes (int): number of classes. 13 | epsilon (float): weight. 
14 | """ 15 | 16 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True): 17 | super(CrossEntropyLabelSmooth, self).__init__() 18 | self.num_classes = num_classes 19 | self.epsilon = epsilon 20 | self.use_gpu = use_gpu 21 | self.logsoftmax = nn.LogSoftmax(dim=1) 22 | 23 | def forward(self, inputs, targets): 24 | """ 25 | Args: 26 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 27 | targets: ground truth labels with shape (num_classes) 28 | """ 29 | log_probs = self.logsoftmax(inputs) 30 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 31 | if self.use_gpu: targets = targets.cuda() 32 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 33 | loss = (- targets * log_probs).mean(0).sum() 34 | return loss 35 | 36 | class LabelSmoothingCrossEntropy(nn.Module): 37 | """ 38 | NLL loss with label smoothing. 39 | """ 40 | def __init__(self, smoothing=0.1): 41 | """ 42 | Constructor for the LabelSmoothing module. 43 | :param smoothing: label smoothing factor 44 | """ 45 | super(LabelSmoothingCrossEntropy, self).__init__() 46 | assert smoothing < 1.0 47 | self.smoothing = smoothing 48 | self.confidence = 1. - smoothing 49 | 50 | def forward(self, x, target): 51 | logprobs = F.log_softmax(x, dim=-1) 52 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 53 | nll_loss = nll_loss.squeeze(1) 54 | smooth_loss = -logprobs.mean(dim=-1) 55 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 56 | return loss.mean() -------------------------------------------------------------------------------- /loss_TransREID/triplet_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def normalize(x, axis=-1): 6 | """Normalizing to unit length along the specified dimension. 7 | Args: 8 | x: pytorch Variable 9 | Returns: 10 | x: pytorch Variable, same shape as input 11 | """ 12 | x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12) 13 | return x 14 | 15 | 16 | def euclidean_dist(x, y): 17 | """ 18 | Args: 19 | x: pytorch Variable, with shape [m, d] 20 | y: pytorch Variable, with shape [n, d] 21 | Returns: 22 | dist: pytorch Variable, with shape [m, n] 23 | """ 24 | m, n = x.size(0), y.size(0) 25 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 26 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 27 | dist = xx + yy 28 | dist = dist - 2 * torch.matmul(x, y.t()) 29 | # dist.addmm_(1, -2, x, y.t()) 30 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 31 | return dist 32 | 33 | 34 | def cosine_dist(x, y): 35 | """ 36 | Args: 37 | x: pytorch Variable, with shape [m, d] 38 | y: pytorch Variable, with shape [n, d] 39 | Returns: 40 | dist: pytorch Variable, with shape [m, n] 41 | """ 42 | m, n = x.size(0), y.size(0) 43 | x_norm = torch.pow(x, 2).sum(1, keepdim=True).sqrt().expand(m, n) 44 | y_norm = torch.pow(y, 2).sum(1, keepdim=True).sqrt().expand(n, m).t() 45 | xy_intersection = torch.mm(x, y.t()) 46 | dist = xy_intersection/(x_norm * y_norm) 47 | dist = (1. - dist) / 2 48 | return dist 49 | 50 | 51 | def hard_example_mining(dist_mat, labels, return_inds=False): 52 | """For each anchor, find the hardest positive and negative sample. 53 | Args: 54 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N] 55 | labels: pytorch LongTensor, with shape [N] 56 | return_inds: whether to return the indices. Save time if `False`(?) 
57 | Returns: 58 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N] 59 | dist_an: pytorch Variable, distance(anchor, negative); shape [N] 60 | p_inds: pytorch LongTensor, with shape [N]; 61 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1 62 | n_inds: pytorch LongTensor, with shape [N]; 63 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1 64 | NOTE: Only consider the case in which all labels have same num of samples, 65 | thus we can cope with all anchors in parallel. 66 | """ 67 | 68 | assert len(dist_mat.size()) == 2 69 | assert dist_mat.size(0) == dist_mat.size(1) 70 | N = dist_mat.size(0) 71 | 72 | # shape [N, N] 73 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()) 74 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t()) 75 | 76 | # `dist_ap` means distance(anchor, positive) 77 | # both `dist_ap` and `relative_p_inds` with shape [N, 1] 78 | dist_ap, relative_p_inds = torch.max( 79 | dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True) 80 | # print(dist_mat[is_pos].shape) 81 | # `dist_an` means distance(anchor, negative) 82 | # both `dist_an` and `relative_n_inds` with shape [N, 1] 83 | dist_an, relative_n_inds = torch.min( 84 | dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True) 85 | # shape [N] 86 | dist_ap = dist_ap.squeeze(1) 87 | dist_an = dist_an.squeeze(1) 88 | 89 | if return_inds: 90 | # shape [N, N] 91 | ind = (labels.new().resize_as_(labels) 92 | .copy_(torch.arange(0, N).long()) 93 | .unsqueeze(0).expand(N, N)) 94 | # shape [N, 1] 95 | p_inds = torch.gather( 96 | ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data) 97 | n_inds = torch.gather( 98 | ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data) 99 | # shape [N] 100 | p_inds = p_inds.squeeze(1) 101 | n_inds = n_inds.squeeze(1) 102 | return dist_ap, dist_an, p_inds, n_inds 103 | 104 | return dist_ap, dist_an 105 | 106 | 107 | class TripletLoss(object): 108 | """ 109 | Triplet loss using HARDER example mining, 110 | modified based on original triplet loss using hard example mining 111 | """ 112 | 113 | def __init__(self, margin=None, hard_factor=0.0): 114 | self.margin = margin 115 | self.hard_factor = hard_factor 116 | if margin is not None: 117 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 118 | else: 119 | self.ranking_loss = nn.SoftMarginLoss() 120 | 121 | def __call__(self, global_feat, labels, normalize_feature=False): 122 | if normalize_feature: 123 | global_feat = normalize(global_feat, axis=-1) 124 | dist_mat = euclidean_dist(global_feat, global_feat) 125 | dist_ap, dist_an = hard_example_mining(dist_mat, labels) 126 | 127 | dist_ap *= (1.0 + self.hard_factor) 128 | dist_an *= (1.0 - self.hard_factor) 129 | 130 | y = dist_an.new().resize_as_(dist_an).fill_(1) 131 | if self.margin is not None: 132 | loss = self.ranking_loss(dist_an, dist_ap, y) 133 | else: 134 | loss = self.ranking_loss(dist_an - dist_ap, y) 135 | return loss, dist_ap, dist_an 136 | 137 | 138 | -------------------------------------------------------------------------------- /model/DETR_model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from model.text_feature_extract import TextExtract, TextExtract_Bert_lstm 3 | from torchvision import models 4 | import torch 5 | from torch.nn import init 6 | from vit_pytorch import pixel_ViT, DECODER, PartQuery,mydecoder,mydecoder_DETR 7 | from einops.layers.torch import Rearrange 8 | from model.model import ft_net_TransREID_local, 
ft_net_TransREID_local_smallDeiT, ft_net_TransREID_local_smallVit 9 | from VD_project import SOHO_Pre_VD 10 | from einops import rearrange, repeat 11 | 12 | def weights_init_kaiming(m): 13 | classname = m.__class__.__name__ 14 | if classname.find('Conv2d') != -1: 15 | init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity='relu') 16 | elif classname.find('Linear') != -1: 17 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 18 | # init.constant(m.bias.data, 0.0) 19 | elif classname.find('BatchNorm1d') != -1: 20 | init.normal(m.weight.data, 1.0, 0.02) 21 | init.constant(m.bias.data, 0.0) 22 | elif classname.find('BatchNorm2d') != -1: 23 | init.constant(m.weight.data, 1) 24 | init.constant(m.bias.data, 0) 25 | 26 | 27 | def weights_init_classifier(m): 28 | classname = m.__class__.__name__ 29 | if classname.find('Linear') != -1: 30 | init.normal_(m.weight.data, std=0.001) 31 | # init.constant(m.bias.data, 0.0) 32 | 33 | class conv(nn.Module): 34 | 35 | def __init__(self, input_dim, output_dim, relu=False, BN=False): 36 | super(conv, self).__init__() 37 | 38 | block = [] 39 | block += [nn.Conv2d(input_dim, output_dim, kernel_size=1, bias=False)] 40 | 41 | if BN: 42 | block += [nn.BatchNorm2d(output_dim)] 43 | if relu: 44 | block += [nn.ReLU(inplace=True)] 45 | 46 | self.block = nn.Sequential(*block) 47 | self.block.apply(weights_init_kaiming) 48 | 49 | def forward(self, x): 50 | x = self.block(x) 51 | x = x.squeeze(3).squeeze(2) 52 | return x 53 | 54 | 55 | class TextImgPersonReidNet(nn.Module): 56 | 57 | def __init__(self, opt): 58 | super(TextImgPersonReidNet, self).__init__() 59 | 60 | self.opt = opt 61 | resnet50 = models.resnet50(pretrained=True) 62 | 63 | self.ImageExtract = nn.Sequential(*(list(resnet50.children())[:-2])) 64 | self.TextExtract = TextExtract(opt) 65 | 66 | self.avg_global = nn.AdaptiveMaxPool2d((1, 1)) 67 | self.Decoder = DECODER(opt=opt, dim=opt.d_model, depth=2, heads=4, 68 | mlp_dim=512, pool='cls', patch_dim=2048, dim_head=512, 69 | dropout=0., emb_dropout=0.) 
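The `conv` block defined above is reused throughout these models as a per-query projection head: a part embedding of size `d_model` is reshaped into a fake 1x1 feature map, passed through the 1x1 convolution, and squeezed back to a flat vector of length `feature_length` (see the `conv_1X1_2` heads a few lines below). A minimal sketch of that pattern, with toy dimensions rather than code from this repository:

```python
# Sketch only: mimics the `conv` head above on a single decoder-query embedding.
# The sizes (384 -> 512) are illustrative stand-ins for opt.d_model / opt.feature_length.
import torch
import torch.nn as nn

class Conv1x1Head(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.block = nn.Conv2d(input_dim, output_dim, kernel_size=1, bias=False)

    def forward(self, x):                                  # x: (B, input_dim, 1, 1)
        return self.block(x).squeeze(3).squeeze(2)         # -> (B, output_dim)

head = Conv1x1Head(384, 512)
part_embedding = torch.randn(8, 384)                       # one query output per image
out = head(part_embedding.unsqueeze(2).unsqueeze(2))       # reshape to (B, 384, 1, 1) first
print(out.shape)                                           # torch.Size([8, 512])
```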
70 | # self._reset_parameters() 71 | # self.query_embed_image = nn.Embedding(opt.num_query,opt.d_model) 72 | # self.input_proj = nn.Conv2d(2048, opt.d_model, kernel_size=1) 73 | self.conv_1X1_2 = nn.ModuleList() 74 | for _ in range(opt.num_query): 75 | self.conv_1X1_2.append(conv(opt.d_model, opt.feature_length)) 76 | self.query_embed_image = nn.Parameter(torch.randn(1, 6, 2048)) 77 | 78 | """ 79 | self.part_query_net = nn.ModuleList() 80 | for _ in range(6): 81 | self.part_query_net.append(PartQuery(dim=2048, depth=2, heads=4, 82 | mlp_dim=512, dim_head=512, dropout=0.)) 83 | """ 84 | self.to_patch_embedding = nn.Sequential( 85 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=1, p2=1) 86 | ) 87 | 88 | def forward(self, image, caption_id, text_length): 89 | 90 | image_feature_part = self.img_embedding(image) 91 | text_feature_part = self.txt_embedding(caption_id, text_length) 92 | 93 | return image_feature_part, text_feature_part 94 | 95 | def image_DETR(self, image_feature): 96 | 97 | image_feature = self.to_patch_embedding(image_feature) 98 | query_embed_image = self.query_embed_image.repeat(image_feature.size(0), 1, 1) 99 | 100 | image_feature = self.Decoder(query_embed_image, image_feature) 101 | image_part = [] 102 | for i in range(self.opt.num_query): 103 | image_feature_i = self.conv_1X1_2[i](image_feature[:, i].unsqueeze(2).unsqueeze(2)) 104 | image_part.append(image_feature_i.unsqueeze(0)) 105 | image_part = torch.cat(image_part, dim=0) 106 | return image_part 107 | 108 | def img_embedding(self, image): 109 | image_feature = self.ImageExtract(image) 110 | image_feature_part = self.image_DETR(image_feature) 111 | 112 | return image_feature_part 113 | 114 | def txt_embedding(self, caption_id, text_length): 115 | text_feature = self.TextExtract(caption_id, text_length) 116 | 117 | ignore_mask = (caption_id == 0) 118 | ignore_mask = ignore_mask[:, :text_feature.size(1)] 119 | query_embed_image = self.query_embed_image.repeat(text_feature.size(0), 1, 1) 120 | text_feature = self.Decoder(query_embed_image, text_feature, mask=ignore_mask) 121 | text_feature_part = [] 122 | for i in range(self.opt.num_query): 123 | text_feature_i = self.conv_1X1_2[i](text_feature[:, i].unsqueeze(2).unsqueeze(2)) 124 | text_feature_part.append(text_feature_i.unsqueeze(0)) 125 | text_feature_part = torch.cat(text_feature_part, dim=0) 126 | return text_feature_part 127 | 128 | 129 | class TextImgPersonReidNet_mydecoder_pixelVit_transTXT_3_bert(nn.Module): 130 | 131 | def __init__(self, opt): 132 | super(TextImgPersonReidNet_mydecoder_pixelVit_transTXT_3_bert, self).__init__() 133 | 134 | self.opt = opt 135 | backbone = ft_net_TransREID_local_smallDeiT() 136 | # backbone = ft_net_TransREID_local_smallVit() 137 | self.ImageExtract = backbone 138 | # self.TextExtract = TextExtract_nomax(opt) 139 | self.TextExtract = TextExtract_Bert_lstm(opt) 140 | self.avg_global = nn.AdaptiveMaxPool2d((1, 1)) 141 | self.TXTDecoder = mydecoder(opt= opt,dim=384, depth=2, heads=6, 142 | mlp_dim=512, pool='cls', patch_dim=384, dim_head=512, 143 | dropout=0., emb_dropout=0.) 144 | self.TXTDecoder_2 = mydecoder(opt=opt, dim=384, depth=1, heads=6, 145 | mlp_dim=768, pool='cls', patch_dim=384, dim_head=64, 146 | dropout=0., emb_dropout=0.) 
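Both `image_DETR` and `txt_embedding` above follow the same recipe: a learned parameter of shape `(1, num_query, dim)` is repeated along the batch dimension and used as the query set of a cross-attention decoder over the memory tokens (image patches or caption words), yielding one embedding per part query. The sketch below illustrates that recipe; the single `nn.MultiheadAttention` layer is only a stand-in for the repository's `DECODER`/`mydecoder` modules, not their actual implementation:

```python
# Sketch (assumption: one cross-attention layer standing in for DECODER / mydecoder).
import torch
import torch.nn as nn

class TinyQueryDecoder(nn.Module):
    def __init__(self, dim=384, num_query=6, heads=6):
        super().__init__()
        self.query = nn.Parameter(torch.randn(1, num_query, dim))    # like query_embed_image
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, memory, key_padding_mask=None):
        q = self.query.repeat(memory.size(0), 1, 1)                  # (B, num_query, dim)
        out, _ = self.attn(q, memory, memory,
                           key_padding_mask=key_padding_mask)        # cross-attend to memory
        return out                                                   # (B, num_query, dim)

decoder = TinyQueryDecoder()
patch_tokens = torch.randn(4, 192, 384)      # e.g. 24 x 8 image patch tokens per sample
parts = decoder(patch_tokens)
print(parts.shape)                           # torch.Size([4, 6, 384])
```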
147 | self.pixel_to_patch = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=1, p2=1) 148 | self.patch_to_pixel = Rearrange('b (h w) c -> b c h w', h=24, w=8) 149 | # self.conv_1X1_2 = conv(384, opt.feature_length) 150 | self.conv_1X1_2 = nn.ModuleList() 151 | for _ in range(opt.num_query): 152 | self.conv_1X1_2.append(conv(384, opt.feature_length)) 153 | self.pos_embed_image = nn.Parameter(torch.randn(1, 48, opt.d_model)) 154 | self.query_embed_image = nn.Parameter(torch.randn(1, self.opt.num_query, 384)) 155 | if opt.share_query == False: 156 | self.tgt_embed_image = nn.Parameter(torch.randn(1, self.opt.num_query, 384)) 157 | self.dict_feature = nn.Parameter(torch.randn(1, 400, 384)) 158 | self.mask = nn.Sequential(nn.Linear(384,1), 159 | nn.Sigmoid()) 160 | 161 | # self.vd = SOHO_Pre_VD(8000, 384, decay=0.1, max_decay=0.99) 162 | # self.linear_768 = nn.Linear(2048, 384, bias=False) 163 | 164 | def forward(self, image, caption_id, text_mask): 165 | image_feature = self.ImageExtract(image) 166 | text_feature = self.TextExtract(caption_id, text_mask) 167 | image_feature_fusion = self.image_fusion(image_feature , text_feature , caption_id) 168 | image_feature_part, image_feature_part_dict = self.image_DETR(image_feature_fusion,image_feature) 169 | text_feature_part, text_feature_part_dict = self.text_DETR(text_feature, caption_id) 170 | return image_feature_part, image_feature_part_dict,text_feature_part,text_feature_part_dict 171 | 172 | 173 | def text_DETR(self,text_featuremap,caption_id): 174 | 175 | # text_featuremap = self.linear_768(text_featuremap) 176 | B, L, C = text_featuremap.shape 177 | dict_feature = self.dict_feature.repeat(B, 1, 1) 178 | tgt = text_featuremap 179 | ignore_kv_mask = (caption_id == 0) 180 | ignore_kv_mask = ignore_kv_mask[:, :text_featuremap.size(1)] 181 | ignore_kv_mask = torch.logical_not(ignore_kv_mask) 182 | q_mask = torch.zeros(B,self.opt.num_query).to(self.opt.device) 183 | q_mask = (q_mask == 0) 184 | # memory = self.TXTEncoder(tgt,mask = ignore_kv_mask) 185 | memory = tgt 186 | memory_dict = self.TXTDecoder_2(memory,dict_feature) 187 | if self.opt.share_query: 188 | tgt_embed_image = self.query_embed_image.repeat(B, 1, 1) 189 | else: 190 | tgt_embed_image = self.tgt_embed_image.repeat(B, 1, 1) 191 | 192 | hs = self.TXTDecoder(tgt_embed_image,memory,ignore_kv_mask,q_mask) 193 | hs_dict = self.TXTDecoder(tgt_embed_image, memory_dict, ignore_kv_mask, q_mask) 194 | 195 | text_part = [] 196 | for i in range(self.opt.num_query): 197 | hs_i = self.conv_1X1_2[i](hs[:,i].unsqueeze(2).unsqueeze(2)) 198 | text_part.append(hs_i.unsqueeze(0)) 199 | text_part = torch.cat(text_part, dim=0) 200 | 201 | text_part_dict = [] 202 | for i in range(self.opt.num_query): 203 | hs_i = self.conv_1X1_2[i](hs_dict[:,i].unsqueeze(2).unsqueeze(2)) 204 | text_part_dict.append(hs_i.unsqueeze(0)) 205 | text_part_dict = torch.cat(text_part_dict, dim=0) 206 | return text_part , text_part_dict 207 | 208 | def image_fusion(self,image_feature,text_feature, caption_id): 209 | B, P, C = image_feature.shape 210 | _, L, _ = text_feature.shape 211 | 212 | ignore_kv_mask = (caption_id == 0) 213 | ignore_kv_mask = ignore_kv_mask[:, :L] 214 | ignore_kv_mask = torch.logical_not(ignore_kv_mask) 215 | q_mask = torch.zeros(B, P).to(self.opt.device) 216 | q_mask = (q_mask == 0) 217 | mask = self.mask(image_feature) 218 | memory_mask = image_feature 219 | memory_dict = self.TXTDecoder_2(memory_mask, text_feature , ignore_kv_mask, q_mask) 220 | memory_dict = memory_dict * mask 221 | return 
memory_dict 222 | 223 | 224 | def image_DETR(self,image_featuremap_fusion , image_featuremap): 225 | B , P , C = image_featuremap.shape 226 | dict_feature = self.dict_feature.repeat(B, 1, 1) 227 | 228 | memory = image_featuremap 229 | mask = self.mask(memory) 230 | memory_mask = memory 231 | memory_dict = self.TXTDecoder_2(memory_mask,dict_feature) 232 | memory_dict = memory_dict * mask 233 | query_embed_image = self.query_embed_image.repeat(B,1,1) 234 | if image_featuremap_fusion != None: 235 | 236 | hs = self.TXTDecoder(query_embed_image, image_featuremap_fusion) 237 | else: 238 | 239 | hs = self.TXTDecoder(query_embed_image, image_featuremap) 240 | 241 | hs_dict = self.TXTDecoder(query_embed_image, memory_dict) 242 | 243 | image_part = [] 244 | for i in range(self.opt.num_query): 245 | hs_i = self.conv_1X1_2[i](hs[:,i].unsqueeze(2).unsqueeze(2)) 246 | image_part.append(hs_i.unsqueeze(0)) 247 | image_part = torch.cat(image_part,dim=0) 248 | image_part_dict = [] 249 | for i in range(self.opt.num_query): 250 | hs_i = self.conv_1X1_2[i](hs_dict[:,i].unsqueeze(2).unsqueeze(2)) 251 | image_part_dict.append(hs_i.unsqueeze(0)) 252 | image_part_dict = torch.cat(image_part_dict, dim=0) 253 | return image_part , image_part_dict 254 | 255 | def img_embedding(self, image): 256 | # img_socre = torch.ones(image.size(0), 6).to(self.opt.device) 257 | image_feature = self.ImageExtract(image) 258 | image_feature_part , image_feature_part_dict= self.image_DETR(None, image_feature) 259 | return image_feature_part, image_feature_part_dict 260 | 261 | def txt_embedding(self, caption_id, text_mask): 262 | # text_socre = torch.ones(caption_id.size(0), 6).to(self.opt.device) 263 | text_feature = self.TextExtract(caption_id, text_mask) 264 | # text_feature = text_feature.squeeze(2) 265 | # text_feature = text_feature.permute(0,2,1) 266 | text_feature_part, text_feature_part_dict= self.text_DETR(text_feature, caption_id ) 267 | return text_feature_part, text_feature_part_dict -------------------------------------------------------------------------------- /model/__pycache__/DETR_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/model/__pycache__/DETR_model.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/DETR_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/model/__pycache__/DETR_model.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/model/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/model/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/model/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/text_feature_extract.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/model/__pycache__/text_feature_extract.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/text_feature_extract.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/model/__pycache__/text_feature_extract.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/text_feature_extract.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/model/__pycache__/text_feature_extract.cpython-38.pyc -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat., Aug. 1(st), 2019 at 9:05 4 | 5 | @author: zifyloo 6 | """ 7 | 8 | from torch import nn 9 | from model.text_feature_extract import TextExtract 10 | from torchvision import models 11 | import torch 12 | from torch.nn import init 13 | import torch.nn.functional as F 14 | import numpy as np 15 | from model_TransREID.backbones.vit_pytorch import vit_base_patch16_224_TransReID, vit_small_patch16_224_TransReID, deit_small_patch16_224_TransReID 16 | 17 | 18 | def weights_init_kaiming(m): 19 | classname = m.__class__.__name__ 20 | if classname.find('Conv2d') != -1: 21 | init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity='relu') 22 | elif classname.find('Linear') != -1: 23 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 24 | # init.constant(m.bias.data, 0.0) 25 | elif classname.find('BatchNorm1d') != -1: 26 | init.normal(m.weight.data, 1.0, 0.02) 27 | init.constant(m.bias.data, 0.0) 28 | elif classname.find('BatchNorm2d') != -1: 29 | init.constant(m.weight.data, 1) 30 | init.constant(m.bias.data, 0) 31 | 32 | 33 | def weights_init_classifier(m): 34 | classname = m.__class__.__name__ 35 | if classname.find('Linear') != -1: 36 | init.normal_(m.weight.data, std=0.001) 37 | # init.constant(m.bias.data, 0.0) 38 | 39 | 40 | class conv(nn.Module): 41 | 42 | def __init__(self, input_dim, output_dim, relu=False, BN=False): 43 | super(conv, self).__init__() 44 | 45 | block = [] 46 | block += [nn.Conv2d(input_dim, output_dim, kernel_size=1, bias=False)] 47 | 48 | if BN: 49 | block += [nn.BatchNorm2d(output_dim)] 50 | if relu: 51 | block += [nn.ReLU(inplace=True)] 52 | 53 | self.block = nn.Sequential(*block) 54 | self.block.apply(weights_init_kaiming) 55 | 56 | def forward(self, x): 57 | x = self.block(x) 58 | x = x.squeeze(3).squeeze(2) 59 | return x 60 | 61 | 62 | class TextImgPersonReidNet(nn.Module): 63 | 64 | def __init__(self, opt): 65 | super(TextImgPersonReidNet, self).__init__() 66 | 67 | self.opt = opt 68 | resnet50 = models.resnet50(pretrained=True) 69 | 70 | self.ImageExtract = nn.Sequential(*(list(resnet50.children())[:-2])) 71 | 
self.TextExtract = TextExtract(opt) 72 | 73 | self.avg_global = nn.AdaptiveMaxPool2d((1, 1)) 74 | # self.avg_global = nn.AdaptiveAvgPool2d((1, 1)) 75 | 76 | self.conv_1X1 = conv(2048, opt.feature_length) 77 | 78 | def forward(self, image, caption_id, text_length): 79 | 80 | image_feature = self.img_embedding(image) 81 | text_feature = self.txt_embedding(caption_id, text_length) 82 | # print(text_feature.shape) 83 | 84 | return image_feature, text_feature 85 | 86 | def img_embedding(self, image): 87 | image_feature = self.ImageExtract(image) 88 | 89 | image_feature = self.avg_global(image_feature) 90 | image_feature = self.conv_1X1(image_feature) 91 | 92 | return image_feature 93 | 94 | def txt_embedding(self, caption_id, text_length): 95 | text_feature = self.TextExtract(caption_id, text_length) 96 | 97 | text_feature = self.conv_1X1(text_feature) 98 | 99 | return text_feature 100 | 101 | 102 | class ft_net_TransREID(nn.Module): 103 | 104 | def __init__(self, class_num=751, droprate=0.5, stride=2): 105 | super(ft_net_TransREID, self).__init__() 106 | model_ft = vit_base_patch16_224_TransReID(img_size=[256, 128], sie_xishu=3.0, 107 | camera=0, view=0, stride_size=[16, 16], drop_path_rate=0.1, 108 | drop_rate= 0.0, 109 | attn_drop_rate=0.0) 110 | self.in_planes = 768 111 | model_path = '/home/zhiying/my_test/text-image-reid/jx_vit_base_p16_224-80ecf9dd.pth' 112 | model_ft.load_param(model_path) 113 | print('Loading pretrained ImageNet model......from {}'.format(model_path)) 114 | # avg pooling to global pooling 115 | # self.bottleneck = nn.BatchNorm1d(self.in_planes) 116 | # self.bottleneck.bias.requires_grad_(False) 117 | # self.bottleneck.apply(weights_init_kaiming) 118 | # self.gap = nn.AdaptiveAvgPool2d(1) 119 | self.model = model_ft 120 | # self.classifier = nn.Linear(self.in_planes, class_num, bias=False) 121 | # self.classifier.apply(weights_init_classifier) 122 | 123 | def forward(self, x): 124 | x = self.model(x) 125 | return x 126 | 127 | class ft_net_TransREID_local(nn.Module): 128 | 129 | def __init__(self, class_num=751, droprate=0.5, stride=2): 130 | super(ft_net_TransREID_local, self).__init__() 131 | model_ft = vit_base_patch16_224_TransReID(img_size=[384, 128], sie_xishu=3.0,local_feature=True, 132 | camera=0, view=0, stride_size=[16, 16], drop_path_rate=0.1, 133 | drop_rate= 0.0, 134 | attn_drop_rate=0.0) 135 | self.in_planes = 768 136 | model_path = '/home/zhiying/my_test/text-image-reid/jx_vit_base_p16_224-80ecf9dd.pth' 137 | model_ft.load_param(model_path) 138 | print('Loading pretrained ImageNet model......from {}'.format(model_path)) 139 | # avg pooling to global pooling 140 | # self.bottleneck = nn.BatchNorm1d(self.in_planes) 141 | # self.bottleneck.bias.requires_grad_(False) 142 | # self.bottleneck.apply(weights_init_kaiming) 143 | # self.gap = nn.AdaptiveAvgPool2d(1) 144 | self.model = model_ft 145 | # self.classifier = nn.Linear(self.in_planes, class_num, bias=False) 146 | # self.classifier.apply(weights_init_classifier) 147 | 148 | def forward(self, x): 149 | x = self.model(x) 150 | x = x[:,1:] 151 | return x 152 | 153 | class ft_net_TransREID_local_smallDeiT(nn.Module): 154 | 155 | def __init__(self, class_num=751, droprate=0.5, stride=2): 156 | super(ft_net_TransREID_local_smallDeiT, self).__init__() 157 | model_ft = deit_small_patch16_224_TransReID(img_size=[384, 128], sie_xishu=3.0,local_feature=True, 158 | camera=0, view=0, stride_size=[16, 16], drop_path_rate=0.1, 159 | drop_rate= 0.0, 160 | attn_drop_rate=0.0) 161 | self.in_planes = 768 162 | model_path = 
'/home/zhiyin/deit_small_distilled_patch16_224-649709d9.pth' 163 | model_ft.load_param(model_path) 164 | print('Loading pretrained ImageNet model......from {}'.format(model_path)) 165 | # avg pooling to global pooling 166 | # self.bottleneck = nn.BatchNorm1d(self.in_planes) 167 | # self.bottleneck.bias.requires_grad_(False) 168 | # self.bottleneck.apply(weights_init_kaiming) 169 | # self.gap = nn.AdaptiveAvgPool2d(1) 170 | self.model = model_ft 171 | # self.classifier = nn.Linear(self.in_planes, class_num, bias=False) 172 | # self.classifier.apply(weights_init_classifier) 173 | 174 | def forward(self, x): 175 | x = self.model(x) 176 | x = x[:, 1:] 177 | return x 178 | 179 | class ft_net_TransREID_local_smallVit(nn.Module): 180 | 181 | def __init__(self, class_num=751, droprate=0.5, stride=2): 182 | super(ft_net_TransREID_local_smallVit, self).__init__() 183 | model_ft = vit_small_patch16_224_TransReID(img_size=[384, 128], sie_xishu=3.0,local_feature=True, 184 | camera=0, view=0, stride_size=[16, 16], drop_path_rate=0.1, 185 | drop_rate=0.0, 186 | attn_drop_rate=0.0) 187 | self.in_planes = 768 188 | model_path = '/home/zhiying/my_test/text-image-reid/vit_small_p16_224-15ec54c9.pth' 189 | model_ft.load_param(model_path) 190 | print('Loading pretrained ImageNet model......from {}'.format(model_path)) 191 | # avg pooling to global pooling 192 | # self.bottleneck = nn.BatchNorm1d(self.in_planes) 193 | # self.bottleneck.bias.requires_grad_(False) 194 | # self.bottleneck.apply(weights_init_kaiming) 195 | # self.gap = nn.AdaptiveAvgPool2d(1) 196 | self.model = model_ft 197 | # self.classifier = nn.Linear(self.in_planes, class_num, bias=False) 198 | # self.classifier.apply(weights_init_classifier) 199 | 200 | def forward(self, x): 201 | x = self.model(x) 202 | x = x[:, 1:] 203 | return x -------------------------------------------------------------------------------- /model/text_feature_extract.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import transformers as ppb 4 | 5 | class TextExtract(nn.Module): 6 | 7 | def __init__(self, opt): 8 | super(TextExtract, self).__init__() 9 | 10 | self.opt = opt 11 | self.last_lstm = opt.last_lstm 12 | self.embedding = nn.Embedding(opt.vocab_size, 512, padding_idx=0) 13 | self.dropout = nn.Dropout(0.3) 14 | self.lstm = nn.LSTM(512, 384, num_layers=1, bidirectional=True, bias=False) 15 | 16 | def forward(self, caption_id, text_length): 17 | 18 | text_embedding = self.embedding(caption_id) 19 | text_embedding = self.dropout(text_embedding) 20 | feature = self.calculate_different_length_lstm(text_embedding, text_length, self.lstm) 21 | 22 | # feature = feature.unsqueeze(2).unsqueeze(2) 23 | 24 | return feature 25 | 26 | def calculate_different_length_lstm(self, text_embedding, text_length, lstm): 27 | 28 | text_length = text_length.view(-1) 29 | _, sort_index = torch.sort(text_length, dim=0, descending=True) 30 | _, unsort_index = sort_index.sort() 31 | 32 | sortlength_text_embedding = text_embedding[sort_index, :] 33 | sort_text_length = text_length[sort_index] 34 | # print(sort_text_length) 35 | packed_text_embedding = nn.utils.rnn.pack_padded_sequence(sortlength_text_embedding, 36 | sort_text_length, 37 | batch_first=True) 38 | 39 | packed_feature, [hn, _] = lstm(packed_text_embedding) # [hn, cn] 40 | sort_feature = nn.utils.rnn.pad_packed_sequence(packed_feature, batch_first=True) # including[feature, length] 41 | # print(hn.size(), cn.size()) 42 | 43 | if self.last_lstm: 44 | 
hn = torch.cat([hn[0, :, :], hn[1, :, :]], dim=1)[unsort_index, :] 45 | return hn 46 | else: 47 | unsort_feature = sort_feature[0][unsort_index, :] 48 | unsort_feature = (unsort_feature[:, :, :int(unsort_feature.size(2) / 2)] 49 | + unsort_feature[:, :, int(unsort_feature.size(2) / 2):]) / 2 50 | # print(text_length[9]) 51 | # print(unsort_feature[9,text_length[9]]) 52 | # print(unsort_feature[9, text_length[9]-1]) 53 | # feature, _ = unsort_feature.max(dim=1) 54 | """ 55 | mean_feature = [] 56 | for i in range(len(text_length)): 57 | mean_feature.append(torch.mean(unsort_feature[i, :text_length[i], :], dim=0).unsqueeze(0)) 58 | mean_feature = torch.cat(mean_feature, dim=0) 59 | """ 60 | return unsort_feature 61 | 62 | class TextExtract_Bert_lstm(nn.Module): 63 | def __init__(self, args): 64 | super(TextExtract_Bert_lstm, self).__init__() 65 | 66 | # self.model_txt = Vit_text(768) 67 | self.last_lstm = args.last_lstm 68 | model_class, tokenizer_class, pretrained_weights = (ppb.BertModel, ppb.BertTokenizer, 'bert-base-uncased') 69 | self.text_embed = model_class.from_pretrained(pretrained_weights) 70 | self.text_embed.eval() 71 | for p in self.text_embed.parameters(): 72 | p.requires_grad = False 73 | self.dropout = nn.Dropout(0.3) 74 | self.lstm = nn.LSTM(768, 384, num_layers=1, bidirectional=True, bias=False) 75 | 76 | def forward(self, txt, mask): 77 | length = mask.sum(1) 78 | length = length.cpu() 79 | with torch.no_grad(): 80 | txt = self.text_embed(txt, attention_mask=mask)# 81 | txt = txt[0] ##64 * L * 768 82 | # txt = txt.unsqueeze(1) 83 | # txt = txt.permute(0, 3, 1, 2) ##64 * 768 * 1 * 64 84 | # txt = self.model_txt(txt , trans_mask) # txt4: batch x 2048 x 1 x 64 85 | 86 | txt = self.calculate_different_length_lstm(txt,length,self.lstm) 87 | return txt 88 | 89 | def calculate_different_length_lstm(self, text_embedding, text_length, lstm): 90 | 91 | text_length = text_length.view(-1) 92 | _, sort_index = torch.sort(text_length, dim=0, descending=True) 93 | _, unsort_index = sort_index.sort() 94 | 95 | sortlength_text_embedding = text_embedding[sort_index, :] 96 | sort_text_length = text_length[sort_index] 97 | # print(sort_text_length) 98 | packed_text_embedding = nn.utils.rnn.pack_padded_sequence(sortlength_text_embedding, 99 | sort_text_length, 100 | batch_first=True) 101 | 102 | packed_feature, [hn, _] = lstm(packed_text_embedding) # [hn, cn] 103 | sort_feature = nn.utils.rnn.pad_packed_sequence(packed_feature, batch_first=True) # including[feature, length] 104 | # print(hn.size(), cn.size()) 105 | 106 | if self.last_lstm: 107 | hn = torch.cat([hn[0, :, :], hn[1, :, :]], dim=1)[unsort_index, :] 108 | return hn 109 | else: 110 | unsort_feature = sort_feature[0][unsort_index, :] 111 | unsort_feature = (unsort_feature[:, :, :int(unsort_feature.size(2) / 2)] 112 | + unsort_feature[:, :, int(unsort_feature.size(2) / 2):]) / 2 113 | # print(text_length[9]) 114 | # print(unsort_feature[9,text_length[9]]) 115 | # print(unsort_feature[9, text_length[9]-1]) 116 | # feature, _ = unsort_feature.max(dim=1) 117 | """ 118 | mean_feature = [] 119 | for i in range(len(text_length)): 120 | mean_feature.append(torch.mean(unsort_feature[i, :text_length[i], :], dim=0).unsqueeze(0)) 121 | mean_feature = torch.cat(mean_feature, dim=0) 122 | """ 123 | return unsort_feature -------------------------------------------------------------------------------- /model_TransREID/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_model import make_model 
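`calculate_different_length_lstm` above implements the standard recipe for running an LSTM over zero-padded, variable-length captions: sort by length, pack so the LSTM skips the padding, unpack, restore the original batch order, and merge the two directions of the bidirectional output by averaging. A self-contained sketch with toy sizes (not code from this repository):

```python
# Sketch of the sort / pack / unsort recipe; dimensions are illustrative.
import torch
import torch.nn as nn

lstm = nn.LSTM(16, 8, num_layers=1, bidirectional=True, batch_first=True)
x = torch.randn(3, 10, 16)                     # (batch, max_len, dim), zero-padded
lengths = torch.tensor([10, 4, 7])             # true caption lengths

sorted_len, sort_idx = lengths.sort(descending=True)
_, unsort_idx = sort_idx.sort()

packed = nn.utils.rnn.pack_padded_sequence(x[sort_idx], sorted_len, batch_first=True)
packed_out, _ = lstm(packed)
out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
out = out[unsort_idx]                          # back to original batch order: (3, 10, 16)

half = out.size(2) // 2                        # average forward and backward directions,
merged = (out[:, :, :half] + out[:, :, half:]) / 2   # as in the non-last_lstm branch above
print(merged.shape)                            # torch.Size([3, 10, 8])
```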
-------------------------------------------------------------------------------- /model_TransREID/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/model_TransREID/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /model_TransREID/__pycache__/make_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/model_TransREID/__pycache__/make_model.cpython-38.pyc -------------------------------------------------------------------------------- /model_TransREID/backbones/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/model_TransREID/backbones/__init__.py -------------------------------------------------------------------------------- /model_TransREID/backbones/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/model_TransREID/backbones/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /model_TransREID/backbones/__pycache__/resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/model_TransREID/backbones/__pycache__/resnet.cpython-38.pyc -------------------------------------------------------------------------------- /model_TransREID/backbones/__pycache__/vit_pytorch.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/model_TransREID/backbones/__pycache__/vit_pytorch.cpython-38.pyc -------------------------------------------------------------------------------- /model_TransREID/backbones/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | """3x3 convolution with padding""" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | 12 | 13 | class BasicBlock(nn.Module): 14 | expansion = 1 15 | 16 | def __init__(self, inplanes, planes, stride=1, downsample=None): 17 | super(BasicBlock, self).__init__() 18 | self.conv1 = conv3x3(inplanes, planes, stride) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.relu = nn.ReLU(inplace=True) 21 | self.conv2 = conv3x3(planes, planes) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.downsample = downsample 24 | self.stride = stride 25 | 26 | def forward(self, x): 27 | residual = x 28 | 29 | out = self.conv1(x) 30 | out = self.bn1(out) 31 | out = self.relu(out) 32 | 33 | out = self.conv2(out) 34 | out = self.bn2(out) 35 | 36 | if self.downsample is not None: 37 | residual = self.downsample(x) 38 | 39 | out += residual 40 | out = self.relu(out) 41 | 42 | return out 43 | 44 | 45 | class Bottleneck(nn.Module): 46 | expansion = 4 47 | 48 | def __init__(self, inplanes, 
planes, stride=1, downsample=None): 49 | super(Bottleneck, self).__init__() 50 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 51 | self.bn1 = nn.BatchNorm2d(planes) 52 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 53 | padding=1, bias=False) 54 | self.bn2 = nn.BatchNorm2d(planes) 55 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 56 | self.bn3 = nn.BatchNorm2d(planes * 4) 57 | self.relu = nn.ReLU(inplace=True) 58 | self.downsample = downsample 59 | self.stride = stride 60 | 61 | def forward(self, x): 62 | residual = x 63 | 64 | out = self.conv1(x) 65 | out = self.bn1(out) 66 | out = self.relu(out) 67 | 68 | out = self.conv2(out) 69 | out = self.bn2(out) 70 | out = self.relu(out) 71 | 72 | out = self.conv3(out) 73 | out = self.bn3(out) 74 | 75 | if self.downsample is not None: 76 | residual = self.downsample(x) 77 | 78 | out += residual 79 | out = self.relu(out) 80 | 81 | return out 82 | 83 | 84 | class ResNet(nn.Module): 85 | def __init__(self, last_stride=2, block=Bottleneck,layers=[3, 4, 6, 3]): 86 | self.inplanes = 64 87 | super().__init__() 88 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 89 | bias=False) 90 | self.bn1 = nn.BatchNorm2d(64) 91 | # self.relu = nn.ReLU(inplace=True) # add missed relu 92 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=None, padding=0) 93 | self.layer1 = self._make_layer(block, 64, layers[0]) 94 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 95 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 96 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride) 97 | 98 | def _make_layer(self, block, planes, blocks, stride=1): 99 | downsample = None 100 | if stride != 1 or self.inplanes != planes * block.expansion: 101 | downsample = nn.Sequential( 102 | nn.Conv2d(self.inplanes, planes * block.expansion, 103 | kernel_size=1, stride=stride, bias=False), 104 | nn.BatchNorm2d(planes * block.expansion), 105 | ) 106 | 107 | layers = [] 108 | layers.append(block(self.inplanes, planes, stride, downsample)) 109 | self.inplanes = planes * block.expansion 110 | for i in range(1, blocks): 111 | layers.append(block(self.inplanes, planes)) 112 | 113 | return nn.Sequential(*layers) 114 | 115 | def forward(self, x, cam_label=None): 116 | x = self.conv1(x) 117 | x = self.bn1(x) 118 | # x = self.relu(x) # add missed relu 119 | x = self.maxpool(x) 120 | x = self.layer1(x) 121 | x = self.layer2(x) 122 | x = self.layer3(x) 123 | x = self.layer4(x) 124 | 125 | return x 126 | 127 | def load_param(self, model_path): 128 | param_dict = torch.load(model_path) 129 | for i in param_dict: 130 | if 'fc' in i: 131 | continue 132 | self.state_dict()[i].copy_(param_dict[i]) 133 | 134 | def random_init(self): 135 | for m in self.modules(): 136 | if isinstance(m, nn.Conv2d): 137 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 138 | m.weight.data.normal_(0, math.sqrt(2. / n)) 139 | elif isinstance(m, nn.BatchNorm2d): 140 | m.weight.data.fill_(1) 141 | m.bias.data.zero_() -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | # from .detr import build 3 | 4 | 5 | # def build_model(args): 6 | # return build(args) 7 | -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/models/__pycache__/transformer.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/models/__pycache__/transformer.cpython-38.pyc -------------------------------------------------------------------------------- /models/backbone.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Backbone modules. 4 | """ 5 | from collections import OrderedDict 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import torchvision 10 | from torch import nn 11 | from torchvision.models._utils import IntermediateLayerGetter 12 | from typing import Dict, List 13 | 14 | from util.misc import NestedTensor, is_main_process 15 | 16 | from .position_encoding import build_position_encoding 17 | 18 | 19 | class FrozenBatchNorm2d(torch.nn.Module): 20 | """ 21 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 22 | 23 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 24 | without which any other models than torchvision.models.resnet[18,34,50,101] 25 | produce nans. 
26 | """ 27 | 28 | def __init__(self, n): 29 | super(FrozenBatchNorm2d, self).__init__() 30 | self.register_buffer("weight", torch.ones(n)) 31 | self.register_buffer("bias", torch.zeros(n)) 32 | self.register_buffer("running_mean", torch.zeros(n)) 33 | self.register_buffer("running_var", torch.ones(n)) 34 | 35 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 36 | missing_keys, unexpected_keys, error_msgs): 37 | num_batches_tracked_key = prefix + 'num_batches_tracked' 38 | if num_batches_tracked_key in state_dict: 39 | del state_dict[num_batches_tracked_key] 40 | 41 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 42 | state_dict, prefix, local_metadata, strict, 43 | missing_keys, unexpected_keys, error_msgs) 44 | 45 | def forward(self, x): 46 | # move reshapes to the beginning 47 | # to make it fuser-friendly 48 | w = self.weight.reshape(1, -1, 1, 1) 49 | b = self.bias.reshape(1, -1, 1, 1) 50 | rv = self.running_var.reshape(1, -1, 1, 1) 51 | rm = self.running_mean.reshape(1, -1, 1, 1) 52 | eps = 1e-5 53 | scale = w * (rv + eps).rsqrt() 54 | bias = b - rm * scale 55 | return x * scale + bias 56 | 57 | 58 | class BackboneBase(nn.Module): 59 | 60 | def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): 61 | super().__init__() 62 | for name, parameter in backbone.named_parameters(): 63 | if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: 64 | parameter.requires_grad_(False) 65 | if return_interm_layers: 66 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} 67 | else: 68 | return_layers = {'layer4': "0"} 69 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 70 | self.num_channels = num_channels 71 | 72 | def forward(self, tensor_list: NestedTensor): 73 | xs = self.body(tensor_list.tensors) 74 | out: Dict[str, NestedTensor] = {} 75 | for name, x in xs.items(): 76 | m = tensor_list.mask 77 | assert m is not None 78 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 79 | out[name] = NestedTensor(x, mask) 80 | return out 81 | 82 | 83 | class Backbone(BackboneBase): 84 | """ResNet backbone with frozen BatchNorm.""" 85 | def __init__(self, name: str, 86 | train_backbone: bool, 87 | return_interm_layers: bool, 88 | dilation: bool): 89 | backbone = getattr(torchvision.models, name)( 90 | replace_stride_with_dilation=[False, False, dilation], 91 | pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) 92 | num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 93 | super().__init__(backbone, train_backbone, num_channels, return_interm_layers) 94 | 95 | 96 | class Joiner(nn.Sequential): 97 | def __init__(self, backbone, position_embedding): 98 | super().__init__(backbone, position_embedding) 99 | 100 | def forward(self, tensor_list: NestedTensor): 101 | xs = self[0](tensor_list) 102 | out: List[NestedTensor] = [] 103 | pos = [] 104 | for name, x in xs.items(): 105 | out.append(x) 106 | # position encoding 107 | pos.append(self[1](x).to(x.tensors.dtype)) 108 | 109 | return out, pos 110 | 111 | 112 | def build_backbone(args): 113 | position_embedding = build_position_encoding(args) 114 | train_backbone = args.lr_backbone > 0 115 | return_interm_layers = args.masks 116 | backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) 117 | model = Joiner(backbone, position_embedding) 118 | model.num_channels = backbone.num_channels 119 | return model 120 | 
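`FrozenBatchNorm2d` above folds the fixed statistics into a per-channel `scale` and `bias`, which is algebraically the same normalization a regular `nn.BatchNorm2d` applies in eval mode. A small equivalence check (illustration only, not part of the repository):

```python
# Sanity check: the scale/bias form equals eval-mode BatchNorm2d with the same stats.
import torch
import torch.nn as nn

def frozen_bn(x, weight, bias, running_mean, running_var, eps=1e-5):
    w = weight.reshape(1, -1, 1, 1)
    b = bias.reshape(1, -1, 1, 1)
    rv = running_var.reshape(1, -1, 1, 1)
    rm = running_mean.reshape(1, -1, 1, 1)
    scale = w * (rv + eps).rsqrt()
    return x * scale + (b - rm * scale)

bn = nn.BatchNorm2d(8).eval()
with torch.no_grad():                           # give the frozen stats non-trivial values
    bn.running_mean.normal_()
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.normal_()
    bn.bias.normal_()

x = torch.randn(2, 8, 4, 4)
same = torch.allclose(bn(x), frozen_bn(x, bn.weight, bn.bias,
                                       bn.running_mean, bn.running_var), atol=1e-5)
print(same)                                     # True
```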
-------------------------------------------------------------------------------- /models/matcher.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Modules to compute the matching cost and solve the corresponding LSAP. 4 | """ 5 | import torch 6 | from scipy.optimize import linear_sum_assignment 7 | from torch import nn 8 | 9 | from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou 10 | 11 | 12 | class HungarianMatcher(nn.Module): 13 | """This class computes an assignment between the targets and the predictions of the network 14 | 15 | For efficiency reasons, the targets don't include the no_object. Because of this, in general, 16 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, 17 | while the others are un-matched (and thus treated as non-objects). 18 | """ 19 | 20 | def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1): 21 | """Creates the matcher 22 | 23 | Params: 24 | cost_class: This is the relative weight of the classification error in the matching cost 25 | cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost 26 | cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost 27 | """ 28 | super().__init__() 29 | self.cost_class = cost_class 30 | self.cost_bbox = cost_bbox 31 | self.cost_giou = cost_giou 32 | assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" 33 | 34 | @torch.no_grad() 35 | def forward(self, outputs, targets): 36 | """ Performs the matching 37 | 38 | Params: 39 | outputs: This is a dict that contains at least these entries: 40 | "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits 41 | "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates 42 | 43 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: 44 | "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth 45 | objects in the target) containing the class labels 46 | "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates 47 | 48 | Returns: 49 | A list of size batch_size, containing tuples of (index_i, index_j) where: 50 | - index_i is the indices of the selected predictions (in order) 51 | - index_j is the indices of the corresponding selected targets (in order) 52 | For each batch element, it holds: 53 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes) 54 | """ 55 | bs, num_queries = outputs["pred_logits"].shape[:2] 56 | 57 | # We flatten to compute the cost matrices in a batch 58 | out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes] 59 | out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] 60 | 61 | # Also concat the target labels and boxes 62 | tgt_ids = torch.cat([v["labels"] for v in targets]) 63 | tgt_bbox = torch.cat([v["boxes"] for v in targets]) 64 | 65 | # Compute the classification cost. Contrary to the loss, we don't use the NLL, 66 | # but approximate it in 1 - proba[target class]. 67 | # The 1 is a constant that doesn't change the matching, it can be ommitted. 
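The actual assignment step is delegated to SciPy's `linear_sum_assignment`, which solves the linear sum assignment problem on the combined cost matrix built below. A toy example with made-up costs (rows are predicted queries, columns are ground-truth boxes):

```python
# Toy LSAP example; the numbers are invented, only the call pattern mirrors the matcher.
import numpy as np
from scipy.optimize import linear_sum_assignment

cost = np.array([[0.9, 0.1, 0.8],
                 [0.2, 0.7, 0.6],
                 [0.5, 0.4, 0.05]])
pred_idx, tgt_idx = linear_sum_assignment(cost)
print(list(zip(pred_idx.tolist(), tgt_idx.tolist())))
# [(0, 1), (1, 0), (2, 2)]  -> minimal total cost 0.1 + 0.2 + 0.05 = 0.35
```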
68 | cost_class = -out_prob[:, tgt_ids] 69 | 70 | # Compute the L1 cost between boxes 71 | cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) 72 | 73 | # Compute the giou cost betwen boxes 74 | cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)) 75 | 76 | # Final cost matrix 77 | C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou 78 | C = C.view(bs, num_queries, -1).cpu() 79 | 80 | sizes = [len(v["boxes"]) for v in targets] 81 | indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] 82 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] 83 | 84 | 85 | def build_matcher(args): 86 | return HungarianMatcher(cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou) 87 | -------------------------------------------------------------------------------- /models/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Various positional encodings for the transformer. 4 | """ 5 | import math 6 | import torch 7 | from torch import nn 8 | 9 | from util.misc import NestedTensor 10 | 11 | 12 | class PositionEmbeddingSine(nn.Module): 13 | """ 14 | This is a more standard version of the position embedding, very similar to the one 15 | used by the Attention is all you need paper, generalized to work on images. 16 | """ 17 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 18 | super().__init__() 19 | self.num_pos_feats = num_pos_feats 20 | self.temperature = temperature 21 | self.normalize = normalize 22 | if scale is not None and normalize is False: 23 | raise ValueError("normalize should be True if scale is passed") 24 | if scale is None: 25 | scale = 2 * math.pi 26 | self.scale = scale 27 | 28 | def forward(self, tensor_list: NestedTensor): 29 | x = tensor_list.tensors 30 | mask = tensor_list.mask 31 | assert mask is not None 32 | not_mask = ~mask 33 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 34 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 35 | if self.normalize: 36 | eps = 1e-6 37 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 38 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 39 | 40 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 41 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 42 | 43 | pos_x = x_embed[:, :, :, None] / dim_t 44 | pos_y = y_embed[:, :, :, None] / dim_t 45 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 46 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 47 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 48 | return pos 49 | 50 | 51 | class PositionEmbeddingLearned(nn.Module): 52 | """ 53 | Absolute pos embedding, learned. 
54 | """ 55 | def __init__(self, num_pos_feats=256): 56 | super().__init__() 57 | self.row_embed = nn.Embedding(50, num_pos_feats) 58 | self.col_embed = nn.Embedding(50, num_pos_feats) 59 | self.reset_parameters() 60 | 61 | def reset_parameters(self): 62 | nn.init.uniform_(self.row_embed.weight) 63 | nn.init.uniform_(self.col_embed.weight) 64 | 65 | def forward(self, tensor_list: NestedTensor): 66 | x = tensor_list.tensors 67 | h, w = x.shape[-2:] 68 | i = torch.arange(w, device=x.device) 69 | j = torch.arange(h, device=x.device) 70 | x_emb = self.col_embed(i) 71 | y_emb = self.row_embed(j) 72 | pos = torch.cat([ 73 | x_emb.unsqueeze(0).repeat(h, 1, 1), 74 | y_emb.unsqueeze(1).repeat(1, w, 1), 75 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 76 | return pos 77 | 78 | 79 | def build_position_encoding(args): 80 | N_steps = args.hidden_dim // 2 81 | if args.position_embedding in ('v2', 'sine'): 82 | # TODO find a better way of exposing other arguments 83 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 84 | elif args.position_embedding in ('v3', 'learned'): 85 | position_embedding = PositionEmbeddingLearned(N_steps) 86 | else: 87 | raise ValueError(f"not supported {args.position_embedding}") 88 | 89 | return position_embedding 90 | -------------------------------------------------------------------------------- /option/__pycache__/options.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/option/__pycache__/options.cpython-36.pyc -------------------------------------------------------------------------------- /option/__pycache__/options.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/option/__pycache__/options.cpython-37.pyc -------------------------------------------------------------------------------- /option/__pycache__/options.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/option/__pycache__/options.cpython-38.pyc -------------------------------------------------------------------------------- /option/options.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thurs., Aug. 1(st), 2019 4 | 5 | Update on on Sun., Aug. 
4(th), 2019 6 | 7 | @author: zifyloo 8 | """ 9 | 10 | import argparse 11 | import torch 12 | import logging 13 | import os 14 | from utils.read_write_data import makedir 15 | 16 | logger = logging.getLogger() 17 | logger.setLevel(logging.INFO) 18 | 19 | 20 | class options(): 21 | def __init__(self): 22 | self._par = argparse.ArgumentParser(description='options for Deep Cross Modal') 23 | 24 | self._par.add_argument('--mode', type=str, default='train', help='choose mode [train or test]') 25 | self._par.add_argument('--trained', type=bool, default=False, help='whether the network has pretrained model') 26 | 27 | # self._par.add_argument('--CMPM', type=bool, default=True, help='whether use the CMPM loss') 28 | # self._par.add_argument('--CMPC', type=bool, default=True, help='whether use the CMPC loss') 29 | self._par.add_argument('--bidirectional', type=bool, default=True, help='whether the lstm is bidirectional') 30 | self._par.add_argument('--using_pose', type=bool, default=False, help='whether using pose') 31 | self._par.add_argument('--last_lstm', type=bool, default=False, help='whether just using the last lstm') 32 | self._par.add_argument('--using_noun', type=bool, default=False, help='whether just using the noun') 33 | 34 | self._par.add_argument('--epoch', type=int, default=300, help='train epoch') 35 | self._par.add_argument('--start_epoch', type=int, default=0, help='the start epoch') 36 | self._par.add_argument('--epoch_decay', type=list, default=[], help='decay epoch') 37 | self._par.add_argument('--wd', type=float, default=0.00004, help='weight decay') 38 | self._par.add_argument('--batch_size', type=int, default=16, help='batch size') 39 | self._par.add_argument('--adam_alpha', type=float, default=0.9, help='momentum term of adam') 40 | self._par.add_argument('--adam_beta', type=float, default=0.999, help='momentum term of adam') 41 | self._par.add_argument('--lr', type=float, default=0.002, help='initial learning rate for adam') 42 | self._par.add_argument('--margin', type=float, default=0.2, help='ranking loss margin') 43 | 44 | self._par.add_argument('--vocab_size', type=int, default=5000, help='the size of vocab') 45 | self._par.add_argument('--feature_length', type=int, default=512, help='the length of feature') 46 | self._par.add_argument('--class_num', type=int, default=11003, 47 | help='num of class for StarGAN training on second dataset') 48 | self._par.add_argument('--part', type=int, default=6, help='the num of image part') 49 | self._par.add_argument('--caption_length_max', type=int, default=100, help='the max length of caption') 50 | self._par.add_argument('--random_erasing', type=float, default=0.0, help='the probability of random_erasing') 51 | 52 | self._par.add_argument('--Save_param_every', type=int, default=5, help='the frequency of save the param ') 53 | self._par.add_argument('--save_path', type=str, default='./checkpoints/test', 54 | help='save the result during training') 55 | self._par.add_argument('--GPU_id', type=list, default='2', help='choose GPU ID [0 1]') 56 | self._par.add_argument('--device', type=str, default='', help='cuda devie') 57 | self._par.add_argument('--dataset', type=str, default='CUHK-PEDES', help='choose the dataset ') 58 | self._par.add_argument('--dataroot', type=str, default='/data1/zhiying/text-image/CUHK-PEDES', 59 | help='data root of the Data') 60 | self._par.add_argument('--pkl_root', type=str, 61 | default='/home/zefeng/Exp/code/text-image/code by myself/data/processed_data/', 62 | help='data root of the pkl') 63 | 64 | 
self._par.add_argument('--test_image_num', type=int, default=200, help='the num of images in test mode') 65 | 66 | self.opt = self._par.parse_args() 67 | 68 | self.opt.device = torch.device('cuda:{}'.format(self.opt.GPU_id[0])) 69 | 70 | 71 | def config(opt): 72 | 73 | log_config(opt) 74 | model_root = os.path.join(opt.save_path, 'model') 75 | if os.path.exists(model_root) is False: 76 | makedir(model_root) 77 | 78 | 79 | def log_config(opt): 80 | logroot = os.path.join(opt.save_path, 'log') 81 | if os.path.exists(logroot) is False: 82 | makedir(logroot) 83 | filename = os.path.join(logroot, opt.mode + '.log') 84 | handler = logging.FileHandler(filename) 85 | handler.setLevel(logging.INFO) 86 | formatter = logging.Formatter('%(message)s') 87 | handler.setFormatter(formatter) 88 | logger.addHandler(logging.StreamHandler()) 89 | logger.addHandler(handler) 90 | if opt.mode != 'test': 91 | logger.info(opt) 92 | 93 | 94 | 95 | -------------------------------------------------------------------------------- /processed_data_singledata_CUHK.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | processes the CUHK_PEDES/reid_raw.json, output the train_data, val_data, test_data, 4 | all data including[image_path, caption_id(be coded), label] 5 | 6 | Created on Thurs., Aug. 1(st), 2019 at 20:10 7 | 8 | @author: zifyloo 9 | """ 10 | 11 | from utils_VD.read_write_data import read_json, makedir, save_dict, write_txt 12 | import argparse 13 | from collections import namedtuple 14 | import os 15 | from random import shuffle 16 | import numpy as np 17 | import pickle 18 | import transformers as ppb 19 | 20 | 21 | ImageDecodeData = namedtuple('ImageDecodeData', ['id', 'image_path', 'captions_id', 'split']) 22 | 23 | 24 | class Word2Index(object): 25 | 26 | def __init__(self, vocab): 27 | self._vocab = {w: index + 1 for index, w in enumerate(vocab)} 28 | self.unk_id = len(vocab) + 1 29 | # print(self._vocab) 30 | 31 | def __call__(self, word): 32 | if word not in self._vocab: 33 | return self.unk_id 34 | return self._vocab[word] 35 | 36 | 37 | def parse_args(): 38 | parser = argparse.ArgumentParser(description='Command for data pre_processing') 39 | parser.add_argument('--img_root', default='/data1/zhiying/text-image/CUHK-PEDES/imgs', type=str) 40 | parser.add_argument('--json_root', default='//data1/zhiying/text-image/CUHK-PEDES/reid_raw.json', type=str) 41 | parser.add_argument('--out_root', default='./processed_data_singledata_CUHK', type=str) # processed_data_spa_img 42 | parser.add_argument('--min_word_count', default='2', type=int) 43 | parser.add_argument('--shuffle', default=False, type=bool) 44 | args = parser.parse_args() 45 | return args 46 | 47 | 48 | def split_json(args): 49 | """ 50 | has 40206 image in reid_raw_data 51 | has 13003 id 52 | every id has several images and every image has several caption 53 | data's structure in reid_raw_data is dict ['split', 'captions', 'file_path', 'processed_tokens', 'id'] 54 | """ 55 | reid_raw_data = read_json(args.json_root) 56 | 57 | train_json = [] 58 | test_json = [] 59 | 60 | for data in reid_raw_data: 61 | data_save = { 62 | 'img_path': 'imgs/'+data['file_path'], 63 | 'id': data['id'], 64 | 'tokens': data['processed_tokens'], 65 | 'captions': data['captions'] 66 | } 67 | split = data['split'].lower() 68 | if split == 'train': 69 | train_json.append(data_save) 70 | 71 | if split == 'test': 72 | test_json.append(data_save) 73 | return train_json, test_json 74 | 75 | 76 | def 
build_vocabulary(train_json, args): 77 | 78 | word_count = {} 79 | for data in train_json: 80 | for caption in data['tokens']: 81 | for word in caption: 82 | word_count[word.lower()] = word_count.get(word.lower(), 0) + 1 83 | 84 | word_count_list = [[v, k] for v, k in word_count.items()] 85 | word_count_list.sort(key=lambda x: x[1], reverse=True) # from high to low 86 | 87 | good_vocab = [v for v, k in word_count.items() if k >= args.min_word_count] 88 | 89 | print('top-10 highest frequency words:') 90 | for w, n in word_count_list[0:10]: 91 | print(w, n) 92 | 93 | good_count = len(good_vocab) 94 | all_count = len(word_count_list) 95 | good_word_rate = good_count * 100.0 / all_count 96 | st = 'good words: %d, total_words: %d, good_word_rate: %f%%' % (good_count, all_count, good_word_rate) 97 | write_txt(st, os.path.join(args.out_root, 'data_message')) 98 | print(st) 99 | word2Ind = Word2Index(good_vocab) 100 | 101 | save_dict(good_vocab, os.path.join(args.out_root, 'ind2word')) 102 | return word2Ind 103 | 104 | 105 | def generate_captionid(data_json, word2Ind, data_name, args, tokenizer): 106 | 107 | id_save = [] 108 | lstm_caption_id_save = [] 109 | bert_caption_id_save = [] 110 | img_path_save = [] 111 | caption_save = [] 112 | same_id_index_save = [] 113 | un_idx = word2Ind.unk_id 114 | # train_id_count = {} 115 | data_save_by_id = {} 116 | 117 | count_id = [] 118 | for data in data_json: 119 | 120 | if data['id'] in [1369, 4116, 6116]: # only one image 121 | print(111) 122 | continue 123 | if data['id'] not in count_id: 124 | count_id.append(data['id']) 125 | 126 | id_new = len(count_id) - 1 127 | 128 | data_save_i = { 129 | 'img_path': data['img_path'], 130 | 'id': id_new, 131 | 'tokens': data['tokens'], 132 | 'captions': data['captions'] 133 | } 134 | if id_new not in data_save_by_id.keys(): 135 | data_save_by_id[id_new] = [] 136 | 137 | data_save_by_id[id_new].append(data_save_i) 138 | 139 | data_order = 0 140 | for id_new, data_save_by_id_i in data_save_by_id.items(): 141 | 142 | caption_length = 0 143 | for data_save_by_id_i_i in data_save_by_id_i: 144 | caption_length += len(data_save_by_id_i_i['captions']) 145 | 146 | data_order_i = data_order + np.arange(caption_length) 147 | data_order_i_begin = 0 148 | 149 | for data_save_by_id_i_i in data_save_by_id_i: 150 | caption_length_i = len(data_save_by_id_i_i['captions']) 151 | data_order_i_end = data_order_i_begin + caption_length_i 152 | data_order_i_select = np.delete(data_order_i, np.arange(data_order_i_begin, data_order_i_end)) 153 | data_order_i_begin = data_order_i_end 154 | 155 | for j in range(len(data_save_by_id_i_i['tokens'])): 156 | tokens_j = data_save_by_id_i_i['tokens'][j] 157 | lstm_caption_id = [] 158 | for word in tokens_j: 159 | lstm_caption_id.append(word2Ind(word)) 160 | if un_idx in lstm_caption_id: 161 | lstm_caption_id = list(filter(lambda x: x != un_idx, lstm_caption_id)) 162 | 163 | caption_j = data_save_by_id_i_i['captions'][j] 164 | bert_caption_id = tokenizer.encode(caption_j, add_special_tokens=True) 165 | 166 | id_save.append(data_save_by_id_i_i['id']) 167 | img_path_save.append(data_save_by_id_i_i['img_path']) 168 | same_id_index_save.append(data_order_i_select) 169 | 170 | lstm_caption_id_save.append(lstm_caption_id) 171 | 172 | bert_caption_id_save.append(bert_caption_id) 173 | caption_save.append(caption_j) 174 | 175 | data_order = data_order + caption_length 176 | print(sorted(count_id)) 177 | data_save = { 178 | 'id': id_save, 179 | 'img_path': img_path_save, 180 | 'same_id_index': 
same_id_index_save, 181 | 182 | 'lstm_caption_id': lstm_caption_id_save, 183 | 'bert_caption_id': bert_caption_id_save, 184 | 'captions': caption_save, 185 | } 186 | 187 | img_num = len(set(img_path_save)) 188 | id_num = len(set(id_save)) 189 | # print(sorted(set(id_save))) 190 | caption_num = len(lstm_caption_id_save) 191 | """ 192 | for i in range(len(same_id_index_save)): 193 | for j in same_id_index_save[i]: 194 | if id_save[i] != id_save[j] or i in same_id_index_save[i]: 195 | print(111) 196 | """ 197 | st = '%s_img_num: %d, %s_id_num: %d, %s_caption_num: %d, ' %(data_name, img_num, data_name, id_num, data_name, caption_num) 198 | write_txt(st, os.path.join(args.out_root, 'data_message')) 199 | 200 | return data_save 201 | 202 | 203 | def generate_test_val_caption_id(data_json, word2Ind, data_name, args, tokenizer): 204 | id_save = [] 205 | lstm_caption_id_save = [] 206 | bert_caption_id_save = [] 207 | caption_save = [] 208 | img_path_save = [] 209 | img_caption_index_save = [] 210 | caption_matching_img_index_save = [] 211 | caption_label_save = [] 212 | 213 | un_idx = word2Ind.unk_id 214 | 215 | img_caption_index_i = 0 216 | caption_matching_img_index_i = 0 217 | for data in data_json: 218 | id_save.append(data['id']) 219 | img_path_save.append(data['img_path']) 220 | 221 | for j in range(len(data['tokens'])): 222 | 223 | tokens_j = data['tokens'][j] 224 | lstm_caption_id = [] 225 | for word in tokens_j: 226 | lstm_caption_id.append(word2Ind(word)) 227 | if un_idx in lstm_caption_id: 228 | lstm_caption_id = list(filter(lambda x: x != un_idx, lstm_caption_id)) 229 | 230 | caption_j = data['captions'][j] 231 | bert_caption_id = tokenizer.encode(caption_j, add_special_tokens=True) 232 | 233 | caption_matching_img_index_save.append(caption_matching_img_index_i) 234 | lstm_caption_id_save.append(lstm_caption_id) 235 | bert_caption_id_save.append(bert_caption_id) 236 | caption_save.append(caption_j) 237 | 238 | caption_label_save.append(data['id']) 239 | img_caption_index_save.append([img_caption_index_i, img_caption_index_i+len(data['captions'])-1]) 240 | img_caption_index_i += len(data['captions']) 241 | caption_matching_img_index_i += 1 242 | 243 | data_save = { 244 | 'id': id_save, 245 | 'img_path': img_path_save, 246 | 'img_caption_index': img_caption_index_save, 247 | 248 | 'caption_matching_img_index': caption_matching_img_index_save, 249 | 'caption_label': caption_label_save, 250 | 'lstm_caption_id': lstm_caption_id_save, 251 | 'bert_caption_id': bert_caption_id_save, 252 | 'captions': caption_save, 253 | } 254 | 255 | img_num = len(set(img_path_save)) 256 | id_num = len(set(id_save)) 257 | caption_num = len(lstm_caption_id_save) 258 | st = '%s_img_num: %d, %s_id_num: %d, %s_caption_num: %d, ' % ( 259 | data_name, img_num, data_name, id_num, data_name, caption_num) 260 | write_txt(st, os.path.join(args.out_root, 'data_message')) 261 | 262 | # print(sorted(set(id_save))) 263 | 264 | return data_save 265 | 266 | 267 | def main(args): 268 | train_json, test_json = split_json(args) 269 | 270 | word2Ind = build_vocabulary(train_json, args) 271 | 272 | # with open('./word2Ind' + '.pkl', 'wb') as f: 273 | # pickle.dump(word2Ind, f, pickle.HIGHEST_PROTOCOL) 274 | 275 | model_class, tokenizer_class, pretrained_weights = (ppb.BertModel, ppb.BertTokenizer, 'bert-base-uncased') 276 | tokenizer = tokenizer_class.from_pretrained(pretrained_weights) 277 | 278 | train_save = generate_captionid(train_json, word2Ind, 'train', args, tokenizer) 279 | test_save = 
generate_test_val_caption_id(test_json, word2Ind, 'test', args, tokenizer) 280 | 281 | 282 | save_dict(train_save, os.path.join(args.out_root, 'train_save')) 283 | save_dict(test_save, os.path.join(args.out_root, 'test_save')) 284 | 285 | 286 | 287 | if __name__ == '__main__': 288 | 289 | args = parse_args() 290 | if args.shuffle: 291 | args.out_root = args.out_root + '_shuffle' 292 | 293 | makedir(args.out_root) 294 | main(args) 295 | """ 296 | from utils.read_write_data import read_dict 297 | train_save_dic = read_dict(args.out_root + '/train_save.pkl') 298 | 299 | id_save = train_save_dic['id'] 300 | img_path_save = train_save_dic['img_path'] 301 | caption_id_save = train_save_dic['caption_id'] 302 | same_id_index_save = train_save_dic['same_id_index'] 303 | 304 | num = 1600 305 | print(id_save[num]) 306 | print(img_path_save[num]) 307 | print(caption_id_save[num]) 308 | 309 | print(same_id_index_save[num]) 310 | for x in same_id_index_save[num]: 311 | print(id_save[x]) 312 | print(img_path_save[x]) 313 | 314 | 315 | train_save_dic = read_dict('./processed_data/train_save.pkl') 316 | id_save = train_save_dic['id'] 317 | img_path_save = train_save_dic['img_path'] 318 | caption_id_save = train_save_dic['caption_id'] 319 | print(id_save[num]) 320 | print(img_path_save[num]) 321 | print(caption_id_save[num]) 322 | """ 323 | -------------------------------------------------------------------------------- /processed_data_singledata_ICFG.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | processes the CUHK_PEDES/reid_raw.json, output the train_data, val_data, test_data, 4 | all data including[image_path, caption_id(be coded), label] 5 | 6 | Created on Thurs., Aug. 1(st), 2019 at 20:10 7 | 8 | @author: zifyloo 9 | """ 10 | 11 | from utils.read_write_data import read_json, makedir, save_dict, write_txt 12 | import argparse 13 | from collections import namedtuple 14 | import os 15 | from random import shuffle 16 | import numpy as np 17 | import pickle 18 | import transformers as ppb 19 | 20 | 21 | ImageDecodeData = namedtuple('ImageDecodeData', ['id', 'image_path', 'captions_id', 'split']) 22 | 23 | 24 | class Word2Index(object): 25 | 26 | def __init__(self, vocab): 27 | self._vocab = {w: index + 1 for index, w in enumerate(vocab)} 28 | self.unk_id = len(vocab) + 1 29 | # print(self._vocab) 30 | 31 | def __call__(self, word): 32 | if word not in self._vocab: 33 | return self.unk_id 34 | return self._vocab[word] 35 | 36 | 37 | def parse_args(): 38 | parser = argparse.ArgumentParser(description='Command for data pre_processing') 39 | parser.add_argument('--img_root', default='/data1/zhiying/text-image/data/ICFG_PEDES', type=str) 40 | parser.add_argument('--json_root', default='/data1/zhiying/text-image/data/ICFG-PEDES.json', type=str) 41 | parser.add_argument('--out_root', default='./processed_data_singledata_ICFG', type=str) # processed_data_spa_img 42 | parser.add_argument('--min_word_count', default='2', type=int) 43 | parser.add_argument('--shuffle', default=False, type=bool) 44 | args = parser.parse_args() 45 | return args 46 | 47 | 48 | def split_json(args): 49 | """ 50 | has 40206 image in reid_raw_data 51 | has 13003 id 52 | every id has several images and every image has several caption 53 | data's structure in reid_raw_data is dict ['split', 'captions', 'file_path', 'processed_tokens', 'id'] 54 | """ 55 | reid_raw_data = read_json(args.json_root) 56 | 57 | train_json = [] 58 | test_json = [] 59 | 60 | for data in 
reid_raw_data: 61 | data_save = { 62 | 'img_path': data['file_path'], 63 | 'id': data['id'], 64 | 'tokens': data['processed_tokens'], 65 | 'captions': data['captions'] 66 | } 67 | split = data['split'].lower() 68 | if split == 'train': 69 | train_json.append(data_save) 70 | 71 | if split == 'test': 72 | data_save['tokens'] = data_save['tokens'][0] 73 | test_json.append(data_save) 74 | return train_json, test_json 75 | 76 | 77 | def build_vocabulary(train_json, args): 78 | 79 | word_count = {} 80 | for data in train_json: 81 | for caption in data['tokens']: 82 | for word in caption: 83 | word_count[word.lower()] = word_count.get(word.lower(), 0) + 1 84 | 85 | word_count_list = [[v, k] for v, k in word_count.items()] 86 | word_count_list.sort(key=lambda x: x[1], reverse=True) # from high to low 87 | 88 | good_vocab = [v for v, k in word_count.items() if k >= args.min_word_count] 89 | 90 | print('top-10 highest frequency words:') 91 | for w, n in word_count_list[0:10]: 92 | print(w, n) 93 | 94 | good_count = len(good_vocab) 95 | all_count = len(word_count_list) 96 | good_word_rate = good_count * 100.0 / all_count 97 | st = 'good words: %d, total_words: %d, good_word_rate: %f%%' % (good_count, all_count, good_word_rate) 98 | write_txt(st, os.path.join(args.out_root, 'data_message')) 99 | print(st) 100 | word2Ind = Word2Index(good_vocab) 101 | 102 | save_dict(good_vocab, os.path.join(args.out_root, 'ind2word')) 103 | return word2Ind 104 | 105 | 106 | def generate_captionid(data_json, word2Ind, data_name, args, tokenizer): 107 | 108 | id_save = [] 109 | lstm_caption_id_save = [] 110 | bert_caption_id_save = [] 111 | img_path_save = [] 112 | caption_save = [] 113 | same_id_index_save = [] 114 | un_idx = word2Ind.unk_id 115 | # train_id_count = {} 116 | data_save_by_id = {} 117 | 118 | count_id = [] 119 | for data in data_json: 120 | 121 | if data['id'] in [1369, 4116, 6116]: # only one image 122 | print(111) 123 | continue 124 | if data['id'] not in count_id: 125 | count_id.append(data['id']) 126 | 127 | id_new = len(count_id) - 1 128 | 129 | data_save_i = { 130 | 'img_path': data['img_path'], 131 | 'id': id_new, 132 | 'tokens': data['tokens'], 133 | 'captions': data['captions'] 134 | } 135 | if id_new not in data_save_by_id.keys(): 136 | data_save_by_id[id_new] = [] 137 | 138 | data_save_by_id[id_new].append(data_save_i) 139 | 140 | data_order = 0 141 | for id_new, data_save_by_id_i in data_save_by_id.items(): 142 | 143 | caption_length = 0 144 | for data_save_by_id_i_i in data_save_by_id_i: 145 | caption_length += len(data_save_by_id_i_i['captions']) 146 | 147 | data_order_i = data_order + np.arange(caption_length) 148 | data_order_i_begin = 0 149 | 150 | for data_save_by_id_i_i in data_save_by_id_i: 151 | caption_length_i = len(data_save_by_id_i_i['captions']) 152 | data_order_i_end = data_order_i_begin + caption_length_i 153 | data_order_i_select = np.delete(data_order_i, np.arange(data_order_i_begin, data_order_i_end)) 154 | data_order_i_begin = data_order_i_end 155 | 156 | for j in range(len(data_save_by_id_i_i['tokens'])): 157 | tokens_j = data_save_by_id_i_i['tokens'][j] 158 | lstm_caption_id = [] 159 | for word in tokens_j: 160 | lstm_caption_id.append(word2Ind(word)) 161 | if un_idx in lstm_caption_id: 162 | lstm_caption_id = list(filter(lambda x: x != un_idx, lstm_caption_id)) 163 | 164 | caption_j = data_save_by_id_i_i['captions'][j] 165 | bert_caption_id = tokenizer.encode(caption_j, add_special_tokens=True) 166 | 167 | id_save.append(data_save_by_id_i_i['id']) 168 | 
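# for this caption, also record its image path, the indices of the other captions that share this identity
# (excluding this image's own captions), and its LSTM / BERT token ids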
img_path_save.append(data_save_by_id_i_i['img_path']) 169 | same_id_index_save.append(data_order_i_select) 170 | 171 | lstm_caption_id_save.append(lstm_caption_id) 172 | 173 | bert_caption_id_save.append(bert_caption_id) 174 | caption_save.append(caption_j) 175 | 176 | data_order = data_order + caption_length 177 | print(sorted(count_id)) 178 | data_save = { 179 | 'id': id_save, 180 | 'img_path': img_path_save, 181 | 'same_id_index': same_id_index_save, 182 | 183 | 'lstm_caption_id': lstm_caption_id_save, 184 | 'bert_caption_id': bert_caption_id_save, 185 | 'captions': caption_save, 186 | } 187 | 188 | img_num = len(set(img_path_save)) 189 | id_num = len(set(id_save)) 190 | # print(sorted(set(id_save))) 191 | caption_num = len(lstm_caption_id_save) 192 | """ 193 | for i in range(len(same_id_index_save)): 194 | for j in same_id_index_save[i]: 195 | if id_save[i] != id_save[j] or i in same_id_index_save[i]: 196 | print(111) 197 | """ 198 | st = '%s_img_num: %d, %s_id_num: %d, %s_caption_num: %d, ' %(data_name, img_num, data_name, id_num, data_name, caption_num) 199 | write_txt(st, os.path.join(args.out_root, 'data_message')) 200 | 201 | return data_save 202 | 203 | 204 | def generate_test_val_caption_id(data_json, word2Ind, data_name, args, tokenizer): 205 | id_save = [] 206 | lstm_caption_id_save = [] 207 | bert_caption_id_save = [] 208 | caption_save = [] 209 | img_path_save = [] 210 | img_caption_index_save = [] 211 | caption_matching_img_index_save = [] 212 | caption_label_save = [] 213 | 214 | un_idx = word2Ind.unk_id 215 | 216 | img_caption_index_i = 0 217 | caption_matching_img_index_i = 0 218 | for data in data_json: 219 | id_save.append(data['id']) 220 | img_path_save.append(data['img_path']) 221 | 222 | for j in range(len(data['tokens'])): 223 | 224 | tokens_j = data['tokens'][j] 225 | lstm_caption_id = [] 226 | for word in tokens_j: 227 | lstm_caption_id.append(word2Ind(word)) 228 | if un_idx in lstm_caption_id: 229 | lstm_caption_id = list(filter(lambda x: x != un_idx, lstm_caption_id)) 230 | 231 | caption_j = data['captions'][j] 232 | bert_caption_id = tokenizer.encode(caption_j, add_special_tokens=True) 233 | 234 | caption_matching_img_index_save.append(caption_matching_img_index_i) 235 | lstm_caption_id_save.append(lstm_caption_id) 236 | bert_caption_id_save.append(bert_caption_id) 237 | caption_save.append(caption_j) 238 | 239 | caption_label_save.append(data['id']) 240 | img_caption_index_save.append([img_caption_index_i, img_caption_index_i+len(data['captions'])-1]) 241 | img_caption_index_i += len(data['captions']) 242 | caption_matching_img_index_i += 1 243 | 244 | data_save = { 245 | 'id': id_save, 246 | 'img_path': img_path_save, 247 | 'img_caption_index': img_caption_index_save, 248 | 249 | 'caption_matching_img_index': caption_matching_img_index_save, 250 | 'caption_label': caption_label_save, 251 | 'lstm_caption_id': lstm_caption_id_save, 252 | 'bert_caption_id': bert_caption_id_save, 253 | 'captions': caption_save, 254 | } 255 | 256 | img_num = len(set(img_path_save)) 257 | id_num = len(set(id_save)) 258 | caption_num = len(lstm_caption_id_save) 259 | st = '%s_img_num: %d, %s_id_num: %d, %s_caption_num: %d, ' % ( 260 | data_name, img_num, data_name, id_num, data_name, caption_num) 261 | write_txt(st, os.path.join(args.out_root, 'data_message')) 262 | 263 | # print(sorted(set(id_save))) 264 | 265 | return data_save 266 | 267 | 268 | def main(args): 269 | train_json, test_json = split_json(args) 270 | 271 | word2Ind = build_vocabulary(train_json, args) 272 | 273 | # 
with open('./word2Ind' + '.pkl', 'wb') as f: 274 | # pickle.dump(word2Ind, f, pickle.HIGHEST_PROTOCOL) 275 | 276 | model_class, tokenizer_class, pretrained_weights = (ppb.BertModel, ppb.BertTokenizer, 'bert-base-uncased') 277 | tokenizer = tokenizer_class.from_pretrained(pretrained_weights) 278 | 279 | train_save = generate_captionid(train_json, word2Ind, 'train', args, tokenizer) 280 | test_save = generate_test_val_caption_id(test_json, word2Ind, 'test', args, tokenizer) 281 | 282 | 283 | save_dict(train_save, os.path.join(args.out_root, 'train_save')) 284 | save_dict(test_save, os.path.join(args.out_root, 'test_save')) 285 | 286 | 287 | 288 | if __name__ == '__main__': 289 | 290 | args = parse_args() 291 | if args.shuffle: 292 | args.out_root = args.out_root + '_shuffle' 293 | 294 | makedir(args.out_root) 295 | main(args) 296 | """ 297 | from utils.read_write_data import read_dict 298 | train_save_dic = read_dict(args.out_root + '/train_save.pkl') 299 | 300 | id_save = train_save_dic['id'] 301 | img_path_save = train_save_dic['img_path'] 302 | caption_id_save = train_save_dic['caption_id'] 303 | same_id_index_save = train_save_dic['same_id_index'] 304 | 305 | num = 1600 306 | print(id_save[num]) 307 | print(img_path_save[num]) 308 | print(caption_id_save[num]) 309 | 310 | print(same_id_index_save[num]) 311 | for x in same_id_index_save[num]: 312 | print(id_save[x]) 313 | print(img_path_save[x]) 314 | 315 | 316 | train_save_dic = read_dict('./processed_data/train_save.pkl') 317 | id_save = train_save_dic['id'] 318 | img_path_save = train_save_dic['img_path'] 319 | caption_id_save = train_save_dic['caption_id'] 320 | print(id_save[num]) 321 | print(img_path_save[num]) 322 | print(caption_id_save[num]) 323 | """ 324 | -------------------------------------------------------------------------------- /random_erasing.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torchvision.transforms import * 4 | 5 | from PIL import Image 6 | import random 7 | import math 8 | import numpy as np 9 | import torch 10 | 11 | class RandomErasing(object): 12 | """ Randomly selects a rectangle region in an image and erases its pixels. 13 | 'Random Erasing Data Augmentation' by Zhong et al. 14 | See https://arxiv.org/pdf/1708.04896.pdf 15 | Args: 16 | probability: The probability that the Random Erasing operation will be performed. 17 | sl: Minimum proportion of erased area against input image. 18 | sh: Maximum proportion of erased area against input image. 19 | r1: Minimum aspect ratio of erased area. 20 | mean: Erasing value. 
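Note: the erase region is re-sampled for up to 100 attempts until a rectangle with the sampled area and aspect ratio fits inside the image; if none fits, the image is returned unchanged.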
21 | """ 22 | 23 | def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]): 24 | self.probability = probability 25 | self.mean = mean 26 | self.sl = sl 27 | self.sh = sh 28 | self.r1 = r1 29 | 30 | def __call__(self, img): 31 | 32 | if random.uniform(0, 1) > self.probability: 33 | return img 34 | 35 | for attempt in range(100): 36 | area = img.size()[1] * img.size()[2] 37 | 38 | target_area = random.uniform(self.sl, self.sh) * area 39 | aspect_ratio = random.uniform(self.r1, 1/self.r1) 40 | 41 | h = int(round(math.sqrt(target_area * aspect_ratio))) 42 | w = int(round(math.sqrt(target_area / aspect_ratio))) 43 | 44 | if w < img.size()[2] and h < img.size()[1]: 45 | x1 = random.randint(0, img.size()[1] - h) 46 | y1 = random.randint(0, img.size()[2] - w) 47 | if img.size()[0] == 3: 48 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 49 | img[1, x1:x1+h, y1:y1+w] = self.mean[1] 50 | img[2, x1:x1+h, y1:y1+w] = self.mean[2] 51 | else: 52 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 53 | return img 54 | 55 | return img 56 | -------------------------------------------------------------------------------- /read_json.py: -------------------------------------------------------------------------------- 1 | from utils.read_write_data import read_json, makedir, save_dict, write_txt 2 | import argparse 3 | from collections import namedtuple 4 | import os 5 | import nltk 6 | from nltk.tag import StanfordPOSTagger 7 | from random import shuffle 8 | import numpy as np 9 | import pickle 10 | import transformers as ppb 11 | import time 12 | import json 13 | 14 | reid_raw_data = read_json('./nouns_10_choose.json') 15 | 16 | print(len(reid_raw_data.keys())) -------------------------------------------------------------------------------- /reidtools.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | __all__ = ['visualize_ranked_results'] 5 | 6 | import numpy as np 7 | import os 8 | import os.path as osp 9 | import shutil 10 | # from os import listdir 11 | from PIL import Image 12 | from PIL import Image,ImageDraw,ImageFont 13 | 14 | # from data_process.utils import mkdir_if_missing 15 | 16 | def mkdir_if_missing(path): 17 | if not os.path.isdir(path): 18 | os.mkdir(path) 19 | def visualize_ranked_results(distmat, query, gallery, save_dir='', topk=10,wrong_indices=None): 20 | """Visualizes ranked results. 21 | 22 | Supports both image-reid and video-reid. 23 | 24 | Args: 25 | distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery). 26 | dataset (tuple): a 2-tuple containing (query, gallery), each of which contains 27 | tuples of (img_path(s), pid, camid). 28 | save_dir (str): directory to save output images. 29 | topk (int, optional): denoting top-k images in the rank list to be visualized. 
30 | wrong_indices (ndarray): a 2-tuple containing wrong prediction q_pid and the predicted picture index in gallery 31 | """ 32 | num_q, num_g = distmat.shape 33 | 34 | print('Visualizing top-{} ranks'.format(topk)) 35 | print('# query: {}\n# gallery {}'.format(num_q, num_g)) 36 | print('Saving images to "{}"'.format(save_dir)) 37 | print(len(query.label)) 38 | assert num_q == len(query) 39 | assert num_g == len(gallery) 40 | 41 | indices = np.argsort(-distmat, axis=1) 42 | mkdir_if_missing(save_dir) 43 | 44 | def _cp_img_to(src, dst, rank, prefix): 45 | """ 46 | Args: 47 | src: image path or tuple (for vidreid) 48 | dst: target directory 49 | rank: int, denoting ranked position, starting from 1 50 | prefix: string 51 | """ 52 | if isinstance(src, tuple) or isinstance(src, list): 53 | dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3)) 54 | mkdir_if_missing(dst) 55 | for img_path in src: 56 | shutil.copy(img_path, dst) 57 | else: 58 | dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3) + '_name_' + osp.basename(src)) 59 | shutil.copy(src, dst) 60 | HEIGHT = 256 61 | WIDTH = 128 62 | 63 | for q_idx in range(num_q): 64 | # if q_idx not in wrong_indices[0]: 65 | # continue 66 | ims = [] 67 | qpid = query.label[q_idx] 68 | 69 | save_img_path = save_dir+'{}.jpg'.format(qpid) 70 | #'./img\\data\\market1501\\query\\0001_c1s1_001051_00.jpg' 71 | save_img_path.replace('\\','/') 72 | # 73 | # q_im = Image.open(qimg_path).resize((WIDTH, HEIGHT), Image.BILINEAR) 74 | # ims.append(q_im) 75 | 76 | # if isinstance(qimg_path, tuple) or isinstance(qimg_path, list): 77 | # qdir = osp.join(save_dir, osp.basename(qimg_path[0])) 78 | # else: 79 | # qdir = osp.join(save_dir, osp.basename(qimg_path)) 80 | # mkdir_if_missing(qdir) 81 | # _cp_img_to(qimg_path, qdir, rank=0, prefix='query') 82 | 83 | rank_idx = 1 84 | for g_idx in indices[q_idx,:]: 85 | print(indices[q_idx,:]) 86 | print(distmat[q_idx][indices[q_idx,:]]) 87 | gimg_path, gpid = gallery.img_path[g_idx], gallery.label[g_idx] 88 | # invalid = (qpid == gpid) & (qcamid == gcamid) 89 | # if not invalid: 90 | g_im = Image.open(gimg_path).resize((WIDTH, HEIGHT), Image.BILINEAR) 91 | draw = ImageDraw.Draw(g_im) 92 | if gpid==qpid: 93 | color = (0,255,0)#绿色 94 | else: 95 | color = (255,0,0)#红色 96 | draw.text((8, 8), str(gpid), fill=color)#在坐标(8,8)位置打印gpid,颜色为color,对的为绿色,错的为红色 97 | ims.append(g_im) 98 | rank_idx += 1 99 | if rank_idx > topk: 100 | break 101 | img_ = Image.new(ims[0].mode, (WIDTH*len(ims), HEIGHT))#制作新图片,由于图片是query+前rankk张图,所以WIDTH要x一个len 102 | for i, im in enumerate(ims): 103 | img_.paste(im, box=(i*WIDTH,0)) 104 | img_.save(save_img_path) 105 | 106 | print("Done") -------------------------------------------------------------------------------- /test_ICFG_my.py: -------------------------------------------------------------------------------- 1 | from option.options import options, config 2 | from data.dataloader import get_dataloader 3 | import torch 4 | import random 5 | from model.model import TextImgPersonReidNet 6 | from loss.Id_loss import Id_Loss 7 | from loss.RankingLoss import RankingLoss 8 | from torch import optim 9 | import logging 10 | import os 11 | from test_during_train import test , test_part 12 | from torch.autograd import Variable 13 | from model.DETR_model import TextImgPersonReidNet_mydecoder_pixelVit_transTXT_3, TextImgPersonReidNet_mydecoder_pixelVit_transTXT_3_vit 14 | import torch.nn as nn 15 | logger = logging.getLogger() 16 | logger.setLevel(logging.INFO) 17 | 18 | 19 | def save_checkpoint(state, opt): 20 
| 21 | filename = os.path.join(opt.save_path, 'model/best.pth.tar') 22 | torch.save(state, filename) 23 | 24 | 25 | def load_checkpoint(opt): 26 | filename = os.path.join(opt.save_path, 'model/best.pth.tar') 27 | state = torch.load(filename) 28 | 29 | return state 30 | 31 | 32 | def calculate_similarity(image_embedding, text_embedding): 33 | image_embedding_norm = image_embedding / image_embedding.norm(dim=1, keepdim=True) 34 | text_embedding_norm = text_embedding / text_embedding.norm(dim=1, keepdim=True) 35 | 36 | similarity = torch.mm(image_embedding_norm, text_embedding_norm.t()) 37 | 38 | return similarity 39 | 40 | def calculate_similarity_part(numpart,image_embedding, text_embedding): 41 | image_embedding = torch.cat([image_embedding[i] for i in range(numpart)],dim=1) 42 | text_embedding = torch.cat([text_embedding[i] for i in range(numpart)], dim=1) 43 | image_embedding_norm = image_embedding / image_embedding.norm(dim=1, keepdim=True) 44 | text_embedding_norm = text_embedding / text_embedding.norm(dim=1, keepdim=True) 45 | 46 | similarity = torch.mm(image_embedding_norm, text_embedding_norm.t()) 47 | 48 | return similarity 49 | 50 | def calculate_similarity_score(opt,image_embedding, text_embedding , img_score ,txt_score): 51 | img_size = img_score.size(0) 52 | txt_size = txt_score.size(0) 53 | part_num = img_score.size(1) 54 | Final_matrix = torch.FloatTensor(img_size, txt_size).zero_().to(opt.device) 55 | Fq_matrix = torch.FloatTensor(img_size, txt_size).zero_().to(opt.device) 56 | for i in range(part_num): 57 | # print(i) 58 | # Compute pairwise distance, replace by the official when merged 59 | image_embedding_i = image_embedding[i] 60 | text_embedding_i = text_embedding[i] 61 | image_embedding_i = image_embedding_i / image_embedding_i.norm(dim=1, keepdim=True) 62 | text_embedding_i = text_embedding_i / text_embedding_i.norm(dim=1, keepdim=True) 63 | similarity = torch.mm(image_embedding_i, text_embedding_i.t()) 64 | img_score_i = img_score[:, i].unsqueeze(1) # .view(q_score.size(0), 1) 65 | txt_score_i = txt_score[:, i].unsqueeze(1) 66 | # print(img_score.shape) 67 | # print(img_score_i.shape) 68 | q_matrix = torch.mm(img_score_i, txt_score_i.t()) 69 | final_matrix = similarity.mul(q_matrix) 70 | Final_matrix = Final_matrix + final_matrix 71 | Fq_matrix = Fq_matrix + q_matrix 72 | Fq_matrix = Fq_matrix + 1e-12 73 | # print(Fq_matrix) 74 | dist_part = torch.div(Final_matrix, Fq_matrix) 75 | 76 | return dist_part 77 | 78 | if __name__ == '__main__': 79 | opt = options().opt 80 | opt.GPU_id = '1' 81 | opt.device = torch.device('cuda:{}'.format(opt.GPU_id)) 82 | opt.data_augment = False 83 | opt.lr = 0.001 84 | opt.margin = 0.2 85 | 86 | opt.feature_length = 512 87 | 88 | opt.train_dataset = 'CUHK-PEDES' 89 | opt.dataset = 'MSMT-PEDES' 90 | 91 | if opt.dataset == 'MSMT-PEDES': 92 | opt.pkl_root = '/data1/zhiying/text-image/MSMT-PEDES/3-1/' 93 | opt.class_num = 3102 94 | opt.vocab_size = 2500 95 | opt.dataroot = '/data1/zhiying/text-image/data/ICFG_PEDES' 96 | # opt.class_num = 2802 97 | # opt.vocab_size = 2300 98 | elif opt.dataset == 'CUHK-PEDES': 99 | opt.pkl_root = '/data1/zhiying/text-image/CUHK-PEDES_/' # same_id_new_ 100 | opt.class_num = 11003 101 | opt.vocab_size = 5000 102 | opt.dataroot = '/data1/zhiying/text-image/CUHK-PEDES' 103 | 104 | opt.d_model = 1024 105 | opt.nhead = 4 106 | opt.dim_feedforward = 2048 107 | opt.normalize_before = False 108 | opt.num_encoder_layers = 3 109 | opt.num_decoder_layers = 3 110 | opt.num_query = 6 111 | opt.detr_lr = 0.0001 112 | 
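# learning-rate settings for the DETR-style decoder and the text branches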
opt.txt_detr_lr = 0.0001 113 | opt.txt_lstm_lr = 0.001 114 | opt.res_y = False 115 | opt.noself = False 116 | opt.post_norm = False 117 | opt.n_heads = 4 118 | opt.n_layers = 2 119 | opt.share_query = True 120 | model_name = 'random_my_small_DeiT_2version_head6_384lstm' 121 | # model_name = 'test' 122 | opt.save_path = './checkpoints/dual_modal/{}/'.format(opt.train_dataset) + model_name 123 | 124 | opt.epoch = 60 125 | opt.epoch_decay = [20, 40, 50] 126 | 127 | opt.batch_size = 64 128 | opt.start_epoch = 0 129 | opt.trained = False 130 | 131 | config(opt) 132 | opt.epoch_decay = [i - opt.start_epoch for i in opt.epoch_decay] 133 | 134 | train_dataloader = get_dataloader(opt) 135 | opt.mode = 'test' 136 | test_img_dataloader, test_txt_dataloader = get_dataloader(opt) 137 | opt.mode = 'train' 138 | # train_dataloader = get_dataloader(opt) 139 | # opt.mode = 'test' 140 | # test_img_dataloader, test_txt_dataloader = get_dataloader(opt) 141 | # opt.mode = 'train' 142 | 143 | id_loss_fun = nn.ModuleList() 144 | for _ in range(opt.num_query): 145 | id_loss_fun.append(Id_Loss(opt).to(opt.device)) 146 | ranking_loss_fun = RankingLoss(opt) 147 | network = TextImgPersonReidNet_mydecoder_pixelVit_transTXT_3_vit(opt).to(opt.device) 148 | test_best = 0 149 | test_history = 0 150 | state = load_checkpoint(opt) 151 | network.load_state_dict(state['network']) 152 | test_best = state['test_best'] 153 | test_history = test_best 154 | print('load the {} epoch param successfully'.format(state['epoch'])) 155 | """ 156 | network.eval() 157 | test_best = test(opt, 0, 0, network, 158 | test_img_dataloader, test_txt_dataloader, test_best) 159 | network.train() 160 | exit(0) 161 | """ 162 | network.eval() 163 | test_best = test_part(opt, state['epoch'], 1, network, 164 | test_img_dataloader, test_txt_dataloader, test_best) 165 | logging.info('Training Done') 166 | 167 | 168 | 169 | 170 | 171 | -------------------------------------------------------------------------------- /train_mydecoder_pixelvit_txtimg_3_bert.py: -------------------------------------------------------------------------------- 1 | from option.options import options, config 2 | from data.dataloader import get_dataloader 3 | import torch 4 | import random 5 | from model.model import TextImgPersonReidNet 6 | from loss.Id_loss import Id_Loss 7 | from loss.RankingLoss import RankingLoss 8 | from torch import optim 9 | import logging 10 | import os 11 | from test_during_train import test , test_part 12 | from torch.autograd import Variable 13 | from model.DETR_model import TextImgPersonReidNet_mydecoder_pixelVit_transTXT_3_bert 14 | import torch.nn as nn 15 | 16 | seed_num = 233 17 | torch.manual_seed(seed_num) 18 | random.seed(seed_num) 19 | 20 | logger = logging.getLogger() 21 | logger.setLevel(logging.INFO) 22 | 23 | 24 | def save_checkpoint(state, opt): 25 | 26 | filename = os.path.join(opt.save_path, 'model/best.pth.tar') 27 | torch.save(state, filename) 28 | 29 | 30 | def load_checkpoint(opt): 31 | filename = os.path.join(opt.save_path, 'model/best.pth.tar') 32 | state = torch.load(filename) 33 | 34 | return state 35 | 36 | 37 | def calculate_similarity(image_embedding, text_embedding): 38 | image_embedding_norm = image_embedding / image_embedding.norm(dim=1, keepdim=True) 39 | text_embedding_norm = text_embedding / text_embedding.norm(dim=1, keepdim=True) 40 | 41 | similarity = torch.mm(image_embedding_norm, text_embedding_norm.t()) 42 | 43 | return similarity 44 | 45 | def calculate_similarity_part(numpart,image_embedding, text_embedding): 46 | 
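# concatenate the numpart part-level embeddings along the feature dimension, L2-normalise each fused vector,
# and return the image-to-text cosine-similarity matrix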
image_embedding = torch.cat([image_embedding[i] for i in range(numpart)],dim=1) 47 | text_embedding = torch.cat([text_embedding[i] for i in range(numpart)], dim=1) 48 | image_embedding_norm = image_embedding / image_embedding.norm(dim=1, keepdim=True) 49 | text_embedding_norm = text_embedding / text_embedding.norm(dim=1, keepdim=True) 50 | 51 | similarity = torch.mm(image_embedding_norm, text_embedding_norm.t()) 52 | 53 | return similarity 54 | 55 | def calculate_part_id(id_loss_fun,num_query,image_embedding,text_embedding): 56 | id_loss_ = [] 57 | pred_i2t_ = [] 58 | pred_t2i_ = [] 59 | for i in range(num_query): 60 | id_loss, pred_i2t_local, pred_t2i_local = id_loss_fun[i](image_embedding[i], text_embedding[i], label) 61 | id_loss_.append(id_loss) 62 | pred_i2t_.append(pred_i2t_local) 63 | pred_t2i_.append(pred_t2i_local) 64 | id_loss_ = torch.stack(id_loss_) 65 | id_loss = torch.mean(id_loss_) 66 | pred_i2t_ = torch.stack(pred_i2t_) 67 | pred_i2t_local = torch.mean(pred_i2t_) 68 | pred_t2i_ = torch.stack(pred_t2i_) 69 | pred_t2i_local = torch.mean(pred_t2i_) 70 | 71 | return id_loss , pred_i2t_local, pred_t2i_local 72 | 73 | if __name__ == '__main__': 74 | opt = options().opt 75 | opt.GPU_id = '0' 76 | opt.device = torch.device('cuda:{}'.format(opt.GPU_id)) 77 | opt.data_augment = False 78 | opt.lr = 0.001 79 | opt.margin = 0.3 80 | 81 | opt.feature_length = 512 82 | 83 | opt.dataset = 'CUHK-PEDES' 84 | 85 | if opt.dataset == 'MSMT-PEDES': 86 | opt.pkl_root = '/home/zhiyin/tran_ACMMM/processed_data_singledata_ICFG/' 87 | opt.class_num = 3102 88 | opt.vocab_size = 2500 89 | opt.dataroot = '/home/zhiyin/ICFG_PEDES/ICFG_PEDES' 90 | # opt.class_num = 2802 91 | # opt.vocab_size = 2300 92 | elif opt.dataset == 'CUHK-PEDES': 93 | opt.pkl_root = '/home/zhiyin/tran_ACMMM/processed_data_singledata_CUHK/' # same_id_new_ 94 | opt.class_num = 11000 95 | opt.vocab_size = 5000 96 | opt.dataroot = '/home/zhiyin/CUHK-PEDES' 97 | 98 | opt.d_model = 1024 99 | opt.nhead = 4 100 | opt.dim_feedforward = 2048 101 | opt.normalize_before = False 102 | opt.num_encoder_layers = 3 103 | opt.num_decoder_layers = 3 104 | opt.num_query = 6 105 | opt.detr_lr = 0.0001 106 | opt.txt_detr_lr = 0.0001 107 | opt.txt_lstm_lr = 0.001 108 | opt.res_y = False 109 | opt.noself = False 110 | opt.post_norm = False 111 | opt.n_heads = 4 112 | opt.n_layers = 2 113 | opt.share_query = True 114 | opt.ViT_layer = 8 115 | opt.wordtype = 'bert' 116 | model_name = 'model_get' 117 | # model_name = 'test' 118 | opt.save_path = './checkpoints/dual_modal/{}/'.format(opt.dataset) + model_name 119 | 120 | opt.epoch = 60 121 | opt.epoch_decay = [20, 40, 50] 122 | 123 | opt.batch_size = 64 124 | opt.start_epoch = 0 125 | opt.trained = False 126 | 127 | config(opt) 128 | opt.epoch_decay = [i - opt.start_epoch for i in opt.epoch_decay] 129 | 130 | train_dataloader = get_dataloader(opt) 131 | opt.mode = 'test' 132 | test_img_dataloader, test_txt_dataloader = get_dataloader(opt) 133 | opt.mode = 'train' 134 | id_loss_fun = nn.ModuleList() 135 | for _ in range(opt.num_query): 136 | id_loss_fun.append(Id_Loss(opt).to(opt.device)) 137 | ranking_loss_fun = RankingLoss(opt) 138 | network = TextImgPersonReidNet_mydecoder_pixelVit_transTXT_3_bert(opt).to(opt.device) 139 | logging.info("Model_size: {:.5f}M".format(sum(p.numel() for p in network.parameters()) / 1000000.0)) 140 | ignored_params = (list(map(id, network.ImageExtract.parameters())) 141 | + list(map(id, network.TextExtract.parameters())) 142 | + list(map(id, network.conv_1X1_2.parameters())) 143 
| # + list(map(id, network.conv_1X1.parameters())) 144 | # + list(map(id, network.TXTEncoder.parameters())) 145 | # + list(map(id, network.TXTDecoder.parameters())) 146 | ) 147 | DETR_params = filter(lambda p: id(p) not in ignored_params, network.parameters()) 148 | DETR_params = list(DETR_params) 149 | param_groups = [{'params': DETR_params, 'lr': opt.detr_lr}, 150 | # {'params': network.TXTEncoder.parameters(), 'lr': opt.txt_detr_lr}, 151 | # {'params': network.TXTDecoder.parameters(), 'lr': opt.txt_detr_lr}, 152 | {'params': network.ImageExtract.parameters(), 'lr': opt.lr * 0.1}, 153 | {'params': network.TextExtract.parameters(), 'lr': opt.lr}, 154 | {'params': network.conv_1X1_2.parameters(), 'lr': opt.lr}, 155 | # {'params': network.conv_1X1.parameters(), 'lr': opt.lr}, 156 | {'params': id_loss_fun.parameters(), 'lr': opt.lr} 157 | ] 158 | 159 | optimizer = optim.Adam(param_groups, betas=(opt.adam_alpha, opt.adam_beta)) 160 | 161 | test_best = 0 162 | test_history = 0 163 | if opt.trained: 164 | state = load_checkpoint(opt) 165 | network.load_state_dict(state['network']) 166 | test_best = state['test_best'] 167 | test_history = test_best 168 | id_loss_fun.load_state_dict(state['W']) 169 | print('load the {} epoch param successfully'.format(state['epoch'])) 170 | """ 171 | network.eval() 172 | test_best = test(opt, 0, 0, network, 173 | test_img_dataloader, test_txt_dataloader, test_best) 174 | network.train() 175 | exit(0) 176 | """ 177 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, opt.epoch_decay) 178 | 179 | for epoch in range(opt.start_epoch, opt.epoch): 180 | 181 | id_loss_sum = 0 182 | ranking_loss_sum = 0 183 | pred_i2t_local_sum = 0 184 | pred_t2i_local_sum = 0 185 | 186 | scheduler.step() 187 | for param in optimizer.param_groups: 188 | logging.info('lr:{}'.format(param['lr'])) 189 | 190 | for times, [image, label, caption_code, caption_length, caption_mask] in enumerate(train_dataloader): 191 | 192 | # network.eval() 193 | # test_best = test_part(opt, epoch + 1, times + 1, network, 194 | # test_img_dataloader, test_txt_dataloader, test_best) 195 | # network.train() 196 | image = Variable(image.to(opt.device)) 197 | label = Variable(label.to(opt.device)) 198 | caption_code = Variable(caption_code.to(opt.device).long()) 199 | caption_mask = Variable(caption_mask.to(opt.device)) 200 | 201 | 202 | image_embedding,image_embedding_dict, text_embedding ,text_embedding_dict= network(image, caption_code, caption_mask) 203 | 204 | id_loss , pred_i2t_local, pred_t2i_local = calculate_part_id(id_loss_fun,opt.num_query ,image_embedding, text_embedding) 205 | 206 | id_loss_dict, pred_i2t_local_dict, pred_t2i_local_dict = calculate_part_id(id_loss_fun,opt.num_query, image_embedding_dict, text_embedding_dict) 207 | 208 | similarity = calculate_similarity_part(opt.num_query,image_embedding, text_embedding) 209 | ranking_loss = ranking_loss_fun(similarity, label) 210 | similarity_dict = calculate_similarity_part(opt.num_query, image_embedding_dict, text_embedding_dict) 211 | ranking_loss_dict = ranking_loss_fun(similarity_dict, label) 212 | 213 | similarity_dict_text = calculate_similarity_part(opt.num_query, text_embedding, text_embedding_dict) 214 | ranking_loss_dict_text = ranking_loss_fun(similarity_dict_text, label) 215 | 216 | similarity_dict_image = calculate_similarity_part(opt.num_query, image_embedding, image_embedding_dict) 217 | ranking_loss_dict_image = ranking_loss_fun(similarity_dict_image, label) 218 | 219 | optimizer.zero_grad() 220 | loss = (id_loss + ranking_loss + 
id_loss_dict + ranking_loss_dict + ranking_loss_dict_text + ranking_loss_dict_image) 221 | loss.backward() 222 | # network.eval() 223 | # test_best = test_part(opt, epoch + 1, times + 1, network, 224 | # test_img_dataloader, test_txt_dataloader, test_best) 225 | # network.train() 226 | optimizer.step() 227 | # network.eval() 228 | # test_best = test_part(opt, epoch + 1, times + 1, network, 229 | # test_img_dataloader, test_txt_dataloader, test_best) 230 | # network.train() 231 | if (times + 1) % 50 == 0: 232 | logging.info("Epoch: %d/%d Setp: %d, ranking_loss: %.2f, id_loss: %.2f, ranking_loss_dict: %.2f, id_loss_dict: %.2f,ranking_loss_dict_text: %.2f, ranking_loss_dict_image: %.2f," 233 | "pred_i2t_local: %.3f pred_t2i_local %.3f" 234 | % (epoch+1, opt.epoch, times+1, ranking_loss, id_loss, ranking_loss_dict,id_loss_dict,ranking_loss_dict_text,ranking_loss_dict_image,pred_i2t_local, pred_t2i_local)) 235 | 236 | ranking_loss_sum += ranking_loss 237 | id_loss_sum += id_loss 238 | pred_i2t_local_sum += pred_i2t_local 239 | pred_t2i_local_sum += pred_t2i_local 240 | 241 | ranking_loss_avg = ranking_loss_sum / (times + 1) 242 | id_loss_avg = id_loss_sum / (times + 1) 243 | pred_i2t_local_avg = pred_i2t_local_sum / (times + 1) 244 | pred_t2i_local_avg = pred_t2i_local_sum / (times + 1) 245 | 246 | logging.info("Epoch: %d/%d , ranking_loss: %.2f, id_loss: %.2f," 247 | " pred_i2t_local: %.3f, pred_t2i_local %.3f " 248 | % (epoch+1, opt.epoch, ranking_loss_avg, id_loss_avg, pred_i2t_local_avg, pred_t2i_local_avg)) 249 | 250 | print(model_name) 251 | network.eval() 252 | test_best = test_part(opt, epoch + 1, times + 1, network, 253 | test_img_dataloader, test_txt_dataloader, test_best) 254 | network.train() 255 | if test_best > test_history: 256 | state = { 257 | 'test_best': test_best, 258 | 'network': network.cpu().state_dict(), 259 | 'optimizer': optimizer.state_dict(), 260 | 'W': id_loss_fun.cpu().state_dict(), 261 | 'epoch': epoch + 1} 262 | 263 | save_checkpoint(state, opt) 264 | network.to(opt.device) 265 | id_loss_fun.to(opt.device) 266 | 267 | test_history = test_best 268 | 269 | logging.info('Training Done') 270 | 271 | 272 | 273 | 274 | 275 | -------------------------------------------------------------------------------- /utils/__pycache__/random_erasing.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/utils/__pycache__/random_erasing.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/read_write_data.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/utils/__pycache__/read_write_data.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/read_write_data.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/utils/__pycache__/read_write_data.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/read_write_data.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/utils/__pycache__/read_write_data.cpython-38.pyc -------------------------------------------------------------------------------- /utils/read_write_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | the tool to read or write the data 4 | 5 | Created on Thurs., Aug. 1(st), 2019 at 20:15 6 | 7 | @author: zifyloo 8 | """ 9 | 10 | import os 11 | import json 12 | import pickle 13 | 14 | 15 | def makedir(root): 16 | if not os.path.exists(root): 17 | os.makedirs(root) 18 | 19 | 20 | def write_json(data, root): 21 | with open(root, 'w') as f: 22 | json.dump(data, f) 23 | 24 | 25 | def read_json(root): 26 | with open(root, 'r') as f: 27 | data = json.load(f) 28 | 29 | return data 30 | 31 | 32 | def read_dict(root): 33 | with open(root, 'rb') as f: 34 | data = pickle.load(f) 35 | 36 | return data 37 | 38 | 39 | def save_dict(data, name): 40 | with open(name + '.pkl', 'wb') as f: 41 | pickle.dump(data, f, pickle.HIGHEST_PROTOCOL) 42 | 43 | 44 | def write_txt(data, name): 45 | with open(name, 'a') as f: 46 | f.write(data) 47 | f.write('\n') 48 | 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /utils_RVN/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | 4 | from .directory import write_json, makedir 5 | -------------------------------------------------------------------------------- /utils_RVN/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/utils_RVN/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils_RVN/__pycache__/directory.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/utils_RVN/__pycache__/directory.cpython-38.pyc -------------------------------------------------------------------------------- /utils_RVN/__pycache__/metric.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/utils_RVN/__pycache__/metric.cpython-38.pyc -------------------------------------------------------------------------------- /utils_RVN/directory.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | def makedir(root): 5 | if not os.path.exists(root): 6 | os.makedirs(root) 7 | 8 | 9 | def write_json(data, root): 10 | with open(dir, 'w') as f: 11 | json.dump(data, f) 12 | 13 | 14 | def check_exists(root): 15 | if os.path.exists(root): 16 | return True 17 | return False 18 | 19 | def check_file(root, keyword): 20 | if not os.path.isfile(root): 21 | raise RuntimeError('===> No {} in {}'.format(keyword, root)) 22 | -------------------------------------------------------------------------------- /utils_RVN/visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plot 2 | import os 3 | import cv2 4 | 5 | # visualize loss & accuracy 6 | def visualize_curve(log_root): 7 | log_file = open(log_root, 
'r') 8 | result_root = log_root[:log_root.rfind('/') + 1] + 'train.jpg' 9 | loss = [] 10 | 11 | top1_i2t = [] 12 | top10_i2t = [] 13 | top1_t2i = [] 14 | top10_t2i = [] 15 | for line in log_file.readlines(): 16 | line = line.strip().split() 17 | 18 | if 'top10_t2i' not in line[-2]: 19 | continue 20 | 21 | loss.append(line[1]) 22 | top1_i2t.append(line[3]) 23 | top10_i2t.append(line[5]) 24 | top1_t2i.append(line[7]) 25 | top10_t2i.append(line[9]) 26 | 27 | log_file.close() 28 | 29 | plt.figure('loss') 30 | plt.plot(loss) 31 | 32 | plt.figure('accuracy') 33 | plt.subplot(211) 34 | plt.plot(top1_i2t, label = 'top1') 35 | plt.plot(top10_i2t, label = 'top10') 36 | plt.legend(['image to text'], loc = 'upper right') 37 | plt.subplot(212) 38 | plt.plot(top1_t2i, label = 'top1') 39 | plt.plot(top10_i2t, label = 'top10') 40 | plt.legend(['text to image'], loc = 'upper right') 41 | plt.savefig(result_root) 42 | plt.show() 43 | 44 | -------------------------------------------------------------------------------- /vit_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from vit_pytorch.vit_pytorch import ViT,TransformerEncode,pixel_ViT,DECODER,ENCODER,PartQuery,mydecoder,mydecoder_DETR 2 | -------------------------------------------------------------------------------- /vit_pytorch/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/vit_pytorch/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /vit_pytorch/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/vit_pytorch/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /vit_pytorch/__pycache__/vit_pytorch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/vit_pytorch/__pycache__/vit_pytorch.cpython-37.pyc -------------------------------------------------------------------------------- /vit_pytorch/__pycache__/vit_pytorch.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiyinShao-H/LGUR/c203cc7763abbc640d90b74aeb3986b73472410e/vit_pytorch/__pycache__/vit_pytorch.cpython-38.pyc -------------------------------------------------------------------------------- /vit_pytorch/distill.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from vit_pytorch.vit_pytorch import ViT 5 | from vit_pytorch.t2t import T2TViT 6 | from vit_pytorch.efficient import ViT as EfficientViT 7 | 8 | from einops import rearrange, repeat 9 | 10 | # helpers 11 | 12 | def exists(val): 13 | return val is not None 14 | 15 | # classes 16 | 17 | class DistillMixin: 18 | def forward(self, img, distill_token = None, mask = None): 19 | distilling = exists(distill_token) 20 | x = self.to_patch_embedding(img) 21 | b, n, _ = x.shape 22 | 23 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 24 | x = torch.cat((cls_tokens, x), dim = 1) 25 | x += 
self.pos_embedding[:, :(n + 1)] 26 | 27 | if distilling: 28 | distill_tokens = repeat(distill_token, '() n d -> b n d', b = b) 29 | x = torch.cat((x, distill_tokens), dim = 1) 30 | 31 | x = self._attend(x, mask) 32 | 33 | if distilling: 34 | x, distill_tokens = x[:, :-1], x[:, -1] 35 | 36 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] 37 | 38 | x = self.to_latent(x) 39 | out = self.mlp_head(x) 40 | 41 | if distilling: 42 | return out, distill_tokens 43 | 44 | return out 45 | 46 | class DistillableViT(DistillMixin, ViT): 47 | def __init__(self, *args, **kwargs): 48 | super(DistillableViT, self).__init__(*args, **kwargs) 49 | self.args = args 50 | self.kwargs = kwargs 51 | self.dim = kwargs['dim'] 52 | self.num_classes = kwargs['num_classes'] 53 | 54 | def to_vit(self): 55 | v = ViT(*self.args, **self.kwargs) 56 | v.load_state_dict(self.state_dict()) 57 | return v 58 | 59 | def _attend(self, x, mask): 60 | x = self.dropout(x) 61 | x = self.transformer(x, mask) 62 | return x 63 | 64 | class DistillableT2TViT(DistillMixin, T2TViT): 65 | def __init__(self, *args, **kwargs): 66 | super(DistillableT2TViT, self).__init__(*args, **kwargs) 67 | self.args = args 68 | self.kwargs = kwargs 69 | self.dim = kwargs['dim'] 70 | self.num_classes = kwargs['num_classes'] 71 | 72 | def to_vit(self): 73 | v = T2TViT(*self.args, **self.kwargs) 74 | v.load_state_dict(self.state_dict()) 75 | return v 76 | 77 | def _attend(self, x, mask): 78 | x = self.dropout(x) 79 | x = self.transformer(x) 80 | return x 81 | 82 | class DistillableEfficientViT(DistillMixin, EfficientViT): 83 | def __init__(self, *args, **kwargs): 84 | super(DistillableEfficientViT, self).__init__(*args, **kwargs) 85 | self.args = args 86 | self.kwargs = kwargs 87 | self.dim = kwargs['dim'] 88 | self.num_classes = kwargs['num_classes'] 89 | 90 | def to_vit(self): 91 | v = EfficientViT(*self.args, **self.kwargs) 92 | v.load_state_dict(self.state_dict()) 93 | return v 94 | 95 | def _attend(self, x, mask): 96 | return self.transformer(x) 97 | 98 | # knowledge distillation wrapper 99 | 100 | class DistillWrapper(nn.Module): 101 | def __init__( 102 | self, 103 | *, 104 | teacher, 105 | student, 106 | temperature = 1., 107 | alpha = 0.5 108 | ): 109 | super().__init__() 110 | assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))) , 'student must be a vision transformer' 111 | 112 | self.teacher = teacher 113 | self.student = student 114 | 115 | dim = student.dim 116 | num_classes = student.num_classes 117 | self.temperature = temperature 118 | self.alpha = alpha 119 | 120 | self.distillation_token = nn.Parameter(torch.randn(1, 1, dim)) 121 | 122 | self.distill_mlp = nn.Sequential( 123 | nn.LayerNorm(dim), 124 | nn.Linear(dim, num_classes) 125 | ) 126 | 127 | def forward(self, img, labels, temperature = None, alpha = None, **kwargs): 128 | b, *_ = img.shape 129 | alpha = alpha if exists(alpha) else self.alpha 130 | T = temperature if exists(temperature) else self.temperature 131 | 132 | with torch.no_grad(): 133 | teacher_logits = self.teacher(img) 134 | 135 | student_logits, distill_tokens = self.student(img, distill_token = self.distillation_token, **kwargs) 136 | distill_logits = self.distill_mlp(distill_tokens) 137 | 138 | loss = F.cross_entropy(student_logits, labels) 139 | 140 | distill_loss = F.kl_div( 141 | F.log_softmax(distill_logits / T, dim = -1), 142 | F.softmax(teacher_logits / T, dim = -1).detach(), 143 | reduction = 'batchmean') 144 | 145 | distill_loss *= T ** 2 146 | 147 | return loss * 
alpha + distill_loss * (1 - alpha) 148 | -------------------------------------------------------------------------------- /vit_pytorch/efficient.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from einops import rearrange, repeat 4 | from einops.layers.torch import Rearrange 5 | 6 | class ViT(nn.Module): 7 | def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, pool = 'cls', channels = 3): 8 | super().__init__() 9 | assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size' 10 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 11 | num_patches = (image_size // patch_size) ** 2 12 | patch_dim = channels * patch_size ** 2 13 | 14 | self.to_patch_embedding = nn.Sequential( 15 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size), 16 | nn.Linear(patch_dim, dim), 17 | ) 18 | 19 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 20 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 21 | self.transformer = transformer 22 | 23 | self.pool = pool 24 | self.to_latent = nn.Identity() 25 | 26 | self.mlp_head = nn.Sequential( 27 | nn.LayerNorm(dim), 28 | nn.Linear(dim, num_classes) 29 | ) 30 | 31 | def forward(self, img): 32 | x = self.to_patch_embedding(img) 33 | b, n, _ = x.shape 34 | 35 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 36 | x = torch.cat((cls_tokens, x), dim=1) 37 | x += self.pos_embedding[:, :(n + 1)] 38 | x = self.transformer(x) 39 | 40 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] 41 | 42 | x = self.to_latent(x) 43 | return self.mlp_head(x) 44 | -------------------------------------------------------------------------------- /vit_pytorch/mpp.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import reduce 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | from einops import rearrange, repeat 9 | 10 | # helpers 11 | 12 | 13 | def prob_mask_like(t, prob): 14 | batch, seq_length, _ = t.shape 15 | return torch.zeros((batch, seq_length)).float().uniform_(0, 1) < prob 16 | 17 | 18 | def get_mask_subset_with_prob(patched_input, prob): 19 | batch, seq_len, _, device = *patched_input.shape, patched_input.device 20 | max_masked = math.ceil(prob * seq_len) 21 | 22 | rand = torch.rand((batch, seq_len), device=device) 23 | _, sampled_indices = rand.topk(max_masked, dim=-1) 24 | 25 | new_mask = torch.zeros((batch, seq_len), device=device) 26 | new_mask.scatter_(1, sampled_indices, 1) 27 | return new_mask.bool() 28 | 29 | 30 | # mpp loss 31 | 32 | 33 | class MPPLoss(nn.Module): 34 | def __init__(self, patch_size, channels, output_channel_bits, 35 | max_pixel_val): 36 | super(MPPLoss, self).__init__() 37 | self.patch_size = patch_size 38 | self.channels = channels 39 | self.output_channel_bits = output_channel_bits 40 | self.max_pixel_val = max_pixel_val 41 | 42 | def forward(self, predicted_patches, target, mask): 43 | # reshape target to patches 44 | p = self.patch_size 45 | target = rearrange(target, 46 | "b c (h p1) (w p2) -> b (h w) c (p1 p2) ", 47 | p1=p, 48 | p2=p) 49 | 50 | avg_target = target.mean(dim=3) 51 | 52 | bin_size = self.max_pixel_val / self.output_channel_bits 53 | channel_bins = torch.arange(bin_size, self.max_pixel_val, bin_size) 54 | discretized_target = torch.bucketize(avg_target, channel_bins) 
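# each channel's averaged pixel value is bucketed into one of `output_channel_bits` discrete levels,
# then one-hot encoded and packed into a single class label per patch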
55 | discretized_target = F.one_hot(discretized_target, 56 | self.output_channel_bits) 57 | c, bi = self.channels, self.output_channel_bits 58 | discretized_target = rearrange(discretized_target, 59 | "b n c bi -> b n (c bi)", 60 | c=c, 61 | bi=bi) 62 | 63 | bin_mask = 2**torch.arange(c * bi - 1, -1, 64 | -1).to(discretized_target.device, 65 | discretized_target.dtype) 66 | target_label = torch.sum(bin_mask * discretized_target, -1) 67 | 68 | predicted_patches = predicted_patches[mask] 69 | target_label = target_label[mask] 70 | loss = F.cross_entropy(predicted_patches, target_label) 71 | return loss 72 | 73 | 74 | # main class 75 | 76 | 77 | class MPP(nn.Module): 78 | def __init__(self, 79 | transformer, 80 | patch_size, 81 | dim, 82 | output_channel_bits=3, 83 | channels=3, 84 | max_pixel_val=1.0, 85 | mask_prob=0.15, 86 | replace_prob=0.5, 87 | random_patch_prob=0.5): 88 | super().__init__() 89 | 90 | self.transformer = transformer 91 | self.loss = MPPLoss(patch_size, channels, output_channel_bits, 92 | max_pixel_val) 93 | 94 | # output transformation 95 | self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels)) 96 | 97 | # vit related dimensions 98 | self.patch_size = patch_size 99 | 100 | # mpp related probabilities 101 | self.mask_prob = mask_prob 102 | self.replace_prob = replace_prob 103 | self.random_patch_prob = random_patch_prob 104 | 105 | # token ids 106 | self.mask_token = nn.Parameter(torch.randn(1, 1, dim * channels)) 107 | 108 | def forward(self, input, **kwargs): 109 | transformer = self.transformer 110 | # clone original image for loss 111 | img = input.clone().detach() 112 | 113 | # reshape raw image to patches 114 | p = self.patch_size 115 | input = rearrange(input, 116 | 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', 117 | p1=p, 118 | p2=p) 119 | 120 | mask = get_mask_subset_with_prob(input, self.mask_prob) 121 | 122 | # mask input with mask patches with probability of `replace_prob` (keep patches the same with probability 1 - replace_prob) 123 | masked_input = input.clone().detach() 124 | 125 | # if random token probability > 0 for mpp 126 | if self.random_patch_prob > 0: 127 | random_patch_sampling_prob = self.random_patch_prob / ( 128 | 1 - self.replace_prob) 129 | random_patch_prob = prob_mask_like(input, 130 | random_patch_sampling_prob) 131 | bool_random_patch_prob = mask * random_patch_prob == True 132 | random_patches = torch.randint(0, 133 | input.shape[1], 134 | (input.shape[0], input.shape[1]), 135 | device=input.device) 136 | randomized_input = masked_input[ 137 | torch.arange(masked_input.shape[0]).unsqueeze(-1), 138 | random_patches] 139 | masked_input[bool_random_patch_prob] = randomized_input[ 140 | bool_random_patch_prob] 141 | 142 | # [mask] input 143 | replace_prob = prob_mask_like(input, self.replace_prob) 144 | bool_mask_replace = (mask * replace_prob) == True 145 | masked_input[bool_mask_replace] = self.mask_token 146 | 147 | # linear embedding of patches 148 | masked_input = transformer.to_patch_embedding[-1](masked_input) 149 | 150 | # add cls token to input sequence 151 | b, n, _ = masked_input.shape 152 | cls_tokens = repeat(transformer.cls_token, '() n d -> b n d', b=b) 153 | masked_input = torch.cat((cls_tokens, masked_input), dim=1) 154 | 155 | # add positional embeddings to input 156 | masked_input += transformer.pos_embedding[:, :(n + 1)] 157 | masked_input = transformer.dropout(masked_input) 158 | 159 | # get generator output and get mpp loss 160 | masked_input = transformer.transformer(masked_input, **kwargs) 161 | cls_logits = 
self.to_bits(masked_input) 162 | logits = cls_logits[:, 1:, :] 163 | 164 | mpp_loss = self.loss(logits, img, mask) 165 | 166 | return mpp_loss 167 | -------------------------------------------------------------------------------- /vit_pytorch/t2t.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | 5 | from vit_pytorch.vit_pytorch import Transformer 6 | 7 | from einops import rearrange, repeat 8 | from einops.layers.torch import Rearrange 9 | 10 | # helpers 11 | 12 | def exists(val): 13 | return val is not None 14 | 15 | def conv_output_size(image_size, kernel_size, stride, padding): 16 | return int(((image_size - kernel_size + (2 * padding)) / stride) + 1) 17 | 18 | # classes 19 | 20 | class RearrangeImage(nn.Module): 21 | def forward(self, x): 22 | return rearrange(x, 'b (h w) c -> b c h w', h = int(math.sqrt(x.shape[1]))) 23 | 24 | # main class 25 | 26 | class T2TViT(nn.Module): 27 | def __init__( 28 | self, *, image_size, num_classes, dim, depth = None, heads = None, mlp_dim = None, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., transformer = None, t2t_layers = ((7, 4), (3, 2), (3, 2))): 29 | super().__init__() 30 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 31 | 32 | layers = [] 33 | layer_dim = channels 34 | output_image_size = image_size 35 | 36 | for i, (kernel_size, stride) in enumerate(t2t_layers): 37 | layer_dim *= kernel_size ** 2 38 | is_first = i == 0 39 | output_image_size = conv_output_size(output_image_size, kernel_size, stride, stride // 2) 40 | 41 | layers.extend([ 42 | RearrangeImage() if not is_first else nn.Identity(), 43 | nn.Unfold(kernel_size = kernel_size, stride = stride, padding = stride // 2), 44 | Rearrange('b c n -> b n c'), 45 | Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout), 46 | ]) 47 | 48 | layers.append(nn.Linear(layer_dim, dim)) 49 | self.to_patch_embedding = nn.Sequential(*layers) 50 | 51 | self.pos_embedding = nn.Parameter(torch.randn(1, output_image_size ** 2 + 1, dim)) 52 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 53 | self.dropout = nn.Dropout(emb_dropout) 54 | 55 | if not exists(transformer): 56 | assert all([exists(depth), exists(heads), exists(mlp_dim)]), 'depth, heads, and mlp_dim must be supplied' 57 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 58 | else: 59 | self.transformer = transformer 60 | 61 | self.pool = pool 62 | self.to_latent = nn.Identity() 63 | 64 | self.mlp_head = nn.Sequential( 65 | nn.LayerNorm(dim), 66 | nn.Linear(dim, num_classes) 67 | ) 68 | 69 | def forward(self, img): 70 | x = self.to_patch_embedding(img) 71 | b, n, _ = x.shape 72 | 73 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 74 | x = torch.cat((cls_tokens, x), dim=1) 75 | x += self.pos_embedding 76 | x = self.dropout(x) 77 | 78 | x = self.transformer(x) 79 | 80 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] 81 | 82 | x = self.to_latent(x) 83 | return self.mlp_head(x) 84 | -------------------------------------------------------------------------------- /vit_pytorch/train_2module_guide.py: -------------------------------------------------------------------------------- 1 | from option.options import options, config 2 | from data.dataloader import get_dataloader 3 | import torch 4 | import random 5 | from model.model import TextImgPersonReidNet, 
TextImgPersonReidNet_Res50_fusetrans_2moudel 6 | from loss.Id_loss import Id_Loss , Id_Loss_2, Id_Loss_3 7 | from loss.RankingLoss import RankingLoss 8 | from torch import optim 9 | import logging 10 | import os 11 | from test_during_train import test , test_2module , test_TOP50,test_TOP50_2 , test_TOP50_test 12 | from torch.autograd import Variable 13 | import time 14 | logger = logging.getLogger() 15 | logger.setLevel(logging.INFO) 16 | 17 | 18 | def save_checkpoint(state, opt): 19 | 20 | filename = os.path.join(opt.save_path, 'model/best.pth.tar') 21 | torch.save(state, filename) 22 | 23 | 24 | def load_checkpoint(opt): 25 | filename = os.path.join(opt.save_path, 'model/best.pth.tar') 26 | state = torch.load(filename) 27 | 28 | return state 29 | 30 | 31 | def calculate_similarity(image_embedding, text_embedding): 32 | image_embedding_norm = image_embedding / image_embedding.norm(dim=1, keepdim=True) 33 | text_embedding_norm = text_embedding / text_embedding.norm(dim=1, keepdim=True) 34 | 35 | similarity = torch.mm(image_embedding_norm, text_embedding_norm.t()) 36 | 37 | return similarity 38 | 39 | 40 | if __name__ == '__main__': 41 | opt = options().opt 42 | opt.GPU_id = '1' 43 | opt.device = torch.device('cuda:{}'.format(opt.GPU_id)) 44 | # opt.GPU_id1 = '1' 45 | # opt.device1 = torch.device('cuda:{}'.format(opt.GPU_id1)) 46 | opt.data_augment = False 47 | opt.lr = 0.001 48 | opt.margin = 0.2 49 | 50 | opt.feature_length = 1024 51 | 52 | opt.dataset = 'CUHK-PEDES' 53 | 54 | if opt.dataset == 'MSMT-PEDES': 55 | opt.pkl_root = '/data1/zhiying/text-image/MSMT-PEDES/3-1/' 56 | opt.class_num = 3102 57 | opt.vocab_size = 2500 58 | # opt.class_num = 2802 59 | # opt.vocab_size = 2300 60 | elif opt.dataset == 'CUHK-PEDES': 61 | opt.pkl_root = '/data1/zhiying/text-image/CUHK-PEDES_/' # same_id_new_ 62 | opt.class_num = 11000 63 | opt.vocab_size = 5000 64 | 65 | model_name = 'Decoder_theta1_batchsize32'.format(opt.lr) 66 | # model_name = 'test' 67 | opt.save_path = './checkpoints/{}/'.format(opt.dataset) + model_name 68 | opt.arf = 0.1 69 | opt.dim = 1024 70 | opt.depth = 4 71 | opt.heads = 8 72 | opt.mlp_dim = 1024 73 | opt.dim_head = 64 74 | opt.channels = 2048 75 | opt.epoch = 70 76 | opt.epoch_decay = [20, 40, 50] 77 | 78 | opt.batch_size = 32 79 | opt.start_epoch = 0 80 | opt.trained = False 81 | 82 | config(opt) 83 | opt.epoch_decay = [i - opt.start_epoch for i in opt.epoch_decay] 84 | 85 | train_dataloader = get_dataloader(opt) 86 | opt.mode = 'test' 87 | test_img_dataloader, test_txt_dataloader = get_dataloader(opt) 88 | opt.mode = 'train' 89 | 90 | id_loss_fusion = Id_Loss_3(opt).to(opt.device) 91 | id_loss_fun = Id_Loss_2(opt).to(opt.device) 92 | ranking_loss_fun = RankingLoss(opt) 93 | ranking_loss_fun1 = RankingLoss(opt) 94 | network = TextImgPersonReidNet_Res50_fusetrans_2moudel(opt).to(opt.device) 95 | 96 | cnn_params = list(map(id, network.ImageExtract.parameters())) 97 | trans_params = list(map(id, network.Decoder.parameters())) 98 | other_params = filter(lambda p: id(p) not in cnn_params + trans_params, network.parameters()) 99 | other_params = list(other_params) 100 | other_params.extend(list(id_loss_fun.parameters())) 101 | other_params.extend(list(id_loss_fusion.parameters())) 102 | param_groups = [{'params': other_params, 'lr': opt.lr}, 103 | {'params': network.Decoder.parameters(), 'lr': opt.lr * 0.1}, 104 | {'params': network.ImageExtract.parameters(), 'lr': opt.lr*0.1}] 105 | optimizer = optim.Adam(param_groups, betas=(opt.adam_alpha, opt.adam_beta)) 106 | 107 | test_best 
= 0 108 | test_front_best = 0 109 | test_history = 0 110 | if opt.trained: 111 | state = load_checkpoint(opt) 112 | network.load_state_dict(state['network']) 113 | test_best = state['test_best'] 114 | test_history = test_best 115 | id_loss_fun.load_state_dict(state['W']) 116 | print('load the {} epoch param successfully'.format(state['epoch'])) 117 | """ 118 | network.eval() 119 | test_best = test(opt, 0, 0, network, 120 | test_img_dataloader, test_txt_dataloader, test_best) 121 | network.train() 122 | exit(0) 123 | """ 124 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, opt.epoch_decay) 125 | 126 | for epoch in range(opt.start_epoch, opt.epoch): 127 | id_image_loss_sum = 0 128 | id_text_loss_sum = 0 129 | ranking_loss_sum = 0 130 | id_fusion_loss_sum = 0 131 | rank_fusion_sum = 0 132 | pred_i2t_local_sum = 0 133 | pred_t2i_local_sum = 0 134 | pred_fuse_sum = 0 135 | scheduler.step() 136 | for param in optimizer.param_groups: 137 | logging.info('lr:{}'.format(param['lr'])) 138 | 139 | for times, [image, label, caption_code, caption_length] in enumerate(train_dataloader): 140 | 141 | image = Variable(image.to(opt.device)) 142 | label = Variable(label.to(opt.device)) 143 | caption_code = Variable(caption_code.to(opt.device).long()) 144 | optimizer.zero_grad() 145 | # caption_length = caption_length.to(opt.device) 146 | feature_id , score_mat , image_embedding , text_embedding = network(image, caption_code, caption_length) 147 | # image_embedding, text_embedding = network(image, caption_code, caption_length) 148 | ##########compute 2module loss 149 | id_image_loss, id_text_loss, pred_i2t_local, pred_t2i_local = id_loss_fun(image_embedding, text_embedding, label) 150 | # id_loss, pred_i2t_local, pred_t2i_local = [0, 0, 0] 151 | similarity = calculate_similarity(image_embedding, text_embedding) 152 | ranking_loss = ranking_loss_fun(similarity, label) 153 | ##########compute fusion loss 154 | id_fusion_loss, pred_fuse ,_ = id_loss_fusion(feature_id, label) 155 | # id_loss, pred_i2t_local, pred_t2i_local = [0, 0, 0] 156 | similarity_fusion = score_mat.squeeze(2) 157 | ranking_loss_fusion = ranking_loss_fun1(similarity_fusion, label) 158 | 159 | # theta = 0 160 | if epoch < 20 : 161 | theta = 0.05*epoch 162 | theta = 1 163 | loss = (id_image_loss + id_text_loss + ranking_loss + theta * id_fusion_loss + theta * ranking_loss_fusion) 164 | else: 165 | loss = (id_image_loss + id_text_loss + ranking_loss + id_fusion_loss + ranking_loss_fusion) 166 | # loss = (id_image_loss+ id_text_loss + ranking_loss + theta * id_fusion_loss + theta * ranking_loss_fusion) 167 | # loss = (id_image_loss + id_text_loss + ranking_loss ) 168 | # loss = (id_fusion_loss + ranking_loss_fusion) 169 | loss.backward() 170 | # network.eval() 171 | # start_time = time.time() 172 | # # test_best = test_2module(opt, epoch + 1, times + 1, network, 173 | # # test_img_dataloader, test_txt_dataloader, test_best) 174 | # # test_best = test_TOP50_2(opt, epoch + 1, times + 1, network, 175 | # # test_img_dataloader, test_txt_dataloader, test_best, test_front_best) 176 | # test_best, test_best_front = test_TOP50_test(opt, epoch + 1, times + 1, network, 177 | # test_img_dataloader, test_txt_dataloader, test_best, 178 | # test_front_best) 179 | # end_time = time.time() 180 | # test_time = end_time - start_time 181 | # print('Test complete in {:.0f}m {:.0f}s'.format( 182 | # test_time // 60, test_time % 60)) 183 | # network.train() 184 | optimizer.step() 185 | 186 | if (times + 1) % 50 == 0: 187 | logging.info( 188 | "Epoch: %d/%d Setp: 
%d, id_image_loss: %.2f, id_text_loss: %.2f,ranking_loss: %.2f , ranking_loss_fusion: %.2f, id_fusion_loss: %.2f " 189 | "pred_fuse: %.3f " 190 | "pred_i2t: %.3f pred_t2i %.3f" 191 | % (epoch + 1, opt.epoch, times + 1, id_image_loss, id_text_loss, ranking_loss, ranking_loss_fusion, 192 | id_fusion_loss, pred_fuse, pred_i2t_local, pred_t2i_local)) 193 | 194 | ranking_loss_sum += ranking_loss 195 | id_fusion_loss_sum += id_fusion_loss 196 | id_image_loss_sum += id_image_loss 197 | id_text_loss_sum += id_text_loss 198 | rank_fusion_sum += ranking_loss_fusion 199 | pred_fuse_sum += pred_fuse 200 | pred_i2t_local_sum += pred_i2t_local 201 | pred_t2i_local_sum += pred_t2i_local 202 | 203 | ranking_loss_avg = ranking_loss_sum / (times + 1) 204 | id_fusion_loss_avg = id_fusion_loss_sum / (times + 1) 205 | pred_i2t_local_avg = pred_i2t_local_sum / (times + 1) 206 | pred_t2i_local_avg = pred_t2i_local_sum / (times + 1) 207 | id_image_loss_avg = id_image_loss_sum / (times + 1) 208 | id_text_loss_avg = id_text_loss_sum / (times + 1) 209 | rank_fusion_avg = rank_fusion_sum / (times + 1) 210 | pred_fuse_avg = pred_fuse_sum / (times + 1) 211 | 212 | logging.info("Epoch: %d/%d , id_image_loss: %.2f, id_text_loss: %.2f,ranking_loss: %.2f , ranking_loss_fusion: %.2f, id_fusion_loss: %.2f " 213 | "pred_fuse: %.3f " 214 | "pred_i2t: %.3f pred_t2i %.3f" 215 | % (epoch + 1, opt.epoch, id_image_loss_avg, id_text_loss_avg, ranking_loss_avg, rank_fusion_avg, 216 | id_fusion_loss_avg, pred_fuse_avg, pred_i2t_local_avg, pred_t2i_local_avg)) 217 | 218 | print(model_name) 219 | network.eval() 220 | start_time = time.time() 221 | # test_best = test_2module(opt, epoch + 1, times + 1, network, 222 | # test_img_dataloader, test_txt_dataloader, test_best) 223 | test_best , test_best_front= test_TOP50_2(opt, epoch + 1, times + 1, network, 224 | test_img_dataloader, test_txt_dataloader, test_best, test_front_best) 225 | # test_best, test_best_front = test_TOP50_test(opt, epoch + 1, times + 1, network, 226 | # test_img_dataloader, test_txt_dataloader, test_best, test_front_best) 227 | end_time = time.time() 228 | test_time = end_time - start_time 229 | print('Test complete in {:.0f}m {:.0f}s'.format( 230 | test_time // 60, test_time % 60)) 231 | network.train() 232 | if test_best > test_history: 233 | state = { 234 | 'test_best': test_best, 235 | 'network': network.cpu().state_dict(), 236 | 'optimizer': optimizer.state_dict(), 237 | 'W': id_loss_fun.cpu().state_dict(), 238 | 'W1': id_loss_fusion.cpu().state_dict(), 239 | 'epoch': epoch + 1} 240 | 241 | save_checkpoint(state, opt) 242 | network.to(opt.device) 243 | id_loss_fun.to(opt.device) 244 | id_loss_fusion.to(opt.device) 245 | test_history = test_best 246 | 247 | logging.info('Training Done') 248 | 249 | 250 | 251 | 252 | 253 | -------------------------------------------------------------------------------- /vit_pytorch_TransREID/__init__.py: -------------------------------------------------------------------------------- 1 | from vit_pytorch.vit_pytorch import ViT 2 | -------------------------------------------------------------------------------- /vit_pytorch_TransREID/distill.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from vit_pytorch.vit_pytorch import ViT 5 | from vit_pytorch.t2t import T2TViT 6 | from vit_pytorch.efficient import ViT as EfficientViT 7 | 8 | from einops import rearrange, repeat 9 | 10 | # helpers 11 | 12 | def exists(val): 13 | 
return val is not None 14 | 15 | # classes 16 | 17 | class DistillMixin: 18 | def forward(self, img, distill_token = None, mask = None): 19 | distilling = exists(distill_token) 20 | x = self.to_patch_embedding(img) 21 | b, n, _ = x.shape 22 | 23 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 24 | x = torch.cat((cls_tokens, x), dim = 1) 25 | x += self.pos_embedding[:, :(n + 1)] 26 | 27 | if distilling: 28 | distill_tokens = repeat(distill_token, '() n d -> b n d', b = b) 29 | x = torch.cat((x, distill_tokens), dim = 1) 30 | 31 | x = self._attend(x, mask) 32 | 33 | if distilling: 34 | x, distill_tokens = x[:, :-1], x[:, -1] 35 | 36 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] 37 | 38 | x = self.to_latent(x) 39 | out = self.mlp_head(x) 40 | 41 | if distilling: 42 | return out, distill_tokens 43 | 44 | return out 45 | 46 | class DistillableViT(DistillMixin, ViT): 47 | def __init__(self, *args, **kwargs): 48 | super(DistillableViT, self).__init__(*args, **kwargs) 49 | self.args = args 50 | self.kwargs = kwargs 51 | self.dim = kwargs['dim'] 52 | self.num_classes = kwargs['num_classes'] 53 | 54 | def to_vit(self): 55 | v = ViT(*self.args, **self.kwargs) 56 | v.load_state_dict(self.state_dict()) 57 | return v 58 | 59 | def _attend(self, x, mask): 60 | x = self.dropout(x) 61 | x = self.transformer(x, mask) 62 | return x 63 | 64 | class DistillableT2TViT(DistillMixin, T2TViT): 65 | def __init__(self, *args, **kwargs): 66 | super(DistillableT2TViT, self).__init__(*args, **kwargs) 67 | self.args = args 68 | self.kwargs = kwargs 69 | self.dim = kwargs['dim'] 70 | self.num_classes = kwargs['num_classes'] 71 | 72 | def to_vit(self): 73 | v = T2TViT(*self.args, **self.kwargs) 74 | v.load_state_dict(self.state_dict()) 75 | return v 76 | 77 | def _attend(self, x, mask): 78 | x = self.dropout(x) 79 | x = self.transformer(x) 80 | return x 81 | 82 | class DistillableEfficientViT(DistillMixin, EfficientViT): 83 | def __init__(self, *args, **kwargs): 84 | super(DistillableEfficientViT, self).__init__(*args, **kwargs) 85 | self.args = args 86 | self.kwargs = kwargs 87 | self.dim = kwargs['dim'] 88 | self.num_classes = kwargs['num_classes'] 89 | 90 | def to_vit(self): 91 | v = EfficientViT(*self.args, **self.kwargs) 92 | v.load_state_dict(self.state_dict()) 93 | return v 94 | 95 | def _attend(self, x, mask): 96 | return self.transformer(x) 97 | 98 | # knowledge distillation wrapper 99 | 100 | class DistillWrapper(nn.Module): 101 | def __init__( 102 | self, 103 | *, 104 | teacher, 105 | student, 106 | temperature = 1., 107 | alpha = 0.5 108 | ): 109 | super().__init__() 110 | assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))) , 'student must be a vision transformer' 111 | 112 | self.teacher = teacher 113 | self.student = student 114 | 115 | dim = student.dim 116 | num_classes = student.num_classes 117 | self.temperature = temperature 118 | self.alpha = alpha 119 | 120 | self.distillation_token = nn.Parameter(torch.randn(1, 1, dim)) 121 | 122 | self.distill_mlp = nn.Sequential( 123 | nn.LayerNorm(dim), 124 | nn.Linear(dim, num_classes) 125 | ) 126 | 127 | def forward(self, img, labels, temperature = None, alpha = None, **kwargs): 128 | b, *_ = img.shape 129 | alpha = alpha if exists(alpha) else self.alpha 130 | T = temperature if exists(temperature) else self.temperature 131 | 132 | with torch.no_grad(): 133 | teacher_logits = self.teacher(img) 134 | 135 | student_logits, distill_tokens = self.student(img, distill_token = self.distillation_token, 
**kwargs) 136 | distill_logits = self.distill_mlp(distill_tokens) 137 | 138 | loss = F.cross_entropy(student_logits, labels) 139 | 140 | distill_loss = F.kl_div( 141 | F.log_softmax(distill_logits / T, dim = -1), 142 | F.softmax(teacher_logits / T, dim = -1).detach(), 143 | reduction = 'batchmean') 144 | 145 | distill_loss *= T ** 2 146 | 147 | return loss * alpha + distill_loss * (1 - alpha) 148 | -------------------------------------------------------------------------------- /vit_pytorch_TransREID/efficient.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from einops import rearrange, repeat 4 | from einops.layers.torch import Rearrange 5 | 6 | class ViT(nn.Module): 7 | def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, pool = 'cls', channels = 3): 8 | super().__init__() 9 | assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size' 10 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 11 | num_patches = (image_size // patch_size) ** 2 12 | patch_dim = channels * patch_size ** 2 13 | 14 | self.to_patch_embedding = nn.Sequential( 15 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size), 16 | nn.Linear(patch_dim, dim), 17 | ) 18 | 19 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 20 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 21 | self.transformer = transformer 22 | 23 | self.pool = pool 24 | self.to_latent = nn.Identity() 25 | 26 | self.mlp_head = nn.Sequential( 27 | nn.LayerNorm(dim), 28 | nn.Linear(dim, num_classes) 29 | ) 30 | 31 | def forward(self, img): 32 | x = self.to_patch_embedding(img) 33 | b, n, _ = x.shape 34 | 35 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 36 | x = torch.cat((cls_tokens, x), dim=1) 37 | x += self.pos_embedding[:, :(n + 1)] 38 | x = self.transformer(x) 39 | 40 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] 41 | 42 | x = self.to_latent(x) 43 | return self.mlp_head(x) 44 | -------------------------------------------------------------------------------- /vit_pytorch_TransREID/mpp.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import reduce 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | from einops import rearrange, repeat 9 | 10 | # helpers 11 | 12 | 13 | def prob_mask_like(t, prob): 14 | batch, seq_length, _ = t.shape 15 | return torch.zeros((batch, seq_length)).float().uniform_(0, 1) < prob 16 | 17 | 18 | def get_mask_subset_with_prob(patched_input, prob): 19 | batch, seq_len, _, device = *patched_input.shape, patched_input.device 20 | max_masked = math.ceil(prob * seq_len) 21 | 22 | rand = torch.rand((batch, seq_len), device=device) 23 | _, sampled_indices = rand.topk(max_masked, dim=-1) 24 | 25 | new_mask = torch.zeros((batch, seq_len), device=device) 26 | new_mask.scatter_(1, sampled_indices, 1) 27 | return new_mask.bool() 28 | 29 | 30 | # mpp loss 31 | 32 | 33 | class MPPLoss(nn.Module): 34 | def __init__(self, patch_size, channels, output_channel_bits, 35 | max_pixel_val): 36 | super(MPPLoss, self).__init__() 37 | self.patch_size = patch_size 38 | self.channels = channels 39 | self.output_channel_bits = output_channel_bits 40 | self.max_pixel_val = max_pixel_val 41 | 42 | def forward(self, predicted_patches, target, mask): 43 | # reshape target to 
patches 44 | p = self.patch_size 45 | target = rearrange(target, 46 | "b c (h p1) (w p2) -> b (h w) c (p1 p2) ", 47 | p1=p, 48 | p2=p) 49 | 50 | avg_target = target.mean(dim=3) 51 | 52 | bin_size = self.max_pixel_val / self.output_channel_bits 53 | channel_bins = torch.arange(bin_size, self.max_pixel_val, bin_size) 54 | discretized_target = torch.bucketize(avg_target, channel_bins) 55 | discretized_target = F.one_hot(discretized_target, 56 | self.output_channel_bits) 57 | c, bi = self.channels, self.output_channel_bits 58 | discretized_target = rearrange(discretized_target, 59 | "b n c bi -> b n (c bi)", 60 | c=c, 61 | bi=bi) 62 | 63 | bin_mask = 2**torch.arange(c * bi - 1, -1, 64 | -1).to(discretized_target.device, 65 | discretized_target.dtype) 66 | target_label = torch.sum(bin_mask * discretized_target, -1) 67 | 68 | predicted_patches = predicted_patches[mask] 69 | target_label = target_label[mask] 70 | loss = F.cross_entropy(predicted_patches, target_label) 71 | return loss 72 | 73 | 74 | # main class 75 | 76 | 77 | class MPP(nn.Module): 78 | def __init__(self, 79 | transformer, 80 | patch_size, 81 | dim, 82 | output_channel_bits=3, 83 | channels=3, 84 | max_pixel_val=1.0, 85 | mask_prob=0.15, 86 | replace_prob=0.5, 87 | random_patch_prob=0.5): 88 | super().__init__() 89 | 90 | self.transformer = transformer 91 | self.loss = MPPLoss(patch_size, channels, output_channel_bits, 92 | max_pixel_val) 93 | 94 | # output transformation 95 | self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels)) 96 | 97 | # vit related dimensions 98 | self.patch_size = patch_size 99 | 100 | # mpp related probabilities 101 | self.mask_prob = mask_prob 102 | self.replace_prob = replace_prob 103 | self.random_patch_prob = random_patch_prob 104 | 105 | # token ids 106 | self.mask_token = nn.Parameter(torch.randn(1, 1, dim * channels)) 107 | 108 | def forward(self, input, **kwargs): 109 | transformer = self.transformer 110 | # clone original image for loss 111 | img = input.clone().detach() 112 | 113 | # reshape raw image to patches 114 | p = self.patch_size 115 | input = rearrange(input, 116 | 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', 117 | p1=p, 118 | p2=p) 119 | 120 | mask = get_mask_subset_with_prob(input, self.mask_prob) 121 | 122 | # mask input with mask patches with probability of `replace_prob` (keep patches the same with probability 1 - replace_prob) 123 | masked_input = input.clone().detach() 124 | 125 | # if random token probability > 0 for mpp 126 | if self.random_patch_prob > 0: 127 | random_patch_sampling_prob = self.random_patch_prob / ( 128 | 1 - self.replace_prob) 129 | random_patch_prob = prob_mask_like(input, 130 | random_patch_sampling_prob) 131 | bool_random_patch_prob = mask * random_patch_prob == True 132 | random_patches = torch.randint(0, 133 | input.shape[1], 134 | (input.shape[0], input.shape[1]), 135 | device=input.device) 136 | randomized_input = masked_input[ 137 | torch.arange(masked_input.shape[0]).unsqueeze(-1), 138 | random_patches] 139 | masked_input[bool_random_patch_prob] = randomized_input[ 140 | bool_random_patch_prob] 141 | 142 | # [mask] input 143 | replace_prob = prob_mask_like(input, self.replace_prob) 144 | bool_mask_replace = (mask * replace_prob) == True 145 | masked_input[bool_mask_replace] = self.mask_token 146 | 147 | # linear embedding of patches 148 | masked_input = transformer.to_patch_embedding[-1](masked_input) 149 | 150 | # add cls token to input sequence 151 | b, n, _ = masked_input.shape 152 | cls_tokens = repeat(transformer.cls_token, '() n d -> b 
n d', b=b) 153 | masked_input = torch.cat((cls_tokens, masked_input), dim=1) 154 | 155 | # add positional embeddings to input 156 | masked_input += transformer.pos_embedding[:, :(n + 1)] 157 | masked_input = transformer.dropout(masked_input) 158 | 159 | # get generator output and get mpp loss 160 | masked_input = transformer.transformer(masked_input, **kwargs) 161 | cls_logits = self.to_bits(masked_input) 162 | logits = cls_logits[:, 1:, :] 163 | 164 | mpp_loss = self.loss(logits, img, mask) 165 | 166 | return mpp_loss 167 | -------------------------------------------------------------------------------- /vit_pytorch_TransREID/t2t.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | 5 | from vit_pytorch.vit_pytorch import Transformer 6 | 7 | from einops import rearrange, repeat 8 | from einops.layers.torch import Rearrange 9 | 10 | # helpers 11 | 12 | def exists(val): 13 | return val is not None 14 | 15 | def conv_output_size(image_size, kernel_size, stride, padding): 16 | return int(((image_size - kernel_size + (2 * padding)) / stride) + 1) 17 | 18 | # classes 19 | 20 | class RearrangeImage(nn.Module): 21 | def forward(self, x): 22 | return rearrange(x, 'b (h w) c -> b c h w', h = int(math.sqrt(x.shape[1]))) 23 | 24 | # main class 25 | 26 | class T2TViT(nn.Module): 27 | def __init__( 28 | self, *, image_size, num_classes, dim, depth = None, heads = None, mlp_dim = None, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., transformer = None, t2t_layers = ((7, 4), (3, 2), (3, 2))): 29 | super().__init__() 30 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 31 | 32 | layers = [] 33 | layer_dim = channels 34 | output_image_size = image_size 35 | 36 | for i, (kernel_size, stride) in enumerate(t2t_layers): 37 | layer_dim *= kernel_size ** 2 38 | is_first = i == 0 39 | output_image_size = conv_output_size(output_image_size, kernel_size, stride, stride // 2) 40 | 41 | layers.extend([ 42 | RearrangeImage() if not is_first else nn.Identity(), 43 | nn.Unfold(kernel_size = kernel_size, stride = stride, padding = stride // 2), 44 | Rearrange('b c n -> b n c'), 45 | Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout), 46 | ]) 47 | 48 | layers.append(nn.Linear(layer_dim, dim)) 49 | self.to_patch_embedding = nn.Sequential(*layers) 50 | 51 | self.pos_embedding = nn.Parameter(torch.randn(1, output_image_size ** 2 + 1, dim)) 52 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 53 | self.dropout = nn.Dropout(emb_dropout) 54 | 55 | if not exists(transformer): 56 | assert all([exists(depth), exists(heads), exists(mlp_dim)]), 'depth, heads, and mlp_dim must be supplied' 57 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 58 | else: 59 | self.transformer = transformer 60 | 61 | self.pool = pool 62 | self.to_latent = nn.Identity() 63 | 64 | self.mlp_head = nn.Sequential( 65 | nn.LayerNorm(dim), 66 | nn.Linear(dim, num_classes) 67 | ) 68 | 69 | def forward(self, img): 70 | x = self.to_patch_embedding(img) 71 | b, n, _ = x.shape 72 | 73 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 74 | x = torch.cat((cls_tokens, x), dim=1) 75 | x += self.pos_embedding 76 | x = self.dropout(x) 77 | 78 | x = self.transformer(x) 79 | 80 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] 81 | 82 | x = self.to_latent(x) 83 | return self.mlp_head(x) 
84 | -------------------------------------------------------------------------------- /vit_pytorch_TransREID/vit_pytorch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | 5 | from einops import rearrange, repeat 6 | from einops.layers.torch import Rearrange 7 | 8 | class Residual(nn.Module): 9 | def __init__(self, fn): 10 | super().__init__() 11 | self.fn = fn 12 | def forward(self, x, **kwargs): 13 | return self.fn(x, **kwargs) + x 14 | 15 | class PreNorm(nn.Module): 16 | def __init__(self, dim, fn): 17 | super().__init__() 18 | self.norm = nn.LayerNorm(dim) 19 | self.fn = fn 20 | def forward(self, x, **kwargs): 21 | return self.fn(self.norm(x), **kwargs) 22 | 23 | class FeedForward(nn.Module): 24 | def __init__(self, dim, hidden_dim, dropout = 0.): 25 | super().__init__() 26 | self.net = nn.Sequential( 27 | nn.Linear(dim, hidden_dim), 28 | # nn.GELU(), 29 | nn.ReLU(), 30 | nn.Dropout(dropout), 31 | nn.Linear(hidden_dim, dim), 32 | nn.Dropout(dropout) 33 | ) 34 | def forward(self, x): 35 | return self.net(x) 36 | 37 | class Attention(nn.Module): 38 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 39 | super().__init__() 40 | inner_dim = dim_head * heads 41 | project_out = not (heads == 1 and dim_head == dim) 42 | 43 | self.heads = heads 44 | self.scale = dim_head ** -0.5 45 | 46 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 47 | 48 | self.to_out = nn.Sequential( 49 | nn.Linear(inner_dim, dim), 50 | nn.Dropout(dropout) 51 | ) if project_out else nn.Identity() 52 | 53 | def forward(self, x, mask = None): 54 | b, n, _, h = *x.shape, self.heads 55 | qkv = self.to_qkv(x).chunk(3, dim = -1) 56 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 57 | 58 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 59 | mask_value = -torch.finfo(dots.dtype).max 60 | 61 | if mask is not None: 62 | mask = F.pad(mask.flatten(1), (1, 0), value = True) 63 | assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' 64 | mask = rearrange(mask, 'b i -> b () i ()') * rearrange(mask, 'b j -> b () () j') 65 | dots.masked_fill_(~mask, mask_value) 66 | del mask 67 | 68 | attn = dots.softmax(dim=-1) 69 | 70 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 71 | out = rearrange(out, 'b h n d -> b n (h d)') 72 | out = self.to_out(out) 73 | return out 74 | 75 | class Transformer(nn.Module): 76 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 77 | super().__init__() 78 | self.layers = nn.ModuleList([]) 79 | for _ in range(depth): 80 | self.layers.append(nn.ModuleList([ 81 | Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))), 82 | Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))) 83 | ])) 84 | def forward(self, x, mask = None): 85 | for attn, ff in self.layers: 86 | x = attn(x, mask = mask) 87 | x = ff(x) 88 | return x 89 | 90 | class ViT(nn.Module): 91 | def __init__(self, *, image_size_h, image_size_w,patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): 92 | super().__init__() 93 | assert image_size_h % patch_size == 0, 'Image dimensions must be divisible by the patch size.' 94 | assert image_size_w % patch_size == 0, 'Image dimensions must be divisible by the patch size.' 
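# unlike the square-image ViT variants earlier in this repository, this version takes
# separate image_size_h / image_size_w, so the patch grid computed below is
# (image_size_h // patch_size) x (image_size_w // patch_size) rather than a square grid,
# which suits non-square inputs such as pedestrian crops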
95 | # num_patches = (image_size // patch_size) ** 2 96 | num_patches = (image_size_h // patch_size)*(image_size_w // patch_size) 97 | patch_dim = channels * patch_size ** 2 98 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 99 | 100 | self.to_patch_embedding = nn.Sequential( 101 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size), 102 | nn.Linear(patch_dim, dim), 103 | ) 104 | 105 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 106 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 107 | self.dropout = nn.Dropout(emb_dropout) 108 | 109 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 110 | 111 | self.pool = pool 112 | self.to_latent = nn.Identity() 113 | 114 | self.mlp_head = nn.Sequential( 115 | nn.LayerNorm(dim), 116 | nn.Linear(dim, num_classes) 117 | ) 118 | self.conv = nn.Linear(dim, dim) 119 | def forward(self, img, mask = None): 120 | x = self.to_patch_embedding(img) 121 | b, n, _ = x.shape 122 | 123 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 124 | x = torch.cat((cls_tokens, x), dim=1) 125 | x += self.pos_embedding[:, :(n + 1)] 126 | x = self.dropout(x) 127 | 128 | x = self.transformer(x, mask) 129 | 130 | f = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] 131 | # f = self.conv(x) 132 | x = self.to_latent(f) 133 | if self.training: 134 | return self.mlp_head(x) , f 135 | else: 136 | return f 137 | # return x --------------------------------------------------------------------------------
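For reference, here is a minimal usage sketch of the `ViT` class defined in `vit_pytorch_TransREID/vit_pytorch.py` above. The keyword arguments mirror its constructor; the concrete values (a 256x128 crop, patch size 16, 12 layers and heads, a batch of 4) are illustrative assumptions rather than settings taken from the training scripts, except for `num_classes=11000`, which follows the CUHK-PEDES class count configured in `train_2module_guide.py`. It assumes the repository root is on `PYTHONPATH`.

```python
import torch
from vit_pytorch_TransREID.vit_pytorch import ViT

# hypothetical configuration: 256x128 pedestrian crop, 16x16 patches -> 16x8 = 128 tokens
model = ViT(
    image_size_h=256,
    image_size_w=128,
    patch_size=16,
    num_classes=11000,   # CUHK-PEDES identity count used elsewhere in this repo
    dim=768,
    depth=12,
    heads=12,
    mlp_dim=3072,
)

imgs = torch.randn(4, 3, 256, 128)

# in eval mode forward() returns only the pooled feature vector
model.eval()
with torch.no_grad():
    feats = model(imgs)
print(feats.shape)                  # torch.Size([4, 768])

# in training mode it returns (classification logits, pooled features)
model.train()
logits, feats = model(imgs)
print(logits.shape, feats.shape)    # torch.Size([4, 11000]) torch.Size([4, 768])
```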