├── paper └── neurips_CMPS.pdf ├── reid ├── utils │ ├── data │ │ ├── __init__.py │ │ ├── preprocessor.py │ │ ├── sampler.py │ │ ├── transforms.py │ │ └── dataset.py │ ├── osutils.py │ ├── meters.py │ ├── __init__.py │ ├── logging.py │ └── serialization.py ├── evaluation_metrics │ ├── __init__.py │ ├── classification.py │ └── ranking.py ├── feature_extraction │ ├── __init__.py │ ├── database.py │ └── cnn.py ├── __init__.py ├── loss │ ├── __init__.py │ ├── oim.py │ ├── triplet.py │ ├── ms_loss.py │ ├── tripletAttack.py │ └── loss.py ├── models │ ├── __init__.py │ ├── resnet.py │ ├── DDAG.py │ ├── resnet2.py │ ├── AGW.py │ └── attention.py ├── datasets │ ├── __init__.py │ ├── sysu.py │ └── regdb.py └── evaluators.py ├── run.sh ├── deal_SYSU_testset_ID.py ├── testset_to_query.py ├── cross-modal_dataset_to_market_format.py ├── README.md ├── utils.py ├── test.py ├── MI_SGD.py └── CMPS_attack.py /paper/neurips_CMPS.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/finger-monkey/CMPS/HEAD/paper/neurips_CMPS.pdf -------------------------------------------------------------------------------- /reid/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .dataset import Dataset 4 | from .preprocessor import Preprocessor 5 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | 2 | ##########################DDAG 3 | 4 | 5 | CUDA_VISIBLE_DEVICES=1 python CMPS_attack.py --data /sda1/data -s regdb_v2 --batch_size 64 --resume /sda1/DDAG/save_model/regdb_G_P_3_drop_0.2_4_8_lr_0.1_seed_0_best.t --max-eps 8 6 | -------------------------------------------------------------------------------- /reid/evaluation_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .classification import accuracy 4 | from .ranking import cmc, mean_ap 5 | 6 | __all__ = [ 7 | 'accuracy', 8 | 'cmc', 9 | 'mean_ap', 10 | ] 11 | -------------------------------------------------------------------------------- /reid/utils/osutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import errno 4 | 5 | 6 | def mkdir_if_missing(dir_path): 7 | try: 8 | os.makedirs(dir_path) 9 | except OSError as e: 10 | if e.errno != errno.EEXIST: 11 | raise 12 | -------------------------------------------------------------------------------- /reid/feature_extraction/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .cnn import extract_cnn_feature, extract_pcb_feature 4 | from .database import FeatureDatabase 5 | 6 | __all__ = [ 7 | 'extract_cnn_feature', 'extract_pcb_feature', 8 | 'FeatureDatabase', 9 | ] 10 | -------------------------------------------------------------------------------- /reid/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . import datasets 4 | from . import evaluation_metrics 5 | from . import feature_extraction 6 | from . import loss 7 | from . import models 8 | from . import utils 9 | from . 
import evaluators
10 |
11 |
12 | __version__ = '0.2.0'
13 |
--------------------------------------------------------------------------------
/reid/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .oim import oim, OIM, OIMLoss
4 | from .tripletAttack import TripletLoss
5 | from .triplet import Triplet
6 |
7 | __all__ = [
8 |     'oim',
9 |     'OIM',
10 |     'OIMLoss',
11 |     'loss', 'ms_loss',
12 |     'TripletLoss', 'Triplet'
13 | ]
14 |
--------------------------------------------------------------------------------
/reid/utils/meters.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 |
4 | class AverageMeter(object):
5 |     """Computes and stores the average and current value"""
6 |
7 |     def __init__(self):
8 |         self.val = 0
9 |         self.avg = 0
10 |         self.sum = 0
11 |         self.count = 0
12 |
13 |     def reset(self):
14 |         self.val = 0
15 |         self.avg = 0
16 |         self.sum = 0
17 |         self.count = 0
18 |
19 |     def update(self, val, n=1):
20 |         self.val = val
21 |         self.sum += val * n
22 |         self.count += n
23 |         self.avg = self.sum / self.count
24 |
--------------------------------------------------------------------------------
/reid/evaluation_metrics/classification.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from ..utils import to_torch
4 |
5 |
6 | def accuracy(output, target, topk=(1,)):
7 |     output, target = to_torch(output), to_torch(target)
8 |     maxk = max(topk)
9 |     batch_size = target.size(0)
10 |
11 |     _, pred = output.topk(maxk, 1, True, True)
12 |     pred = pred.t()
13 |     correct = pred.eq(target.view(1, -1).expand_as(pred))
14 |
15 |     ret = []
16 |     for k in topk:
17 |         correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True)
18 |         ret.append(correct_k.mul_(1. 
/ batch_size)) 19 | return ret 20 | -------------------------------------------------------------------------------- /reid/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from .resnet import * 3 | from .AGW import embed_net 4 | from .DDAG import embed_net2 5 | 6 | __factory = { 7 | 'resnet18': resnet18, 8 | 'resnet34': resnet34, 9 | 'resnet50': resnet50, 10 | 'resnet101': resnet101, 11 | 'resnet152': resnet152, 12 | 'AGW': embed_net, 13 | 'DDAG': embed_net2 14 | } 15 | 16 | def names(): 17 | return sorted(__factory.keys()) 18 | 19 | def create(name, *args, **kwargs): 20 | if name not in __factory: 21 | raise KeyError("Unknown model:", name) 22 | return __factory[name](*args, **kwargs) 23 | -------------------------------------------------------------------------------- /reid/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | 6 | def to_numpy(tensor): 7 | if torch.is_tensor(tensor): 8 | return tensor.cpu().numpy() 9 | elif type(tensor).__module__ != 'numpy': 10 | raise ValueError("Cannot convert {} to numpy array" 11 | .format(type(tensor))) 12 | return tensor 13 | 14 | 15 | def to_torch(ndarray): 16 | if type(ndarray).__module__ == 'numpy': 17 | return torch.from_numpy(ndarray) 18 | elif not torch.is_tensor(ndarray): 19 | raise ValueError("Cannot convert {} to torch tensor" 20 | .format(type(ndarray))) 21 | return ndarray 22 | -------------------------------------------------------------------------------- /reid/utils/logging.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import sys 4 | 5 | from .osutils import mkdir_if_missing 6 | 7 | 8 | class Logger(object): 9 | def __init__(self, fpath=None): 10 | self.console = sys.stdout 11 | self.file = None 12 | if fpath is not None: 13 | mkdir_if_missing(os.path.dirname(fpath)) 14 | self.file = open(fpath, 'w') 15 | 16 | def __del__(self): 17 | self.close() 18 | 19 | def __enter__(self): 20 | pass 21 | 22 | def __exit__(self, *args): 23 | self.close() 24 | 25 | def write(self, msg): 26 | self.console.write(msg) 27 | if self.file is not None: 28 | self.file.write(msg) 29 | 30 | def flush(self): 31 | self.console.flush() 32 | if self.file is not None: 33 | self.file.flush() 34 | os.fsync(self.file.fileno()) 35 | 36 | def close(self): 37 | self.console.close() 38 | if self.file is not None: 39 | self.file.close() 40 | -------------------------------------------------------------------------------- /reid/utils/data/preprocessor.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os.path as osp 3 | 4 | from PIL import Image 5 | 6 | 7 | class Preprocessor(object): 8 | def __init__(self, dataset, root=None, transform=None): 9 | super(Preprocessor, self).__init__() 10 | self.dataset = dataset 11 | self.root = root 12 | self.transform = transform 13 | 14 | def __len__(self): 15 | return len(self.dataset) 16 | 17 | def __getitem__(self, indices): 18 | if isinstance(indices, (tuple, list)): 19 | return [self._get_single_item(index) for index in indices] 20 | return self._get_single_item(indices) 21 | 22 | def _get_single_item(self, index): 23 | fname, pid, camid = self.dataset[index] 24 | fpath = fname 25 | if self.root is not None: 26 | fpath = 
osp.join(self.root, fname) 27 | img = Image.open(fpath).convert('RGB') 28 | if self.transform is not None: 29 | img = self.transform(img) 30 | return img, fname, pid, camid 31 | -------------------------------------------------------------------------------- /reid/utils/data/sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils.data.sampler import ( 7 | Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, 8 | WeightedRandomSampler) 9 | 10 | 11 | class RandomIdentitySampler(Sampler): 12 | def __init__(self, data_source, num_instances=1): 13 | self.data_source = data_source 14 | self.num_instances = num_instances 15 | self.index_dic = defaultdict(list) 16 | for index, (_, pid, _) in enumerate(data_source): 17 | self.index_dic[pid].append(index) 18 | self.pids = list(self.index_dic.keys()) 19 | self.num_samples = len(self.pids) 20 | 21 | def __len__(self): 22 | return self.num_samples * self.num_instances 23 | 24 | def __iter__(self): 25 | indices = torch.randperm(self.num_samples) 26 | ret = [] 27 | for i in indices: 28 | pid = self.pids[i] 29 | t = self.index_dic[pid] 30 | if len(t) >= self.num_instances: 31 | t = np.random.choice(t, size=self.num_instances, replace=False) 32 | else: 33 | t = np.random.choice(t, size=self.num_instances, replace=True) 34 | ret.extend(t) 35 | return iter(ret) 36 | -------------------------------------------------------------------------------- /reid/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import warnings 3 | from .regdb import Regdb 4 | from .sysu import Sysu 5 | 6 | warnings.filterwarnings("ignore") 7 | __factory = { 8 | 'regdb_v2': Regdb, 9 | 'sysu_v2': Sysu 10 | } 11 | 12 | def names(): 13 | return sorted(__factory.keys()) 14 | 15 | def create(name, root, *args, **kwargs): 16 | """ 17 | Create a dataset instance. 18 | 19 | Parameters 20 | ---------- 21 | name : str 22 | The dataset name. Can be one of 'viper', 'cuhk01', 'cuhk03', 23 | 'market1501', and 'dukemtmc'. 24 | root : str 25 | The path to the dataset directory. 26 | split_id : int, optional 27 | The index of data split. Default: 0 28 | num_val : int or float, optional 29 | When int, it means the number of validation identities. When float, 30 | it means the proportion of validation to all the trainval. Default: 100 31 | download : bool, optional 32 | If True, will download the dataset. Default: False 33 | """ 34 | if name not in __factory: 35 | raise KeyError("Unknown dataset:", name) 36 | return __factory[name](root, *args, **kwargs) 37 | 38 | def get_dataset(name, root, *args, **kwargs): 39 | warnings.warn("get_dataset is deprecated. 
Use create instead.")
40 |     return create(name, root, *args, **kwargs)
41 |
--------------------------------------------------------------------------------
/deal_SYSU_testset_ID.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import shutil
4 |
5 | # Define the paths of folder A and folder B
6 | # folder_A = 'D:/works/studio/data/SYSU-MM01/SYSU-MM01/visible/cam1/'  # replace with the actual path of folder A
7 | # folder_B = 'D:/works/studio/data/SYSU-MM01/SYSU-MM01/visible/test/cam1/'  # replace with the actual path of folder B
8 |
9 | folder_A = 'D:/works/studio/data/SYSU-MM01/SYSU-MM01/thermal/cam6/'  # replace with the actual path of folder A
10 | folder_B = 'D:/works/studio/data/SYSU-MM01/SYSU-MM01/thermal/test/cam6/'  # replace with the actual path of folder B
11 |
12 | # Make sure the output root folder exists
13 | if not os.path.exists(folder_B):
14 |     os.makedirs(folder_B)
15 |
16 | # Open and read the test_id.txt file to get the list of test IDs
17 | with open('D:/works/studio/data/SYSU-MM01/SYSU-MM01/test_id.txt', 'r') as file:
18 |     test_ids = [int(id.strip()) for id in file.readline().split(',')]
19 |
20 | # Walk through the subfolders of folder A
21 | for root, dirs, files in os.walk(folder_A):
22 |     for folder_name in dirs:
23 |         # Get the subfolder's ID number
24 |         folder_number = int(folder_name)
25 |
26 |         # If the subfolder's ID is in the test ID list
27 |         if folder_number in test_ids:
28 |             # Build the full paths of the source and target folders
29 |             source_folder_path = os.path.join(root, folder_name)
30 |             target_folder_path = os.path.join(folder_B, folder_name)
31 |
32 |             # Move the subfolder into folder B
33 |             shutil.move(source_folder_path, target_folder_path)
34 |             print(f"Moved subfolder {folder_name} to {target_folder_path}")
35 |
36 | print("Task complete!")
37 |
--------------------------------------------------------------------------------
/testset_to_query.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import os
4 | import random
5 | import shutil
6 |
7 | # Specify the paths of folder A and folder B
8 | folder_A = 'D:/works/studio/data/RegDB/RegDB/split/bounding_box_test/Visible/'
9 | folder_B = 'D:/works/studio/data/RegDB/RegDB/split/bounding_box_test/Visible2/'
10 |
11 | # Get all subfolders of folder A
12 | subfolders_A = [f for f in os.listdir(folder_A) if os.path.isdir(os.path.join(folder_A, f))]
13 |
14 | # Iterate over each subfolder
15 | for folder_i in subfolders_A:
16 |     # Build the full path of folder i
17 |     folder_A_i = os.path.join(folder_A, folder_i)
18 |
19 |     # Get all image files in folder i
20 |     image_files = [f for f in os.listdir(folder_A_i) if f.endswith('.bmp')]  # note the image format
21 |
22 |     # Compute the number of images to move (half of them)
23 |     num_images_to_cut = len(image_files) // 2
24 |
25 |     # Randomly choose the images to move
26 |     images_to_cut = random.sample(image_files, num_images_to_cut)
27 |
28 |     # Build the path of folder i inside folder B
29 |     folder_B_i = os.path.join(folder_B, folder_i)
30 |
31 |     # Make sure folder i exists inside folder B
32 |     if not os.path.exists(folder_B_i):
33 |         os.makedirs(folder_B_i)
34 |
35 |     # Iterate over the selected images and move them
36 |     for image in images_to_cut:
37 |         source_path = os.path.join(folder_A_i, image)
38 |         target_path = os.path.join(folder_B_i, image)
39 |
40 |         # Move the file
41 |         shutil.move(source_path, target_path)
42 |
43 |         # Print information about the moved file
44 |         print(f"Moved file: {source_path} to {target_path}")
45 |
46 | print("Task complete!")
47 |
--------------------------------------------------------------------------------
/reid/feature_extraction/database.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import h5py
4 | import numpy as np
5 | from torch.utils.data import Dataset
6 | import warnings
7 | warnings.filterwarnings("ignore")
8 |
9 | class FeatureDatabase(Dataset):
10 |     def __init__(self, *args, **kwargs):
11 |         super(FeatureDatabase, self).__init__()
12 |         self.fid = h5py.File(*args, **kwargs)
13 |
14 |     def 
__enter__(self): 15 | return self 16 | 17 | def __exit__(self, exc_type, exc_val, exc_tb): 18 | self.close() 19 | 20 | def __getitem__(self, keys): 21 | if isinstance(keys, (tuple, list)): 22 | return [self._get_single_item(k) for k in keys] 23 | return self._get_single_item(keys) 24 | 25 | def _get_single_item(self, key): 26 | return np.asarray(self.fid[key]) 27 | 28 | def __setitem__(self, key, value): 29 | if key in self.fid: 30 | if self.fid[key].shape == value.shape and \ 31 | self.fid[key].dtype == value.dtype: 32 | self.fid[key][...] = value 33 | else: 34 | del self.fid[key] 35 | self.fid.create_dataset(key, data=value) 36 | else: 37 | self.fid.create_dataset(key, data=value) 38 | 39 | def __delitem__(self, key): 40 | del self.fid[key] 41 | 42 | def __len__(self): 43 | return len(self.fid) 44 | 45 | def __iter__(self): 46 | return iter(self.fid) 47 | 48 | def flush(self): 49 | self.fid.flush() 50 | 51 | def close(self): 52 | self.fid.close() 53 | -------------------------------------------------------------------------------- /reid/feature_extraction/cnn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import OrderedDict 3 | 4 | import torch 5 | from torch.autograd import Variable 6 | 7 | from ..utils import to_torch 8 | import warnings 9 | warnings.filterwarnings("ignore") 10 | 11 | def extract_cnn_feature(model, inputs, modules=None): 12 | model.eval() 13 | inputs = to_torch(inputs) 14 | inputs = Variable(inputs, volatile=True).cuda() 15 | if modules is None: 16 | outputs = model(inputs)[0] 17 | outputs = outputs.data.cpu() 18 | return outputs 19 | 20 | # Register forward hook for each module 21 | outputs = OrderedDict() 22 | handles = [] 23 | for m in modules: 24 | outputs[id(m)] = None 25 | def func(m, i, o): outputs[id(m)] = o.data.cpu() 26 | handles.append(m.register_forward_hook(func)) 27 | model(inputs) 28 | for h in handles: 29 | h.remove() 30 | return list(outputs.values()) 31 | 32 | 33 | def extract_pcb_feature(model, inputs, modules=None): 34 | model.eval() 35 | inputs = to_torch(inputs) 36 | inputs = Variable(inputs, volatile=True).cuda() 37 | if modules is None: 38 | outputs = model(inputs) 39 | outputs = outputs.data.cpu() 40 | return outputs 41 | 42 | # Register forward hook for each module 43 | outputs = OrderedDict() 44 | handles = [] 45 | for m in modules: 46 | outputs[id(m)] = None 47 | def func(m, i, o): outputs[id(m)] = o.data.cpu() 48 | handles.append(m.register_forward_hook(func)) 49 | model(inputs) 50 | for h in handles: 51 | h.remove() 52 | return list(outputs.values()) 53 | -------------------------------------------------------------------------------- /reid/loss/oim.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, autograd 6 | 7 | 8 | class OIM(autograd.Function): 9 | def __init__(self, lut, momentum=0.5): 10 | super(OIM, self).__init__() 11 | self.lut = lut 12 | self.momentum = momentum 13 | 14 | def forward(self, inputs, targets): 15 | self.save_for_backward(inputs, targets) 16 | outputs = inputs.mm(self.lut.t()) 17 | return outputs 18 | 19 | def backward(self, grad_outputs): 20 | inputs, targets = self.saved_tensors 21 | grad_inputs = None 22 | if self.needs_input_grad[0]: 23 | grad_inputs = grad_outputs.mm(self.lut) 24 | for x, y in zip(inputs, targets): 25 | self.lut[y] = self.momentum * 
self.lut[y] + (1. - self.momentum) * x 26 | self.lut[y] /= self.lut[y].norm() 27 | return grad_inputs, None 28 | 29 | 30 | def oim(inputs, targets, lut, momentum=0.5): 31 | return OIM(lut, momentum=momentum)(inputs, targets) 32 | 33 | 34 | class OIMLoss(nn.Module): 35 | def __init__(self, num_features, num_classes, scalar=1.0, momentum=0.5, 36 | weight=None, size_average=True): 37 | super(OIMLoss, self).__init__() 38 | self.num_features = num_features 39 | self.num_classes = num_classes 40 | self.momentum = momentum 41 | self.scalar = scalar 42 | self.weight = weight 43 | self.size_average = size_average 44 | 45 | self.register_buffer('lut', torch.zeros(num_classes, num_features)) 46 | 47 | def forward(self, inputs, targets): 48 | inputs = oim(inputs, targets, self.lut, momentum=self.momentum) 49 | inputs *= self.scalar 50 | loss = F.cross_entropy(inputs, targets, weight=self.weight, 51 | size_average=self.size_average) 52 | return loss, inputs 53 | -------------------------------------------------------------------------------- /reid/utils/serialization.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import json 3 | import os.path as osp 4 | import shutil 5 | 6 | import torch 7 | from torch.nn import Parameter 8 | 9 | from .osutils import mkdir_if_missing 10 | 11 | 12 | def read_json(fpath): 13 | with open(fpath, 'r') as f: 14 | obj = json.load(f) 15 | return obj 16 | 17 | 18 | def write_json(obj, fpath): 19 | mkdir_if_missing(osp.dirname(fpath)) 20 | with open(fpath, 'w') as f: 21 | json.dump(obj, f, indent=4, separators=(',', ': ')) 22 | 23 | 24 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): 25 | mkdir_if_missing(osp.dirname(fpath)) 26 | torch.save(state, fpath) 27 | if is_best: 28 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar')) 29 | 30 | 31 | def load_checkpoint(fpath): 32 | if osp.isfile(fpath): 33 | checkpoint = torch.load(fpath) 34 | print("=> Loaded checkpoint '{}'".format(fpath)) 35 | return checkpoint 36 | else: 37 | raise ValueError("=> No checkpoint found at '{}'".format(fpath)) 38 | 39 | 40 | def copy_state_dict(state_dict, model, strip=None): 41 | tgt_state = model.state_dict() 42 | copied_names = set() 43 | for name, param in state_dict.items(): 44 | if strip is not None and name.startswith(strip): 45 | name = name[len(strip):] 46 | if name not in tgt_state: 47 | continue 48 | if isinstance(param, Parameter): 49 | param = param.data 50 | if param.size() != tgt_state[name].size(): 51 | print('mismatch:', name, param.size(), tgt_state[name].size()) 52 | continue 53 | tgt_state[name].copy_(param) 54 | copied_names.add(name) 55 | 56 | missing = set(tgt_state.keys()) - copied_names 57 | if len(missing) > 0: 58 | print("missing keys in state_dict:", missing) 59 | 60 | return model 61 | -------------------------------------------------------------------------------- /cross-modal_dataset_to_market_format.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import os 4 | import shutil 5 | 6 | # 指定文件夹A和文件夹B的路径 7 | # folder_A = 'D:/works/studio/data/RegDB/RegDB/split/bounding_box_test/Visible/' 8 | # folder_B = 'D:/works/studio/data/RegDB/RegDB/split/deal/bounding_box_test/Visible/' 9 | 10 | # folder_A = 'D:/works/studio/data/RegDB/RegDB/split/combime/bounding_box_test/' 11 | # folder_B = 'D:/works/studio/data/RegDB/RegDB/split/combime/deal/bounding_box_test/' 12 | 13 | folder_A = 
'D:/works/studio/data/SYSU-MM01/SYSU-MM01/thermal/test/cam3/'
14 | folder_B = 'D:/works/studio/data/SYSU-MM01/SYSU-MM01/deal/thermal/test/cam3/'
15 |
16 | # folder_A = 'D:/works/studio/data/SYSU-MM01/SYSU-MM01/visible/test/cam5/'
17 | # folder_B = 'D:/works/studio/data/SYSU-MM01/SYSU-MM01/deal/visible/test/cam5/'
18 |
19 | # Make sure the output root folder exists
20 | if not os.path.exists(folder_B):
21 |     os.makedirs(folder_B)
22 |
23 | # The cam value specified by the user
24 | # cam_value = input("Please enter the cam value (e.g. c1_s1): ")
25 | cam_value = "c3_s1"
26 | # Iterate over the subfolders of folder A
27 | for folder_K in os.listdir(folder_A):
28 |     # Check whether the subfolder name is numeric
29 |     if folder_K.isdigit():
30 |         # Compute the pid of folder K, formatted as a four-digit string
31 |         pid = folder_K.zfill(4)
32 |
33 |         # Get all image files in folder K
34 |         # image_files = [f for f in os.listdir(os.path.join(folder_A, folder_K)) if f.endswith('.bmp')]  ## NOTE the image format!
35 |         image_files = [f for f in os.listdir(os.path.join(folder_A, folder_K)) if f.endswith('.jpg')]  ## NOTE the image format!
36 |
37 |         # Iterate over the image files in folder K
38 |         for i, image_file in enumerate(image_files, start=1):
39 |             # Build the new file name
40 |             length = str(i).zfill(6)  # use six digits for the image index
41 |             # new_filename = f"{pid}_{cam_value}_{length}_01.bmp"
42 |             new_filename = f"{pid}_{cam_value}_{length}_01.jpg"
43 |
44 |             # Build the full paths of the source and target files
45 |             source_path = os.path.join(folder_A, folder_K, image_file)
46 |             target_path = os.path.join(folder_B, new_filename)
47 |
48 |             # Copy the file under its new name
49 |             shutil.copy2(source_path, target_path)
50 |
51 |             # Print information about the copied file
52 |             print(f"Copied file: {source_path} to {target_path}")
53 |
54 | print("Task complete!")
55 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Code for the NeurIPS 2024 paper "Cross-Modality Perturbation Synergy Attack for Person Re-identification".
2 |
3 |
4 |
5 | ## [Paper](paper/neurips_CMPS.pdf)
6 |
7 | ## Requirements:
8 | * python 3.7
9 | * CUDA==11.2
10 | * faiss-gpu==1.6.0
11 |
12 |
13 | ## Preparing Data
14 |
15 | * A processed tar file with all needed files is available on [BaiduYun](https://pan.baidu.com/s/160oRNcDSemBprqBUBX0PUQ?pwd=9pmk) (Password: 9pmk).
16 |
17 |
18 | ## Run our code
19 |
20 | See `run.sh` for more information.
21 |
22 | If you find this code useful in your research, please cite:
23 |
24 | ```
25 | @article{gong2024cross,
26 |   title={Cross-modality perturbation synergy attack for person re-identification},
27 |   author={Gong, Yunpeng and Zhong, Zhun and Qu, Yansong and Luo, Zhiming and Ji, Rongrong and Jiang, Min},
28 |   journal={Advances in Neural Information Processing Systems},
29 |   volume={37},
30 |   pages={23352--23377},
31 |   year={2024}
32 | }
33 | ```
34 |
35 | ## Acknowledgments
36 |
37 | The code is based on [LTA](https://github.com/finger-monkey/LTA_and_joint-defence) and [Random Color Erasing](https://github.com/finger-monkey/Data-Augmentation).
38 | If you use the code, please cite their papers. 
39 | ``` 40 | @inproceedings{colorAttack2022, 41 | title={Person re-identification method based on color attack and joint defence}, 42 | author={Gong, Yunpeng and Huang, Liqing and Chen, Lifei}, 43 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 44 | pages={4313--4322}, 45 | year={2022} 46 | } 47 | ``` 48 | ``` 49 | @article{gong2021eliminate, 50 | title={Eliminate deviation with deviation for data augmentation and a general multi-modal data learning method}, 51 | author={Gong, Yunpeng and Huang, Liqing and Chen, Lifei}, 52 | journal={arXiv preprint arXiv:2101.08533}, 53 | year={2021} 54 | } 55 | ``` 56 | 57 | 58 | 59 | 60 | 61 | 62 | ## Contact Me 63 | 64 | Email: fmonkey625@gmail.com 65 | 66 | 67 | 68 | 69 | 70 | Flag Counter 71 | -------------------------------------------------------------------------------- /reid/loss/triplet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Variable 6 | from torch.nn import functional as F 7 | from scipy.stats import norm 8 | 9 | import numpy as np 10 | 11 | 12 | class Triplet(nn.Module): 13 | def __init__(self, margin=0, num_instances=0, use_semi=True): 14 | super(Triplet, self).__init__() 15 | self.margin = margin 16 | self.ranking_loss = nn.MarginRankingLoss(margin=self.margin) 17 | self.K = num_instances 18 | self.use_semi = use_semi 19 | 20 | def forward(self, inputs, targets, epoch): 21 | n = inputs.size(0) 22 | P = n / self.K 23 | t0 = 20.0 24 | t1 = 40.0 25 | 26 | # Compute pairwise distance, replace by the official when merged 27 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 28 | dist = dist + dist.t() 29 | dist.addmm_(1, -2, inputs, inputs.t()) 30 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 31 | # For each anchor, find the hardest positive and negative 32 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 33 | dist_ap, dist_an = [], [] 34 | if self.use_semi: 35 | for i in range(P): 36 | for j in range(self.K): 37 | neg_examples = dist[i * self.K + j][mask[i * self.K + j] == 0] 38 | for pair in range(j + 1, self.K): 39 | ap = dist[i * self.K + j][i * self.K + pair].view(1) 40 | dist_ap.append(ap) 41 | dist_an.append(neg_examples.min().view(1)) 42 | else: 43 | for i in range(n): 44 | dist_ap.append(dist[i][mask[i]].max().view(1)) 45 | dist_an.append(dist[i][mask[i] == 0].min().view(1)) 46 | dist_ap = torch.cat(dist_ap) 47 | dist_an = torch.cat(dist_an) 48 | # Compute ranking hinge loss 49 | y = dist_an.data.new() 50 | y.resize_as_(dist_an.data) 51 | y.fill_(1) 52 | y = Variable(y) 53 | loss = self.ranking_loss(dist_an, dist_ap, y) 54 | prec = (dist_an.data > dist_ap.data).sum() * 1. 
/ y.size(0) 55 | return loss, prec -------------------------------------------------------------------------------- /reid/loss/ms_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import pdb 5 | 6 | class MultiSimilarityLoss(nn.Module): 7 | def __init__(self, scale_pos = 0.2, scale_neg =40, thresh = 0.5, margin = 0.3): 8 | super(MultiSimilarityLoss, self).__init__() 9 | self.thresh = thresh 10 | self.margin = margin 11 | 12 | self.scale_pos = scale_pos 13 | self.scale_neg = scale_neg 14 | 15 | self.l2norm = Normalize(2) 16 | def forward(self, feats, labels): 17 | assert feats.size(0) == labels.size(0), \ 18 | f"feats.size(0): {feats.size(0)} is not equal to labels.size(0): {labels.size(0)}" 19 | batch_size = feats.size(0) 20 | 21 | feats = self.l2norm(feats) 22 | sim_mat = torch.matmul(feats, torch.t(feats)) 23 | 24 | epsilon = 1e-5 25 | loss = list() 26 | ptr = 0 27 | for i in range(batch_size): 28 | pos_pair_ = sim_mat[i][labels == labels[i]] 29 | pos_pair_ = pos_pair_[pos_pair_ < 1 - epsilon] 30 | neg_pair_ = sim_mat[i][labels != labels[i]] 31 | 32 | neg_pair = neg_pair_[neg_pair_ + self.margin > min(pos_pair_)] 33 | pos_pair = pos_pair_[pos_pair_ < max(neg_pair_)] 34 | # pos_pair = pos_pair_[pos_pair_ - self.margin < max(neg_pair_)] 35 | 36 | if len(neg_pair) < 1 or len(pos_pair) < 1: 37 | ptr = ptr + 1 38 | continue 39 | # pdb.set_trace() 40 | # weighting step 41 | pos_loss = 1.0 / self.scale_pos * torch.log( 42 | 1 + torch.sum(torch.exp(-self.scale_pos * (pos_pair - self.thresh)))) 43 | neg_loss = 1.0 / self.scale_neg * torch.log( 44 | 1 + torch.sum(torch.exp(self.scale_neg * (neg_pair - self.thresh)))) 45 | loss.append(pos_loss + neg_loss) 46 | # pdb.set_trace() 47 | if len(loss) == 0: 48 | return torch.zeros([], requires_grad=True) 49 | 50 | loss = sum(loss) / batch_size 51 | 52 | return loss, ptr 53 | 54 | class Normalize(nn.Module): 55 | def __init__(self, power=2): 56 | super(Normalize, self).__init__() 57 | self.power = power 58 | 59 | def forward(self, x): 60 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. 
/ self.power) 61 | out = x.div(norm) 62 | return out -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import argparse 4 | import sys 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from reid import models 8 | from torch.nn import functional as F 9 | from reid import datasets 10 | from MI_SGD import MI_SGD,keepGradUpdate 11 | from test import test 12 | from reid.utils.data import transforms as T 13 | from torchvision.transforms import Resize 14 | from reid.utils.data.preprocessor import Preprocessor 15 | from reid.evaluators import Evaluator 16 | from torch.optim.optimizer import Optimizer, required 17 | import random 18 | import numpy as np 19 | import math 20 | from reid.evaluators import extract_features 21 | from reid.utils.meters import AverageMeter 22 | import torchvision 23 | import faiss 24 | from torchvision import transforms 25 | import logging 26 | 27 | def get_data(sourceName, split_id, data_dir, height, width, 28 | batch_size, workers, combine): 29 | root = osp.join(data_dir, sourceName) 30 | 31 | sourceSet = datasets.create(sourceName, root, num_val=0.1, split_id=split_id) 32 | num_classes = sourceSet.num_trainval_ids if combine else sourceSet.num_train_ids 33 | tgtSet = sourceSet 34 | class_tgt = tgtSet.num_trainval_ids if combine else tgtSet.num_train_ids 35 | 36 | train_transformer = T.Compose([ 37 | Resize((height, width)), 38 | transforms.RandomGrayscale(p=0.2), 39 | T.ToTensor(), 40 | ]) 41 | 42 | train_transformer2 = T.Compose([ 43 | Resize((height, width)), 44 | T.ToTensor(), 45 | ]) 46 | 47 | train_step1 = DataLoader( 48 | Preprocessor(sourceSet.trainval, root=sourceSet.images_dir, transform=train_transformer), 49 | batch_size=batch_size, num_workers=workers, shuffle=False, pin_memory=True) 50 | train_step3 = DataLoader( 51 | Preprocessor(sourceSet.trainval, root=sourceSet.images_dir, transform=train_transformer2), 52 | batch_size=batch_size, num_workers=workers, shuffle=False, pin_memory=True) 53 | 54 | return sourceSet, sourceSet, num_classes, class_tgt, train_step1, train_step3 55 | 56 | 57 | 58 | 59 | def calDist(qFeat, gFeat): 60 | m, n = qFeat.size(0), gFeat.size(0) 61 | x = qFeat.view(m, -1) 62 | y = gFeat.view(n, -1) 63 | dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 64 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 65 | dist_m.addmm_(1, -2, x, y.t()) 66 | return dist_m 67 | 68 | -------------------------------------------------------------------------------- /reid/loss/tripletAttack.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import torch 3 | from torch import nn 4 | from torch.autograd import Variable 5 | from torch.nn import functional as F 6 | from scipy.stats import norm 7 | import numpy as np 8 | 9 | class TripletLoss(nn.Module): 10 | def __init__(self, margin=0, num_instances=0, use_semi=True): 11 | super(TripletLoss, self).__init__() 12 | self.margin = margin 13 | self.ranking_loss = nn.MarginRankingLoss(margin=self.margin) 14 | self.K = num_instances 15 | self.use_semi = use_semi 16 | 17 | def forward(self, inputs, purtub, targets, epoch): 18 | n = inputs.size(0) 19 | P = n / self.K 20 | 21 | # Compute pairwise distance, replace by the official when merged 22 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 23 | dist = dist + 
dist.t() 24 | dist.addmm_(1, -2, inputs, inputs.t()) 25 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 26 | # 27 | purtubeDist = torch.pow(purtub, 2).sum(dim=1, keepdim=True).expand(n, n) 28 | purtubeDist = purtubeDist + purtubeDist.t() 29 | purtubeDist.addmm_(1, -2, purtub, purtub.t()) 30 | purtubeDist = purtubeDist.clamp(min=1e-12).sqrt() 31 | # For each anchor, find the hardest positive and negative 32 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 33 | dist_ap, dist_an = [], [] 34 | if self.use_semi: 35 | for i in range(P): 36 | for j in range(self.K): 37 | neg_examples = dist[i * self.K + j][mask[i * self.K + j] == 0] 38 | for pair in range(j + 1, self.K): 39 | ap = dist[i * self.K + j][i * self.K + pair].view(1) 40 | dist_ap.append(ap) 41 | dist_an.append(neg_examples.min().view(1)) 42 | else: 43 | for i in range(n): 44 | maxLoc = dist[i][mask[i]].argmax().view(1) 45 | minLoc = dist[i][mask[i] == 0].argmin().view(1) 46 | dist_ap.append(purtubeDist[i][maxLoc].view(1)) 47 | dist_an.append(purtubeDist[i][minLoc].view(1)) 48 | dist_ap = torch.cat(dist_ap) 49 | dist_an = torch.cat(dist_an) 50 | # Compute ranking hinge loss 51 | y = dist_an.data.new() 52 | y.resize_as_(dist_an.data) 53 | y.fill_(1) 54 | y = Variable(y) 55 | # loss = self.ranking_loss(dist_an, dist_ap, y) 56 | loss = self.ranking_loss(dist_ap, dist_an, y) 57 | prec = (dist_an.data > dist_ap.data).sum() * 1. / y.size(0) 58 | return loss, prec 59 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import argparse 4 | import sys 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from reid import models 8 | from torch.nn import functional as F 9 | from reid import datasets 10 | from MI_SGD import MI_SGD,keepGradUpdate 11 | from reid.utils.data import transforms as T 12 | from torchvision.transforms import Resize 13 | from reid.utils.data.preprocessor import Preprocessor 14 | from reid.evaluators import Evaluator 15 | from torch.optim.optimizer import Optimizer, required 16 | import random 17 | import numpy as np 18 | import math 19 | from reid.evaluators import extract_features 20 | from reid.utils.meters import AverageMeter 21 | import torchvision 22 | import faiss 23 | from torchvision import transforms 24 | 25 | 26 | MODE = "bilinear" 27 | 28 | def test(dataset, net, noise, args, evaluator, epoch): 29 | print(">> Evaluating network on test datasets...") 30 | 31 | net = net.cuda() 32 | net.eval() 33 | normalize = T.Normalize( 34 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 35 | ) 36 | 37 | def add_noise(img): 38 | n = noise.cpu() 39 | img = img.cpu() 40 | n = F.interpolate( 41 | n.unsqueeze(0), mode=MODE, size=tuple(img.shape[-2:]), align_corners=True 42 | ).squeeze() 43 | return torch.clamp(img + n, 0, 1) 44 | 45 | query_trans = T.Compose([ 46 | T.RectScale(args.height, args.width), 47 | T.ToTensor(), T.Lambda(lambda img: add_noise(img)), 48 | normalize 49 | ]) 50 | test_transformer = T.Compose([ 51 | T.RectScale(args.height, args.width), 52 | T.ToTensor(), normalize 53 | ]) 54 | query_loader = DataLoader( 55 | Preprocessor(dataset.query, root=dataset.images_dir, transform=query_trans), 56 | batch_size=args.batch_size, num_workers=0, shuffle=False, pin_memory=True 57 | ) 58 | gallery_loader = DataLoader( 59 | Preprocessor(dataset.gallery, root=dataset.images_dir, transform=test_transformer), 60 | 
batch_size=args.batch_size, num_workers=8, shuffle=False, pin_memory=True 61 | ) 62 | qFeats, gFeats, qnames, gnames = [], [], [], [] 63 | with torch.no_grad(): 64 | for (inputs, qname, _, _) in query_loader: 65 | inputs = inputs.cuda() 66 | qFeats.append(net(inputs)[0]) 67 | qnames.extend(qname) 68 | qFeats = torch.cat(qFeats, 0) 69 | for (inputs, gname, _, _) in gallery_loader: 70 | inputs = inputs.cuda() 71 | gFeats.append(net(inputs)[0]) 72 | gnames.extend(gname) 73 | gFeats = torch.cat(gFeats, 0) 74 | distMat = calDist(qFeats, gFeats) 75 | 76 | # evaluate on test datasets 77 | evaluator.evaMat(distMat, dataset.query, dataset.gallery) 78 | return 79 | 80 | def calDist(qFeat, gFeat): 81 | m, n = qFeat.size(0), gFeat.size(0) 82 | x = qFeat.view(m, -1) 83 | y = gFeat.view(n, -1) 84 | dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 85 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 86 | dist_m.addmm_(1, -2, x, y.t()) 87 | return dist_m 88 | -------------------------------------------------------------------------------- /reid/utils/data/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from PIL import Image, ImageOps, ImageEnhance 4 | import random 5 | import math 6 | from torchvision.transforms import * 7 | 8 | 9 | class RectScale(object): 10 | def __init__(self, height, width, interpolation=Image.BILINEAR): 11 | self.height = height 12 | self.width = width 13 | self.interpolation = interpolation 14 | 15 | def __call__(self, img): 16 | w, h = img.size 17 | if h == self.height and w == self.width: 18 | return img 19 | return img.resize((self.width, self.height), self.interpolation) 20 | 21 | 22 | class RandomSizedRectCrop(object): 23 | def __init__(self, height, width, interpolation=Image.BILINEAR): 24 | self.height = height 25 | self.width = width 26 | self.interpolation = interpolation 27 | 28 | def __call__(self, img): 29 | for attempt in range(10): 30 | area = img.size[0] * img.size[1] 31 | target_area = random.uniform(0.64, 1.0) * area 32 | aspect_ratio = random.uniform(2, 3) 33 | 34 | h = int(round(math.sqrt(target_area * aspect_ratio))) 35 | w = int(round(math.sqrt(target_area / aspect_ratio))) 36 | 37 | if w <= img.size[0] and h <= img.size[1]: 38 | x1 = random.randint(0, img.size[0] - w) 39 | y1 = random.randint(0, img.size[1] - h) 40 | 41 | img = img.crop((x1, y1, x1 + w, y1 + h)) 42 | assert (img.size == (w, h)) 43 | 44 | return img.resize((self.width, self.height), self.interpolation) 45 | 46 | # Fallback 47 | scale = RectScale(self.height, self.width, 48 | interpolation=self.interpolation) 49 | return scale(img) 50 | 51 | 52 | class RandomErasing(object): 53 | ''' 54 | Class that performs Random Erasing in Random Erasing Data Augmentation by Zhong et al. 55 | ------------------------------------------------------------------------------------- 56 | probability: The probability that the operation will be performed. 
57 | sl: min erasing area 58 | sh: max erasing area 59 | r1: min aspect ratio 60 | mean: erasing value 61 | ------------------------------------------------------------------------------------- 62 | ''' 63 | 64 | def __init__(self, probability=0.5, sl=0.02, sh=0.2, r1=0.3, mean=[0.4914, 0.4822, 0.4465]): 65 | self.probability = probability 66 | self.mean = mean 67 | self.sl = sl 68 | self.sh = sh 69 | self.r1 = r1 70 | 71 | def __call__(self, img): 72 | 73 | if random.uniform(0, 1) > self.probability: 74 | return img 75 | 76 | for attempt in range(100): 77 | area = img.size()[1] * img.size()[2] 78 | 79 | target_area = random.uniform(self.sl, self.sh) * area 80 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 81 | 82 | h = int(round(math.sqrt(target_area * aspect_ratio))) 83 | w = int(round(math.sqrt(target_area / aspect_ratio))) 84 | 85 | if w < img.size()[2] and h < img.size()[1]: 86 | x1 = random.randint(0, img.size()[1] - h) 87 | y1 = random.randint(0, img.size()[2] - w) 88 | if img.size()[0] == 3: 89 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 90 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 91 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 92 | else: 93 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 94 | return img 95 | 96 | return img 97 | -------------------------------------------------------------------------------- /reid/models/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import math 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from torch.nn import init 8 | from torch.autograd import Variable 9 | import torchvision 10 | # from torch_deform_conv.layers import ConvOffset2D 11 | from reid.utils.serialization import load_checkpoint, save_checkpoint 12 | 13 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 14 | 'resnet152'] 15 | 16 | 17 | class ResNet(nn.Module): 18 | __factory = { 19 | 18: torchvision.models.resnet18, 20 | 34: torchvision.models.resnet34, 21 | 50: torchvision.models.resnet50, 22 | 101: torchvision.models.resnet101, 23 | 152: torchvision.models.resnet152, 24 | } 25 | 26 | def __init__(self, depth, checkpoint=None, pretrained=True, num_features=2048, 27 | dropout=0.1, num_classes=0): 28 | super(ResNet, self).__init__() 29 | 30 | self.depth = depth 31 | self.checkpoint = checkpoint 32 | self.pretrained = pretrained 33 | self.num_features = num_features 34 | self.dropout = dropout 35 | self.num_classes = num_classes 36 | 37 | if self.dropout > 0: 38 | self.drop = nn.Dropout(self.dropout) 39 | # Construct base (pretrained) resnet 40 | if depth not in ResNet.__factory: 41 | raise KeyError("Unsupported depth:", depth) 42 | self.base = ResNet.__factory[depth](pretrained=pretrained) 43 | out_planes = self.base.fc.in_features 44 | 45 | # resume from pre-iteration training 46 | if self.checkpoint: 47 | state_dict = load_checkpoint(checkpoint) 48 | self.load_state_dict(state_dict['state_dict'], strict=False) 49 | 50 | self.feat = nn.Linear(out_planes, self.num_features, bias=False) 51 | self.feat_bn = nn.BatchNorm1d(self.num_features) 52 | self.relu = nn.ReLU(inplace=True) 53 | init.normal(self.feat.weight, std=0.001) 54 | init.constant(self.feat_bn.weight, 1) 55 | init.constant(self.feat_bn.bias, 0) 56 | 57 | # x2 classifier 58 | self.classifier_x2 = nn.Linear(self.num_features, self.num_classes) 59 | init.normal(self.classifier_x2.weight, std=0.001) 60 | init.constant(self.classifier_x2.bias, 0) 61 | 62 | if not 
self.pretrained: 63 | self.reset_params() 64 | 65 | def forward(self, x): 66 | for name, module in self.base._modules.items(): 67 | if name == 'avgpool': 68 | break 69 | x = module(x) 70 | 71 | x1 = F.avg_pool2d(x, x.size()[2:]) 72 | x1 = x1.view(x1.size(0), -1) 73 | x2 = self.feat(x1) 74 | x2 = self.feat_bn(x2) 75 | x2 = self.relu(x2) 76 | x2 = self.drop(x2) 77 | x2 = self.classifier_x2(x2) 78 | return x1, x2 79 | 80 | def reset_params(self): 81 | for m in self.modules(): 82 | if isinstance(m, nn.Conv2d): 83 | init.normal(m.weight, std=0.001) 84 | if m.bias is not None: 85 | init.constant(m.bias, 0) 86 | elif isinstance(m, nn.BatchNorm2d): 87 | init.constant(m.weight, 1) 88 | init.constant(m.bias, 0) 89 | elif isinstance(m, nn.Linear): 90 | init.normal(m.weight, std=0.001) 91 | if m.bias is not None: 92 | init.constant(m.bias, 0) 93 | 94 | 95 | def resnet18(**kwargs): 96 | return ResNet(18, **kwargs) 97 | 98 | 99 | def resnet34(**kwargs): 100 | return ResNet(34, **kwargs) 101 | 102 | 103 | def resnet50(**kwargs): 104 | return ResNet(50, **kwargs) 105 | 106 | 107 | def resnet101(**kwargs): 108 | return ResNet(101, **kwargs) 109 | 110 | 111 | def resnet152(**kwargs): 112 | return ResNet(152, **kwargs) 113 | -------------------------------------------------------------------------------- /reid/evaluation_metrics/ranking.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | from sklearn.metrics import average_precision_score 6 | 7 | from ..utils import to_numpy 8 | 9 | 10 | def _unique_sample(ids_dict, num): 11 | # mask = np.zeros(num, dtype=np.bool) 12 | mask = np.zeros(num, dtype=bool) 13 | for _, indices in ids_dict.items(): 14 | i = np.random.choice(indices) 15 | mask[i] = True 16 | return mask 17 | 18 | 19 | def cmc(distmat, query_ids=None, gallery_ids=None, 20 | query_cams=None, gallery_cams=None, topk=100, 21 | separate_camera_set=False, 22 | single_gallery_shot=False, 23 | first_match_break=False): 24 | distmat = to_numpy(distmat) 25 | m, n = distmat.shape 26 | # Fill up default values 27 | if query_ids is None: 28 | query_ids = np.arange(m) 29 | if gallery_ids is None: 30 | gallery_ids = np.arange(n) 31 | if query_cams is None: 32 | query_cams = np.zeros(m).astype(np.int32) 33 | if gallery_cams is None: 34 | gallery_cams = np.ones(n).astype(np.int32) 35 | # Ensure numpy array 36 | query_ids = np.asarray(query_ids) 37 | gallery_ids = np.asarray(gallery_ids) 38 | query_cams = np.asarray(query_cams) 39 | gallery_cams = np.asarray(gallery_cams) 40 | # Sort and find correct matches 41 | indices = np.argsort(distmat, axis=1) 42 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 43 | # Compute CMC for each query 44 | ret = np.zeros(topk) 45 | num_valid_queries = 0 46 | for i in range(m): 47 | # Filter out the same id and same camera 48 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 49 | (gallery_cams[indices[i]] != query_cams[i])) 50 | if separate_camera_set: 51 | # Filter out samples from same camera 52 | valid &= (gallery_cams[indices[i]] != query_cams[i]) 53 | if not np.any(matches[i, valid]): continue 54 | if single_gallery_shot: 55 | repeat = 10 56 | gids = gallery_ids[indices[i][valid]] 57 | inds = np.where(valid)[0] 58 | ids_dict = defaultdict(list) 59 | for j, x in zip(inds, gids): 60 | ids_dict[x].append(j) 61 | else: 62 | repeat = 1 63 | for _ in range(repeat): 64 | if single_gallery_shot: 65 | # Randomly choose one 
instance for each id 66 | sampled = (valid & _unique_sample(ids_dict, len(valid))) 67 | index = np.nonzero(matches[i, sampled])[0] 68 | else: 69 | index = np.nonzero(matches[i, valid])[0] 70 | delta = 1. / (len(index) * repeat) 71 | for j, k in enumerate(index): 72 | if k - j >= topk: break 73 | if first_match_break: 74 | ret[k - j] += 1 75 | break 76 | ret[k - j] += delta 77 | num_valid_queries += 1 78 | if num_valid_queries == 0: 79 | raise RuntimeError("No valid query") 80 | return ret.cumsum() / num_valid_queries 81 | 82 | 83 | def mean_ap(distmat, query_ids=None, gallery_ids=None, 84 | query_cams=None, gallery_cams=None): 85 | distmat = to_numpy(distmat) 86 | m, n = distmat.shape 87 | # Fill up default values 88 | if query_ids is None: 89 | query_ids = np.arange(m) 90 | if gallery_ids is None: 91 | gallery_ids = np.arange(n) 92 | if query_cams is None: 93 | query_cams = np.zeros(m).astype(np.int32) 94 | if gallery_cams is None: 95 | gallery_cams = np.ones(n).astype(np.int32) 96 | # Ensure numpy array 97 | query_ids = np.asarray(query_ids) 98 | gallery_ids = np.asarray(gallery_ids) 99 | query_cams = np.asarray(query_cams) 100 | gallery_cams = np.asarray(gallery_cams) 101 | # Sort and find correct matches 102 | indices = np.argsort(distmat, axis=1) 103 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 104 | # Compute AP for each query 105 | aps = [] 106 | for i in range(m): 107 | # Filter out the same id and same camera 108 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 109 | (gallery_cams[indices[i]] != query_cams[i])) 110 | y_true = matches[i, valid] 111 | y_score = -distmat[i][indices[i]][valid] 112 | if not np.any(y_true): continue 113 | aps.append(average_precision_score(y_true, y_score)) 114 | if len(aps) == 0: 115 | raise RuntimeError("No valid query") 116 | return np.mean(aps) 117 | -------------------------------------------------------------------------------- /reid/evaluators.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | from collections import OrderedDict 4 | from collections import namedtuple 5 | import torch 6 | from .evaluation_metrics import cmc, mean_ap 7 | from .feature_extraction import extract_cnn_feature, extract_pcb_feature 8 | from .utils.meters import AverageMeter 9 | 10 | def extract_features(model, data_loader, print_freq=1, metric=None): 11 | model.eval() 12 | batch_time = AverageMeter() 13 | data_time = AverageMeter() 14 | features = OrderedDict() 15 | labels = OrderedDict() 16 | end = time.time() 17 | for i, (imgs, fnames, pids, _) in enumerate(data_loader): 18 | data_time.update(time.time() - end) 19 | 20 | outputs = extract_cnn_feature(model, imgs) 21 | for fname, output, pid in zip(fnames, outputs, pids): 22 | features[fname] = output 23 | labels[fname] = pid 24 | batch_time.update(time.time() - end) 25 | end = time.time() 26 | if (i + 1) % print_freq == 0: 27 | print('Extract Features: [{}/{}]\t\t' 28 | 'Batch Loader Time {:.3f} ({:.3f})\t\t' 29 | 'Data Loader Time {:.3f} ({:.3f})\t\t' 30 | .format(i + 1, len(data_loader), 31 | batch_time.val, batch_time.avg, 32 | data_time.val, data_time.avg)) 33 | return features, labels 34 | def pairwise_distance(features, query=None, gallery=None, metric=None): 35 | if query is None and gallery is None: 36 | n = len(features) 37 | x = torch.cat(list(features.values())) 38 | x = x.view(n, -1) 39 | if metric is not None: 40 | x = metric.transform(x) 41 | dist = torch.pow(x, 
2).sum(dim=1, keepdim=True) * 2 42 | dist = dist.expand(n, n) - 2 * torch.mm(x, x.t()) 43 | return dist 44 | 45 | x = torch.cat([features[f].unsqueeze(0) for f, _, _ in query], 0) 46 | y = torch.cat([features[f].unsqueeze(0) for f, _, _ in gallery], 0) 47 | m, n = x.size(0), y.size(0) 48 | x = x.view(m, -1) 49 | y = y.view(n, -1) 50 | if metric is not None: 51 | x = metric.transform(x) 52 | y = metric.transform(y) 53 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 54 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 55 | dist.addmm_(1, -2, x, y.t()) 56 | return dist 57 | def evaluate_all(distmat, query=None, gallery=None, 58 | query_ids=None, gallery_ids=None, 59 | query_cams=None, gallery_cams=None, 60 | cmc_topk=(1, 10,20)): 61 | if query is not None and gallery is not None: 62 | query_ids = [pid for _, pid, _ in query] 63 | gallery_ids = [pid for _, pid, _ in gallery] 64 | query_cams = [cam for _, _, cam in query] 65 | gallery_cams = [cam for _, _, cam in gallery] 66 | else: 67 | assert (query_ids is not None and gallery_ids is not None 68 | and query_cams is not None and gallery_cams is not None) 69 | mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams) 70 | print('Mean AP: {:4.2%}'.format(mAP)) 71 | cmc_configs = { 72 | 'score': dict(separate_camera_set=False, 73 | single_gallery_shot=False, 74 | first_match_break=True)} 75 | cmc_scores = {name: cmc(distmat, query_ids, gallery_ids, 76 | query_cams, gallery_cams, **params) 77 | for name, params in cmc_configs.items()} 78 | print('CMC Scores{:>12}' 79 | .format('score')) 80 | 81 | rank_score = namedtuple( 82 | 'rank_score', 83 | ['map', 'score'], 84 | ) 85 | for k in cmc_topk: 86 | print(' top-{:<4}{:12.2%}' 87 | .format(k, 88 | cmc_scores['score'][k - 1])) 89 | score = rank_score( 90 | mAP, 91 | cmc_scores['score'], 92 | ) 93 | return score 94 | class Evaluator(object): 95 | def __init__(self, model, print_freq=1): 96 | super(Evaluator, self).__init__() 97 | self.model = model 98 | self.print_freq = print_freq 99 | 100 | def evaluate(self, data_loader, query, gallery, metric=None): 101 | features, _ = extract_features(self.model, data_loader, print_freq=self.print_freq) 102 | distmat = pairwise_distance(features, query, gallery, metric=metric) 103 | return evaluate_all(distmat, query=query, gallery=gallery) 104 | 105 | def evaMat(self, distMat, query, gallery, saveRank=False, root=None): 106 | return evaluate_all(distMat, query=query, gallery=gallery) 107 | -------------------------------------------------------------------------------- /reid/utils/data/dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os.path as osp 3 | 4 | import numpy as np 5 | 6 | from ..serialization import read_json 7 | 8 | 9 | def _pluck(identities, indices, relabel=False): 10 | ret = [] 11 | for index, pid in enumerate(indices): 12 | pid_images = identities[pid] 13 | for camid, cam_images in enumerate(pid_images): 14 | for fname in cam_images: 15 | name = osp.splitext(fname)[0] 16 | x, y, _ = map(int, name.split('_')) 17 | assert pid == x and camid == y 18 | if relabel: 19 | ret.append((fname, index, camid)) 20 | else: 21 | ret.append((fname, pid, camid)) 22 | return ret 23 | 24 | def _pluck_gallery(identities, indices, relabel=False): 25 | ret = [] 26 | for index, pid in enumerate(indices): 27 | pid_images = identities[pid] 28 | for camid, cam_images in enumerate(pid_images): 29 | if len(cam_images[:-1])==0: 30 | for 
fname in cam_images: 31 | name = osp.splitext(fname)[0] 32 | x, y, _ = map(int, name.split('_')) 33 | assert pid == x and camid == y 34 | if relabel: 35 | ret.append((fname, index, camid)) 36 | else: 37 | ret.append((fname, pid, camid)) 38 | else: 39 | for fname in cam_images[:-1]: 40 | name = osp.splitext(fname)[0] 41 | x, y, _ = map(int, name.split('_')) 42 | assert pid == x and camid == y 43 | if relabel: 44 | ret.append((fname, index, camid)) 45 | else: 46 | ret.append((fname, pid, camid)) 47 | return ret 48 | 49 | def _pluck_query(identities, indices, relabel=False): 50 | ret = [] 51 | for index, pid in enumerate(indices): 52 | pid_images = identities[pid] 53 | for camid, cam_images in enumerate(pid_images): 54 | for fname in cam_images[-1:]: 55 | name = osp.splitext(fname)[0] 56 | x, y, _ = map(int, name.split('_')) 57 | assert pid == x and camid == y 58 | if relabel: 59 | ret.append((fname, index, camid)) 60 | else: 61 | ret.append((fname, pid, camid)) 62 | return ret 63 | 64 | 65 | class Dataset(object): 66 | def __init__(self, root, split_id=0): 67 | self.root = root 68 | self.split_id = split_id 69 | self.meta = None 70 | self.split = None 71 | self.train, self.val, self.trainval = [], [], [] 72 | self.query, self.gallery = [], [] 73 | self.num_train_ids, self.num_val_ids, self.num_trainval_ids = 0, 0, 0 74 | 75 | @property 76 | def images_dir(self): 77 | return osp.join(self.root, 'images') 78 | 79 | def load(self, num_val=0.3, verbose=True): 80 | splits = read_json(osp.join(self.root, 'splits.json')) 81 | if self.split_id >= len(splits): 82 | raise ValueError("split_id exceeds total splits {}" 83 | .format(len(splits))) 84 | self.split = splits[self.split_id] 85 | 86 | # Randomly split train / val 87 | trainval_pids = np.asarray(self.split['trainval']) 88 | np.random.shuffle(trainval_pids) 89 | num = len(trainval_pids) 90 | if isinstance(num_val, float): 91 | num_val = int(round(num * num_val)) 92 | if num_val >= num or num_val < 0: 93 | raise ValueError("num_val exceeds total identities {}" 94 | .format(num)) 95 | train_pids = sorted(trainval_pids[:-num_val]) 96 | val_pids = sorted(trainval_pids[-num_val:]) 97 | 98 | self.meta = read_json(osp.join(self.root, 'meta.json')) 99 | identities = self.meta['identities'] 100 | self.train = _pluck(identities, train_pids, relabel=True) 101 | self.val = _pluck(identities, val_pids, relabel=True) 102 | self.trainval = _pluck(identities, trainval_pids, relabel=True) 103 | self.query = _pluck_query(identities, self.split['query']) 104 | #self.gallery = _pluck(identities, self.split['gallery']) 105 | self.gallery = _pluck_gallery(identities, self.split['gallery']) 106 | self.num_train_ids = len(train_pids) 107 | self.num_val_ids = len(val_pids) 108 | self.num_trainval_ids = len(trainval_pids) 109 | 110 | if verbose: 111 | print(self.__class__.__name__, "dataset loaded") 112 | print(" subset | # ids | # images") 113 | print(" ---------------------------") 114 | print(" train | {:5d} | {:8d}" 115 | .format(self.num_train_ids, len(self.train))) 116 | print(" val | {:5d} | {:8d}" 117 | .format(self.num_val_ids, len(self.val))) 118 | print(" trainval | {:5d} | {:8d}" 119 | .format(self.num_trainval_ids, len(self.trainval))) 120 | print(" query | {:5d} | {:8d}" 121 | .format(len(self.split['query']), len(self.query))) 122 | print(" gallery | {:5d} | {:8d}" 123 | .format(len(self.split['gallery']), len(self.gallery))) 124 | 125 | def _check_integrity(self): 126 | return osp.isdir(osp.join(self.root, 'images')) and \ 127 | 
osp.isfile(osp.join(self.root, 'meta.json')) and \ 128 | osp.isfile(osp.join(self.root, 'splits.json')) 129 | -------------------------------------------------------------------------------- /MI_SGD.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import argparse 4 | import sys 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from reid import models 8 | from torch.nn import functional as F 9 | from reid import datasets 10 | from reid.utils.data import transforms as T 11 | from torchvision.transforms import Resize 12 | from reid.utils.data.preprocessor import Preprocessor 13 | from reid.evaluators import Evaluator 14 | from torch.optim.optimizer import Optimizer, required 15 | import random 16 | import numpy as np 17 | import math 18 | from reid.evaluators import extract_features 19 | from reid.utils.meters import AverageMeter 20 | import torchvision 21 | import faiss 22 | from torchvision import transforms 23 | 24 | CHECK = 1e-5 25 | SAT_MIN = 0.5 26 | 27 | 28 | class MI_SGD(Optimizer): 29 | def __init__( 30 | self, params, lr=required, momentum=0, dampening=0, weight_decay=0, 31 | nesterov=False, max_eps=10 / 255 32 | ): 33 | if lr is not required and lr < 0.0: 34 | raise ValueError("Error learning rate: {}".format(lr)) 35 | if momentum < 0.0: 36 | raise ValueError("Error momentum: {}".format(momentum)) 37 | if weight_decay < 0.0: 38 | raise ValueError("Error weight_decay: {}".format(weight_decay)) 39 | 40 | defaults = dict( 41 | lr=lr, 42 | momentum=momentum, 43 | dampening=dampening, 44 | weight_decay=weight_decay, 45 | nesterov=nesterov, 46 | sign=False, 47 | ) 48 | if nesterov and (momentum <= 0 or dampening != 0): 49 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 50 | super(MI_SGD, self).__init__(params, defaults) 51 | self.sat = 0 52 | self.sat_prev = 0 53 | self.max_eps = max_eps 54 | 55 | def __setstate__(self, state): 56 | super(MI_SGD, self).__setstate__(state) 57 | for group in self.param_groups: 58 | group.setdefault("nesterov", False) 59 | 60 | 61 | def rescale(self, ): 62 | for group in self.param_groups: 63 | if not group["sign"]: 64 | continue 65 | for p in group["params"]: 66 | self.sat_prev = self.sat 67 | self.sat = (p.data.abs() >= self.max_eps).sum().item() / p.data.numel() 68 | sat_change = abs(self.sat - self.sat_prev) 69 | if rescale_check(CHECK, self.sat, sat_change, SAT_MIN): 70 | print('rescaled') 71 | p.data = p.data / 2 72 | 73 | def step(self, closure=None): 74 | loss = None 75 | if closure is not None: 76 | loss = closure() 77 | 78 | for group in self.param_groups: 79 | weight_decay = group["weight_decay"] 80 | momentum = group["momentum"] 81 | dampening = group["dampening"] 82 | nesterov = group["nesterov"] 83 | 84 | for p in group["params"]: 85 | if p.grad is None: 86 | continue 87 | d_p = p.grad.data 88 | if group["sign"]: 89 | d_p = d_p / (d_p.norm(1) + 1e-12) 90 | if weight_decay != 0: 91 | d_p.add_(weight_decay, p.data) 92 | if momentum != 0: 93 | param_state = self.state[p] 94 | if "momentum_buffer" not in param_state: 95 | buf = param_state["momentum_buffer"] = torch.zeros_like(p.data) 96 | buf.mul_(momentum).add_(d_p) 97 | else: 98 | buf = param_state["momentum_buffer"] 99 | buf.mul_(momentum).add_(1 - dampening, d_p) 100 | if nesterov: 101 | d_p = d_p.add(momentum, buf) 102 | else: 103 | d_p = buf 104 | 105 | if group["sign"]: 106 | p.data.add_(-group["lr"], d_p.sign()) 107 | p.data = torch.clamp(p.data, -self.max_eps, 
self.max_eps) 108 | else: 109 | p.data.add_(-group["lr"], d_p) 110 | 111 | return loss 112 | 113 | def rescale_check(check, sat, sat_change, sat_min): 114 | return sat_change < check and sat > sat_min 115 | 116 | def keepGradUpdate(noiseData, optimizer, gradInfo, max_eps): 117 | weight_decay = optimizer.param_groups[0]["weight_decay"] 118 | momentum = optimizer.param_groups[0]["momentum"] 119 | dampening = optimizer.param_groups[0]["dampening"] 120 | nesterov = optimizer.param_groups[0]["nesterov"] 121 | lr = optimizer.param_groups[0]["lr"] 122 | 123 | d_p = gradInfo 124 | if optimizer.param_groups[0]["sign"]: 125 | d_p = d_p / (d_p.norm(1) + 1e-12) 126 | if weight_decay != 0: 127 | d_p.add_(weight_decay, noiseData) 128 | if momentum != 0: 129 | param_state = optimizer.state[noiseData] 130 | if "momentum_buffer" not in param_state: 131 | buf = param_state["momentum_buffer"] = torch.zeros_like(noiseData.data) 132 | # buf.mul_(momentum).add_(d_p) 133 | buf = buf * momentum + d_p 134 | else: 135 | buf = param_state["momentum_buffer"] 136 | buf = buf * momentum + (1 - dampening) * d_p 137 | if nesterov: 138 | d_p = d_p + momentum * buf 139 | else: 140 | d_p = buf 141 | 142 | if optimizer.param_groups[0]["sign"]: 143 | noiseData = noiseData - lr * d_p.sign() 144 | noiseData = torch.clamp(noiseData, -max_eps, max_eps) 145 | else: 146 | noiseData = noiseData - lr * d_p.sign() 147 | 148 | 149 | return noiseData 150 | 151 | -------------------------------------------------------------------------------- /reid/models/DDAG.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | from torchvision import models 5 | from torch.autograd import Variable 6 | from .resnet2 import resnet50, resnet18 7 | import torch.nn.functional as F 8 | import math 9 | from .attention import GraphAttentionLayer, IWPA 10 | 11 | class Normalize(nn.Module): 12 | def __init__(self, power=2): 13 | super(Normalize, self).__init__() 14 | self.power = power 15 | 16 | def forward(self, x): 17 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. 
/ self.power) 18 | out = x.div(norm) 19 | return out 20 | 21 | def weights_init_kaiming(m): 22 | classname = m.__class__.__name__ 23 | if classname.find('Conv') != -1: 24 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 25 | elif classname.find('Linear') != -1: 26 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 27 | init.zeros_(m.bias.data) 28 | elif classname.find('BatchNorm1d') != -1: 29 | init.normal_(m.weight.data, 1.0, 0.01) 30 | init.zeros_(m.bias.data) 31 | 32 | 33 | def weights_init_classifier(m): 34 | classname = m.__class__.__name__ 35 | if classname.find('Linear') != -1: 36 | init.normal_(m.weight.data, 0, 0.001) 37 | if m.bias: 38 | init.zeros_(m.bias.data) 39 | 40 | # Defines the new fc layer and classification layer 41 | # |--Linear--|--bn--|--relu--|--Linear--| 42 | class FeatureBlock(nn.Module): 43 | def __init__(self, input_dim, low_dim, dropout=0.5, relu=True): 44 | super(FeatureBlock, self).__init__() 45 | feat_block = [] 46 | feat_block += [nn.Linear(input_dim, low_dim)] 47 | feat_block += [nn.BatchNorm1d(low_dim)] 48 | 49 | feat_block = nn.Sequential(*feat_block) 50 | feat_block.apply(weights_init_kaiming) 51 | self.feat_block = feat_block 52 | 53 | def forward(self, x): 54 | x = self.feat_block(x) 55 | return x 56 | 57 | 58 | class ClassBlock(nn.Module): 59 | def __init__(self, input_dim, class_num, dropout=0.5, relu=True): 60 | super(ClassBlock, self).__init__() 61 | classifier = [] 62 | if relu: 63 | classifier += [nn.LeakyReLU(0.1)] 64 | if dropout: 65 | classifier += [nn.Dropout(p=dropout)] 66 | 67 | classifier += [nn.Linear(input_dim, class_num)] 68 | classifier = nn.Sequential(*classifier) 69 | classifier.apply(weights_init_classifier) 70 | 71 | self.classifier = classifier 72 | 73 | def forward(self, x): 74 | x = self.classifier(x) 75 | return x 76 | 77 | class visible_module(nn.Module): 78 | def __init__(self, arch='resnet50'): 79 | super(visible_module, self).__init__() 80 | 81 | model_v = resnet50(pretrained=True, 82 | last_conv_stride=1, last_conv_dilation=1) 83 | # avg pooling to global pooling 84 | self.visible = model_v 85 | 86 | def forward(self, x): 87 | x = self.visible.conv1(x) 88 | x = self.visible.bn1(x) 89 | x = self.visible.relu(x) 90 | x = self.visible.maxpool(x) 91 | return x 92 | 93 | 94 | class thermal_module(nn.Module): 95 | def __init__(self, arch='resnet50'): 96 | super(thermal_module, self).__init__() 97 | 98 | model_t = resnet50(pretrained=True, 99 | last_conv_stride=1, last_conv_dilation=1) 100 | # avg pooling to global pooling 101 | self.thermal = model_t 102 | 103 | def forward(self, x): 104 | x = self.thermal.conv1(x) 105 | x = self.thermal.bn1(x) 106 | x = self.thermal.relu(x) 107 | x = self.thermal.maxpool(x) 108 | return x 109 | 110 | 111 | class base_resnet(nn.Module): 112 | def __init__(self, arch='resnet50'): 113 | super(base_resnet, self).__init__() 114 | 115 | model_base = resnet50(pretrained=True, 116 | last_conv_stride=1, last_conv_dilation=1) 117 | # avg pooling to global pooling 118 | model_base.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 119 | self.base = model_base 120 | 121 | def forward(self, x): 122 | x = self.base.layer1(x) 123 | x = self.base.layer2(x) 124 | x = self.base.layer3(x) 125 | x = self.base.layer4(x) 126 | return x 127 | 128 | class embed_net2(nn.Module): 129 | def __init__(self, low_dim=512, drop=0.2, part=3, alpha=0.2, nheads=4, arch='resnet50', wpa=False,pretrained=True,num_classes=0): 130 | super(embed_net2, self).__init__() 131 | 132 | self.thermal_module = thermal_module(arch=arch) 
133 |         self.visible_module = visible_module(arch=arch) 134 |         self.base_resnet = base_resnet(arch=arch) 135 |         pool_dim = 2048 136 |         self.dropout = drop 137 |         self.part = part 138 |         self.lpa = wpa 139 | 140 |         self.l2norm = Normalize(2) 141 |         self.bottleneck = nn.BatchNorm1d(pool_dim) 142 |         self.bottleneck.bias.requires_grad_(False)  # no shift 143 | 144 |         self.classifier = nn.Linear(pool_dim, num_classes, bias=False) 145 | 146 |         self.classifier1 = nn.Linear(pool_dim, num_classes, bias=False) 147 |         self.classifier2 = nn.Linear(pool_dim, num_classes, bias=False) 148 | 149 |         self.bottleneck.apply(weights_init_kaiming) 150 |         self.classifier.apply(weights_init_classifier) 151 |         self.classifier1.apply(weights_init_classifier) 152 |         self.classifier2.apply(weights_init_classifier) 153 | 154 |         self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 155 |         self.wpa = IWPA(pool_dim, part) 156 | 157 | 158 |         self.attentions = [GraphAttentionLayer(pool_dim, low_dim, dropout=drop, alpha=alpha, concat=True) for _ in range(nheads)] 159 |         for i, attention in enumerate(self.attentions): 160 |             self.add_module('attention_{}'.format(i), attention) 161 | 162 |         self.out_att = GraphAttentionLayer(low_dim * nheads, num_classes, dropout=drop, alpha=alpha, concat=False) 163 | 164 |     def forward(self, x1, adj=0, modal=1, cpa=False): 165 |         if modal == 1: 166 |             x = self.visible_module(x1) 167 |         elif modal == 2: 168 |             x = self.thermal_module(x1) 169 |         x = self.base_resnet(x) 170 |         x_pool = self.avgpool(x) 171 |         x_pool = x_pool.view(x_pool.size(0), x_pool.size(1)) 172 |         feat = self.bottleneck(x_pool) 173 | 174 |         feat_att = self.wpa(x, feat, 1, self.part) 175 | 176 | 177 |         if self.training: 178 |             x_g = F.dropout(x_pool, self.dropout, training=self.training) 179 |             x_g = torch.cat([att(x_g, adj) for att in self.attentions], dim=1) 180 |             x_g = F.dropout(x_g, self.dropout, training=self.training) 181 |             x_g = F.elu(self.out_att(x_g, adj)) 182 |             return x_pool, self.classifier(feat) 183 |         else: 184 |             return x_pool, self.classifier(feat) -------------------------------------------------------------------------------- /reid/datasets/sysu.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import numpy as np 4 | from ..utils.data import Dataset 5 | from ..utils.osutils import mkdir_if_missing 6 | from ..utils.serialization import read_json 7 | from ..utils.serialization import write_json 8 | 9 | 10 | def _pluck(identities, indices, relabel=False): 11 |     ret = [] 12 |     for index, pid in enumerate(indices): 13 |         pid_images = identities[pid] 14 |         for camid, cam_images in enumerate(pid_images): 15 |             for fname in cam_images: 16 |                 name = osp.splitext(fname)[0] 17 |                 x, y, _ = map(int, name.split('_')) 18 |                 assert pid == x and camid == y 19 |                 if relabel: 20 |                     ret.append((fname, index, camid)) 21 |                 else: 22 |                     ret.append((fname, pid, camid)) 23 |     return ret 24 | 25 | class Sysu(Dataset): 26 |     url = 'https://drive.google.com/file/5fh-rUzbwVRfdgs453ytzWG9COHM/view' 27 |     md5 = '34005ab7d134rgs4eeafe81gr433' 28 | 29 |     def __init__(self, root, split_id=0, num_val=96, download=True): 30 |         super(Sysu, self).__init__(root, split_id=split_id) 31 | 32 |         if download: 33 |             self.download() 34 | 35 |         if not self._check_integrity(): 36 |             raise RuntimeError("Dataset not found or corrupted. 
" + 37 | "You can use download=True to download it.") 38 | 39 | self.load(num_val) 40 | 41 | def download(self): 42 | if self._check_integrity(): 43 | return 44 | 45 | import re 46 | import hashlib 47 | import shutil 48 | from glob import glob 49 | from zipfile import ZipFile 50 | 51 | raw_dir = osp.join(self.root, 'raw') 52 | mkdir_if_missing(raw_dir) 53 | 54 | # Download the raw zip file 55 | fpath = osp.join(raw_dir, 'sysu_v2.zip') 56 | 57 | if osp.isfile(fpath): 58 | print("Using downloaded file: " + fpath) 59 | else: 60 | raise RuntimeError("Please download the dataset manually from {} " 61 | "to {}".format(self.url, fpath)) 62 | 63 | # Extract the file 64 | exdir = raw_dir 65 | if not osp.isdir(exdir): 66 | print("Extracting zip file") 67 | with ZipFile(fpath) as z: 68 | z.extractall(path=raw_dir) 69 | 70 | # Format 71 | images_dir = osp.join(self.root, 'images') 72 | mkdir_if_missing(images_dir) 73 | 74 | # XX identities (+1 for background) with X camera views each 75 | identities = [[[] for _ in range(6)] for _ in range(534)] 76 | # print("identities=",identities) 77 | 78 | def register(subdir, pattern=re.compile(r'([-\d]+)_c(\d)')): 79 | fnames = [] 80 | fpaths = sorted(glob(osp.join(exdir, subdir, '*.jpg'))) # 81 | pids = set() 82 | for fpath in fpaths: 83 | fname = osp.basename(fpath) 84 | pid, cam = map(int, pattern.search(fname).groups()) 85 | if pid == -1: continue # 86 | assert 0 <= pid <= 534 # 87 | assert 1 <= cam <= 6 88 | cam -= 1 89 | pids.add(pid) 90 | fname = ('{:08d}_{:02d}_{:04d}.jpg' 91 | .format(pid, cam, len(identities[pid][cam]))) 92 | identities[pid][cam].append(fname) 93 | shutil.copy(fpath, osp.join(images_dir, fname)) 94 | fnames.append(fname) 95 | return pids, fnames 96 | 97 | trainval_pids, _ = register('bounding_box_train') 98 | gallery_pids, gallery_fnames = register('bounding_box_test') 99 | query_pids, query_fnames = register('query') 100 | assert query_pids <= gallery_pids 101 | 102 | # Save meta information into a json file 103 | meta = {'name': 'SYSU', 'shot': 'multiple', 'num_cameras': 6, 104 | 'identities': identities, 105 | 'query_fnames': query_fnames, 106 | 'gallery_fnames': gallery_fnames} 107 | write_json(meta, osp.join(self.root, 'meta.json')) 108 | 109 | # Save the only training / test split 110 | splits = [{ 111 | 'trainval': sorted(list(trainval_pids)), 112 | 'query': sorted(list(query_pids)), 113 | 'gallery': sorted(list(gallery_pids))}] 114 | write_json(splits, osp.join(self.root, 'splits.json')) 115 | 116 | def load(self, num_val=0.3, verbose=True): 117 | splits = read_json(osp.join(self.root, 'splits.json')) 118 | if self.split_id >= len(splits): 119 | raise ValueError("split_id exceeds total splits {}" 120 | .format(len(splits))) 121 | self.split = splits[self.split_id] 122 | 123 | # Randomly split train / val 124 | trainval_pids = np.asarray(self.split['trainval']) 125 | np.random.shuffle(trainval_pids) 126 | num = len(trainval_pids) 127 | if isinstance(num_val, float): 128 | num_val = int(round(num * num_val)) 129 | if num_val >= num or num_val < 0: 130 | raise ValueError("num_val exceeds total identities {}" 131 | .format(num)) 132 | train_pids = sorted(trainval_pids[:-num_val]) 133 | val_pids = sorted(trainval_pids[-num_val:]) 134 | 135 | self.meta = read_json(osp.join(self.root, 'meta.json')) 136 | identities = self.meta['identities'] 137 | 138 | self.train = _pluck(identities, train_pids, relabel=True) 139 | self.val = _pluck(identities, val_pids, relabel=True) 140 | self.trainval = _pluck(identities, trainval_pids, 
relabel=True) 141 | self.num_train_ids = len(train_pids) 142 | self.num_val_ids = len(val_pids) 143 | self.num_trainval_ids = len(trainval_pids) 144 | 145 | query_fnames = self.meta['query_fnames'] 146 | gallery_fnames = self.meta['gallery_fnames'] 147 | self.query = [] 148 | for fname in query_fnames: 149 | name = osp.splitext(fname)[0] 150 | pid, cam, _ = map(int, name.split('_')) 151 | self.query.append((fname, pid, cam)) 152 | self.gallery = [] 153 | for fname in gallery_fnames: 154 | name = osp.splitext(fname)[0] 155 | pid, cam, _ = map(int, name.split('_')) 156 | self.gallery.append((fname, pid, cam)) 157 | 158 | if verbose: 159 | print(self.__class__.__name__, "dataset loaded") 160 | print(" Subset | # ids | # images") 161 | print(" ---------------------------") 162 | print(" Train | {:5d} | {:8d}" 163 | .format(self.num_train_ids, len(self.train))) 164 | print(" Val | {:5d} | {:8d}" 165 | .format(self.num_val_ids, len(self.val))) 166 | print(" Trainval | {:5d} | {:8d}" 167 | .format(self.num_trainval_ids, len(self.trainval))) 168 | print(" Query | {:5d} | {:8d}" 169 | .format(len(self.split['query']), len(self.query))) 170 | print(" Gallery | {:5d} | {:8d}" 171 | .format(len(self.split['gallery']), len(self.gallery))) 172 | 173 | 174 | -------------------------------------------------------------------------------- /reid/datasets/regdb.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import numpy as np 4 | from ..utils.data import Dataset 5 | from ..utils.osutils import mkdir_if_missing 6 | from ..utils.serialization import read_json 7 | from ..utils.serialization import write_json 8 | 9 | def _pluck(identities, indices, relabel=False): 10 | ret = [] 11 | for index, pid in enumerate(indices): 12 | pid_images = identities[pid] 13 | for camid, cam_images in enumerate(pid_images): 14 | for fname in cam_images: 15 | name = osp.splitext(fname)[0] 16 | x, y, _ = map(int, name.split('_')) 17 | assert pid == x and camid == y 18 | if relabel: 19 | ret.append((fname, index, camid)) 20 | else: 21 | ret.append((fname, pid, camid)) 22 | return ret 23 | 24 | 25 | class Regdb(Dataset): 26 | url = 'https://drive.google.com/file/5fh-rUzbwVRfdgs453ytzWG9COHM/view' 27 | md5 = '34005ab7d134rgs4eeafe81gr433' 28 | 29 | def __init__(self, root, split_id=0, num_val=30, download=True): 30 | super(Regdb, self).__init__(root, split_id=split_id) 31 | 32 | if download: 33 | self.download() 34 | 35 | if not self._check_integrity(): 36 | raise RuntimeError("Dataset not found or corrupted. 
" + 37 | "You can use download=True to download it.") 38 | 39 | self.load(num_val) 40 | 41 | def download(self): 42 | if self._check_integrity(): 43 | return 44 | 45 | import re 46 | import hashlib 47 | import shutil 48 | from glob import glob 49 | from zipfile import ZipFile 50 | 51 | raw_dir = osp.join(self.root, 'raw') 52 | mkdir_if_missing(raw_dir) 53 | 54 | # Download the raw zip file 55 | fpath = osp.join(raw_dir, 'regdb_v2.zip') 56 | 57 | if osp.isfile(fpath): 58 | print("Using downloaded file: " + fpath) 59 | else: 60 | raise RuntimeError("Please download the dataset manually from {} " 61 | "to {}".format(self.url, fpath)) 62 | 63 | exdir = raw_dir 64 | if not osp.isdir(exdir): 65 | print("Extracting zip file") 66 | with ZipFile(fpath) as z: 67 | z.extractall(path=raw_dir) 68 | 69 | # Format 70 | images_dir = osp.join(self.root, 'images') 71 | mkdir_if_missing(images_dir) 72 | 73 | # 412 identities (+1 for background) with 3 camera views each 74 | identities = [[[] for _ in range(3)] for _ in range(413)] 75 | # print("identities=",identities) 76 | 77 | def register(subdir, pattern=re.compile(r'([-\d]+)_c(\d)')): 78 | fnames = [] 79 | fpaths = sorted(glob(osp.join(exdir, subdir, '*.bmp'))) # 80 | pids = set() 81 | for fpath in fpaths: 82 | fname = osp.basename(fpath) 83 | pid, cam = map(int, pattern.search(fname).groups()) 84 | if pid == -1: continue 85 | assert 0 <= pid <= 413 86 | assert 1 <= cam <= 3 87 | cam -= 1 88 | pids.add(pid) 89 | fname = ('{:08d}_{:02d}_{:04d}.bmp' 90 | .format(pid, cam, len(identities[pid][cam]))) 91 | identities[pid][cam].append(fname) 92 | shutil.copy(fpath, osp.join(images_dir, fname)) 93 | fnames.append(fname) 94 | return pids, fnames 95 | 96 | trainval_pids, _ = register('bounding_box_train') 97 | gallery_pids, gallery_fnames = register('bounding_box_test') 98 | query_pids, query_fnames = register('query') 99 | assert query_pids <= gallery_pids 100 | assert trainval_pids.isdisjoint(gallery_pids) 101 | 102 | # Save meta information into a json file 103 | meta = {'name': 'Regdb', 'shot': 'multiple', 'num_cameras': 3, 104 | 'identities': identities, 105 | 'query_fnames': query_fnames, 106 | 'gallery_fnames': gallery_fnames} 107 | write_json(meta, osp.join(self.root, 'meta.json')) 108 | 109 | # Save the only training / test split 110 | splits = [{ 111 | 'trainval': sorted(list(trainval_pids)), 112 | 'query': sorted(list(query_pids)), 113 | 'gallery': sorted(list(gallery_pids))}] 114 | write_json(splits, osp.join(self.root, 'splits.json')) 115 | 116 | 117 | def load(self, num_val=0.3, verbose=True): 118 | splits = read_json(osp.join(self.root, 'splits.json')) 119 | if self.split_id >= len(splits): 120 | raise ValueError("split_id exceeds total splits {}" 121 | .format(len(splits))) 122 | self.split = splits[self.split_id] 123 | 124 | # Randomly split train / val 125 | trainval_pids = np.asarray(self.split['trainval']) 126 | np.random.shuffle(trainval_pids) 127 | num = len(trainval_pids) 128 | if isinstance(num_val, float): 129 | num_val = int(round(num * num_val)) 130 | if num_val >= num or num_val < 0: 131 | raise ValueError("num_val exceeds total identities {}" 132 | .format(num)) 133 | train_pids = sorted(trainval_pids[:-num_val]) 134 | val_pids = sorted(trainval_pids[-num_val:]) 135 | 136 | self.meta = read_json(osp.join(self.root, 'meta.json')) 137 | identities = self.meta['identities'] 138 | 139 | self.train = _pluck(identities, train_pids, relabel=True) 140 | self.val = _pluck(identities, val_pids, relabel=True) 141 | self.trainval = 
_pluck(identities, trainval_pids, relabel=True) 142 | self.num_train_ids = len(train_pids) 143 | self.num_val_ids = len(val_pids) 144 | self.num_trainval_ids = len(trainval_pids) 145 | 146 | 147 | query_fnames = self.meta['query_fnames'] 148 | gallery_fnames = self.meta['gallery_fnames'] 149 | self.query = [] 150 | for fname in query_fnames: 151 | name = osp.splitext(fname)[0] 152 | pid, cam, _ = map(int, name.split('_')) 153 | self.query.append((fname, pid, cam)) 154 | self.gallery = [] 155 | for fname in gallery_fnames: 156 | name = osp.splitext(fname)[0] 157 | pid, cam, _ = map(int, name.split('_')) 158 | self.gallery.append((fname, pid, cam)) 159 | 160 | 161 | if verbose: 162 | print(self.__class__.__name__, "dataset loaded") 163 | print(" Subset | # ids | # images") 164 | print(" ---------------------------") 165 | print(" Train | {:5d} | {:8d}" 166 | .format(self.num_train_ids, len(self.train))) 167 | print(" Val | {:5d} | {:8d}" 168 | .format(self.num_val_ids, len(self.val))) 169 | print(" Trainval | {:5d} | {:8d}" 170 | .format(self.num_trainval_ids, len(self.trainval))) 171 | print(" Query | {:5d} | {:8d}" 172 | .format(len(self.split['query']), len(self.query))) 173 | print(" Gallery | {:5d} | {:8d}" 174 | .format(len(self.split['gallery']), len(self.gallery))) 175 | 176 | -------------------------------------------------------------------------------- /reid/models/resnet2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152'] 7 | 8 | model_urls = { 9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 14 | } 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1, dilation=1): 18 | """3x3 convolution with padding""" 19 | # original padding is 1; original dilation is 1 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=dilation, bias=False, dilation=dilation) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(inplanes, planes, stride, dilation) 30 | self.bn1 = nn.BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = nn.BatchNorm2d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 60 | super(Bottleneck, self).__init__() 61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 62 | self.bn1 = nn.BatchNorm2d(planes) 63 | # original padding is 1; original dilation is 
1 64 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=dilation, bias=False, dilation=dilation) 65 | self.bn2 = nn.BatchNorm2d(planes) 66 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 67 | self.bn3 = nn.BatchNorm2d(planes * 4) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class ResNet(nn.Module): 96 | 97 | def __init__(self, block, layers, last_conv_stride=2, last_conv_dilation=1): 98 | 99 | self.inplanes = 64 100 | super(ResNet, self).__init__() 101 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 102 | bias=False) 103 | self.bn1 = nn.BatchNorm2d(64) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 106 | self.layer1 = self._make_layer(block, 64, layers[0]) 107 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 108 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 109 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_conv_stride, dilation=last_conv_dilation) 110 | 111 | for m in self.modules(): 112 | if isinstance(m, nn.Conv2d): 113 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 114 | m.weight.data.normal_(0, math.sqrt(2. / n)) 115 | elif isinstance(m, nn.BatchNorm2d): 116 | m.weight.data.fill_(1) 117 | m.bias.data.zero_() 118 | 119 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 120 | downsample = None 121 | if stride != 1 or self.inplanes != planes * block.expansion: 122 | downsample = nn.Sequential( 123 | nn.Conv2d(self.inplanes, planes * block.expansion, 124 | kernel_size=1, stride=stride, bias=False), 125 | nn.BatchNorm2d(planes * block.expansion), 126 | ) 127 | 128 | layers = [] 129 | layers.append(block(self.inplanes, planes, stride, downsample, dilation)) 130 | self.inplanes = planes * block.expansion 131 | for i in range(1, blocks): 132 | layers.append(block(self.inplanes, planes)) 133 | 134 | return nn.Sequential(*layers) 135 | 136 | def forward(self, x): 137 | x = self.conv1(x) 138 | x = self.bn1(x) 139 | x = self.relu(x) 140 | x = self.maxpool(x) 141 | 142 | x = self.layer1(x) 143 | x = self.layer2(x) 144 | x = self.layer3(x) 145 | x = self.layer4(x) 146 | 147 | return x 148 | 149 | 150 | def remove_fc(state_dict): 151 | """Remove the fc layer parameters from state_dict.""" 152 | # for key, value in state_dict.items(): 153 | for key, value in list(state_dict.items()): 154 | if key.startswith('fc.'): 155 | del state_dict[key] 156 | return state_dict 157 | 158 | 159 | def resnet18(pretrained=False, **kwargs): 160 | """Constructs a ResNet-18 model. 161 | Args: 162 | pretrained (bool): If True, returns a model pre-trained on ImageNet 163 | """ 164 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 165 | if pretrained: 166 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet18']))) 167 | return model 168 | 169 | 170 | def resnet34(pretrained=False, **kwargs): 171 | """Constructs a ResNet-34 model. 
172 | Args: 173 | pretrained (bool): If True, returns a model pre-trained on ImageNet 174 | """ 175 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 176 | if pretrained: 177 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet34']))) 178 | return model 179 | 180 | 181 | def resnet50(pretrained=False, **kwargs): 182 | """Constructs a ResNet-50 model. 183 | Args: 184 | pretrained (bool): If True, returns a model pre-trained on ImageNet 185 | """ 186 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 187 | if pretrained: 188 | # model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet50']))) 189 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet50']))) 190 | return model 191 | 192 | 193 | def resnet101(pretrained=False, **kwargs): 194 | """Constructs a ResNet-101 model. 195 | Args: 196 | pretrained (bool): If True, returns a model pre-trained on ImageNet 197 | """ 198 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 199 | if pretrained: 200 | model.load_state_dict( 201 | remove_fc(model_zoo.load_url(model_urls['resnet101']))) 202 | return model 203 | 204 | 205 | def resnet152(pretrained=False, **kwargs): 206 | """Constructs a ResNet-152 model. 207 | Args: 208 | pretrained (bool): If True, returns a model pre-trained on ImageNet 209 | """ 210 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 211 | if pretrained: 212 | model.load_state_dict( 213 | remove_fc(model_zoo.load_url(model_urls['resnet152']))) 214 | return model -------------------------------------------------------------------------------- /reid/models/AGW.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | from .resnet2 import * 5 | 6 | class Normalize(nn.Module): 7 | def __init__(self, power=2): 8 | super(Normalize, self).__init__() 9 | self.power = power 10 | 11 | def forward(self, x): 12 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. 
/ self.power) 13 | out = x.div(norm) 14 | return out 15 | 16 | class Non_local(nn.Module): 17 | def __init__(self, in_channels, reduc_ratio=2): 18 | super(Non_local, self).__init__() 19 | 20 | self.in_channels = in_channels 21 | self.inter_channels = reduc_ratio//reduc_ratio 22 | 23 | self.g = nn.Sequential( 24 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, 25 | padding=0), 26 | ) 27 | 28 | self.W = nn.Sequential( 29 | nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, 30 | kernel_size=1, stride=1, padding=0), 31 | nn.BatchNorm2d(self.in_channels), 32 | ) 33 | nn.init.constant_(self.W[1].weight, 0.0) 34 | nn.init.constant_(self.W[1].bias, 0.0) 35 | 36 | 37 | 38 | self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 39 | kernel_size=1, stride=1, padding=0) 40 | 41 | self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 42 | kernel_size=1, stride=1, padding=0) 43 | 44 | def forward(self, x): 45 | batch_size = x.size(0) 46 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 47 | g_x = g_x.permute(0, 2, 1) 48 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 49 | theta_x = theta_x.permute(0, 2, 1) 50 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 51 | f = torch.matmul(theta_x, phi_x) 52 | N = f.size(-1) 53 | f_div_C = f / N 54 | y = torch.matmul(f_div_C, g_x) 55 | y = y.permute(0, 2, 1).contiguous() 56 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 57 | W_y = self.W(y) 58 | z = W_y + x 59 | 60 | return z 61 | 62 | def weights_init_kaiming(m): 63 | classname = m.__class__.__name__ 64 | if classname.find('Conv') != -1: 65 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 66 | elif classname.find('Linear') != -1: 67 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 68 | init.zeros_(m.bias.data) 69 | elif classname.find('BatchNorm1d') != -1: 70 | init.normal_(m.weight.data, 1.0, 0.01) 71 | init.zeros_(m.bias.data) 72 | 73 | def weights_init_classifier(m): 74 | classname = m.__class__.__name__ 75 | if classname.find('Linear') != -1: 76 | init.normal_(m.weight.data, 0, 0.001) 77 | if m.bias: 78 | init.zeros_(m.bias.data) 79 | 80 | 81 | 82 | class visible_module(nn.Module): 83 | def __init__(self, arch='resnet50'): 84 | super(visible_module, self).__init__() 85 | 86 | model_v = resnet50(pretrained=True, 87 | last_conv_stride=1, last_conv_dilation=1) 88 | # avg pooling to global pooling 89 | self.visible = model_v 90 | 91 | def forward(self, x): 92 | x = self.visible.conv1(x) 93 | x = self.visible.bn1(x) 94 | x = self.visible.relu(x) 95 | x = self.visible.maxpool(x) 96 | return x 97 | 98 | 99 | class thermal_module(nn.Module): 100 | def __init__(self, arch='resnet50'): 101 | super(thermal_module, self).__init__() 102 | 103 | model_t = resnet50(pretrained=True, 104 | last_conv_stride=1, last_conv_dilation=1) 105 | # avg pooling to global pooling 106 | self.thermal = model_t 107 | 108 | def forward(self, x): 109 | x = self.thermal.conv1(x) 110 | x = self.thermal.bn1(x) 111 | x = self.thermal.relu(x) 112 | x = self.thermal.maxpool(x) 113 | return x 114 | 115 | 116 | class base_resnet(nn.Module): 117 | def __init__(self, arch='resnet50'): 118 | super(base_resnet, self).__init__() 119 | 120 | model_base = resnet50(pretrained=True, 121 | last_conv_stride=1, last_conv_dilation=1) 122 | # avg pooling to global pooling 123 | model_base.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 124 | self.base = 
model_base 125 | 126 |     def forward(self, x): 127 |         x = self.base.layer1(x) 128 |         x = self.base.layer2(x) 129 |         x = self.base.layer3(x) 130 |         x = self.base.layer4(x) 131 |         return x 132 | 133 | class embed_net(nn.Module): 134 |     def __init__(self, no_local='on', gm_pool='on', arch='resnet50',pretrained=True,num_classes=0): 135 |         super(embed_net, self).__init__() 136 | 137 |         self.thermal_module = thermal_module(arch=arch) 138 |         self.visible_module = visible_module(arch=arch) 139 |         self.base_resnet = base_resnet(arch=arch) 140 |         self.non_local = no_local 141 |         if self.non_local =='on': 142 |             layers=[3, 4, 6, 3] 143 |             non_layers=[0,2,3,0] 144 |             self.NL_1 = nn.ModuleList( 145 |                 [Non_local(256) for i in range(non_layers[0])]) 146 |             self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])]) 147 |             self.NL_2 = nn.ModuleList( 148 |                 [Non_local(512) for i in range(non_layers[1])]) 149 |             self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])]) 150 |             self.NL_3 = nn.ModuleList( 151 |                 [Non_local(1024) for i in range(non_layers[2])]) 152 |             self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])]) 153 |             self.NL_4 = nn.ModuleList( 154 |                 [Non_local(2048) for i in range(non_layers[3])]) 155 |             self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])]) 156 |         pool_dim = 2048 157 |         self.l2norm = Normalize(2) 158 |         self.bottleneck = nn.BatchNorm1d(pool_dim) 159 |         self.bottleneck.bias.requires_grad_(False)  # no shift 160 | 161 |         self.classifier = nn.Linear(pool_dim, num_classes, bias=False) 162 | 163 |         self.bottleneck.apply(weights_init_kaiming) 164 |         self.classifier.apply(weights_init_classifier) 165 |         self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 166 |         self.gm_pool = gm_pool 167 | 168 |     def forward(self, x1, modal=1): 169 |         if modal == 1: 170 |             x = self.visible_module(x1) 171 |         elif modal == 2: 172 |             x = self.thermal_module(x1) 173 | 174 |         # shared block 175 |         if self.non_local == 'on': 176 |             NL1_counter = 0 177 |             if len(self.NL_1_idx) == 0: self.NL_1_idx = [-1] 178 |             for i in range(len(self.base_resnet.base.layer1)): 179 |                 x = self.base_resnet.base.layer1[i](x) 180 |                 if i == self.NL_1_idx[NL1_counter]: 181 |                     _, C, H, W = x.shape 182 |                     x = self.NL_1[NL1_counter](x) 183 |                     NL1_counter += 1 184 |             # Layer 2 185 |             NL2_counter = 0 186 |             if len(self.NL_2_idx) == 0: self.NL_2_idx = [-1] 187 |             for i in range(len(self.base_resnet.base.layer2)): 188 |                 x = self.base_resnet.base.layer2[i](x) 189 |                 if i == self.NL_2_idx[NL2_counter]: 190 |                     _, C, H, W = x.shape 191 |                     x = self.NL_2[NL2_counter](x) 192 |                     NL2_counter += 1 193 |             # Layer 3 194 |             NL3_counter = 0 195 |             if len(self.NL_3_idx) == 0: self.NL_3_idx = [-1] 196 |             for i in range(len(self.base_resnet.base.layer3)): 197 |                 x = self.base_resnet.base.layer3[i](x) 198 |                 if i == self.NL_3_idx[NL3_counter]: 199 |                     _, C, H, W = x.shape 200 |                     x = self.NL_3[NL3_counter](x) 201 |                     NL3_counter += 1 202 |             # Layer 4 203 |             NL4_counter = 0 204 |             if len(self.NL_4_idx) == 0: self.NL_4_idx = [-1] 205 |             for i in range(len(self.base_resnet.base.layer4)): 206 |                 x = self.base_resnet.base.layer4[i](x) 207 |                 if i == self.NL_4_idx[NL4_counter]: 208 |                     _, C, H, W = x.shape 209 |                     x = self.NL_4[NL4_counter](x) 210 |                     NL4_counter += 1 211 |         else: 212 |             x = self.base_resnet(x) 213 |         if self.gm_pool == 'on': 214 |             b, c, h, w = x.shape 215 |             x = x.view(b, c, -1) 216 |             p = 3.0 217 |             x_pool = (torch.mean(x**p, dim=-1) + 1e-12)**(1/p) 218 |         else: 219 |             x_pool = self.avgpool(x) 220 |             x_pool = x_pool.view(x_pool.size(0), x_pool.size(1)) 221 | 222 |         feat
= self.bottleneck(x_pool) 223 | 224 |         if self.training: 225 |             return x_pool, self.classifier(feat) 226 |         else: 227 |             return x_pool, self.classifier(feat)  # need to align (keep the same outputs as the training branch) 228 |             # return self.l2norm(x_pool), self.l2norm(feat) -------------------------------------------------------------------------------- /reid/models/attention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | """ 6 | PART of the code is from the following link 7 | https://github.com/Diego999/pyGAT/blob/master/layers.py 8 | """ 9 | 10 | 11 | class Normalize(nn.Module): 12 |     def __init__(self, power=2): 13 |         super(Normalize, self).__init__() 14 |         self.power = power 15 | 16 |     def forward(self, x): 17 |         norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power) 18 |         out = x.div(norm) 19 |         return out 20 | 21 | class GraphAttentionLayer(nn.Module): 22 |     """ 23 |     Simple GAT layer, similar to https://arxiv.org/abs/1710.10903 24 |     """ 25 | 26 |     def __init__(self, in_features, out_features, dropout, alpha=0.2, concat=True): 27 |         super(GraphAttentionLayer, self).__init__() 28 |         self.dropout = dropout 29 |         self.in_features = in_features 30 |         self.out_features = out_features 31 |         self.alpha = alpha 32 |         self.concat = concat 33 | 34 |         self.W = nn.Parameter(torch.zeros(size=(in_features, out_features))) 35 |         nn.init.xavier_uniform_(self.W.data, gain=1.414) 36 |         self.a = nn.Parameter(torch.zeros(size=(2 * out_features, 1))) 37 |         nn.init.xavier_uniform_(self.a.data, gain=1.414) 38 | 39 |         self.leakyrelu = nn.LeakyReLU(self.alpha) 40 | 41 |     def forward(self, input, adj): 42 |         h = torch.mm(input, self.W) 43 |         N = h.size()[0] 44 | 45 |         a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features) 46 |         e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2)) 47 | 48 |         zero_vec = -9e15 * torch.ones_like(e) 49 |         attention = torch.where(adj > 0, e, zero_vec) 50 |         attention = F.softmax(attention, dim=1) 51 |         attention = F.dropout(attention, self.dropout, training=self.training) 52 |         h_prime = torch.matmul(attention, h) 53 | 54 |         if self.concat: 55 |             return F.elu(h_prime) 56 |         else: 57 |             return h_prime 58 | 59 |     def __repr__(self): 60 |         return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' 61 | 62 | 63 | class SpecialSpmmFunction(torch.autograd.Function): 64 |     """Special function for only sparse region backpropagation layer.""" 65 | 66 |     @staticmethod 67 |     def forward(ctx, indices, values, shape, b): 68 |         assert indices.requires_grad == False 69 |         a = torch.sparse_coo_tensor(indices, values, shape) 70 |         ctx.save_for_backward(a, b) 71 |         ctx.N = shape[0] 72 |         return torch.matmul(a, b) 73 | 74 |     @staticmethod 75 |     def backward(ctx, grad_output): 76 |         a, b = ctx.saved_tensors 77 |         grad_values = grad_b = None 78 |         if ctx.needs_input_grad[1]: 79 |             grad_a_dense = grad_output.matmul(b.t()) 80 |             edge_idx = a._indices()[0, :] * ctx.N + a._indices()[1, :] 81 |             grad_values = grad_a_dense.view(-1)[edge_idx] 82 |         if ctx.needs_input_grad[3]: 83 |             grad_b = a.t().matmul(grad_output) 84 |         return None, grad_values, None, grad_b 85 | 86 | 87 | class SpecialSpmm(nn.Module): 88 |     def forward(self, indices, values, shape, b): 89 |         return SpecialSpmmFunction.apply(indices, values, shape, b) 90 | 91 | 92 | class SpGraphAttentionLayer(nn.Module): 93 |     """ 94 |     Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903 95 |     """ 96
| 97 | def __init__(self, in_features, out_features, dropout, alpha, concat=True): 98 | super(SpGraphAttentionLayer, self).__init__() 99 | self.in_features = in_features 100 | self.out_features = out_features 101 | self.alpha = alpha 102 | self.concat = concat 103 | 104 | self.W = nn.Parameter(torch.zeros(size=(in_features, out_features))) 105 | nn.init.xavier_normal_(self.W.data, gain=1.414) 106 | 107 | self.a = nn.Parameter(torch.zeros(size=(1, 2 * out_features))) 108 | nn.init.xavier_normal_(self.a.data, gain=1.414) 109 | 110 | self.dropout = nn.Dropout(dropout) 111 | self.leakyrelu = nn.LeakyReLU(self.alpha) 112 | self.special_spmm = SpecialSpmm() 113 | 114 | def forward(self, input, adj): 115 | dv = 'cuda' if input.is_cuda else 'cpu' 116 | 117 | N = input.size()[0] 118 | edge = adj.nonzero().t() 119 | 120 | h = torch.mm(input, self.W) 121 | # h: N x out 122 | assert not torch.isnan(h).any() 123 | 124 | # Self-attention on the nodes - Shared attention mechanism 125 | edge_h = torch.cat((h[edge[0, :], :], h[edge[1, :], :]), dim=1).t() 126 | # edge: 2*D x E 127 | 128 | edge_e = torch.exp(-self.leakyrelu(self.a.mm(edge_h).squeeze())) 129 | assert not torch.isnan(edge_e).any() 130 | # edge_e: E 131 | 132 | e_rowsum = self.special_spmm(edge, edge_e, torch.Size([N, N]), torch.ones(size=(N, 1), device=dv)) 133 | # e_rowsum: N x 1 134 | 135 | edge_e = self.dropout(edge_e) 136 | # edge_e: E 137 | 138 | h_prime = self.special_spmm(edge, edge_e, torch.Size([N, N]), h) 139 | assert not torch.isnan(h_prime).any() 140 | # h_prime: N x out 141 | 142 | h_prime = h_prime.div(e_rowsum) 143 | # h_prime: N x out 144 | assert not torch.isnan(h_prime).any() 145 | 146 | if self.concat: 147 | # if this layer is not last layer, 148 | return F.elu(h_prime) 149 | else: 150 | # if this layer is last layer, 151 | return h_prime 152 | 153 | def __repr__(self): 154 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' 155 | 156 | 157 | class IWPA(nn.Module): 158 | """ 159 | Part attention layer, "Dynamic Dual-Attentive Aggregation Learning for Visible-Infrared Person Re-Identification" 160 | """ 161 | def __init__(self, in_channels, part = 3, inter_channels=None, out_channels=None): 162 | super(IWPA, self).__init__() 163 | 164 | self.in_channels = in_channels 165 | self.inter_channels = inter_channels 166 | self.out_channels = out_channels 167 | self.l2norm = Normalize(2) 168 | 169 | if self.inter_channels is None: 170 | self.inter_channels = in_channels 171 | 172 | if self.out_channels is None: 173 | self.out_channels = in_channels 174 | 175 | conv_nd = nn.Conv2d 176 | 177 | self.fc1 = nn.Sequential( 178 | conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, 179 | padding=0), 180 | ) 181 | 182 | self.fc2 = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 183 | kernel_size=1, stride=1, padding=0) 184 | 185 | self.fc3 = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 186 | kernel_size=1, stride=1, padding=0) 187 | 188 | self.W = nn.Sequential( 189 | conv_nd(in_channels=self.inter_channels, out_channels=self.out_channels, 190 | kernel_size=1, stride=1, padding=0), 191 | nn.BatchNorm2d(self.out_channels), 192 | ) 193 | nn.init.constant_(self.W[1].weight, 0.0) 194 | nn.init.constant_(self.W[1].bias, 0.0) 195 | 196 | 197 | self.bottleneck = nn.BatchNorm1d(in_channels) 198 | self.bottleneck.bias.requires_grad_(False) # no shift 199 | 200 | nn.init.normal_(self.bottleneck.weight.data, 
1.0, 0.01) 201 | nn.init.zeros_(self.bottleneck.bias.data) 202 | 203 | # weighting vector of the part features 204 | self.gate = nn.Parameter(torch.FloatTensor(part)) 205 | nn.init.constant_(self.gate, 1/part) 206 | def forward(self, x, feat, t=None, part=0): 207 | bt, c, h, w = x.shape 208 | b = bt // t 209 | 210 | # get part features 211 | part_feat = F.adaptive_avg_pool2d(x, (part, 1)) 212 | part_feat = part_feat.view(b, t, c, part) 213 | part_feat = part_feat.permute(0, 2, 1, 3) # B, C, T, Part 214 | 215 | part_feat1 = self.fc1(part_feat).view(b, self.inter_channels, -1) # B, C//r, T*Part 216 | part_feat1 = part_feat1.permute(0, 2, 1) # B, T*Part, C//r 217 | 218 | part_feat2 = self.fc2(part_feat).view(b, self.inter_channels, -1) # B, C//r, T*Part 219 | 220 | part_feat3 = self.fc3(part_feat).view(b, self.inter_channels, -1) # B, C//r, T*Part 221 | part_feat3 = part_feat3.permute(0, 2, 1) # B, T*Part, C//r 222 | 223 | # get cross-part attention 224 | cpa_att = torch.matmul(part_feat1, part_feat2) # B, T*Part, T*Part 225 | cpa_att = F.softmax(cpa_att, dim=-1) 226 | 227 | # collect contextual information 228 | refined_part_feat = torch.matmul(cpa_att, part_feat3) # B, T*Part, C//r 229 | refined_part_feat = refined_part_feat.permute(0, 2, 1).contiguous() # B, C//r, T*Part 230 | refined_part_feat = refined_part_feat.view(b, self.inter_channels, part) # B, C//r, T, Part 231 | 232 | gate = F.softmax(self.gate, dim=-1) 233 | weight_part_feat = torch.matmul(refined_part_feat, gate) 234 | x = F.adaptive_avg_pool2d(x, (1, 1)) 235 | 236 | weight_part_feat = weight_part_feat + feat 237 | feat = self.bottleneck(weight_part_feat) 238 | 239 | return feat -------------------------------------------------------------------------------- /CMPS_attack.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import argparse 4 | import sys 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from reid import models 8 | from torch.nn import functional as F 9 | from reid import datasets 10 | from MI_SGD import MI_SGD,keepGradUpdate 11 | from test import test 12 | from utils import get_data,calDist 13 | from reid.utils.data import transforms as T 14 | from torchvision.transforms import Resize 15 | from reid.utils.data.preprocessor import Preprocessor 16 | from reid.evaluators import Evaluator 17 | from torch.optim.optimizer import Optimizer, required 18 | import random 19 | import numpy as np 20 | import math 21 | from reid.evaluators import extract_features 22 | from reid.utils.meters import AverageMeter 23 | import torchvision 24 | import faiss 25 | from torchvision import transforms 26 | import time 27 | 28 | def train_CMPS(train_step1_loader, train_step3_loader, net, noise, epoch, optimizer, 29 | centroids, metaCentroids, normalize): 30 | global args 31 | noise.requires_grad = True 32 | batch_time = AverageMeter() 33 | data_time = AverageMeter() 34 | losses = AverageMeter() 35 | mean = torch.Tensor(normalize.mean).view(1, 3, 1, 1).cuda() 36 | std = torch.Tensor(normalize.std).view(1, 3, 1, 1).cuda() 37 | net.eval() 38 | end = time.time() 39 | optimizer.zero_grad() 40 | optimizer.rescale() 41 | for i, ((input, _, pid, _), (iTest, _, _, _)) in enumerate(zip(train_step1_loader, train_step3_loader)): 42 | data_time.update(time.time() - end) 43 | model.zero_grad() 44 | input = input.cuda() 45 | iTest = iTest.cuda() 46 | with torch.no_grad(): 47 | normInput = (input - mean) / std 48 | feature, realPred = net(normInput) 49 | scores 
= centroids.mm(F.normalize(feature.t(), p=2, dim=0)) 50 | # scores = centroids.mm(feature.t()) 51 | realLab = scores.max(0, keepdim=True)[1] 52 | _, ranks = torch.sort(scores, dim=0, descending=True) 53 | pos_i = ranks[0, :] 54 | neg_i = ranks[-1, :] 55 | neg_feature = centroids[neg_i, :] 56 | pos_feature = centroids[pos_i, :] 57 | 58 | current_noise = noise 59 | current_noise = F.interpolate( 60 | current_noise.unsqueeze(0), 61 | mode=MODE, size=tuple(input.shape[-2:]), align_corners=True, 62 | ).squeeze() 63 | perturted_input = torch.clamp(input + current_noise, 0, 1) 64 | perturted_input_norm = (perturted_input - mean) / std 65 | perturbed_feature = net(perturted_input_norm)[0] 66 | optimizer.zero_grad() 67 | loss_step1 = 10 * F.triplet_margin_loss(perturbed_feature, neg_feature, pos_feature, 0.5) 68 | 69 | 70 | loss_step1 = loss_step1.view(1) 71 | loss = loss_step1 72 | 73 | grad = torch.autograd.grad(loss, noise, create_graph=True)[0] 74 | noiseOneStep = keepGradUpdate(noise, optimizer, grad, MAX_EPS) 75 | 76 | newNoise = F.interpolate( 77 | noiseOneStep.unsqueeze(0), mode=MODE, 78 | size=tuple(iTest.shape[-2:]), align_corners=True, 79 | ).squeeze() 80 | 81 | with torch.no_grad(): 82 | normMte = (iTest - mean) / std 83 | mteFeat = net(normMte)[0] 84 | scores = metaCentroids.mm(F.normalize(mteFeat.t(), p=2, dim=0)) 85 | metaLab = scores.max(0, keepdim=True)[1] 86 | _, ranks = torch.sort(scores, dim=0, descending=True) 87 | pos_i = ranks[0, :] 88 | neg_i = ranks[-1, :] 89 | neg_mte_feat = metaCentroids[neg_i, :] 90 | pos_mte_feat = metaCentroids[pos_i, :] 91 | 92 | perMteInput = torch.clamp(iTest + newNoise, 0, 1) 93 | normPerMteInput = (perMteInput - mean) / std 94 | normMteFeat = net(normPerMteInput)[0] 95 | 96 | loss_step3 = 10 * F.triplet_margin_loss( 97 | normMteFeat, neg_mte_feat, pos_mte_feat, 0.5 98 | ) 99 | 100 | finalLoss = loss_step3 + loss_step1 101 | finalLoss.backward() 102 | 103 | losses.update(loss_step1.item()) 104 | optimizer.step() 105 | 106 | batch_time.update(time.time() - end) 107 | end = time.time() 108 | 109 | if i % args.print_freq == 0: 110 | print( 111 | ">> Train: [{0}][{1}/{2}]\t" 112 | "Batch Loader Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t\t" 113 | "Data Loader Time {data_time.val:.3f} ({data_time.avg:.3f})\t\t" 114 | "Loss {loss.val:.4f} ({loss.avg:.4f})\t" 115 | "Noise l2: {noise:.4f}".format( 116 | epoch + 1, 117 | i, len(train_step1_loader), 118 | batch_time=batch_time, 119 | data_time=data_time, 120 | loss=losses, loss_step3=loss_step3.item(), 121 | noise=noise.norm(), 122 | ) 123 | ) 124 | noise.requires_grad = False 125 | print(f"Train {epoch}: Loss: {losses.avg}") 126 | return losses.avg, noise 127 | 128 | MODE = "bilinear" 129 | 130 | if __name__ == '__main__': 131 | parser = argparse.ArgumentParser( 132 | description='' 133 | ) 134 | 135 | parser.add_argument('--data', type=str, required=True, 136 | help='path to reid dataset') 137 | parser.add_argument('-s', '--source', type=str, default='sysu', 138 | choices=datasets.names()) 139 | parser.add_argument('-a', '--arch', type=str, default='resnet50', 140 | choices=models.names()) 141 | parser.add_argument('--batch_size', type=int, default=16, required=True, 142 | help='number of examples/minibatch') 143 | parser.add_argument('--num_batches', type=int, required=False, 144 | help='number of batches (default entire dataset)') 145 | parser.add_argument('--resume', type=str, default='', metavar='PATH') 146 | parser.add_argument('--split', type=int, default=0) 147 | 
parser.add_argument('--epoch', type=int, default=5) 148 | parser.add_argument('--height', type=int, default=256, 149 | help="input height, default: 256 for resnet*, " 150 | "144 for inception") 151 | parser.add_argument('--width', type=int, default=128, 152 | help="input width, default: 128 for resnet*, " 153 | "56 for inception") 154 | parser.add_argument('--num-instances', type=int, default=8, 155 | help="each minibatch consist of " 156 | "(batch_size // num_instances) identities, and " 157 | "each identity has num_instances instances, " 158 | "default: 4") 159 | parser.add_argument('--print_freq', type=int, default=10) 160 | parser.add_argument("--max-eps", default=8, type=int, help="max eps") 161 | parser.add_argument('--combine-trainval', action='store_true', 162 | help="train and val sets together for training, " 163 | "val set alone for validation") 164 | 165 | args = parser.parse_args() 166 | 167 | sourceSet, sourceSet, num_classes, class_tgt, train_step1, train_step3 = \ 168 | get_data(args.source, 169 | args.split, args.data, args.height, 170 | args.width, args.batch_size, 8, args.combine_trainval) 171 | 172 | model = models.create(args.arch, pretrained=True, num_classes=num_classes) 173 | modelTest = models.create(args.arch, pretrained=True, num_classes=class_tgt) 174 | if args.resume: 175 | checkpoint = torch.load(args.resume) 176 | if 'state_dict' in checkpoint.keys(): 177 | checkpoint = checkpoint['state_dict'] 178 | try: 179 | model.load_state_dict(checkpoint) 180 | except: 181 | allNames = list(checkpoint.keys()) 182 | for name in allNames: 183 | if name.count('classifier') != 0: 184 | del checkpoint[name] 185 | model.load_state_dict(checkpoint, strict=False) 186 | #for test 187 | checkTgt = torch.load(args.resume) 188 | if 'state_dict' in checkTgt.keys(): 189 | checkTgt = checkTgt['state_dict'] 190 | try: 191 | modelTest.load_state_dict(checkTgt) 192 | except: 193 | allNames = list(checkTgt.keys()) 194 | for name in allNames: 195 | if name.count('classifier') != 0: 196 | del checkTgt[name] 197 | modelTest.load_state_dict(checkTgt, strict=False) 198 | 199 | model.eval() 200 | modelTest.eval() 201 | if torch.cuda.is_available(): 202 | model = model.cuda() 203 | modelTest = modelTest.cuda() 204 | features, _ = extract_features(model, train_step1, print_freq=10) 205 | features = torch.stack([features[f] for f, _, _ in sourceSet.trainval]) 206 | metaFeats, _ = extract_features(model, train_step3, print_freq=10) 207 | metaFeats = torch.stack([metaFeats[f] for f, _, _ in sourceSet.trainval]) 208 | if args.source == "sysu": 209 | ncentroids = 395 210 | else: 211 | ncentroids = 206 212 | fDim = features.shape[1] 213 | cluster, metaClu = faiss.Kmeans(fDim, ncentroids, niter=20, gpu=True), \ 214 | faiss.Kmeans(fDim, ncentroids, niter=20, gpu=True) 215 | cluster.train(features.cpu().numpy()) 216 | metaClu.train(metaFeats.cpu().numpy()) 217 | centroids = torch.from_numpy(cluster.centroids).cuda().float() 218 | metaCentroids = torch.from_numpy(metaClu.centroids).cuda().float() 219 | del metaClu, cluster 220 | evaluator = Evaluator(modelTest, args.print_freq) 221 | evaSrc = Evaluator(model, args.print_freq) 222 | noise = torch.zeros((3, args.height, args.width)).cuda() 223 | normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 224 | noise.requires_grad = True 225 | MAX_EPS = args.max_eps / 255.0 226 | optimizer = MI_SGD( 227 | [{"params": [noise], "lr": MAX_EPS / 10, "momentum": 1, "sign": True}], 228 | max_eps=MAX_EPS, 229 | ) 230 | scheduler = 
torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=math.exp(-0.01)) 231 | for epoch in range(args.epoch): 232 | scheduler.step() 233 | begin_time = time.time() 234 | loss, noise = train_CMPS( 235 | train_step1, train_step3, model, noise, epoch, optimizer, 236 | centroids, metaCentroids, normalize 237 | ) 238 | if epoch % 5 == 0: 239 | test(sourceSet, modelTest, noise, args, evaluator, epoch) 240 | 241 | 242 | -------------------------------------------------------------------------------- /reid/loss/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd.function import Function 6 | from torch.autograd import Variable 7 | import pdb 8 | 9 | 10 | class KLLoss(nn.Module): 11 | def __init__(self): 12 | super(KLLoss, self).__init__() 13 | def forward(self, pred, label): 14 | # pred: 2D matrix (batch_size, num_classes) 15 | # label: 1D vector indicating class number 16 | T=3 17 | 18 | predict = F.log_softmax(pred/T,dim=1) 19 | target_data = F.softmax(label/T,dim=1) 20 | target_data =target_data+10**(-7) 21 | target = Variable(target_data.data.cuda(),requires_grad=False) 22 | loss=T*T*((target*(target.log()-predict)).sum(1).sum()/target.size()[0]) 23 | return loss 24 | 25 | class OriTripletLoss(nn.Module): 26 | """Triplet loss with hard positive/negative mining. 27 | 28 | Reference: 29 | Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737. 30 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py. 31 | 32 | Args: 33 | - margin (float): margin for triplet. 34 | """ 35 | 36 | def __init__(self, batch_size, margin=0.3): 37 | super(OriTripletLoss, self).__init__() 38 | self.margin = margin 39 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 40 | 41 | def forward(self, inputs, targets): 42 | """ 43 | Args: 44 | - inputs: feature matrix with shape (batch_size, feat_dim) 45 | - targets: ground truth labels with shape (num_classes) 46 | """ 47 | n = inputs.size(0) 48 | 49 | # Compute pairwise distance, replace by the official when merged 50 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 51 | dist = dist + dist.t() 52 | dist.addmm_(1, -2, inputs, inputs.t()) 53 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 54 | 55 | # For each anchor, find the hardest positive and negative 56 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 57 | dist_ap, dist_an = [], [] 58 | for i in range(n): 59 | dist_ap.append(dist[i][mask[i]].max().unsqueeze(0)) 60 | dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) 61 | dist_ap = torch.cat(dist_ap) 62 | dist_an = torch.cat(dist_an) 63 | 64 | # Compute ranking hinge loss 65 | y = torch.ones_like(dist_an) 66 | loss = self.ranking_loss(dist_an, dist_ap, y) 67 | 68 | # compute accuracy 69 | correct = torch.ge(dist_an, dist_ap).sum().item() 70 | return loss, correct 71 | 72 | 73 | class TDRLoss(nn.Module): 74 | """Tri-directional ranking loss. 75 | 76 | Args: 77 | - margin (float): margin for triplet. 
78 |     """
79 | 
80 |     def __init__(self, margin=0.3):
81 |         super(TDRLoss, self).__init__()
82 |         self.margin = margin
83 |         self.ranking_loss = nn.MarginRankingLoss(margin=margin)
84 | 
85 |     def forward(self, inputs, targets):
86 |         """
87 |         Args:
88 |         - inputs: feature matrix with shape (batch_size, feat_dim)
89 |         - targets: ground truth labels with shape (batch_size // 3)
90 |         """
91 |         n = inputs.shape[0] // 3
92 |         input1 = inputs.narrow(0, 0, n)
93 |         input2 = inputs.narrow(0, n, n)
94 |         input3 = inputs.narrow(0, 2 * n, n)
95 | 
96 |         dist1 = pdist_torch(input1, input2)
97 |         dist2 = pdist_torch(input2, input3)
98 |         dist3 = pdist_torch(input1, input3)
99 | 
100 |         # compute mask
101 |         mask = targets.expand(n, n).eq(targets.expand(n, n).t())
102 | 
103 |         # P: 1 2  N: 3
104 |         dist_ap1, dist_an1 = [], []
105 |         for i in range(n):
106 |             dist_ap1.append(dist1[i][mask[i]].max().unsqueeze(0))
107 |             dist_an1.append(dist3[i][mask[i] == 0].min().unsqueeze(0))
108 |         dist_ap1 = torch.cat(dist_ap1)
109 |         dist_an1 = torch.cat(dist_an1)
110 | 
111 |         # Compute ranking hinge loss
112 |         y = torch.ones_like(dist_an1)
113 |         loss1 = self.ranking_loss(dist_an1, dist_ap1, y)
114 | 
115 |         # P: 2 3  N: 1
116 |         dist_ap2, dist_an2 = [], []
117 |         for i in range(n):
118 |             dist_ap2.append(dist2[i][mask[i]].max().unsqueeze(0))
119 |             dist_an2.append(dist1[i][mask[i] == 0].min().unsqueeze(0))
120 |         dist_ap2 = torch.cat(dist_ap2)
121 |         dist_an2 = torch.cat(dist_an2)
122 | 
123 |         # Compute ranking hinge loss
124 |         loss2 = self.ranking_loss(dist_an2, dist_ap2, y)
125 | 
126 |         # P: 3 1  N: 2
127 |         dist_ap3, dist_an3 = [], []
128 |         for i in range(n):
129 |             dist_ap3.append(dist3[i][mask[i]].max().unsqueeze(0))
130 |             dist_an3.append(dist2[i][mask[i] == 0].min().unsqueeze(0))
131 |         dist_ap3 = torch.cat(dist_ap3)
132 |         dist_an3 = torch.cat(dist_an3)
133 | 
134 |         # Compute ranking hinge loss
135 |         loss3 = self.ranking_loss(dist_an3, dist_ap3, y)
136 | 
137 |         # compute accuracy
138 |         correct1 = torch.ge(dist_an1, dist_ap1).sum().item()
139 |         correct2 = torch.ge(dist_an2, dist_ap2).sum().item()
140 |         correct3 = torch.ge(dist_an3, dist_ap3).sum().item()
141 | 
142 | 
143 |         # regularizer
144 |         # pdb.set_trace()
145 |         loss_reg = dist_ap1.mean() + dist_ap2.mean() + dist_ap3.mean()
146 |         return loss1 + loss2 + loss3, loss_reg, correct1 + correct2 + correct3
147 | 
148 | class WTDRLoss(nn.Module):
149 |     """Weighted tri-directional ranking loss.
150 | 
151 |     Args:
152 |     - margin (float): margin for triplet.
153 |     """
154 | 
155 |     def __init__(self, margin=0.3):
156 |         super(WTDRLoss, self).__init__()
157 |         self.margin = margin
158 |         self.ranking_loss = nn.MarginRankingLoss(reduction='none', margin=margin)
159 | 
160 |     def forward(self, inputs, targets):
161 |         """
162 |         Args:
163 |         - inputs: feature matrix with shape (batch_size, feat_dim)
164 |         - targets: ground truth labels with shape (batch_size // 3)
165 |         """
166 |         n = inputs.shape[0] // 3
167 |         input1 = inputs.narrow(0, 0, n)
168 |         input2 = inputs.narrow(0, n, n)
169 |         input3 = inputs.narrow(0, 2 * n, n)
170 | 
171 |         dist1 = pdist_torch(input1, input2)
172 |         dist2 = pdist_torch(input2, input3)
173 |         dist3 = pdist_torch(input1, input3)
174 | 
175 |         # compute mask
176 |         mask = targets.expand(n, n).eq(targets.expand(n, n).t())
177 | 
178 |         # P: 1 2  N: 3
179 |         dist_ap1, dist_an1 = [], []
180 |         for i in range(n):
181 |             dist_ap1.append(dist1[i][mask[i]].max().unsqueeze(0))
182 |             dist_an1.append(dist3[i][mask[i] == 0].min().unsqueeze(0))
183 |         dist_ap1 = torch.cat(dist_ap1)
184 |         dist_an1 = torch.cat(dist_an1)
185 | 
186 |         # Compute ranking hinge loss
187 |         y = torch.ones_like(dist_an1)
188 |         loss1 = self.ranking_loss(dist_an1, dist_ap1, y)
189 |         weights1 = loss1.data.exp()
190 |         # weights1 = loss1.data.pow(2)
191 | 
192 |         # P: 2 3  N: 1
193 |         dist_ap2, dist_an2 = [], []
194 |         for i in range(n):
195 |             dist_ap2.append(dist2[i][mask[i]].max().unsqueeze(0))
196 |             dist_an2.append(dist1[i][mask[i] == 0].min().unsqueeze(0))
197 |         dist_ap2 = torch.cat(dist_ap2)
198 |         dist_an2 = torch.cat(dist_an2)
199 | 
200 |         # Compute ranking hinge loss
201 |         loss2 = self.ranking_loss(dist_an2, dist_ap2, y)
202 |         weights2 = loss2.data.exp()
203 |         # weights2 = loss2.data.pow(2)
204 | 
205 | 
206 |         # P: 3 1  N: 2
207 |         dist_ap3, dist_an3 = [], []
208 |         for i in range(n):
209 |             dist_ap3.append(dist3[i][mask[i]].max().unsqueeze(0))
210 |             dist_an3.append(dist2[i][mask[i] == 0].min().unsqueeze(0))
211 |         dist_ap3 = torch.cat(dist_ap3)
212 |         dist_an3 = torch.cat(dist_an3)
213 | 
214 |         # Compute ranking hinge loss
215 |         loss3 = self.ranking_loss(dist_an3, dist_ap3, y)
216 |         weights3 = loss3.data.exp()
217 |         # weights3 = loss3.data.pow(2)
218 | 
219 |         # compute accuracy
220 |         correct1 = torch.ge(dist_an1, dist_ap1).sum().item()
221 |         correct2 = torch.ge(dist_an2, dist_ap2).sum().item()
222 |         correct3 = torch.ge(dist_an3, dist_ap3).sum().item()
223 | 
224 | 
225 |         # weighted aggregation loss
226 |         weights_sum = torch.cat((weights1, weights2, weights3), 0)
227 |         wloss1 = torch.mul(weights1.div_(weights_sum.sum()), loss1).sum()
228 |         wloss2 = torch.mul(weights2.div_(weights_sum.sum()), loss2).sum()
229 |         wloss3 = torch.mul(weights3.div_(weights_sum.sum()), loss3).sum()
230 | 
231 |         return 3 * (wloss1 + wloss2 + wloss3), correct1 + correct2 + correct3
232 | 
233 | 
234 | class BDTRLoss(nn.Module):
235 |     """Bi-directional ranking loss between two modalities.
236 | 
237 |     Reference:
238 |     Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
239 |     Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.
240 | 
241 |     Args:
242 |     - margin (float): margin for triplet.
243 |     """
244 |     def __init__(self, batch_size, margin=0.5):
245 |         super(BDTRLoss, self).__init__()
246 |         self.margin = margin
247 |         self.ranking_loss = nn.MarginRankingLoss(margin=margin)
248 |         self.batch_size = batch_size
249 |         self.mask = torch.eye(batch_size)
250 |     def forward(self, input, target):
251 |         """
252 |         Args:
253 |         - input: feature matrix with shape (batch_size, feat_dim)
254 |         - target: ground truth labels with shape (batch_size)
255 |         """
256 |         n = self.batch_size
257 |         input1 = input.narrow(0, 0, n)
258 |         input2 = input.narrow(0, n, n)
259 | 
260 |         # modal 1 to modal 2
261 |         # Compute modal 1 to modal 2 distance
262 |         dist = pdist_torch(input1, input2)
263 | 
264 |         dist_ap1, dist_an1 = [], []
265 |         for i in range(n):
266 |             dist_ap1.append(dist[i, i].unsqueeze(0))
267 |             dist_an1.append(dist[i][self.mask[i] == 0].min().unsqueeze(0))
268 |         dist_ap1 = torch.cat(dist_ap1)
269 |         dist_an1 = torch.cat(dist_an1)
270 | 
271 |         # Compute ranking hinge loss for modal 1 to modal 2
272 |         y = torch.ones_like(dist_an1)
273 |         loss1 = self.ranking_loss(dist_an1, dist_ap1, y)
274 | 
275 |         # compute accuracy
276 |         correct1 = torch.ge(dist_an1, dist_ap1).sum().item()
277 | 
278 |         # modal 2 to modal 1
279 |         # Compute modal 2 to modal 1 distance
280 |         dist2 = pdist_torch(input2, input1)
281 | 
282 |         # For each anchor, take the paired positive and the hardest negative
283 |         dist_ap2, dist_an2 = [], []
284 |         for i in range(n):
285 |             dist_ap2.append(dist2[i, i].unsqueeze(0))
286 |             dist_an2.append(dist2[i][self.mask[i] == 0].min().unsqueeze(0))
287 |         dist_ap2 = torch.cat(dist_ap2)
288 |         dist_an2 = torch.cat(dist_an2)
289 | 
290 |         # Compute ranking hinge loss for modal 2 to modal 1
291 |         y2 = torch.ones_like(dist_an2)
292 |         loss2 = self.ranking_loss(dist_an2, dist_ap2, y2)
293 | 
294 |         # compute accuracy
295 |         correct2 = torch.ge(dist_an2, dist_ap2).sum().item()
296 | 
297 |         inter_loss = torch.add(loss1, loss2)
298 | 
299 |         # intra-modality loss is not computed here; only the inter-modality term is returned
300 | 
301 | 
302 |         return inter_loss, correct1 + correct2
303 | 
304 | class CTriLoss:  # hinge-style triplet loss over centroid/feature distances (numpy)
305 |     def __init__(self, rho):
306 |         self.rho = rho
307 | 
308 |     def __call__(self, C_n_g, f_adv_RGB, C_p_ir, f_adv_RGB_ir, C_n_RGB, f_adv_ir, C_p_g, f_adv_ir_g, C_n_ir, f_adv_g, C_p_RGB, f_adv_g_RGB):
309 |         loss1 = np.maximum(np.linalg.norm(C_n_g - f_adv_RGB) - np.linalg.norm(C_p_ir - f_adv_RGB_ir) + self.rho, 0)
310 |         loss2 = np.maximum(np.linalg.norm(C_n_RGB - f_adv_ir) - np.linalg.norm(C_p_g - f_adv_ir_g) + self.rho, 0)
311 |         loss3 = np.maximum(np.linalg.norm(C_n_ir - f_adv_g) - np.linalg.norm(C_p_RGB - f_adv_g_RGB) + self.rho, 0)
312 | 
313 |         total_loss = loss1 + loss2 + loss3
314 |         return total_loss
315 | 
316 | 
317 | def pdist_torch(emb1, emb2):
318 |     '''
319 |     compute the euclidean distance matrix between embeddings1 and embeddings2
320 |     using gpu
321 |     '''
322 |     m, n = emb1.shape[0], emb2.shape[0]
323 |     emb1_pow = torch.pow(emb1, 2).sum(dim=1, keepdim=True).expand(m, n)
324 |     emb2_pow = torch.pow(emb2, 2).sum(dim=1, keepdim=True).expand(n, m).t()
325 |     dist_mtx = emb1_pow + emb2_pow
326 |     dist_mtx = dist_mtx.addmm_(emb1, emb2.t(), beta=1, alpha=-2)
327 |     # dist_mtx = dist_mtx.clamp(min=1e-12)
328 |     dist_mtx = dist_mtx.clamp(min=1e-12).sqrt()
329 |     return dist_mtx
330 | 
331 | 
332 | def pdist_np(emb1, emb2):
333 |     '''
334 |     compute the (squared) euclidean distance matrix between embeddings1 and embeddings2
335 |     using cpu; the square root below is left commented out
336 |     '''
337 |     m, n = emb1.shape[0], emb2.shape[0]
338 |     emb1_pow = np.square(emb1).sum(axis=1)[..., np.newaxis]
339 |     emb2_pow = np.square(emb2).sum(axis=1)[np.newaxis, ...]
340 |     dist_mtx = -2 * np.matmul(emb1, emb2.T) + emb1_pow + emb2_pow
341 |     # dist_mtx = np.sqrt(dist_mtx.clip(min=1e-12))
342 |     return dist_mtx
--------------------------------------------------------------------------------
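The tri-directional losses in loss.py assume that the input batch is the concatenation of three equally sized modality chunks (n = batch_size // 3) that share one set of identity labels. A minimal usage sketch, not part of the repository: the batch size, feature dimension, and random features below are illustrative assumptions; only the import path reid/loss/loss.py comes from the tree above.

import torch
from reid.loss.loss import TDRLoss

# Toy batch: 4 identities per modality, 2048-D features (shapes are illustrative).
feat_rgb = torch.randn(4, 2048)
feat_ir = torch.randn(4, 2048)
feat_gray = torch.randn(4, 2048)

# The three modality chunks are stacked along the batch dimension; targets
# holds the n identity labels shared by all three chunks.
inputs = torch.cat([feat_rgb, feat_ir, feat_gray], dim=0)
targets = torch.tensor([0, 1, 2, 3])

criterion = TDRLoss(margin=0.3)
loss, loss_reg, correct = criterion(inputs, targets)

The loss returns three values: the summed ranking loss over the three modality pairings, a regularizer built from the positive-pair distances, and the number of correctly ordered triplets.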