├── reid
│   ├── utils
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── preprocessor.py
│   │   │   ├── sampler.py
│   │   │   ├── transforms.py
│   │   │   └── dataset.py
│   │   ├── osutils.py
│   │   ├── meters.py
│   │   ├── __init__.py
│   │   ├── logging.py
│   │   └── serialization.py
│   ├── evaluation_metrics
│   │   ├── __init__.py
│   │   ├── classification.py
│   │   └── ranking.py
│   ├── feature_extraction
│   │   ├── __init__.py
│   │   ├── database.py
│   │   └── cnn.py
│   ├── loss
│   │   ├── __init__.py
│   │   ├── oim.py
│   │   ├── triplet.py
│   │   └── tripletAttack.py
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── CnMix.py
│   │   ├── Sketch.py
│   │   ├── sysu.py
│   │   └── regdb.py
│   ├── trainers.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── resnet.py
│   │   ├── PCB.py
│   │   ├── inception.py
│   │   ├── DDAG.py
│   │   ├── attention.py
│   │   ├── baseline.py
│   │   └── AGW.py
│   ├── rerank.py
│   └── evaluators.py
├── run.sh
├── requirements.txt
├── readme.md
├── CnMix_process.py
├── MOAA
│   ├── Solutions.py
│   ├── operators.py
│   └── MOAA.py
└── Multiform_attack.py

/reid/utils/data/__init__.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import

from .dataset import Dataset
from .preprocessor import Preprocessor
--------------------------------------------------------------------------------

/reid/evaluation_metrics/__init__.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import

from .classification import accuracy
from .ranking import cmc, mean_ap

__all__ = [
    'accuracy',
    'cmc',
    'mean_ap',
]
--------------------------------------------------------------------------------

/reid/utils/osutils.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
import os
import errno


def mkdir_if_missing(dir_path):
    try:
        os.makedirs(dir_path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise
--------------------------------------------------------------------------------

/reid/feature_extraction/__init__.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import

from .cnn import extract_cnn_feature, extract_pcb_feature
from .database import FeatureDatabase

__all__ = [
    'extract_cnn_feature', 'extract_pcb_feature',
    'FeatureDatabase',
]
--------------------------------------------------------------------------------

/reid/loss/__init__.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import

from .oim import oim, OIM, OIMLoss
from .tripletAttack import TripletLoss
from .triplet import Triplet

__all__ = [
    'oim',
    'OIM',
    'OIMLoss',
    'TripletLoss', 'Triplet'
]
--------------------------------------------------------------------------------

/reid/__init__.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import

from . import datasets
from . import evaluation_metrics
from . import feature_extraction
from . import loss
from . import models
from . import utils
from . import evaluators
from . import trainers
from . import rerank

__version__ = '0.2.0'
--------------------------------------------------------------------------------

/reid/utils/meters.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
--------------------------------------------------------------------------------

/reid/evaluation_metrics/classification.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import

from ..utils import to_torch


def accuracy(output, target, topk=(1,)):
    output, target = to_torch(output), to_torch(target)
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    ret = []
    for k in topk:
        # reshape, not view: slices of the transposed tensor are non-contiguous
        correct_k = correct[:k].reshape(-1).float().sum(dim=0, keepdim=True)
        ret.append(correct_k.mul_(1. / batch_size))
    return ret
--------------------------------------------------------------------------------

/reid/utils/__init__.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import

import torch


def to_numpy(tensor):
    if torch.is_tensor(tensor):
        return tensor.cpu().numpy()
    elif type(tensor).__module__ != 'numpy':
        raise ValueError("Cannot convert {} to numpy array"
                         .format(type(tensor)))
    return tensor


def to_torch(ndarray):
    if type(ndarray).__module__ == 'numpy':
        return torch.from_numpy(ndarray)
    elif not torch.is_tensor(ndarray):
        raise ValueError("Cannot convert {} to torch tensor"
                         .format(type(ndarray)))
    return ndarray
--------------------------------------------------------------------------------

/run.sh:
--------------------------------------------------------------------------------
CUDA_VISIBLE_DEVICES=1 python -W ignore Multiform_attack.py -s sysu_v2 -m CnMix -m2 CnMix -t sysu_v2 --data /sda1/gyp/data --batch_size 128 \
--resume /sda1/gyp/DDAG/save_model/sysu_G_P_3_drop_0.2_4_8_lr_0.1_seed_0_best.t \
--resumeSearchTgt /sda1/gyp/DDAG/save_model/CnMix.t \
--resumeSearchTgt2 /sda1/gyp/DDAG/save_model/CnMix.t \
--resumeTgt /sda1/gyp/DDAG/save_model/regdb_G_P_3_drop_0.2_4_8_lr_0.1_seed_0_trial_1_best.t

CUDA_VISIBLE_DEVICES=1 python -W ignore Multiform_attack.py -s regdb_v2 -m CnMix -m2 CnMix -t regdb_v2 --data /sda1/gyp/data --batch_size 64 \
--resume /sda1/gyp/DDAG/save_model/regdb_G_P_3_drop_0.2_4_8_lr_0.1_seed_0_trial_1_best.t \
--resumeSearchTgt /sda1/gyp/DDAG/save_model/CnMix.t \
--resumeSearchTgt2 /sda1/gyp/DDAG/save_model/CnMix.t \
--resumeTgt /sda1/gyp/DDAG/save_model/regdb_G_P_3_drop_0.2_4_8_lr_0.1_seed_0_trial_1_best.t
--------------------------------------------------------------------------------

/reid/utils/logging.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
import os
import sys

from .osutils import mkdir_if_missing


class Logger(object):
    def __init__(self, fpath=None):
        self.console = sys.stdout
        self.file = None
        if fpath is not None:
            mkdir_if_missing(os.path.dirname(fpath))
            self.file = open(fpath, 'w')

    def __del__(self):
        self.close()

    def __enter__(self):
        # return self so the logger is usable in a with-statement
        return self

    def __exit__(self, *args):
        self.close()

    def write(self, msg):
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            os.fsync(self.file.fileno())

    def close(self):
        self.console.close()
        if self.file is not None:
            self.file.close()
--------------------------------------------------------------------------------
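A typical way to use this `Logger` is to tee stdout to a file at program start (a sketch; the log path is hypothetical):

```python
import sys
import os.path as osp
from reid.utils.logging import Logger

# Mirror everything printed to the console into a log file as well.
sys.stdout = Logger(osp.join('logs', 'train.txt'))
print('this line goes to the console and to logs/train.txt')
```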
/reid/utils/data/preprocessor.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
import os.path as osp

from PIL import Image


class Preprocessor(object):
    def __init__(self, dataset, root=None, transform=None):
        super(Preprocessor, self).__init__()
        self.dataset = dataset
        self.root = root
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, indices):
        if isinstance(indices, (tuple, list)):
            return [self._get_single_item(index) for index in indices]
        return self._get_single_item(indices)

    def _get_single_item(self, index):
        fname, pid, camid = self.dataset[index]
        fpath = fname
        if self.root is not None:
            fpath = osp.join(self.root, fname)
        img = Image.open(fpath).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, fname, pid, camid
--------------------------------------------------------------------------------

/reid/utils/data/sampler.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from collections import defaultdict

import numpy as np
import torch
from torch.utils.data.sampler import (
    Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler,
    WeightedRandomSampler)


class RandomIdentitySampler(Sampler):
    def __init__(self, data_source, num_instances=1):
        self.data_source = data_source
        self.num_instances = num_instances
        self.index_dic = defaultdict(list)
        for index, (_, pid, _) in enumerate(data_source):
            self.index_dic[pid].append(index)
        self.pids = list(self.index_dic.keys())
        self.num_samples = len(self.pids)

    def __len__(self):
        return self.num_samples * self.num_instances

    def __iter__(self):
        indices = torch.randperm(self.num_samples)
        ret = []
        for i in indices:
            pid = self.pids[i]
            t = self.index_dic[pid]
            if len(t) >= self.num_instances:
                t = np.random.choice(t, size=self.num_instances, replace=False)
            else:
                t = np.random.choice(t, size=self.num_instances, replace=True)
            ret.extend(t)
        return iter(ret)
--------------------------------------------------------------------------------
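`RandomIdentitySampler` yields `num_instances` indices per identity, producing the P-identities-by-K-instances batches the triplet losses below expect. A minimal usage sketch, assuming a `train_set` list of `(fname, pid, camid)` tuples and a hypothetical image root:

```python
from torch.utils.data import DataLoader
from reid.utils.data.preprocessor import Preprocessor
from reid.utils.data.sampler import RandomIdentitySampler

K = 4  # instances per identity
loader = DataLoader(
    Preprocessor(train_set, root='images/', transform=None),
    batch_size=32,  # should be a multiple of K
    sampler=RandomIdentitySampler(train_set, num_instances=K))
```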
/reid/datasets/__init__.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
import warnings


from .CnMix import CnMix
from .Sketch import Sketch
from .regdb import Regdb
from .sysu import Sysu


__factory = {

    'CnMix': CnMix,
    'regdb_v2': Regdb,
    'Sketch': Sketch,
    'sysu_v2': Sysu
}


def names():
    return sorted(__factory.keys())


def create(name, root, *args, **kwargs):
    """
    Create a dataset instance.

    Parameters
    ----------
    name : str
        The dataset name. Can be one of 'CnMix', 'regdb_v2', 'Sketch',
        and 'sysu_v2'.
    root : str
        The path to the dataset directory.
    split_id : int, optional
        The index of data split. Default: 0
    num_val : int or float, optional
        When int, it means the number of validation identities. When float,
        it means the proportion of validation to all the trainval. Default: 100
    download : bool, optional
        If True, will download the dataset. Default: False
    """
    if name not in __factory:
        raise KeyError("Unknown dataset:", name)
    return __factory[name](root, *args, **kwargs)


def get_dataset(name, root, *args, **kwargs):
    warnings.warn("get_dataset is deprecated. Use create instead.")
    return create(name, root, *args, **kwargs)
--------------------------------------------------------------------------------
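For example (a sketch; the data root is hypothetical, and individual dataset constructors may accept extra arguments):

```python
from reid import datasets

print(datasets.names())  # ['CnMix', 'Sketch', 'regdb_v2', 'sysu_v2']
dataset = datasets.create('sysu_v2', '/path/to/data')
```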
/reid/feature_extraction/database.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import

import h5py
import numpy as np
from torch.utils.data import Dataset


class FeatureDatabase(Dataset):
    def __init__(self, *args, **kwargs):
        super(FeatureDatabase, self).__init__()
        self.fid = h5py.File(*args, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def __getitem__(self, keys):
        if isinstance(keys, (tuple, list)):
            return [self._get_single_item(k) for k in keys]
        return self._get_single_item(keys)

    def _get_single_item(self, key):
        return np.asarray(self.fid[key])

    def __setitem__(self, key, value):
        if key in self.fid:
            if self.fid[key].shape == value.shape and \
               self.fid[key].dtype == value.dtype:
                self.fid[key][...] = value
            else:
                del self.fid[key]
                self.fid.create_dataset(key, data=value)
        else:
            self.fid.create_dataset(key, data=value)

    def __delitem__(self, key):
        del self.fid[key]

    def __len__(self):
        return len(self.fid)

    def __iter__(self):
        return iter(self.fid)

    def flush(self):
        self.fid.flush()

    def close(self):
        self.fid.close()
--------------------------------------------------------------------------------
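`FeatureDatabase` is a thin dict-like wrapper over an HDF5 file; constructor arguments go straight to `h5py.File`. A usage sketch with hypothetical keys:

```python
import numpy as np
from reid.feature_extraction import FeatureDatabase

with FeatureDatabase('features.h5', 'w') as db:
    db['img_0001'] = np.random.rand(2048).astype(np.float32)

with FeatureDatabase('features.h5', 'r') as db:
    feat = db['img_0001']  # returned as a numpy array
```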
/reid/feature_extraction/cnn.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from collections import OrderedDict

import torch

from ..utils import to_torch


def extract_cnn_feature(model, inputs, modules=None):
    model.eval()
    inputs = to_torch(inputs).cuda()
    if modules is None:
        # Variable(..., volatile=True) is gone in modern PyTorch;
        # torch.no_grad() is the equivalent inference mode.
        with torch.no_grad():
            outputs = model(inputs)[0]
        return outputs.cpu()

    # Register forward hook for each module
    outputs = OrderedDict()
    handles = []
    for m in modules:
        outputs[id(m)] = None
        def func(m, i, o): outputs[id(m)] = o.data.cpu()
        handles.append(m.register_forward_hook(func))
    with torch.no_grad():
        model(inputs)
    for h in handles:
        h.remove()
    return list(outputs.values())


def extract_pcb_feature(model, inputs, modules=None):
    model.eval()
    inputs = to_torch(inputs).cuda()
    if modules is None:
        with torch.no_grad():
            outputs = model(inputs)
        return outputs.cpu()

    # Register forward hook for each module
    outputs = OrderedDict()
    handles = []
    for m in modules:
        outputs[id(m)] = None
        def func(m, i, o): outputs[id(m)] = o.data.cpu()
        handles.append(m.register_forward_hook(func))
    with torch.no_grad():
        model(inputs)
    for h in handles:
        h.remove()
    return list(outputs.values())
--------------------------------------------------------------------------------

/reid/loss/oim.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import

import torch
import torch.nn.functional as F
from torch import nn, autograd


class OIM(autograd.Function):
    # Ported to the static-method autograd.Function API; the legacy
    # class-based form raises a RuntimeError on PyTorch >= 1.5.
    @staticmethod
    def forward(ctx, inputs, targets, lut, momentum):
        ctx.save_for_backward(inputs, targets)
        ctx.lut = lut
        ctx.momentum = momentum
        outputs = inputs.mm(lut.t())
        return outputs

    @staticmethod
    def backward(ctx, grad_outputs):
        inputs, targets = ctx.saved_tensors
        grad_inputs = None
        if ctx.needs_input_grad[0]:
            grad_inputs = grad_outputs.mm(ctx.lut)
        # Update the lookup table with a momentum-smoothed, renormalized
        # copy of each sample's feature.
        for x, y in zip(inputs, targets):
            ctx.lut[y] = ctx.momentum * ctx.lut[y] + (1. - ctx.momentum) * x
            ctx.lut[y] /= ctx.lut[y].norm()
        return grad_inputs, None, None, None


def oim(inputs, targets, lut, momentum=0.5):
    return OIM.apply(inputs, targets, lut, momentum)


class OIMLoss(nn.Module):
    def __init__(self, num_features, num_classes, scalar=1.0, momentum=0.5,
                 weight=None, size_average=True):
        super(OIMLoss, self).__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.momentum = momentum
        self.scalar = scalar
        self.weight = weight
        self.size_average = size_average

        self.register_buffer('lut', torch.zeros(num_classes, num_features))

    def forward(self, inputs, targets):
        inputs = oim(inputs, targets, self.lut, momentum=self.momentum)
        inputs *= self.scalar
        # reduction replaces the removed size_average keyword
        loss = F.cross_entropy(inputs, targets, weight=self.weight,
                               reduction='mean' if self.size_average else 'sum')
        return loss, inputs
--------------------------------------------------------------------------------
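`OIMLoss` keeps a per-class lookup table as a buffer and scores each feature against it; the table itself is refreshed in the backward pass. A minimal usage sketch with made-up sizes:

```python
import torch
from reid.loss import OIMLoss

criterion = OIMLoss(num_features=128, num_classes=10)
feats = torch.randn(4, 128, requires_grad=True)
targets = torch.tensor([0, 1, 2, 3])

loss, logits = criterion(feats, targets)
loss.backward()  # also updates criterion.lut inside OIM.backward
```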
/reid/utils/serialization.py:
--------------------------------------------------------------------------------
from __future__ import print_function, absolute_import
import json
import os.path as osp
import shutil

import torch
from torch.nn import Parameter

from .osutils import mkdir_if_missing


def read_json(fpath):
    with open(fpath, 'r') as f:
        obj = json.load(f)
    return obj


def write_json(obj, fpath):
    mkdir_if_missing(osp.dirname(fpath))
    with open(fpath, 'w') as f:
        json.dump(obj, f, indent=4, separators=(',', ': '))


def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
    mkdir_if_missing(osp.dirname(fpath))
    torch.save(state, fpath)
    if is_best:
        shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar'))


def load_checkpoint(fpath):
    if osp.isfile(fpath):
        checkpoint = torch.load(fpath)
        print("=> Loaded checkpoint '{}'".format(fpath))
        return checkpoint
    else:
        raise ValueError("=> No checkpoint found at '{}'".format(fpath))


def copy_state_dict(state_dict, model, strip=None):
    tgt_state = model.state_dict()
    copied_names = set()
    for name, param in state_dict.items():
        if strip is not None and name.startswith(strip):
            name = name[len(strip):]
        if name not in tgt_state:
            continue
        if isinstance(param, Parameter):
            param = param.data
        if param.size() != tgt_state[name].size():
            print('mismatch:', name, param.size(), tgt_state[name].size())
            continue
        tgt_state[name].copy_(param)
        copied_names.add(name)

    missing = set(tgt_state.keys()) - copied_names
    if len(missing) > 0:
        print("missing keys in state_dict:", missing)

    return model
--------------------------------------------------------------------------------

/requirements.txt:
--------------------------------------------------------------------------------
Brotli @ file:///tmp/abs_ecyw11_7ze/croots/recipe/brotli-split_1659616059936/work
certifi @ file:///croot/certifi_1700501669400/work/certifi
cffi @ file:///croot/cffi_1700254295673/work
charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
contourpy==1.2.0
cryptography @ file:///croot/cryptography_1694444244250/work
cycler==0.12.1
debugpy @ file:///croot/debugpy_1690905042057/work
faiss==1.7.3
filelock==3.13.1
fonttools==4.45.1
fsspec==2023.12.2
h5py==3.10.0
huggingface-hub==0.20.3
idna @ file:///croot/idna_1666125576474/work
importlib-resources==6.1.1
joblib==1.3.2
kiwisolver==1.4.5
matplotlib==3.8.2
metric-learn==0.7.0
mkl-fft @ file:///croot/mkl_fft_1695058164594/work
mkl-random @ file:///croot/mkl_random_1695059800811/work
mkl-service==2.4.0
munch==4.0.0
numpy @ file:///croot/numpy_and_numpy_base_1701295038894/work/dist/numpy-1.26.2-cp39-cp39-linux_x86_64.whl#sha256=ab2439928d6a64e481fc12c271cde207be02e933033d4fecde57d3e534b5025c
opencv-python==4.8.1.78
ops==2.9.0
packaging==23.2
Pillow @ file:///croot/pillow_1696580024257/work
pretrainedmodels==0.7.4
protobuf==4.25.1
pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
pyOpenSSL @ file:///croot/pyopenssl_1690223430423/work
pyparsing==3.1.1
PySocks @ file:///tmp/build/80754af9/pysocks_1605305812635/work
python-dateutil==2.8.2
pytorch-metric-learning==2.4.1
PyYAML==6.0.1
requests @ file:///croot/requests_1690400202158/work
safetensors==0.4.2
scikit-learn==1.3.2
scipy==1.11.4
six==1.16.0
tensorboardX==2.6.2.2
threadpoolctl==3.2.0
timm==0.9.12
torch==1.12.0
torchaudio==0.12.0
torchvision==0.13.0
tqdm==4.66.1
typing_extensions @ file:///croot/typing_extensions_1690297465030/work
urllib3 @ file:///croot/urllib3_1698257533958/work
websocket-client==1.7.0
zipp==3.17.0
--------------------------------------------------------------------------------

/readme.md:
--------------------------------------------------------------------------------
# Cross-Modality Attack Boosted by Gradient-Evolutionary Multiform Optimization (XXXX 2024)

Code for the XXXX 2024 paper "Cross-Modality Attack Boosted by Gradient-Evolutionary Multiform Optimization".


## [Paper](paper/Multiform_attack.pdf)

## [Supplemental Material](paper/Supplementary_Materials.pdf)

## Requirements:
* python 3.7
* CUDA==10.1
* Market-1501 (to be transformed into CnMix), Sketch-ReID, SYSU, and RegDB datasets.
* faiss-gpu==1.6.0
* Other necessary packages listed in [requirements.txt](requirements.txt)

## Preparing Data

* Clone our repo.

Market-1501 (namely CnMix); the steps for SYSU and RegDB are the same:
* Download "Market-1501-v15.09.15.zip".
* Create a new directory and name it "data".
* Create a directory called "raw" under "data" and put "Market-1501-v15.09.15.zip" under it (see the layout sketch after this list).
* The processed dataset is provided in the link below; please refer to it.
* To adapt different dataset formats to this code, we have provided conversion scripts. Please refer to CnMix_process.py, cross-modal_dataset_to_market_format.py, deal_SYSU_testset_ID.py, and testset_to_query.py.


* There is a processed tar file in [BaiduYun](https://pan.baidu.com/s/1dAMc0HEk_xEBQIJD1JWkPA?pwd=kwwu) (Password: kwwu) with all needed files.
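The expected layout after these steps (a sketch; only the zip placement is prescribed above):

```
data/
└── raw/
    └── Market-1501-v15.09.15.zip
```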
## Preparing Models

* Download re-ID models from [BaiduYun](https://pan.baidu.com/s/1lGoahWk--y-A008zl01VMQ?pwd=k4np) (Password: k4np)


## Run our code

See run.sh for more information.

If you find this code useful in your research, please consider citing:

```
@inproceedings{XXXXX,
  title={Cross-Modality Attack Boosted by Gradient-Evolutionary Multiform Optimization},
  author={XXXXXXXXXx},
  booktitle={XXX},
  volume={35},
  number={4},
  pages={3128--3135},
  year={2024}
}
```

## Contact Me

Email: fmonkey625@gmail.com


### ↳ Visitors
[![Visit tracker](https://clustrmaps.com/map_v2.png?cl=ffffff&w=896&t=tt&d=zLtXBhTnXw66l00fakOMI4K9BJmzjJ_0hpftLgebA_Y)](https://clustrmaps.com/site/1c4pf)
--------------------------------------------------------------------------------

/reid/loss/triplet.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import

import torch
from torch import nn


class Triplet(nn.Module):
    def __init__(self, margin=0, num_instances=0, use_semi=True):
        super(Triplet, self).__init__()
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=self.margin)
        self.K = num_instances
        self.use_semi = use_semi

    def forward(self, inputs, targets, epoch):
        n = inputs.size(0)
        P = n // self.K  # number of identities per batch (integer division)

        # Compute pairwise distance, replace by the official when merged
        dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
        dist = dist + dist.t()
        dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
        # For each anchor, find the hardest positive and negative
        mask = targets.expand(n, n).eq(targets.expand(n, n).t())
        dist_ap, dist_an = [], []
        if self.use_semi:
            for i in range(P):
                for j in range(self.K):
                    neg_examples = dist[i * self.K + j][mask[i * self.K + j] == 0]
                    for pair in range(j + 1, self.K):
                        ap = dist[i * self.K + j][i * self.K + pair].view(1)
                        dist_ap.append(ap)
                        dist_an.append(neg_examples.min().view(1))
        else:
            for i in range(n):
                dist_ap.append(dist[i][mask[i]].max().view(1))
                dist_an.append(dist[i][mask[i] == 0].min().view(1))
        dist_ap = torch.cat(dist_ap)
        dist_an = torch.cat(dist_an)
        # Compute ranking hinge loss
        y = torch.ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)
        prec = (dist_an.data > dist_ap.data).sum() * 1. / y.size(0)
        return loss, prec
--------------------------------------------------------------------------------
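The squared-distance expansion used above, ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, can be sanity-checked against `torch.cdist` (a standalone sketch):

```python
import torch

x = torch.randn(8, 16)
sq = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(8, 8)
d_expand = (sq + sq.t()).addmm(x, x.t(), beta=1, alpha=-2).clamp(min=1e-12).sqrt()
d_ref = torch.cdist(x, x)
print(torch.allclose(d_expand, d_ref, atol=1e-4))  # True
```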
/CnMix_process.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import cv2
import random
from PIL import Image, ImageOps

def to_sketch(img):
    # Convert img to numpy array
    img_np = np.array(img)
    # Invert image colors
    img_inv = 255 - img_np
    # Gaussian blur of img_inv
    img_blur = cv2.GaussianBlur(img_inv, (21, 21), sigmaX=0, sigmaY=0)
    # Blend original and blurred images
    img_blend = cv2.divide(img_np, 255 - img_blur, scale=256)
    return Image.fromarray(img_blend)

def random_choose(r, g, b, gray_or_sketch):
    p = [r, g, b, gray_or_sketch, gray_or_sketch]
    idx = list(range(5))
    random.shuffle(idx)
    return Image.merge('RGB', (p[idx[0]], p[idx[1]], p[idx[2]]))

def fuse_rgb_gray_sketch(img, G, G_rgb, S_rgb):
    # Split img into RGB channels
    r, g, b = img.split()
    # Convert img to grayscale
    gray = ImageOps.grayscale(img)

    p = random.random()

    if p < G:
        return Image.merge('RGB', (gray, gray, gray))
    elif p < G + G_rgb:
        return random_choose(r, g, b, gray)
    elif p < G + G_rgb + S_rgb:
        sketch = to_sketch(gray)
        return random_choose(r, g, b, sketch)
    else:
        return img

def process_dataset(input_dir, output_dir, G, G_rgb, S_rgb):
    # Ensure the output directory exists
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Iterate through all files in the input directory
    for root, _, files in os.walk(input_dir):
        for file in files:
            if file.endswith(('jpg', 'jpeg', 'png')):
                img_path = os.path.join(root, file)
                # convert('RGB') so split() always sees three channels
                img = Image.open(img_path).convert('RGB')

                # Transform the image
                output_img = fuse_rgb_gray_sketch(img, G, G_rgb, S_rgb)

                # Save the transformed image
                relative_path = os.path.relpath(img_path, input_dir)
                output_path = os.path.join(output_dir, relative_path)
                output_dirname = os.path.dirname(output_path)
                if not os.path.exists(output_dirname):
                    os.makedirs(output_dirname)
                output_img.save(output_path)


if __name__ == "__main__":
    input_dir = '/sda1/data/market1501/'
    output_dir = '/sda1/data/market1501_processed/'

    G = 0.3
    G_rgb = 0.4
    S_rgb = 0.2

    process_dataset(input_dir, output_dir, G, G_rgb, S_rgb)
--------------------------------------------------------------------------------
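For a single image, the channel-mixing step can be previewed in isolation (a sketch; the input path is hypothetical). With the probabilities below, an image becomes grayscale 30% of the time, a gray/RGB channel mix 40%, a sketch/RGB mix 20%, and stays unchanged 10%:

```python
from PIL import Image
from CnMix_process import fuse_rgb_gray_sketch

img = Image.open('example.jpg').convert('RGB')
out = fuse_rgb_gray_sketch(img, G=0.3, G_rgb=0.4, S_rgb=0.2)
out.save('example_mixed.jpg')
```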
/reid/loss/tripletAttack.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import

import torch
from torch import nn


class TripletLoss(nn.Module):
    def __init__(self, margin=0, num_instances=0, use_semi=True):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=self.margin)
        self.K = num_instances
        self.use_semi = use_semi

    def forward(self, inputs, purtub, targets, epoch):
        n = inputs.size(0)
        P = n // self.K  # number of identities per batch (integer division)

        # Compute pairwise distance, replace by the official when merged
        dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
        dist = dist + dist.t()
        dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
        #
        purtubeDist = torch.pow(purtub, 2).sum(dim=1, keepdim=True).expand(n, n)
        purtubeDist = purtubeDist + purtubeDist.t()
        purtubeDist.addmm_(purtub, purtub.t(), beta=1, alpha=-2)
        purtubeDist = purtubeDist.clamp(min=1e-12).sqrt()
        # For each anchor, find the hardest positive and negative
        mask = targets.expand(n, n).eq(targets.expand(n, n).t())
        dist_ap, dist_an = [], []
        if self.use_semi:
            for i in range(P):
                for j in range(self.K):
                    neg_examples = dist[i * self.K + j][mask[i * self.K + j] == 0]
                    for pair in range(j + 1, self.K):
                        ap = dist[i * self.K + j][i * self.K + pair].view(1)
                        dist_ap.append(ap)
                        dist_an.append(neg_examples.min().view(1))
        else:
            for i in range(n):
                # Map argmax/argmin positions in the masked vectors back to
                # global column indices before indexing purtubeDist.
                pos_idx = mask[i].nonzero().view(-1)
                neg_idx = (mask[i] == 0).nonzero().view(-1)
                maxLoc = pos_idx[dist[i][mask[i]].argmax()]
                minLoc = neg_idx[dist[i][mask[i] == 0].argmin()]
                dist_ap.append(purtubeDist[i][maxLoc].view(1))
                dist_an.append(purtubeDist[i][minLoc].view(1))
        dist_ap = torch.cat(dist_ap)
        dist_an = torch.cat(dist_an)
        # Compute ranking hinge loss
        y = torch.ones_like(dist_an)
        # loss = self.ranking_loss(dist_an, dist_ap, y)
        loss = self.ranking_loss(dist_ap, dist_an, y)  # reversed relative to the standard triplet loss (see triplet.py)
        prec = (dist_an.data > dist_ap.data).sum() * 1. / y.size(0)
        return loss, prec
--------------------------------------------------------------------------------

/reid/trainers.py:
--------------------------------------------------------------------------------
from __future__ import print_function, absolute_import
import time

import torch
from torch.autograd import Variable

from .evaluation_metrics import accuracy
from .loss import OIMLoss, TripletLoss
from .utils.meters import AverageMeter


class BaseTrainer(object):
    def __init__(self, model, criterions, print_freq=1):
        super(BaseTrainer, self).__init__()
        self.model = model
        self.criterions = criterions
        self.print_freq = print_freq

    def train(self, epoch, data_loader, optimizer):
        self.model.train()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        precisions = AverageMeter()

        end = time.time()
        for i, inputs in enumerate(data_loader):
            data_time.update(time.time() - end)

            inputs, targets = self._parse_data(inputs)
            loss, prec1 = self._forward(inputs, targets, epoch)
            losses.update(loss.item(), targets.size(0))
            precisions.update(prec1, targets.size(0))

            optimizer.zero_grad()
            loss.backward()
            # add gradient clip for lstm
            for param in self.model.parameters():
                if param.grad is not None:
                    # clamp_ is in place; the original non-inplace clamp() discarded its result
                    param.grad.data.clamp_(-1., 1.)

            optimizer.step()

            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % self.print_freq == 0:
                print('Epoch: [{}][{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'
                      'Loss {:.3f} ({:.3f})\t'
                      'Prec {:.2%} ({:.2%})\t'
                      .format(epoch, i + 1, len(data_loader),
                              batch_time.val, batch_time.avg,
                              data_time.val, data_time.avg,
                              losses.val, losses.avg,
                              precisions.val, precisions.avg))

    def _parse_data(self, inputs):
        raise NotImplementedError

    def _forward(self, inputs, targets, epoch):
        raise NotImplementedError


class Trainer(BaseTrainer):
    def _parse_data(self, inputs):
        imgs, _, pids, _ = inputs
        inputs = [Variable(imgs)]
        targets = Variable(pids.cuda())
        return inputs, targets

    def _forward(self, inputs, targets, epoch):
        outputs = self.model(*inputs)  # outputs=[x1,x2,x3]
        # new added by wc
        # x2 triplet loss
        loss_global, prec_global = self.criterions[0](outputs[1], targets, epoch)

        return loss_global, prec_global
--------------------------------------------------------------------------------

/reid/models/__init__.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from .inception import *
from .resnet import *
from .PCB import PCB, PCBTrain
from .baseline import ft_net

from .AGW import embed_net
from .DDAG import embed_net2

__factory = {
    'inception': inception,
    'resnet18': resnet18,
    'resnet34': resnet34,
    'resnet50': resnet50,
    'resnet101': resnet101,
    'resnet152': resnet152,
    'pcb': PCB,
    'pcbt': PCBTrain,
    'baseline': ft_net,
    'AGW': embed_net,
    'DDAG': embed_net2
}


def names():
    return sorted(__factory.keys())


def create(name, *args, **kwargs):
    """
    Create a model instance.

    Parameters
    ----------
    name : str
        Model name. Can be one of 'inception', 'resnet18', 'resnet34',
        'resnet50', 'resnet101', 'resnet152', 'pcb', 'pcbt', 'baseline',
        'AGW', and 'DDAG'.
    pretrained : bool, optional
        Only applied for 'resnet*' models. If True, will use ImageNet pretrained
        model. Default: True
    cut_at_pooling : bool, optional
        If True, will cut the model before the last global pooling layer and
        ignore the remaining kwargs. Default: False
    num_features : int, optional
        If positive, will append a Linear layer after the global pooling layer,
        with this number of output units, followed by a BatchNorm layer.
        Otherwise these layers will not be appended. Default: 256 for
        'inception', 0 for 'resnet*'
    norm : bool, optional
        If True, will normalize the feature to be unit L2-norm for each sample.
        Otherwise will append a ReLU layer after the above Linear layer if
        num_features > 0. Default: False
    dropout : float, optional
        If positive, will append a Dropout layer with this dropout rate.
        Default: 0
    num_classes : int, optional
        If positive, will append a Linear layer at the end as the classifier
        with this number of output units. Default: 0
    """
    if name not in __factory:
        raise KeyError("Unknown model:", name)
    return __factory[name](*args, **kwargs)
--------------------------------------------------------------------------------
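For example (a sketch; the class count is hypothetical):

```python
from reid import models

model = models.create('resnet50', num_features=2048, dropout=0.1, num_classes=395)
```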
/reid/utils/data/transforms.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import

from PIL import Image, ImageOps, ImageEnhance
import random
import math
from torchvision.transforms import *


class RectScale(object):
    def __init__(self, height, width, interpolation=Image.BILINEAR):
        self.height = height
        self.width = width
        self.interpolation = interpolation

    def __call__(self, img):
        w, h = img.size
        if h == self.height and w == self.width:
            return img
        return img.resize((self.width, self.height), self.interpolation)


class RandomSizedRectCrop(object):
    def __init__(self, height, width, interpolation=Image.BILINEAR):
        self.height = height
        self.width = width
        self.interpolation = interpolation

    def __call__(self, img):
        for attempt in range(10):
            area = img.size[0] * img.size[1]
            target_area = random.uniform(0.64, 1.0) * area
            aspect_ratio = random.uniform(2, 3)

            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))

            if w <= img.size[0] and h <= img.size[1]:
                x1 = random.randint(0, img.size[0] - w)
                y1 = random.randint(0, img.size[1] - h)

                img = img.crop((x1, y1, x1 + w, y1 + h))
                assert (img.size == (w, h))

                return img.resize((self.width, self.height), self.interpolation)

        # Fallback
        scale = RectScale(self.height, self.width,
                          interpolation=self.interpolation)
        return scale(img)


class RandomErasing(object):
    '''
    Class that performs Random Erasing in Random Erasing Data Augmentation by Zhong et al.
    -------------------------------------------------------------------------------------
    probability: The probability that the operation will be performed.
    sl: min erasing area
    sh: max erasing area
    r1: min aspect ratio
    mean: erasing value
    -------------------------------------------------------------------------------------
    '''

    def __init__(self, probability=0.5, sl=0.02, sh=0.2, r1=0.3, mean=[0.4914, 0.4822, 0.4465]):
        self.probability = probability
        self.mean = mean
        self.sl = sl
        self.sh = sh
        self.r1 = r1

    def __call__(self, img):

        if random.uniform(0, 1) > self.probability:
            return img

        for attempt in range(100):
            area = img.size()[1] * img.size()[2]

            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, 1 / self.r1)

            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))

            if w < img.size()[2] and h < img.size()[1]:
                x1 = random.randint(0, img.size()[1] - h)
                y1 = random.randint(0, img.size()[2] - w)
                if img.size()[0] == 3:
                    img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
                    img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
                    img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
                else:
                    img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
                return img

        return img
--------------------------------------------------------------------------------
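These transforms slot into a standard torchvision pipeline; note that `RandomErasing` here operates on tensors, so it must come after `ToTensor` (a sketch with common re-ID input sizes):

```python
from torchvision import transforms as T
from reid.utils.data.transforms import RandomSizedRectCrop, RandomErasing

train_transform = T.Compose([
    RandomSizedRectCrop(256, 128),   # height, width
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    RandomErasing(probability=0.5),  # tensor-level, so it goes last
])
```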
/reid/models/resnet.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import

import math
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import init
from torch.autograd import Variable
import torchvision
# from torch_deform_conv.layers import ConvOffset2D
from reid.utils.serialization import load_checkpoint, save_checkpoint

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']


class ResNet(nn.Module):
    __factory = {
        18: torchvision.models.resnet18,
        34: torchvision.models.resnet34,
        50: torchvision.models.resnet50,
        101: torchvision.models.resnet101,
        152: torchvision.models.resnet152,
    }

    def __init__(self, depth, checkpoint=None, pretrained=True, num_features=2048,
                 dropout=0.1, num_classes=0):
        super(ResNet, self).__init__()

        self.depth = depth
        self.checkpoint = checkpoint
        self.pretrained = pretrained
        self.num_features = num_features
        self.dropout = dropout
        self.num_classes = num_classes

        if self.dropout > 0:
            self.drop = nn.Dropout(self.dropout)
        # Construct base (pretrained) resnet
        if depth not in ResNet.__factory:
            raise KeyError("Unsupported depth:", depth)
        self.base = ResNet.__factory[depth](pretrained=pretrained)
        out_planes = self.base.fc.in_features

        # resume from pre-iteration training
        if self.checkpoint:
            state_dict = load_checkpoint(checkpoint)
            self.load_state_dict(state_dict['state_dict'], strict=False)

        self.feat = nn.Linear(out_planes, self.num_features, bias=False)
        self.feat_bn = nn.BatchNorm1d(self.num_features)
        self.relu = nn.ReLU(inplace=True)
        init.normal_(self.feat.weight, std=0.001)
        init.constant_(self.feat_bn.weight, 1)
        init.constant_(self.feat_bn.bias, 0)

        # x2 classifier
        self.classifier_x2 = nn.Linear(self.num_features, self.num_classes)
        init.normal_(self.classifier_x2.weight, std=0.001)
        init.constant_(self.classifier_x2.bias, 0)

        if not self.pretrained:
            self.reset_params()

    def forward(self, x):
        for name, module in self.base._modules.items():
            if name == 'avgpool':
                break
            x = module(x)

        x1 = F.avg_pool2d(x, x.size()[2:])
        x1 = x1.view(x1.size(0), -1)
        x2 = self.feat(x1)
        x2 = self.feat_bn(x2)
        x2 = self.relu(x2)
        if self.dropout > 0:  # self.drop only exists when dropout > 0
            x2 = self.drop(x2)
        x2 = self.classifier_x2(x2)
        return x1, x2

    def reset_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)


def resnet18(**kwargs):
    return ResNet(18, **kwargs)


def resnet34(**kwargs):
    return ResNet(34, **kwargs)


def resnet50(**kwargs):
    return ResNet(50, **kwargs)


def resnet101(**kwargs):
    return ResNet(101, **kwargs)


def resnet152(**kwargs):
    return ResNet(152, **kwargs)
--------------------------------------------------------------------------------

/reid/evaluation_metrics/ranking.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from collections import defaultdict

import numpy as np
from sklearn.metrics import average_precision_score

from ..utils import to_numpy


def _unique_sample(ids_dict, num):
    # mask = np.zeros(num, dtype=np.bool)
    mask = np.zeros(num, dtype=bool)
    for _, indices in ids_dict.items():
        i = np.random.choice(indices)
        mask[i] = True
    return mask


def cmc(distmat, query_ids=None, gallery_ids=None,
        query_cams=None, gallery_cams=None, topk=100,
        separate_camera_set=False,
        single_gallery_shot=False,
        first_match_break=False):
    distmat = to_numpy(distmat)
    m, n = distmat.shape
    # Fill up default values
    if query_ids is None:
        query_ids = np.arange(m)
    if gallery_ids is None:
        gallery_ids = np.arange(n)
    if query_cams is None:
        query_cams = np.zeros(m).astype(np.int32)
    if gallery_cams is None:
        gallery_cams = np.ones(n).astype(np.int32)
    # Ensure numpy array
    query_ids = np.asarray(query_ids)
    gallery_ids = np.asarray(gallery_ids)
    query_cams = np.asarray(query_cams)
    gallery_cams = np.asarray(gallery_cams)
    # Sort and find correct matches
    indices = np.argsort(distmat, axis=1)
    matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
    # Compute CMC for each query
    ret = np.zeros(topk)
    num_valid_queries = 0
    for i in range(m):
        # Filter out the same id and same camera
        valid = ((gallery_ids[indices[i]] != query_ids[i]) |
                 (gallery_cams[indices[i]] != query_cams[i]))
        if separate_camera_set:
            # Filter out samples from same camera
            valid &= (gallery_cams[indices[i]] != query_cams[i])
        if not np.any(matches[i, valid]): continue
        if single_gallery_shot:
            repeat = 10
            gids = gallery_ids[indices[i][valid]]
            inds = np.where(valid)[0]
            ids_dict = defaultdict(list)
            for j, x in zip(inds, gids):
                ids_dict[x].append(j)
        else:
            repeat = 1
        for _ in range(repeat):
            if single_gallery_shot:
                # Randomly choose one instance for each id
                sampled = (valid & _unique_sample(ids_dict, len(valid)))
                index = np.nonzero(matches[i, sampled])[0]
            else:
                index = np.nonzero(matches[i, valid])[0]
            delta = 1. / (len(index) * repeat)
            for j, k in enumerate(index):
                if k - j >= topk: break
                if first_match_break:
                    ret[k - j] += 1
                    break
                ret[k - j] += delta
        num_valid_queries += 1
    if num_valid_queries == 0:
        raise RuntimeError("No valid query")
    return ret.cumsum() / num_valid_queries


def mean_ap(distmat, query_ids=None, gallery_ids=None,
            query_cams=None, gallery_cams=None):
    distmat = to_numpy(distmat)
    m, n = distmat.shape
    # Fill up default values
    if query_ids is None:
        query_ids = np.arange(m)
    if gallery_ids is None:
        gallery_ids = np.arange(n)
    if query_cams is None:
        query_cams = np.zeros(m).astype(np.int32)
    if gallery_cams is None:
        gallery_cams = np.ones(n).astype(np.int32)
    # Ensure numpy array
    query_ids = np.asarray(query_ids)
    gallery_ids = np.asarray(gallery_ids)
    query_cams = np.asarray(query_cams)
    gallery_cams = np.asarray(gallery_cams)
    # Sort and find correct matches
    indices = np.argsort(distmat, axis=1)
    matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
    # Compute AP for each query
    aps = []
    for i in range(m):
        # Filter out the same id and same camera
        valid = ((gallery_ids[indices[i]] != query_ids[i]) |
                 (gallery_cams[indices[i]] != query_cams[i]))
        y_true = matches[i, valid]
        y_score = -distmat[i][indices[i]][valid]
        if not np.any(y_true): continue
        aps.append(average_precision_score(y_true, y_score))
    if len(aps) == 0:
        raise RuntimeError("No valid query")
    return np.mean(aps)
--------------------------------------------------------------------------------
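A quick end-to-end check of the two metrics on random data (a sketch; camera IDs fall back to their defaults):

```python
import numpy as np
from reid.evaluation_metrics import cmc, mean_ap

distmat = np.random.rand(10, 50)            # 10 queries x 50 gallery images
query_ids = np.random.randint(0, 5, 10)
gallery_ids = np.random.randint(0, 5, 50)

scores = cmc(distmat, query_ids, gallery_ids, topk=10, first_match_break=True)
print('rank-1: {:.2%}, mAP: {:.2%}'.format(
    scores[0], mean_ap(distmat, query_ids, gallery_ids)))
```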
/reid/rerank.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Jun 26 14:46:56 2017
@author: luohao
"""

"""
CVPR2017 paper: Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017.
url: http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf
Matlab version: https://github.com/zhunzhong07/person-re-ranking
"""

"""
Modified by L.Song and C.Wang
"""

import numpy as np
from scipy.spatial.distance import cdist

def re_ranking(input_feature_source, input_feature, k1=20, k2=6, lambda_value=0.1):

    all_num = input_feature.shape[0]
    feat = input_feature.astype(np.float16)

    if lambda_value != 0:
        print('Computing source distance...')
        all_num_source = input_feature_source.shape[0]
        sour_tar_dist = np.power(
            cdist(input_feature, input_feature_source), 2).astype(np.float16)
        sour_tar_dist = 1 - np.exp(-sour_tar_dist)
        source_dist_vec = np.min(sour_tar_dist, axis=1)
        source_dist_vec = source_dist_vec / np.max(source_dist_vec)
        source_dist = np.zeros([all_num, all_num])
        for i in range(all_num):
            source_dist[i, :] = source_dist_vec + source_dist_vec[i]
        del sour_tar_dist
        del source_dist_vec

    print('Computing original distance...')
    original_dist = cdist(feat, feat).astype(np.float16)
    original_dist = np.power(original_dist, 2).astype(np.float16)
    del feat
    euclidean_dist = original_dist
    gallery_num = original_dist.shape[0]  # gallery_num = all_num
    original_dist = np.transpose(original_dist / np.max(original_dist, axis=0))
    V = np.zeros_like(original_dist).astype(np.float16)
    initial_rank = np.argsort(original_dist).astype(np.int32)  # default axis=-1

    print('Starting re_ranking...')
    for i in range(all_num):
        # k-reciprocal neighbors
        forward_k_neigh_index = initial_rank[i, :k1 + 1]  # k1+1 because self always ranks first; forward_k_neigh_index[0] == i
        backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]  # shape [k1+1, k1+1]: for each element of forward_k_neigh_index, find its top-k1 neighbors
        fi = np.where(backward_k_neigh_index == i)[0]
        k_reciprocal_index = forward_k_neigh_index[fi]  # R(p,k) in the paper
        k_reciprocal_expansion_index = k_reciprocal_index
        for j in range(len(k_reciprocal_index)):
            candidate = k_reciprocal_index[j]
            candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1]
            candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index, :int(np.around(k1 / 2)) + 1]
            fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
            candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
            if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len(candidate_k_reciprocal_index):
                k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index)

        k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)  # element-wise unique
        weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
        V[i, k_reciprocal_expansion_index] = weight / np.sum(weight)
    # original_dist = original_dist[:query_num,]
    if k2 != 1:
        V_qe = np.zeros_like(V, dtype=np.float16)
        for i in range(all_num):
            V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
        V = V_qe
        del V_qe
    del initial_rank
    invIndex = []
    for i in range(gallery_num):
        invIndex.append(np.where(V[:, i] != 0)[0])  # len(invIndex) = all_num

    jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)

    for i in range(all_num):
        temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16)
        indNonZero = np.where(V[i, :] != 0)[0]
        indImages = []
        indImages = [invIndex[ind] for ind in indNonZero]
        for j in range(len(indNonZero)):
            temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]], V[indImages[j], indNonZero[j]])
        jaccard_dist[i] = 1 - temp_min / (2 - temp_min)

    pos_bool = (jaccard_dist < 0)
    jaccard_dist[pos_bool] = 0.0

    if lambda_value == 0:
        return jaccard_dist
    else:
        final_dist = jaccard_dist * (1 - lambda_value) + source_dist * lambda_value
        return final_dist
--------------------------------------------------------------------------------
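`re_ranking` consumes raw feature matrices, not precomputed distance matrices; a usage sketch with hypothetical shapes:

```python
import numpy as np
from reid.rerank import re_ranking

source_feats = np.random.rand(500, 2048).astype(np.float32)
target_feats = np.random.rand(300, 2048).astype(np.float32)

# (300 x 300) re-ranked distance over the target set,
# regularized by each sample's distance to the source domain
final_dist = re_ranking(source_feats, target_feats, k1=20, k2=6, lambda_value=0.1)
```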
/reid/models/PCB.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import

import torch
import torch.nn as nn
from torch.nn import init
from torchvision import models

__all__ = ['PCB', 'PCBTrain']


def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1:
        init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity='relu')
    elif classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')
        init.constant_(m.bias.data, 0.0)
    elif classname.find('BatchNorm1d') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
    elif classname.find('BatchNorm2d') != -1:
        init.constant_(m.weight.data, 1)
        init.constant_(m.bias.data, 0)


def weights_init_classifier(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        init.normal_(m.weight.data, std=0.001)
        init.constant_(m.bias.data, 0.0)


# Defines the new fc layer and classification layer
# |--Linear--|--bn--|--relu--|--Linear--|
class ClassBlock(nn.Module):
    def __init__(self, input_dim, class_num, relu=True, num_bottleneck=512):
        super(ClassBlock, self).__init__()
        add_block = []

        add_block += [nn.Conv2d(input_dim, num_bottleneck, kernel_size=1, bias=False)]
        add_block += [nn.BatchNorm2d(num_bottleneck)]
        if relu:
            add_block += [nn.ReLU(inplace=True)]
        add_block = nn.Sequential(*add_block)
        add_block.apply(weights_init_kaiming)

        classifier = []
        classifier += [nn.Linear(num_bottleneck, class_num)]
        classifier = nn.Sequential(*classifier)
        classifier.apply(weights_init_classifier)

        self.add_block = add_block
        self.classifier = classifier

    def forward(self, x):
        x = self.add_block(x)
        x = torch.squeeze(x)
        x = self.classifier(x)
        return x


# Part Model proposed in Yifan Sun etal. (2018)
class PCB(nn.Module):
    def __init__(self, num_classes, pretrained=True):
        super(PCB, self).__init__()
        self.part = 6
        # resnet50
        resnet = models.resnet50(pretrained=pretrained)
        # remove the final downsample
        resnet.layer4[0].downsample[0].stride = (1, 1)
        resnet.layer4[0].conv2.stride = (1, 1)
        modules = list(resnet.children())[:-2]
        self.backbone = nn.Sequential(*modules)
        self.avgpool = nn.AdaptiveAvgPool2d((self.part, 1))
        self.dropout = nn.Dropout(p=0.5)

        # define 6 classifiers
        self.classifiers = nn.ModuleList()
        for i in range(self.part):
            self.classifiers.append(ClassBlock(2048, num_classes, True, 256))

    def forward(self, x):
        x = self.backbone(x)
        x = self.avgpool(x)
        x = self.dropout(x)
        part = {}
        predict = {}
        # get six part feature batchsize*2048*6
        for i in range(self.part):
            part[i] = x[:, :, i, :]
            part[i] = torch.unsqueeze(part[i], 3)
            predict[i] = self.classifiers[i].add_block(part[i])  # 6*256-dim

        scores, features = [], []
        for i in range(self.part):
            scores.append(predict[i].view(predict[i].shape[0], -1))  # id-class or 1536
        return torch.cat(scores, 1)  # 1536-dim


class PCBTrain(nn.Module):
    def __init__(self, num_classes, pretrained=True):
        super(PCBTrain, self).__init__()
        self.part = 6
        # resnet50
        resnet = models.resnet50(pretrained=pretrained)
        # remove the final downsample
        resnet.layer4[0].downsample[0].stride = (1, 1)
        resnet.layer4[0].conv2.stride = (1, 1)
        modules = list(resnet.children())[:-2]
        self.backbone = nn.Sequential(*modules)
        self.avgpool = nn.AdaptiveAvgPool2d((self.part, 1))
        self.dropout = nn.Dropout(p=0.5)

        # define 6 classifiers
        self.classifiers = nn.ModuleList()
        for i in range(self.part):
            self.classifiers.append(ClassBlock(2048, num_classes, True, 256))

    def forward(self, x):
        x = self.backbone(x)
        x = self.avgpool(x)
        x = self.dropout(x)
        part = {}
        predict = {}
        # get six part feature batchsize*2048*6
        for i in range(self.part):
            part[i] = x[:, :, i, :]
            part[i] = torch.unsqueeze(part[i], 3)
            predict[i] = self.classifiers[i](part[i])  # 6*256-dim

        scores, features = [], []
        for i in range(self.part):
            scores.append(predict[i].view(predict[i].shape[0], -1))
            features.append(part[i].view(predict[i].shape[0], -1))
        return features, scores  # 1536-dim
--------------------------------------------------------------------------------

/MOAA/Solutions.py:
--------------------------------------------------------------------------------
import numpy as np
from copy import deepcopy
from operator import attrgetter


class Solution:
    def __init__(self, pixels, values, x, p_size):
        self.pixels = pixels  # list of Integers
        self.values = values  # list of Binary tuples, i.e. [0, 1, 1]
        self.x = x  # (w x w x 3)
        self.fitnesses = []
        self.is_adversarial = None
        self.w = x.shape[0]
        self.delta = len(self.pixels)
        self.domination_count = None
        self.dominated_solutions = None
        self.rank = None
        self.crowding_distance = None

        self.loss = None
        self.p_size = p_size

    def copy(self):
        return deepcopy(self)

    def euc_distance(self, img):
        return np.sum((img - self.x.copy()) ** 2)

    def generate_image(self):
        x_adv = self.x.copy()
        for i in range(self.delta):
            x_adv[self.pixels[i] // self.w, self.pixels[i] % self.w] += (self.values[i] * self.p_size)

        return np.clip(x_adv, 0, 1)

    def evaluate(self, loss_function, include_dist):
        img_adv = self.generate_image()
        fs = loss_function(img_adv)
        self.is_adversarial = fs[0]  # Assume first element is boolean always
        self.fitnesses = fs[1:]
        if include_dist:
            dist = self.euc_distance(img_adv)
            self.fitnesses.append(dist)
        else:
            self.fitnesses.append(0)

        self.fitnesses = np.array(self.fitnesses)
        self.loss = fs[1]

    def dominates(self, soln):
        if self.is_adversarial is True and soln.is_adversarial is False:
            return True

        if self.is_adversarial is False and soln.is_adversarial is True:
            return False

        if self.is_adversarial is True and soln.is_adversarial is True:
            return True if self.fitnesses[1] < soln.fitnesses[1] else False

        if self.is_adversarial is False and soln.is_adversarial is False:
            return True if self.fitnesses[0] < soln.fitnesses[0] else False


def fast_nondominated_sort(population):
    fronts = [[]]
    for individual in population:
        individual.domination_count = 0
        individual.dominated_solutions = []
        for other_individual in population:
            if individual.dominates(other_individual):
                individual.dominated_solutions.append(other_individual)
            elif other_individual.dominates(individual):
                individual.domination_count += 1
        if individual.domination_count == 0:
            individual.rank = 0
            fronts[0].append(individual)
    i = 0
    while len(fronts[i]) > 0:
        temp = []
        for individual in fronts[i]:
            for other_individual in individual.dominated_solutions:
                other_individual.domination_count -= 1
                if other_individual.domination_count == 0:
                    other_individual.rank = i + 1
                    temp.append(other_individual)
        i = i + 1
        fronts.append(temp)

    return fronts


def calculate_crowding_distance(front):
    if len(front) > 0:
        solutions_num = len(front)
        for individual in front:
            individual.crowding_distance = 0

        for m in range(len(front[0].fitnesses)):
            front.sort(key=lambda individual: individual.fitnesses[m])
            front[0].crowding_distance = 10 ** 9
            front[solutions_num - 1].crowding_distance = 10 ** 9
            m_values = [individual.fitnesses[m] for individual in front]
            scale = max(m_values) - min(m_values)
            if scale == 0: scale = 1
            for i in range(1, solutions_num - 1):
                front[i].crowding_distance += (front[i + 
1].fitnesses[m] - front[i - 1].fitnesses[m]) / scale 107 | 108 | 109 | def crowding_operator(individual, other_individual): 110 | if (individual.rank < other_individual.rank) or ((individual.rank == other_individual.rank) and ( 111 | individual.crowding_distance > other_individual.crowding_distance)): 112 | return 1 113 | else: 114 | return -1 115 | 116 | 117 | def __tournament(population, tournament_size): 118 | participants = np.random.choice(population, size=(tournament_size,), replace=False) 119 | best = None 120 | for participant in participants: 121 | if best is None or ( 122 | crowding_operator(participant, best) == 1): # and self.__choose_with_prob(self.tournament_prob)): 123 | best = participant 124 | 125 | return best 126 | 127 | 128 | def tournament_selection(population, tournament_size): 129 | parents = [] 130 | while len(parents) < len(population) // 2: 131 | parent1 = __tournament(population, tournament_size) 132 | parent2 = __tournament(population, tournament_size) 133 | 134 | parents.append([parent1, parent2]) 135 | return parents 136 | -------------------------------------------------------------------------------- /reid/models/inception.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.nn import init 7 | 8 | 9 | __all__ = ['InceptionNet', 'inception'] 10 | 11 | 12 | def _make_conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, 13 | bias=False): 14 | conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, 15 | stride=stride, padding=padding, bias=bias) 16 | bn = nn.BatchNorm2d(out_planes) 17 | relu = nn.ReLU(inplace=True) 18 | return nn.Sequential(conv, bn, relu) 19 | 20 | 21 | class Block(nn.Module): 22 | def __init__(self, in_planes, out_planes, pool_method, stride): 23 | super(Block, self).__init__() 24 | self.branches = nn.ModuleList([ 25 | nn.Sequential( 26 | _make_conv(in_planes, out_planes, kernel_size=1, padding=0), 27 | _make_conv(out_planes, out_planes, stride=stride) 28 | ), 29 | nn.Sequential( 30 | _make_conv(in_planes, out_planes, kernel_size=1, padding=0), 31 | _make_conv(out_planes, out_planes), 32 | _make_conv(out_planes, out_planes, stride=stride)) 33 | ]) 34 | 35 | if pool_method == 'Avg': 36 | assert stride == 1 37 | self.branches.append( 38 | _make_conv(in_planes, out_planes, kernel_size=1, padding=0)) 39 | self.branches.append(nn.Sequential( 40 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1), 41 | _make_conv(in_planes, out_planes, kernel_size=1, padding=0))) 42 | else: 43 | self.branches.append( 44 | nn.MaxPool2d(kernel_size=3, stride=stride, padding=1)) 45 | 46 | def forward(self, x): 47 | return torch.cat([b(x) for b in self.branches], 1) 48 | 49 | 50 | class InceptionNet(nn.Module): 51 | def __init__(self, cut_at_pooling=False, num_features=256, norm=False, 52 | dropout=0, num_classes=0): 53 | super(InceptionNet, self).__init__() 54 | self.cut_at_pooling = cut_at_pooling 55 | 56 | self.conv1 = _make_conv(3, 32) 57 | self.conv2 = _make_conv(32, 32) 58 | self.conv3 = _make_conv(32, 32) 59 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=1) 60 | self.in_planes = 32 61 | self.inception4a = self._make_inception(64, 'Avg', 1) 62 | self.inception4b = self._make_inception(64, 'Max', 2) 63 | self.inception5a = self._make_inception(128, 'Avg', 1) 64 | self.inception5b = self._make_inception(128, 'Max', 2) 65 | self.inception6a = 
self._make_inception(256, 'Avg', 1) 66 | self.inception6b = self._make_inception(256, 'Max', 2) 67 | 68 | if not self.cut_at_pooling: 69 | self.num_features = num_features 70 | self.norm = norm 71 | self.dropout = dropout 72 | self.has_embedding = num_features > 0 73 | self.num_classes = num_classes 74 | 75 | self.avgpool = nn.AdaptiveAvgPool2d(1) 76 | 77 | if self.has_embedding: 78 | self.feat = nn.Linear(self.in_planes, self.num_features) 79 | self.feat_bn = nn.BatchNorm1d(self.num_features) 80 | else: 81 | # Change the num_features to CNN output channels 82 | self.num_features = self.in_planes 83 | if self.dropout > 0: 84 | self.drop = nn.Dropout(self.dropout) 85 | if self.num_classes > 0: 86 | self.classifier = nn.Linear(self.num_features, self.num_classes) 87 | 88 | self.reset_params() 89 | 90 | def forward(self, x): 91 | x = self.conv1(x) 92 | x = self.conv2(x) 93 | x = self.conv3(x) 94 | x = self.pool3(x) 95 | x = self.inception4a(x) 96 | x = self.inception4b(x) 97 | x = self.inception5a(x) 98 | x = self.inception5b(x) 99 | x = self.inception6a(x) 100 | x = self.inception6b(x) 101 | 102 | if self.cut_at_pooling: 103 | return x 104 | 105 | x = self.avgpool(x) 106 | x = x.view(x.size(0), -1) 107 | 108 | if self.has_embedding: 109 | x = self.feat(x) 110 | x = self.feat_bn(x) 111 | if self.norm: 112 | x = F.normalize(x) 113 | elif self.has_embedding: 114 | x = F.relu(x) 115 | if self.dropout > 0: 116 | x = self.drop(x) 117 | if self.num_classes > 0: 118 | x = self.classifier(x) 119 | return x 120 | 121 | def _make_inception(self, out_planes, pool_method, stride): 122 | block = Block(self.in_planes, out_planes, pool_method, stride) 123 | self.in_planes = (out_planes * 4 if pool_method == 'Avg' else 124 | out_planes * 2 + self.in_planes) 125 | return block 126 | 127 | def reset_params(self): 128 | for m in self.modules(): 129 | if isinstance(m, nn.Conv2d): 130 | init.kaiming_normal_(m.weight, mode='fan_out') 131 | if m.bias is not None: 132 | init.constant_(m.bias, 0) 133 | elif isinstance(m, nn.BatchNorm2d): 134 | init.constant_(m.weight, 1) 135 | init.constant_(m.bias, 0) 136 | elif isinstance(m, nn.Linear): 137 | init.normal_(m.weight, std=0.001) 138 | if m.bias is not None: 139 | init.constant_(m.bias, 0) 140 | 141 | 142 | def inception(**kwargs): 143 | return InceptionNet(**kwargs) 144 | -------------------------------------------------------------------------------- /MOAA/operators.py: -------------------------------------------------------------------------------- 1 | from MOAA.Solutions import Solution 2 | import numpy as np 3 | import random 4 | 5 | 6 | def mutation(soln: Solution, pm: float, all_pixels: np.ndarray, zero_prob: float): 7 | all_pixels = all_pixels.copy() 8 | pixels = soln.pixels.copy() 9 | rgbs = soln.values.copy() 10 | 11 | eps_it = max([int(len(soln.pixels) * pm), 1]) 12 | eps = len(soln.pixels) 13 | 14 | # select pixels to keep 15 | A_ = np.random.choice(eps, size=(eps - eps_it,), replace=False) 16 | new_pixels = pixels[A_] 17 | new_rgbs = rgbs[A_] 18 | 19 | # select new pixels to replace 20 | u_m = np.delete(all_pixels, pixels) 21 | B = np.random.choice(u_m, size=(eps_it,), replace=False) 22 | 23 | ones_prob = (1 - zero_prob) / 2 24 | rgbs_ = np.random.choice([-1, 1, 0], size=(eps_it, 3), p=(ones_prob, ones_prob, zero_prob)) 25 | pixels_ = all_pixels[B] # all_pixels is arange(n_pixels), so this index lookup just yields B itself 26 | 27 | new_pixels = np.concatenate([new_pixels, pixels_], axis=0) 28 | new_rgbs = np.concatenate([new_rgbs, rgbs_], axis=0) 29 | 30 | soln.pixels = new_pixels 31 | soln.values = new_rgbs 32 | 33 | 34 | def 
crossover(soln1: Solution, soln2: Solution, pc: float): 35 | l = max([int(len(soln1.pixels) * pc), 1]) 36 | k = len(soln1.pixels) 37 | # S1 crossover with S2 38 | # 1. Positions where S2's pixel does not already appear in S1 39 | delta = np.asarray([pi for pi in range(k) if soln2.pixels[pi] not in soln1.pixels]) 40 | 41 | offspring1 = soln1.copy() 42 | if len(delta)>0: 43 | l = l if l <= len(delta) else len(delta) 44 | switched_pixels = np.random.choice(delta, size=(l,)) 45 | offspring1.pixels[switched_pixels] = soln2.pixels[switched_pixels].copy() 46 | offspring1.values[switched_pixels] = soln2.values[switched_pixels].copy() 47 | 48 | # S2 crossover with S1 49 | # 1. Positions where S1's pixel does not already appear in S2 50 | delta = np.asarray([pi for pi in range(k) if soln1.pixels[pi] not in soln2.pixels]) 51 | offspring2 = soln2.copy() # base offspring2 on soln2; copying soln1 here would make the swap below a no-op 52 | if len(delta)>0: 53 | l = l if l <= len(delta) else len(delta) 54 | switched_pixels = np.random.choice(delta, size=(l,)) 55 | offspring2.pixels[switched_pixels] = soln1.pixels[switched_pixels].copy() 56 | offspring2.values[switched_pixels] = soln1.values[switched_pixels].copy() 57 | 58 | return offspring1, offspring2 59 | 60 | 61 | def generate_offspring(parents, pc, pm, all_pixels, zero_prob): 62 | children = [] 63 | for pi in parents: 64 | offspring1, offspring2 = crossover(pi[0], pi[1], pc) 65 | mutation(offspring1, pm, all_pixels, zero_prob) 66 | mutation(offspring2, pm, all_pixels, zero_prob) 67 | 68 | assert len(np.unique(offspring1.pixels)) == len(offspring1.pixels) 69 | assert len(np.unique(offspring2.pixels)) == len(offspring2.pixels) 70 | children.extend([offspring1, offspring2]) 71 | 72 | return children 73 | 74 | def dominates(p, q, objectives): # expects individuals exposing .objectives; MOAA.py imports the .fitnesses-based variant from Solutions.py instead 75 | isBetter = False 76 | for i in range(len(objectives)): 77 | if p.objectives[i] > q.objectives[i]: 78 | return False 79 | elif p.objectives[i] < q.objectives[i]: 80 | isBetter = True 81 | return isBetter 82 | 83 | def fast_nondominated_sort(population, objectives): 84 | S = [[] for _ in range(len(population))] 85 | front = [[]] 86 | n = [0 for _ in range(len(population))] 87 | rank = [0 for _ in range(len(population))] 88 | 89 | for p in range(len(population)): 90 | S[p] = [] 91 | n[p] = 0 92 | for q in range(len(population)): 93 | if dominates(population[p], population[q], objectives): 94 | S[p].append(q) 95 | elif dominates(population[q], population[p], objectives): 96 | n[p] += 1 97 | if n[p] == 0: 98 | rank[p] = 0 99 | if front[0] == []: 100 | front[0] = [p] 101 | else: 102 | front[0].append(p) 103 | 104 | i = 0 105 | while len(front[i]) != 0: 106 | Q = [] 107 | for p in front[i]: 108 | for q in S[p]: 109 | n[q] -= 1 110 | if n[q] == 0: 111 | rank[q] = i + 1 112 | if Q == []: 113 | Q = [q] 114 | else: 115 | Q.append(q) 116 | i = i + 1 117 | front.append(Q) 118 | 119 | del front[len(front)-1] 120 | return front 121 | 122 | 123 | def calculate_crowding_distance(front): 124 | 125 | if len(front) == 0: 126 | return 127 | 128 | num_objectives = len(front[0].objectives) 129 | for individual in front: 130 | individual.crowding_distance = 0 131 | 132 | for i in range(num_objectives): 133 | front.sort(key=lambda x: x.objectives[i]) 134 | front[0].crowding_distance = float('inf') 135 | front[-1].crowding_distance = float('inf') 136 | 137 | if front[0].objectives[i] == front[-1].objectives[i]: 138 | continue 139 | 140 | for j in range(1, len(front) - 1): 141 | front[j].crowding_distance += (front[j + 1].objectives[i] - front[j - 1].objectives[i]) / \ 142 | (front[-1].objectives[i] - front[0].objectives[i]) 143 | 144 | def 
tournament_selection(population, tournament_size): 145 | 146 | selected_parents = [] 147 | for _ in range(len(population) // 2): 148 | tournament = random.sample(population, tournament_size) 149 | tournament.sort(key=lambda x: (x.rank, -x.crowding_distance)) 150 | selected_parents.append((tournament[0], tournament[1])) 151 | 152 | return selected_parents 153 | -------------------------------------------------------------------------------- /reid/utils/data/dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os.path as osp 3 | 4 | import numpy as np 5 | 6 | from ..serialization import read_json 7 | 8 | 9 | def _pluck(identities, indices, relabel=False): 10 | ret = [] 11 | for index, pid in enumerate(indices): 12 | pid_images = identities[pid] 13 | for camid, cam_images in enumerate(pid_images): 14 | for fname in cam_images: 15 | name = osp.splitext(fname)[0] 16 | x, y, _ = map(int, name.split('_')) 17 | assert pid == x and camid == y 18 | if relabel: 19 | ret.append((fname, index, camid)) 20 | else: 21 | ret.append((fname, pid, camid)) 22 | return ret 23 | 24 | def _pluck_gallery(identities, indices, relabel=False): 25 | ret = [] 26 | for index, pid in enumerate(indices): 27 | pid_images = identities[pid] 28 | for camid, cam_images in enumerate(pid_images): 29 | if len(cam_images[:-1])==0: 30 | for fname in cam_images: 31 | name = osp.splitext(fname)[0] 32 | x, y, _ = map(int, name.split('_')) 33 | assert pid == x and camid == y 34 | if relabel: 35 | ret.append((fname, index, camid)) 36 | else: 37 | ret.append((fname, pid, camid)) 38 | else: 39 | for fname in cam_images[:-1]: 40 | name = osp.splitext(fname)[0] 41 | x, y, _ = map(int, name.split('_')) 42 | assert pid == x and camid == y 43 | if relabel: 44 | ret.append((fname, index, camid)) 45 | else: 46 | ret.append((fname, pid, camid)) 47 | return ret 48 | 49 | def _pluck_query(identities, indices, relabel=False): 50 | ret = [] 51 | for index, pid in enumerate(indices): 52 | pid_images = identities[pid] 53 | for camid, cam_images in enumerate(pid_images): 54 | for fname in cam_images[-1:]: 55 | name = osp.splitext(fname)[0] 56 | x, y, _ = map(int, name.split('_')) 57 | assert pid == x and camid == y 58 | if relabel: 59 | ret.append((fname, index, camid)) 60 | else: 61 | ret.append((fname, pid, camid)) 62 | return ret 63 | 64 | 65 | class Dataset(object): 66 | def __init__(self, root, split_id=0): 67 | self.root = root 68 | self.split_id = split_id 69 | self.meta = None 70 | self.split = None 71 | self.train, self.val, self.trainval = [], [], [] 72 | self.query, self.gallery = [], [] 73 | self.num_train_ids, self.num_val_ids, self.num_trainval_ids = 0, 0, 0 74 | 75 | @property 76 | def images_dir(self): 77 | return osp.join(self.root, 'images') 78 | 79 | def load(self, num_val=0.3, verbose=True): 80 | splits = read_json(osp.join(self.root, 'splits.json')) 81 | if self.split_id >= len(splits): 82 | raise ValueError("split_id exceeds total splits {}" 83 | .format(len(splits))) 84 | self.split = splits[self.split_id] 85 | 86 | # Randomly split train / val 87 | trainval_pids = np.asarray(self.split['trainval']) 88 | np.random.shuffle(trainval_pids) 89 | num = len(trainval_pids) 90 | if isinstance(num_val, float): 91 | num_val = int(round(num * num_val)) 92 | if num_val >= num or num_val < 0: 93 | raise ValueError("num_val exceeds total identities {}" 94 | .format(num)) 95 | train_pids = sorted(trainval_pids[:-num_val]) 96 | val_pids = 
sorted(trainval_pids[-num_val:]) 97 | 98 | self.meta = read_json(osp.join(self.root, 'meta.json')) 99 | identities = self.meta['identities'] 100 | self.train = _pluck(identities, train_pids, relabel=True) 101 | self.val = _pluck(identities, val_pids, relabel=True) 102 | self.trainval = _pluck(identities, trainval_pids, relabel=True) 103 | self.query = _pluck_query(identities, self.split['query']) 104 | #self.gallery = _pluck(identities, self.split['gallery']) 105 | self.gallery = _pluck_gallery(identities, self.split['gallery']) 106 | self.num_train_ids = len(train_pids) 107 | self.num_val_ids = len(val_pids) 108 | self.num_trainval_ids = len(trainval_pids) 109 | 110 | if verbose: 111 | print(self.__class__.__name__, "dataset loaded") 112 | print(" subset | # ids | # images") 113 | print(" ---------------------------") 114 | print(" train | {:5d} | {:8d}" 115 | .format(self.num_train_ids, len(self.train))) 116 | print(" val | {:5d} | {:8d}" 117 | .format(self.num_val_ids, len(self.val))) 118 | print(" trainval | {:5d} | {:8d}" 119 | .format(self.num_trainval_ids, len(self.trainval))) 120 | print(" query | {:5d} | {:8d}" 121 | .format(len(self.split['query']), len(self.query))) 122 | print(" gallery | {:5d} | {:8d}" 123 | .format(len(self.split['gallery']), len(self.gallery))) 124 | 125 | def _check_integrity(self): 126 | return osp.isdir(osp.join(self.root, 'images')) and \ 127 | osp.isfile(osp.join(self.root, 'meta.json')) and \ 128 | osp.isfile(osp.join(self.root, 'splits.json')) 129 | -------------------------------------------------------------------------------- /MOAA/MOAA.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import torch 4 | from torch.nn import functional as F 5 | from MOAA.operators import generate_offspring 6 | from MOAA.Solutions import Solution, fast_nondominated_sort, calculate_crowding_distance, tournament_selection 7 | from operator import attrgetter 8 | import faiss 9 | from reid.evaluators import extract_features 10 | from torch.utils.data import DataLoader 11 | 12 | 13 | 14 | class Population: 15 | def __init__(self, solutions: list, loss_function, include_dist): 16 | self.population = solutions 17 | self.fronts = None 18 | self.loss_function = loss_function 19 | self.include_dist = include_dist 20 | 21 | def evaluate(self, search_set, search_set2, model1, model2): 22 | for pi in self.population: 23 | pi.evaluate(self.loss_function, self.include_dist, search_set, search_set2, model1, model2) 24 | 25 | def find_adv_solns(self, max_dist): 26 | adv_solns = [] 27 | for pi in self.population: 28 | if pi.is_adversarial and pi.fitnesses[1] <= max_dist: 29 | adv_solns.append(pi) 30 | 31 | return adv_solns 32 | 33 | class Attack: 34 | def __init__(self, params,search_set, search_set2, modelTest, modelTest2): 35 | self.params = params 36 | self.fitness = [] 37 | self.data = [] 38 | self.search_set = search_set.dataset 39 | self.search_set2 = search_set2.dataset 40 | self.modelTest = modelTest 41 | self.modelTest2 = modelTest2 42 | 43 | 44 | def completion_procedure(self, population, loss_function, fe, success): 45 | adversarial_labels = [] 46 | for soln in population.fronts[0]: 47 | adversarial_labels.append(loss_function.get_label(soln.generate_image())) 48 | 49 | d = {"front0_imgs": [soln.generate_image() for soln in population.fronts[0]], 50 | "queries": fe, 51 | "true_label": loss_function.true, 52 | "adversarial_labels": adversarial_labels, 53 | "front0_fitness": [soln.fitnesses for soln 
in population.fronts[0]], 54 | "fitness_process": self.fitness, 55 | "success": success 56 | } 57 | 58 | np.save(self.params["save_directory"], d, allow_pickle=True) 59 | self.Snoise = population.fronts[0] 60 | 61 | def calculate_D(self, f_adv, centroids): 62 | C = centroids 63 | D_fadv = torch.matmul((f_adv - C).T, torch.inverse(self.S)) @ (f_adv - C) 64 | return D_fadv.sum().item() 65 | 66 | def calculate_S(self, f_adv, y_true, model): 67 | y_pred = model(f_adv).argmax(dim=1) 68 | S_fadv = (y_pred != y_true).float().mean().item() 69 | return S_fadv 70 | 71 | def attack(self,noise): 72 | self.noise = noise 73 | h, w, c = noise.size() 74 | pm = self.params["pm"] 75 | n_pixels = h * w 76 | all_pixels = np.arange(n_pixels) 77 | ones_prob = (1 - self.params["zero_probability"]) / 2 78 | try: 79 | init_solutions = [Solution(np.random.choice(all_pixels, size=(self.params["eps"]), replace=False), 80 | np.random.choice([-1, 1, 0], size=(self.params["eps"], 3), 81 | p=(ones_prob, ones_prob, self.params["zero_probability"])), 82 | noise.copy(), self.params["p_size"]) for _ in range(self.params["population_size"])] 83 | 84 | population = Population(init_solutions, self.calculate_fitness, self.params["include_dist"]) 85 | population.evaluate(self.search_set, self.search_set2, self.modelTest, self.modelTest2) 86 | fe = len(population.population) 87 | 88 | for it in range(1, self.params["iterations"]): 89 | pm = self.params["pm"] 90 | population.fronts = fast_nondominated_sort(population.population) 91 | 92 | adv_solns = population.find_adv_solns(self.params["max_dist"]) 93 | if len(adv_solns) > 0: 94 | self.fitness.append(min(population.population, key=attrgetter('loss')).fitnesses) 95 | self.completion_procedure(population, self.calculate_fitness, fe, True) 96 | return 97 | 98 | self.fitness.append(min(population.population, key=attrgetter('loss')).fitnesses) 99 | 100 | for front in population.fronts: 101 | calculate_crowding_distance(front) 102 | parents = tournament_selection(population.population, self.params["tournament_size"]) 103 | children = generate_offspring(parents, self.params["pc"], pm, all_pixels, self.params["zero_probability"]) 104 | 105 | offsprings = Population(children, self.calculate_fitness, self.params["include_dist"]) 106 | fe += len(offsprings.population) 107 | offsprings.evaluate(self.search_set, self.search_set2,self.modelTest, self.modelTest2) 108 | population.population.extend(offsprings.population) 109 | population.fronts = fast_nondominated_sort(population.population) 110 | front_num = 0 111 | new_solutions = [] 112 | while len(new_solutions) + len(population.fronts[front_num]) <= self.params["population_size"]: 113 | calculate_crowding_distance(population.fronts[front_num]) 114 | new_solutions.extend(population.fronts[front_num]) 115 | front_num += 1 116 | 117 | calculate_crowding_distance(population.fronts[front_num]) 118 | population.fronts[front_num].sort(key=attrgetter("crowding_distance"), reverse=True) 119 | new_solutions.extend(population.fronts[front_num][0:self.params["population_size"] - len(new_solutions)]) 120 | 121 | population = Population(new_solutions, self.calculate_fitness, self.params["include_dist"]) 122 | 123 | population.fronts = fast_nondominated_sort(population.population) 124 | self.fitness.append(min(population.population, key=attrgetter('loss')).fitnesses) 125 | self.completion_procedure(population, self.calculate_fitness, fe, False) 126 | except: 127 | 128 | perturbed_noise = torch.randn_like(noise) * 0.1 129 | perturbed_noise = noise + 
perturbed_noise 130 | perturbed_noise = torch.clamp(perturbed_noise, -self.params["epsilon"], self.params["epsilon"]) 131 | return perturbed_noise 132 | return population.fronts[0] 133 | 134 | def calculate_fitness(self, solution): 135 | f_adv = solution.data 136 | D_fadv1 = self.calculate_D(f_adv, self.sCentroids) 137 | D_fadv2 = self.calculate_D(f_adv, self.sCentroids2) 138 | S_fadv1 = self.calculate_S_on_dataset(f_adv, solution.true_label, self.search_set, self.model) 139 | S_fadv2 = self.calculate_S_on_dataset(f_adv, solution.true_label, self.search_set2, self.model2) 140 | fitness = np.exp(-(D_fadv1 + D_fadv2)) + (1 - (S_fadv1 + S_fadv2) / 2) 141 | solution.fitness = fitness 142 | return fitness 143 | 144 | def calculate_S_on_dataset(self, f_adv, y_true, dataset, model): 145 | model.eval() 146 | correct = 0 147 | total = 0 148 | with torch.no_grad(): 149 | for inputs, labels in DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4): 150 | inputs = inputs.cuda() 151 | labels = labels.cuda() 152 | outputs = model(inputs + f_adv) 153 | _, predicted = torch.max(outputs.data, 1) 154 | total += labels.size(0) 155 | correct += (predicted == labels).sum().item() 156 | S_fadv = (total - correct) / total 157 | return S_fadv -------------------------------------------------------------------------------- /reid/models/DDAG.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | from torchvision import models 5 | from torch.autograd import Variable 6 | from .resnet2 import resnet50, resnet18 7 | import torch.nn.functional as F 8 | import math 9 | from .attention import GraphAttentionLayer, IWPA 10 | 11 | class Normalize(nn.Module): 12 | def __init__(self, power=2): 13 | super(Normalize, self).__init__() 14 | self.power = power 15 | 16 | def forward(self, x): 17 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. 
/ self.power) 18 | out = x.div(norm) 19 | return out 20 | 21 | 22 | 23 | # ##################################################################### 24 | def weights_init_kaiming(m): 25 | classname = m.__class__.__name__ 26 | # print(classname) 27 | if classname.find('Conv') != -1: 28 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 29 | elif classname.find('Linear') != -1: 30 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 31 | init.zeros_(m.bias.data) 32 | elif classname.find('BatchNorm1d') != -1: 33 | init.normal_(m.weight.data, 1.0, 0.01) 34 | init.zeros_(m.bias.data) 35 | 36 | 37 | def weights_init_classifier(m): 38 | classname = m.__class__.__name__ 39 | if classname.find('Linear') != -1: 40 | init.normal_(m.weight.data, 0, 0.001) 41 | if m.bias is not None: 42 | init.zeros_(m.bias.data) 43 | 44 | # Defines the new fc layer and classification layer 45 | # |--Linear--|--bn--|--relu--|--Linear--| 46 | class FeatureBlock(nn.Module): 47 | def __init__(self, input_dim, low_dim, dropout=0.5, relu=True): 48 | super(FeatureBlock, self).__init__() 49 | feat_block = [] 50 | feat_block += [nn.Linear(input_dim, low_dim)] 51 | feat_block += [nn.BatchNorm1d(low_dim)] 52 | 53 | feat_block = nn.Sequential(*feat_block) 54 | feat_block.apply(weights_init_kaiming) 55 | self.feat_block = feat_block 56 | 57 | def forward(self, x): 58 | x = self.feat_block(x) 59 | return x 60 | 61 | 62 | class ClassBlock(nn.Module): 63 | def __init__(self, input_dim, class_num, dropout=0.5, relu=True): 64 | super(ClassBlock, self).__init__() 65 | classifier = [] 66 | if relu: 67 | classifier += [nn.LeakyReLU(0.1)] 68 | if dropout: 69 | classifier += [nn.Dropout(p=dropout)] 70 | 71 | classifier += [nn.Linear(input_dim, class_num)] 72 | classifier = nn.Sequential(*classifier) 73 | classifier.apply(weights_init_classifier) 74 | 75 | self.classifier = classifier 76 | 77 | def forward(self, x): 78 | x = self.classifier(x) 79 | return x 80 | 81 | class visible_module(nn.Module): 82 | def __init__(self, arch='resnet50'): 83 | super(visible_module, self).__init__() 84 | 85 | model_v = resnet50(pretrained=True, 86 | last_conv_stride=1, last_conv_dilation=1) 87 | # avg pooling to global pooling 88 | self.visible = model_v 89 | 90 | def forward(self, x): 91 | x = self.visible.conv1(x) 92 | x = self.visible.bn1(x) 93 | x = self.visible.relu(x) 94 | x = self.visible.maxpool(x) 95 | return x 96 | 97 | 98 | class thermal_module(nn.Module): 99 | def __init__(self, arch='resnet50'): 100 | super(thermal_module, self).__init__() 101 | 102 | model_t = resnet50(pretrained=True, 103 | last_conv_stride=1, last_conv_dilation=1) 104 | # avg pooling to global pooling 105 | self.thermal = model_t 106 | 107 | def forward(self, x): 108 | x = self.thermal.conv1(x) 109 | x = self.thermal.bn1(x) 110 | x = self.thermal.relu(x) 111 | x = self.thermal.maxpool(x) 112 | return x 113 | 114 | 115 | class base_resnet(nn.Module): 116 | def __init__(self, arch='resnet50'): 117 | super(base_resnet, self).__init__() 118 | 119 | model_base = resnet50(pretrained=True, 120 | last_conv_stride=1, last_conv_dilation=1) 121 | # avg pooling to global pooling 122 | model_base.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 123 | self.base = model_base 124 | 125 | def forward(self, x): 126 | x = self.base.layer1(x) 127 | x = self.base.layer2(x) 128 | x = self.base.layer3(x) 129 | x = self.base.layer4(x) 130 | return x 131 | 
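# A quick numeric check of Normalize above, kept as comments so nothing runs at
# import time: for x = [[3., 4.]] and power=2, norm = (3**2 + 4**2) ** 0.5 = 5,
# so Normalize(2)(x) returns [[0.6, 0.8]], i.e. a unit-L2 feature vector.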
132 | # parser.add_argument('--low-dim', default=512, type=int,metavar='D', help='feature dimension') 133 | # class embed_net2(nn.Module): 134 | # def __init__(self, low_dim, class_num, drop=0.2, part = 3, alpha=0.2, nheads=4, arch='resnet50', wpa = False): 135 | # super(embed_net2, self).__init__() 136 | 137 | class embed_net2(nn.Module): 138 | def __init__(self, low_dim=512, drop=0.2, part=3, alpha=0.2, nheads=4, arch='resnet50', wpa=False,pretrained=True,num_classes=0): 139 | super(embed_net2, self).__init__() 140 | 141 | self.thermal_module = thermal_module(arch=arch) 142 | self.visible_module = visible_module(arch=arch) 143 | self.base_resnet = base_resnet(arch=arch) 144 | pool_dim = 2048 145 | self.dropout = drop 146 | self.part = part 147 | self.lpa = wpa 148 | 149 | self.l2norm = Normalize(2) 150 | self.bottleneck = nn.BatchNorm1d(pool_dim) 151 | self.bottleneck.bias.requires_grad_(False) # no shift 152 | 153 | self.classifier = nn.Linear(pool_dim, num_classes, bias=False) 154 | 155 | self.classifier1 = nn.Linear(pool_dim, num_classes, bias=False) 156 | self.classifier2 = nn.Linear(pool_dim, num_classes, bias=False) 157 | 158 | self.bottleneck.apply(weights_init_kaiming) 159 | self.classifier.apply(weights_init_classifier) 160 | self.classifier1.apply(weights_init_classifier) 161 | self.classifier2.apply(weights_init_classifier) 162 | 163 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 164 | self.wpa = IWPA(pool_dim, part) 165 | 166 | 167 | self.attentions = [GraphAttentionLayer(pool_dim, low_dim, dropout=drop, alpha=alpha, concat=True) for _ in range(nheads)] 168 | for i, attention in enumerate(self.attentions): 169 | self.add_module('attention_{}'.format(i), attention) 170 | 171 | self.out_att = GraphAttentionLayer(low_dim * nheads, num_classes, dropout=drop, alpha=alpha, concat=False) 172 | 173 | # feat, feat_att = net(input, input, 0, test_mode[1]) 174 | # def forward(self, x1, x2, adj, modal=0, cpa = False): 175 | def forward(self, x1, adj=0, modal=1, cpa=False): # default changed to modal=1; modal=0 causes problems here 176 | x2 = x1 ################################### 177 | # domain specific block 178 | if modal == 0: 179 | x1 = self.visible_module(x1) 180 | x2 = self.thermal_module(x2) 181 | x = torch.cat((x1, x2), 0) 182 | elif modal == 1: 183 | x = self.visible_module(x1) 184 | elif modal == 2: 185 | x = self.thermal_module(x2) 186 | 187 | # shared four blocks 188 | x = self.base_resnet(x) 189 | x_pool = self.avgpool(x) 190 | x_pool = x_pool.view(x_pool.size(0), x_pool.size(1)) 191 | feat = self.bottleneck(x_pool) 192 | 193 | # feat_att = None #add 194 | # if self.lpa: 195 | # intra-modality weighted part attention 196 | feat_att = self.wpa(x, feat, 1, self.part) 197 | # print("feat_att=",feat_att) 198 | 199 | if self.training: 200 | # cross-modality graph attention 201 | x_g = F.dropout(x_pool, self.dropout, training=self.training) 202 | x_g = torch.cat([att(x_g, adj) for att in self.attentions], dim=1) 203 | x_g = F.dropout(x_g, self.dropout, training=self.training) 204 | x_g = F.elu(self.out_att(x_g, adj)) 205 | # return x_pool, self.classifier(feat), self.classifier(feat_att), F.log_softmax(x_g, dim=1) 206 | return x_pool, self.classifier(feat) 207 | else: 208 | # return self.l2norm(feat), self.l2norm(feat_att) 209 | return x_pool, self.classifier(feat) # keep the eval outputs aligned with the training branch -------------------------------------------------------------------------------- /reid/datasets/CnMix.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import numpy as np 4 | 5 | from ..utils.data import Dataset 6 | from ..utils.osutils import mkdir_if_missing 7 | from 
..utils.serialization import read_json 8 | from ..utils.serialization import write_json 9 | 10 | 11 | ######################## 12 | # Added 13 | def _pluck(identities, indices, relabel=False): 14 | """Extract im names of given pids. 15 | Args: 16 | identities: containing im names 17 | indices: pids 18 | relabel: whether to transform pids to classification labels 19 | """ 20 | ret = [] 21 | for index, pid in enumerate(indices): 22 | pid_images = identities[pid] 23 | for camid, cam_images in enumerate(pid_images): 24 | for fname in cam_images: 25 | name = osp.splitext(fname)[0] 26 | x, y, _ = map(int, name.split('_')) 27 | assert pid == x and camid == y 28 | if relabel: 29 | ret.append((fname, index, camid)) 30 | else: 31 | ret.append((fname, pid, camid)) 32 | return ret 33 | 34 | 35 | ######################## 36 | 37 | 38 | class CnMix(Dataset): 39 | url = 'https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view' 40 | md5 = '65005ab7d12ec1c44de4eeafe813e68a' 41 | 42 | def __init__(self, root, split_id=0, num_val=100, download=True): 43 | super(CnMix, self).__init__(root, split_id=split_id) 44 | 45 | if download: 46 | self.download() 47 | 48 | if not self._check_integrity(): 49 | raise RuntimeError("Dataset not found or corrupted. " + 50 | "You can use download=True to download it.") 51 | 52 | self.load(num_val) 53 | 54 | def download(self): 55 | if self._check_integrity(): 56 | print("Files already downloaded and verified") 57 | return 58 | 59 | import re 60 | import hashlib 61 | import shutil 62 | from glob import glob 63 | from zipfile import ZipFile 64 | 65 | raw_dir = osp.join(self.root, 'raw') 66 | mkdir_if_missing(raw_dir) 67 | 68 | # Download the raw zip file 69 | fpath = osp.join(raw_dir, 'Market-1501-v15.09.15.zip') 70 | if osp.isfile(fpath) and \ 71 | hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.md5: 72 | print("Using downloaded file: " + fpath) 73 | else: 74 | raise RuntimeError("Please download the dataset manually from {} " 75 | "to {}".format(self.url, fpath)) 76 | 77 | # Extract the file 78 | exdir = osp.join(raw_dir, 'Market-1501-v15.09.15') 79 | if not osp.isdir(exdir): 80 | print("Extracting zip file") 81 | with ZipFile(fpath) as z: 82 | z.extractall(path=raw_dir) 83 | 84 | # Format 85 | images_dir = osp.join(self.root, 'images') 86 | mkdir_if_missing(images_dir) 87 | 88 | # 1501 identities (+1 for background) with 6 camera views each 89 | identities = [[[] for _ in range(6)] for _ in range(1502)] 90 | 91 | def register(subdir, pattern=re.compile(r'([-\d]+)_c(\d)')): 92 | fnames = [] ######### Added. Names of images in new dir. 
93 | fpaths = sorted(glob(osp.join(exdir, subdir, '*.jpg'))) 94 | pids = set() 95 | for fpath in fpaths: 96 | fname = osp.basename(fpath) 97 | pid, cam = map(int, pattern.search(fname).groups()) 98 | if pid == -1: continue # junk images are just ignored 99 | assert 0 <= pid <= 1501 # pid == 0 means background 100 | assert 1 <= cam <= 6 101 | cam -= 1 102 | pids.add(pid) 103 | fname = ('{:08d}_{:02d}_{:04d}.jpg' 104 | .format(pid, cam, len(identities[pid][cam]))) 105 | identities[pid][cam].append(fname) 106 | shutil.copy(fpath, osp.join(images_dir, fname)) 107 | fnames.append(fname) ######### Added 108 | return pids, fnames 109 | 110 | trainval_pids, _ = register('bounding_box_train') 111 | gallery_pids, gallery_fnames = register('bounding_box_test') 112 | query_pids, query_fnames = register('query') 113 | assert query_pids <= gallery_pids 114 | assert trainval_pids.isdisjoint(gallery_pids) # checks that trainval_pids and gallery_pids share no identities; any overlap raises an AssertionError that must be resolved 115 | 116 | # Save meta information into a json file 117 | meta = {'name': 'Market1501', 'shot': 'multiple', 'num_cameras': 6, 118 | 'identities': identities, 119 | 'query_fnames': query_fnames, ######### Added 120 | 'gallery_fnames': gallery_fnames} ######### Added 121 | write_json(meta, osp.join(self.root, 'meta.json')) 122 | 123 | # Save the only training / test split 124 | splits = [{ 125 | 'trainval': sorted(list(trainval_pids)), 126 | 'query': sorted(list(query_pids)), 127 | 'gallery': sorted(list(gallery_pids))}] 128 | write_json(splits, osp.join(self.root, 'splits.json')) 129 | 130 | ######################## 131 | # Added 132 | def load(self, num_val=0.3, verbose=True): 133 | splits = read_json(osp.join(self.root, 'splits.json')) 134 | if self.split_id >= len(splits): 135 | raise ValueError("split_id exceeds total splits {}" 136 | .format(len(splits))) 137 | self.split = splits[self.split_id] 138 | 139 | # Randomly split train / val 140 | trainval_pids = np.asarray(self.split['trainval']) 141 | np.random.shuffle(trainval_pids) 142 | num = len(trainval_pids) 143 | if isinstance(num_val, float): 144 | num_val = int(round(num * num_val)) 145 | if num_val >= num or num_val < 0: 146 | raise ValueError("num_val exceeds total identities {}" 147 | .format(num)) 148 | train_pids = sorted(trainval_pids[:-num_val]) 149 | val_pids = sorted(trainval_pids[-num_val:]) 150 | 151 | self.meta = read_json(osp.join(self.root, 'meta.json')) 152 | identities = self.meta['identities'] 153 | 154 | self.train = _pluck(identities, train_pids, relabel=True) 155 | self.val = _pluck(identities, val_pids, relabel=True) 156 | self.trainval = _pluck(identities, trainval_pids, relabel=True) 157 | self.num_train_ids = len(train_pids) 158 | self.num_val_ids = len(val_pids) 159 | self.num_trainval_ids = len(trainval_pids) 160 | 161 | ########## 162 | # Added 163 | query_fnames = self.meta['query_fnames'] 164 | gallery_fnames = self.meta['gallery_fnames'] 165 | self.query = [] 166 | for fname in query_fnames: 167 | name = osp.splitext(fname)[0] 168 | pid, cam, _ = map(int, name.split('_')) 169 | self.query.append((fname, pid, cam)) 170 | self.gallery = [] 171 | for fname in gallery_fnames: 172 | name = osp.splitext(fname)[0] 173 | pid, cam, _ = map(int, name.split('_')) 174 | self.gallery.append((fname, pid, cam)) 175 | ########## 176 | 177 | if verbose: 178 | print(self.__class__.__name__, "dataset loaded") 179 | print(" subset | # ids | # images") 180 | print(" ---------------------------") 181 | print(" 
train | {:5d} | {:8d}" 182 | .format(self.num_train_ids, len(self.train))) 183 | print(" val | {:5d} | {:8d}" 184 | .format(self.num_val_ids, len(self.val))) 185 | print(" trainval | {:5d} | {:8d}" 186 | .format(self.num_trainval_ids, len(self.trainval))) 187 | print(" query | {:5d} | {:8d}" 188 | .format(len(self.split['query']), len(self.query))) 189 | print(" gallery | {:5d} | {:8d}" 190 | .format(len(self.split['gallery']), len(self.gallery))) 191 | ######################## 192 | -------------------------------------------------------------------------------- /reid/datasets/Sketch.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import numpy as np 4 | 5 | from ..utils.data import Dataset 6 | from ..utils.osutils import mkdir_if_missing 7 | from ..utils.serialization import read_json 8 | from ..utils.serialization import write_json 9 | 10 | 11 | ######################## 12 | # Added 13 | def _pluck(identities, indices, relabel=False): 14 | """Extract im names of given pids. 15 | Args: 16 | identities: containing im names 17 | indices: pids 18 | relabel: whether to transform pids to classification labels 19 | """ 20 | ret = [] 21 | for index, pid in enumerate(indices): 22 | pid_images = identities[pid] 23 | for camid, cam_images in enumerate(pid_images): 24 | for fname in cam_images: 25 | name = osp.splitext(fname)[0] 26 | x, y, _ = map(int, name.split('_')) 27 | assert pid == x and camid == y 28 | if relabel: 29 | ret.append((fname, index, camid)) 30 | else: 31 | ret.append((fname, pid, camid)) 32 | return ret 33 | 34 | 35 | ######################## 36 | 37 | 38 | class Sketch(Dataset): 39 | url = 'https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view' 40 | md5 = '65005ab7d12ec1c44de4eeafe813e68a' 41 | 42 | def __init__(self, root, split_id=0, num_val=100, download=True): 43 | super(Sketch, self).__init__(root, split_id=split_id) 44 | 45 | if download: 46 | self.download() 47 | 48 | if not self._check_integrity(): 49 | raise RuntimeError("Dataset not found or corrupted. " + 50 | "You can use download=True to download it.") 51 | 52 | self.load(num_val) 53 | 54 | def download(self): 55 | if self._check_integrity(): 56 | print("Files already downloaded and verified") 57 | return 58 | 59 | import re 60 | import hashlib 61 | import shutil 62 | from glob import glob 63 | from zipfile import ZipFile 64 | 65 | raw_dir = osp.join(self.root, 'raw') 66 | mkdir_if_missing(raw_dir) 67 | 68 | # Download the raw zip file 69 | fpath = osp.join(raw_dir, 'Market-1501-v15.09.15.zip') 70 | if osp.isfile(fpath) and \ 71 | hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.md5: 72 | print("Using downloaded file: " + fpath) 73 | else: 74 | raise RuntimeError("Please download the dataset manually from {} " 75 | "to {}".format(self.url, fpath)) 76 | 77 | # Extract the file 78 | exdir = osp.join(raw_dir, 'Market-1501-v15.09.15') 79 | if not osp.isdir(exdir): 80 | print("Extracting zip file") 81 | with ZipFile(fpath) as z: 82 | z.extractall(path=raw_dir) 83 | 84 | # Format 85 | images_dir = osp.join(self.root, 'images') 86 | mkdir_if_missing(images_dir) 87 | 88 | # 1501 identities (+1 for background) with 6 camera views each 89 | identities = [[[] for _ in range(6)] for _ in range(1502)] 90 | 91 | def register(subdir, pattern=re.compile(r'([-\d]+)_c(\d)')): 92 | fnames = [] ######### Added. Names of images in new dir. 
93 | fpaths = sorted(glob(osp.join(exdir, subdir, '*.jpg'))) 94 | pids = set() 95 | for fpath in fpaths: 96 | fname = osp.basename(fpath) 97 | pid, cam = map(int, pattern.search(fname).groups()) 98 | if pid == -1: continue # junk images are just ignored 99 | assert 0 <= pid <= 1501 # pid == 0 means background 100 | assert 1 <= cam <= 6 101 | cam -= 1 102 | pids.add(pid) 103 | fname = ('{:08d}_{:02d}_{:04d}.jpg' 104 | .format(pid, cam, len(identities[pid][cam]))) 105 | identities[pid][cam].append(fname) 106 | shutil.copy(fpath, osp.join(images_dir, fname)) 107 | fnames.append(fname) ######### Added 108 | return pids, fnames 109 | 110 | trainval_pids, _ = register('bounding_box_train') 111 | gallery_pids, gallery_fnames = register('bounding_box_test') 112 | query_pids, query_fnames = register('query') 113 | assert query_pids <= gallery_pids 114 | assert trainval_pids.isdisjoint(gallery_pids) # checks that trainval_pids and gallery_pids share no identities; any overlap raises an AssertionError that must be resolved 115 | 116 | # Save meta information into a json file 117 | meta = {'name': 'Market1501', 'shot': 'multiple', 'num_cameras': 6, 118 | 'identities': identities, 119 | 'query_fnames': query_fnames, ######### Added 120 | 'gallery_fnames': gallery_fnames} ######### Added 121 | write_json(meta, osp.join(self.root, 'meta.json')) 122 | 123 | # Save the only training / test split 124 | splits = [{ 125 | 'trainval': sorted(list(trainval_pids)), 126 | 'query': sorted(list(query_pids)), 127 | 'gallery': sorted(list(gallery_pids))}] 128 | write_json(splits, osp.join(self.root, 'splits.json')) 129 | 130 | ######################## 131 | # Added 132 | def load(self, num_val=0.3, verbose=True): 133 | splits = read_json(osp.join(self.root, 'splits.json')) 134 | if self.split_id >= len(splits): 135 | raise ValueError("split_id exceeds total splits {}" 136 | .format(len(splits))) 137 | self.split = splits[self.split_id] 138 | 139 | # Randomly split train / val 140 | trainval_pids = np.asarray(self.split['trainval']) 141 | np.random.shuffle(trainval_pids) 142 | num = len(trainval_pids) 143 | if isinstance(num_val, float): 144 | num_val = int(round(num * num_val)) 145 | if num_val >= num or num_val < 0: 146 | raise ValueError("num_val exceeds total identities {}" 147 | .format(num)) 148 | train_pids = sorted(trainval_pids[:-num_val]) 149 | val_pids = sorted(trainval_pids[-num_val:]) 150 | 151 | self.meta = read_json(osp.join(self.root, 'meta.json')) 152 | identities = self.meta['identities'] 153 | 154 | self.train = _pluck(identities, train_pids, relabel=True) 155 | self.val = _pluck(identities, val_pids, relabel=True) 156 | self.trainval = _pluck(identities, trainval_pids, relabel=True) 157 | self.num_train_ids = len(train_pids) 158 | self.num_val_ids = len(val_pids) 159 | self.num_trainval_ids = len(trainval_pids) 160 | 161 | ########## 162 | # Added 163 | query_fnames = self.meta['query_fnames'] 164 | gallery_fnames = self.meta['gallery_fnames'] 165 | self.query = [] 166 | for fname in query_fnames: 167 | name = osp.splitext(fname)[0] 168 | pid, cam, _ = map(int, name.split('_')) 169 | self.query.append((fname, pid, cam)) 170 | self.gallery = [] 171 | for fname in gallery_fnames: 172 | name = osp.splitext(fname)[0] 173 | pid, cam, _ = map(int, name.split('_')) 174 | self.gallery.append((fname, pid, cam)) 175 | ########## 176 | 177 | if verbose: 178 | print(self.__class__.__name__, "dataset loaded") 179 | print(" subset | # ids | # images") 180 | print(" ---------------------------") 181 | print(" 
train | {:5d} | {:8d}" 182 | .format(self.num_train_ids, len(self.train))) 183 | print(" val | {:5d} | {:8d}" 184 | .format(self.num_val_ids, len(self.val))) 185 | print(" trainval | {:5d} | {:8d}" 186 | .format(self.num_trainval_ids, len(self.trainval))) 187 | print(" query | {:5d} | {:8d}" 188 | .format(len(self.split['query']), len(self.query))) 189 | print(" gallery | {:5d} | {:8d}" 190 | .format(len(self.split['gallery']), len(self.gallery))) 191 | ######################## 192 | -------------------------------------------------------------------------------- /reid/datasets/sysu.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import numpy as np 4 | 5 | from ..utils.data import Dataset 6 | from ..utils.osutils import mkdir_if_missing 7 | from ..utils.serialization import read_json 8 | from ..utils.serialization import write_json 9 | 10 | 11 | ######################## 12 | # Added 13 | def _pluck(identities, indices, relabel=False): 14 | """Extract im names of given pids. 15 | Args: 16 | identities: containing im names 17 | indices: pids 18 | relabel: whether to transform pids to classification labels 19 | """ 20 | ret = [] 21 | for index, pid in enumerate(indices): 22 | pid_images = identities[pid] 23 | for camid, cam_images in enumerate(pid_images): 24 | for fname in cam_images: 25 | name = osp.splitext(fname)[0] 26 | x, y, _ = map(int, name.split('_')) 27 | assert pid == x and camid == y 28 | if relabel: 29 | ret.append((fname, index, camid)) 30 | else: 31 | ret.append((fname, pid, camid)) 32 | return ret 33 | 34 | 35 | ######################## 36 | 37 | 38 | class Sysu(Dataset): 39 | url = 'https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view' 40 | md5 = '65005ab7d12ec1c44de4eeafe813e68a' 41 | 42 | def __init__(self, root, split_id=0, num_val=96, download=True): 43 | super(Sysu, self).__init__(root, split_id=split_id) 44 | 45 | if download: 46 | self.download() 47 | 48 | if not self._check_integrity(): 49 | raise RuntimeError("Dataset not found or corrupted. " + 50 | "You can use download=True to download it.") 51 | 52 | self.load(num_val) 53 | 54 | def download(self): 55 | if self._check_integrity(): 56 | return 57 | 58 | import re 59 | import hashlib 60 | import shutil 61 | from glob import glob 62 | from zipfile import ZipFile 63 | 64 | raw_dir = osp.join(self.root, 'raw') 65 | mkdir_if_missing(raw_dir) 66 | 67 | # Download the raw zip file 68 | fpath = osp.join(raw_dir, 'sysu_v2.zip') 69 | # if osp.isfile(fpath) and \ 70 | # hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.md5: 71 | # print("Using downloaded file: " + fpath) 72 | # else: 73 | # raise RuntimeError("Please download the dataset manually from {} " 74 | # "to {}".format(self.url, fpath)) 75 | if osp.isfile(fpath): 76 | print("Using downloaded file: " + fpath) 77 | else: 78 | raise RuntimeError("Please download the dataset manually from {} " 79 | "to {}".format(self.url, fpath)) 80 | 81 | # Extract the file 82 | exdir = raw_dir ## careful with this path: extraction may add an extra top-level directory 83 | if not osp.isdir(exdir): 84 | print("Extracting zip file") 85 | with ZipFile(fpath) as z: 86 | z.extractall(path=raw_dir) 87 | 88 | # Format 89 | images_dir = osp.join(self.root, 'images') 90 | mkdir_if_missing(images_dir) 91 | 92 | # 533 identities (+1 for background) with 6 camera views each 93 | identities = [[[] for _ in range(6)] for _ in range(534)] # the largest pid is 533, but ids start from 1, so 534 entries are needed
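# Filename convention produced by register() below (illustrative example, not
# taken from the dataset): '{:08d}_{:02d}_{:04d}.jpg' packs (pid, 0-based camid,
# per-camera index), so pid 7 seen by camera 3 as its 12th image becomes
# '00000007_02_0011.jpg'; _pluck later recovers (pid, camid) via name.split('_').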
94 | # print("identities=",identities) 95 | 96 | def register(subdir, pattern=re.compile(r'([-\d]+)_c(\d)')): 97 | fnames = [] ######### Added. Names of images in new dir. 98 | fpaths = sorted(glob(osp.join(exdir, subdir, '*.jpg'))) # 99 | pids = set() 100 | for fpath in fpaths: 101 | fname = osp.basename(fpath) 102 | pid, cam = map(int, pattern.search(fname).groups()) 103 | if pid == -1: continue # junk images are just ignored 104 | assert 0 <= pid <= 533 # pid == 0 means background; identities holds 534 entries (0..533) 105 | assert 1 <= cam <= 6 106 | cam -= 1 107 | pids.add(pid) 108 | fname = ('{:08d}_{:02d}_{:04d}.jpg' 109 | .format(pid, cam, len(identities[pid][cam]))) 110 | identities[pid][cam].append(fname) 111 | shutil.copy(fpath, osp.join(images_dir, fname)) 112 | fnames.append(fname) ######### Added 113 | return pids, fnames 114 | 115 | trainval_pids, _ = register('bounding_box_train') 116 | gallery_pids, gallery_fnames = register('bounding_box_test') 117 | query_pids, query_fnames = register('query') 118 | assert query_pids <= gallery_pids 119 | # assert trainval_pids.isdisjoint(gallery_pids) # would check that trainval_pids and gallery_pids share no identities; any overlap raises an AssertionError 120 | 121 | # Save meta information into a json file 122 | meta = {'name': 'Sysu', 'shot': 'multiple', 'num_cameras': 6, 123 | 'identities': identities, 124 | 'query_fnames': query_fnames, ######### Added 125 | 'gallery_fnames': gallery_fnames} ######### Added 126 | write_json(meta, osp.join(self.root, 'meta.json')) 127 | 128 | # Save the only training / test split 129 | splits = [{ 130 | 'trainval': sorted(list(trainval_pids)), 131 | 'query': sorted(list(query_pids)), 132 | 'gallery': sorted(list(gallery_pids))}] 133 | write_json(splits, osp.join(self.root, 'splits.json')) 134 | 135 | ######################## 136 | # Added 137 | def load(self, num_val=0.3, verbose=True): 138 | splits = read_json(osp.join(self.root, 'splits.json')) 139 | if self.split_id >= len(splits): 140 | raise ValueError("split_id exceeds total splits {}" 141 | .format(len(splits))) 142 | self.split = splits[self.split_id] 143 | 144 | # Randomly split train / val 145 | trainval_pids = np.asarray(self.split['trainval']) 146 | np.random.shuffle(trainval_pids) 147 | num = len(trainval_pids) 148 | if isinstance(num_val, float): 149 | num_val = int(round(num * num_val)) 150 | if num_val >= num or num_val < 0: 151 | raise ValueError("num_val exceeds total identities {}" 152 | .format(num)) 153 | train_pids = sorted(trainval_pids[:-num_val]) 154 | val_pids = sorted(trainval_pids[-num_val:]) 155 | 156 | self.meta = read_json(osp.join(self.root, 'meta.json')) 157 | identities = self.meta['identities'] 158 | 159 | self.train = _pluck(identities, train_pids, relabel=True) 160 | self.val = _pluck(identities, val_pids, relabel=True) 161 | self.trainval = _pluck(identities, trainval_pids, relabel=True) 162 | self.num_train_ids = len(train_pids) 163 | self.num_val_ids = len(val_pids) 164 | self.num_trainval_ids = len(trainval_pids) 165 | 166 | ########## 167 | # Added 168 | query_fnames = self.meta['query_fnames'] 169 | gallery_fnames = self.meta['gallery_fnames'] 170 | self.query = [] 171 | for fname in query_fnames: 172 | name = osp.splitext(fname)[0] 173 | pid, cam, _ = map(int, name.split('_')) 174 | self.query.append((fname, pid, cam)) 175 | self.gallery = [] 176 | for fname in gallery_fnames: 177 | name = osp.splitext(fname)[0] 178 | pid, cam, _ = map(int, name.split('_')) 179 | self.gallery.append((fname, pid, cam)) 180 | ########## 181 
| 182 | if verbose: 183 | print(self.__class__.__name__, "dataset loaded") 184 | print(" subset | # ids | # images") 185 | print(" ---------------------------") 186 | print(" train | {:5d} | {:8d}" 187 | .format(self.num_train_ids, len(self.train))) 188 | print(" val | {:5d} | {:8d}" 189 | .format(self.num_val_ids, len(self.val))) 190 | print(" trainval | {:5d} | {:8d}" 191 | .format(self.num_trainval_ids, len(self.trainval))) 192 | print(" query | {:5d} | {:8d}" 193 | .format(len(self.split['query']), len(self.query))) 194 | print(" gallery | {:5d} | {:8d}" 195 | .format(len(self.split['gallery']), len(self.gallery))) 196 | ######################## 197 | -------------------------------------------------------------------------------- /reid/datasets/regdb.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import numpy as np 4 | 5 | from ..utils.data import Dataset 6 | from ..utils.osutils import mkdir_if_missing 7 | from ..utils.serialization import read_json 8 | from ..utils.serialization import write_json 9 | 10 | 11 | ######################## 12 | # Added 13 | def _pluck(identities, indices, relabel=False): 14 | """Extract im names of given pids. 15 | Args: 16 | identities: containing im names 17 | indices: pids 18 | relabel: whether to transform pids to classification labels 19 | """ 20 | ret = [] 21 | for index, pid in enumerate(indices): 22 | pid_images = identities[pid] 23 | for camid, cam_images in enumerate(pid_images): 24 | for fname in cam_images: 25 | name = osp.splitext(fname)[0] 26 | x, y, _ = map(int, name.split('_')) 27 | assert pid == x and camid == y 28 | if relabel: 29 | ret.append((fname, index, camid)) 30 | else: 31 | ret.append((fname, pid, camid)) 32 | return ret 33 | 34 | 35 | ######################## 36 | 37 | 38 | class Regdb(Dataset): 39 | url = 'https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view' 40 | md5 = '65005ab7d12ec1c44de4eeafe813e68a' 41 | 42 | def __init__(self, root, split_id=0, num_val=30, download=True): 43 | super(Regdb, self).__init__(root, split_id=split_id) 44 | 45 | if download: 46 | self.download() 47 | 48 | if not self._check_integrity(): 49 | raise RuntimeError("Dataset not found or corrupted. 
" + 50 | "You can use download=True to download it.") 51 | 52 | self.load(num_val) 53 | 54 | def download(self): 55 | if self._check_integrity(): 56 | return 57 | 58 | import re 59 | import hashlib 60 | import shutil 61 | from glob import glob 62 | from zipfile import ZipFile 63 | 64 | raw_dir = osp.join(self.root, 'raw') 65 | mkdir_if_missing(raw_dir) 66 | 67 | # Download the raw zip file 68 | fpath = osp.join(raw_dir, 'regdb_v2.zip') 69 | # if osp.isfile(fpath) and \ 70 | # hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.md5: 71 | # print("Using downloaded file: " + fpath) 72 | # else: 73 | # raise RuntimeError("Please download the dataset manually from {} " 74 | # "to {}".format(self.url, fpath)) 75 | if osp.isfile(fpath): 76 | print("Using downloaded file: " + fpath) 77 | else: 78 | raise RuntimeError("Please download the dataset manually from {} " 79 | "to {}".format(self.url, fpath)) 80 | 81 | # Extract the file 82 | # exdir = osp.join(raw_dir, 'regdb_v2') 83 | exdir = raw_dir ##这里的路径要小心会多一层外壳 84 | if not osp.isdir(exdir): 85 | print("Extracting zip file") 86 | with ZipFile(fpath) as z: 87 | z.extractall(path=raw_dir) 88 | 89 | # Format 90 | images_dir = osp.join(self.root, 'images') 91 | mkdir_if_missing(images_dir) 92 | 93 | # 412 identities (+1 for background) with 3 camera views each 94 | identities = [[[] for _ in range(3)] for _ in range(413)]#共412个行人,但是编号是从1开始的,所以是413 95 | # print("identities=",identities) 96 | 97 | def register(subdir, pattern=re.compile(r'([-\d]+)_c(\d)')): 98 | fnames = [] ######### Added. Names of images in new dir. 99 | fpaths = sorted(glob(osp.join(exdir, subdir, '*.bmp'))) # 100 | pids = set() 101 | for fpath in fpaths: 102 | fname = osp.basename(fpath) 103 | pid, cam = map(int, pattern.search(fname).groups()) 104 | if pid == -1: continue # junk images are just ignored 105 | assert 0 <= pid <= 413 # pid == 0 means background 106 | assert 1 <= cam <= 3 107 | cam -= 1 108 | pids.add(pid) 109 | fname = ('{:08d}_{:02d}_{:04d}.bmp' 110 | .format(pid, cam, len(identities[pid][cam]))) 111 | identities[pid][cam].append(fname) 112 | shutil.copy(fpath, osp.join(images_dir, fname)) 113 | fnames.append(fname) ######### Added 114 | return pids, fnames 115 | 116 | trainval_pids, _ = register('bounding_box_train') 117 | gallery_pids, gallery_fnames = register('bounding_box_test') 118 | query_pids, query_fnames = register('query') 119 | assert query_pids <= gallery_pids 120 | assert trainval_pids.isdisjoint(gallery_pids)#这句代码的作用是检查trainval_pids和gallery_pids两个集合是否互不相交(即没有共同的元素)。如果它们互不相交,那么代码将继续执行,没有任何问题;但如果它们有共同的元素,那么将引发AssertionError异常,表示存在问题需要解决。 121 | 122 | # Save meta information into a json file 123 | meta = {'name': 'Regdb', 'shot': 'multiple', 'num_cameras': 3, 124 | 'identities': identities, 125 | 'query_fnames': query_fnames, ######### Added 126 | 'gallery_fnames': gallery_fnames} ######### Added 127 | write_json(meta, osp.join(self.root, 'meta.json')) 128 | 129 | # Save the only training / test split 130 | splits = [{ 131 | 'trainval': sorted(list(trainval_pids)), 132 | 'query': sorted(list(query_pids)), 133 | 'gallery': sorted(list(gallery_pids))}] 134 | write_json(splits, osp.join(self.root, 'splits.json')) 135 | 136 | ######################## 137 | # Added 138 | def load(self, num_val=0.3, verbose=True): 139 | splits = read_json(osp.join(self.root, 'splits.json')) 140 | if self.split_id >= len(splits): 141 | raise ValueError("split_id exceeds total splits {}" 142 | .format(len(splits))) 143 | self.split = splits[self.split_id] 144 | 145 | # 
Randomly split train / val 146 | trainval_pids = np.asarray(self.split['trainval']) 147 | np.random.shuffle(trainval_pids) 148 | num = len(trainval_pids) 149 | if isinstance(num_val, float): 150 | num_val = int(round(num * num_val)) 151 | if num_val >= num or num_val < 0: 152 | raise ValueError("num_val exceeds total identities {}" 153 | .format(num)) 154 | train_pids = sorted(trainval_pids[:-num_val]) 155 | val_pids = sorted(trainval_pids[-num_val:]) 156 | 157 | self.meta = read_json(osp.join(self.root, 'meta.json')) 158 | identities = self.meta['identities'] 159 | 160 | self.train = _pluck(identities, train_pids, relabel=True) 161 | self.val = _pluck(identities, val_pids, relabel=True) 162 | self.trainval = _pluck(identities, trainval_pids, relabel=True) 163 | self.num_train_ids = len(train_pids) 164 | self.num_val_ids = len(val_pids) 165 | self.num_trainval_ids = len(trainval_pids) 166 | 167 | ########## 168 | # Added 169 | query_fnames = self.meta['query_fnames'] 170 | gallery_fnames = self.meta['gallery_fnames'] 171 | self.query = [] 172 | for fname in query_fnames: 173 | name = osp.splitext(fname)[0] 174 | pid, cam, _ = map(int, name.split('_')) 175 | self.query.append((fname, pid, cam)) 176 | self.gallery = [] 177 | for fname in gallery_fnames: 178 | name = osp.splitext(fname)[0] 179 | pid, cam, _ = map(int, name.split('_')) 180 | self.gallery.append((fname, pid, cam)) 181 | ########## 182 | 183 | if verbose: 184 | print(self.__class__.__name__, "dataset loaded") 185 | print(" subset | # ids | # images") 186 | print(" ---------------------------") 187 | print(" train | {:5d} | {:8d}" 188 | .format(self.num_train_ids, len(self.train))) 189 | print(" val | {:5d} | {:8d}" 190 | .format(self.num_val_ids, len(self.val))) 191 | print(" trainval | {:5d} | {:8d}" 192 | .format(self.num_trainval_ids, len(self.trainval))) 193 | print(" query | {:5d} | {:8d}" 194 | .format(len(self.split['query']), len(self.query))) 195 | print(" gallery | {:5d} | {:8d}" 196 | .format(len(self.split['gallery']), len(self.gallery))) 197 | ######################## 198 | -------------------------------------------------------------------------------- /reid/models/attention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | """ 6 | PART of the code is from the following link 7 | https://github.com/Diego999/pyGAT/blob/master/layers.py 8 | """ 9 | 10 | 11 | class Normalize(nn.Module): 12 | def __init__(self, power=2): 13 | super(Normalize, self).__init__() 14 | self.power = power 15 | 16 | def forward(self, x): 17 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. 
/ self.power) 18 | out = x.div(norm) 19 | return out 20 | 21 | class GraphAttentionLayer(nn.Module): 22 | """ 23 | Simple GAT layer, similar to https://arxiv.org/abs/1710.10903 24 | """ 25 | 26 | def __init__(self, in_features, out_features, dropout, alpha=0.2, concat=True): 27 | super(GraphAttentionLayer, self).__init__() 28 | self.dropout = dropout 29 | self.in_features = in_features 30 | self.out_features = out_features 31 | self.alpha = alpha 32 | self.concat = concat 33 | 34 | self.W = nn.Parameter(torch.zeros(size=(in_features, out_features))) 35 | nn.init.xavier_uniform_(self.W.data, gain=1.414) 36 | self.a = nn.Parameter(torch.zeros(size=(2 * out_features, 1))) 37 | nn.init.xavier_uniform_(self.a.data, gain=1.414) 38 | 39 | self.leakyrelu = nn.LeakyReLU(self.alpha) 40 | 41 | def forward(self, input, adj): 42 | h = torch.mm(input, self.W) 43 | N = h.size()[0] 44 | 45 | a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features) 46 | e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2)) 47 | 48 | zero_vec = -9e15 * torch.ones_like(e) 49 | attention = torch.where(adj > 0, e, zero_vec) 50 | attention = F.softmax(attention, dim=1) 51 | attention = F.dropout(attention, self.dropout, training=self.training) 52 | h_prime = torch.matmul(attention, h) 53 | 54 | if self.concat: 55 | return F.elu(h_prime) 56 | else: 57 | return h_prime 58 | 59 | def __repr__(self): 60 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' 61 | 62 | 63 | class SpecialSpmmFunction(torch.autograd.Function): 64 | """Special function for only sparse region backpropataion layer.""" 65 | 66 | @staticmethod 67 | def forward(ctx, indices, values, shape, b): 68 | assert indices.requires_grad == False 69 | a = torch.sparse_coo_tensor(indices, values, shape) 70 | ctx.save_for_backward(a, b) 71 | ctx.N = shape[0] 72 | return torch.matmul(a, b) 73 | 74 | @staticmethod 75 | def backward(ctx, grad_output): 76 | a, b = ctx.saved_tensors 77 | grad_values = grad_b = None 78 | if ctx.needs_input_grad[1]: 79 | grad_a_dense = grad_output.matmul(b.t()) 80 | edge_idx = a._indices()[0, :] * ctx.N + a._indices()[1, :] 81 | grad_values = grad_a_dense.view(-1)[edge_idx] 82 | if ctx.needs_input_grad[3]: 83 | grad_b = a.t().matmul(grad_output) 84 | return None, grad_values, None, grad_b 85 | 86 | 87 | class SpecialSpmm(nn.Module): 88 | def forward(self, indices, values, shape, b): 89 | return SpecialSpmmFunction.apply(indices, values, shape, b) 90 | 91 | 92 | class SpGraphAttentionLayer(nn.Module): 93 | """ 94 | Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903 95 | """ 96 | 97 | def __init__(self, in_features, out_features, dropout, alpha, concat=True): 98 | super(SpGraphAttentionLayer, self).__init__() 99 | self.in_features = in_features 100 | self.out_features = out_features 101 | self.alpha = alpha 102 | self.concat = concat 103 | 104 | self.W = nn.Parameter(torch.zeros(size=(in_features, out_features))) 105 | nn.init.xavier_normal_(self.W.data, gain=1.414) 106 | 107 | self.a = nn.Parameter(torch.zeros(size=(1, 2 * out_features))) 108 | nn.init.xavier_normal_(self.a.data, gain=1.414) 109 | 110 | self.dropout = nn.Dropout(dropout) 111 | self.leakyrelu = nn.LeakyReLU(self.alpha) 112 | self.special_spmm = SpecialSpmm() 113 | 114 | def forward(self, input, adj): 115 | dv = 'cuda' if input.is_cuda else 'cpu' 116 | 117 | N = input.size()[0] 118 | edge = adj.nonzero().t() 119 | 120 | h = torch.mm(input, self.W) 121 | 
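        # What follows computes a neighbourhood-softmax attention without
        # materialising the dense N x N matrix: per-edge scores edge_e are
        # row-summed through a sparse matmul (e_rowsum), and h_prime is
        # divided by that sum at the end. A dense sketch of the same
        # computation, where `raw_scores` stands for the pairwise
        # a.[h_i || h_j] terms (avoided here for its O(N^2) memory cost):
        #     att = torch.exp(-self.leakyrelu(raw_scores)) * adj
        #     out = (att @ h) / att.sum(1, keepdim=True)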
# h: N x out 122 | assert not torch.isnan(h).any() 123 | 124 | # Self-attention on the nodes - Shared attention mechanism 125 | edge_h = torch.cat((h[edge[0, :], :], h[edge[1, :], :]), dim=1).t() 126 | # edge: 2*D x E 127 | 128 | edge_e = torch.exp(-self.leakyrelu(self.a.mm(edge_h).squeeze())) 129 | assert not torch.isnan(edge_e).any() 130 | # edge_e: E 131 | 132 | e_rowsum = self.special_spmm(edge, edge_e, torch.Size([N, N]), torch.ones(size=(N, 1), device=dv)) 133 | # e_rowsum: N x 1 134 | 135 | edge_e = self.dropout(edge_e) 136 | # edge_e: E 137 | 138 | h_prime = self.special_spmm(edge, edge_e, torch.Size([N, N]), h) 139 | assert not torch.isnan(h_prime).any() 140 | # h_prime: N x out 141 | 142 | h_prime = h_prime.div(e_rowsum) 143 | # h_prime: N x out 144 | assert not torch.isnan(h_prime).any() 145 | 146 | if self.concat: 147 | # if this layer is not last layer, 148 | return F.elu(h_prime) 149 | else: 150 | # if this layer is last layer, 151 | return h_prime 152 | 153 | def __repr__(self): 154 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' 155 | 156 | 157 | class IWPA(nn.Module): 158 | """ 159 | Part attention layer, "Dynamic Dual-Attentive Aggregation Learning for Visible-Infrared Person Re-Identification" 160 | """ 161 | def __init__(self, in_channels, part = 3, inter_channels=None, out_channels=None): 162 | super(IWPA, self).__init__() 163 | 164 | self.in_channels = in_channels 165 | self.inter_channels = inter_channels 166 | self.out_channels = out_channels 167 | self.l2norm = Normalize(2) 168 | 169 | if self.inter_channels is None: 170 | self.inter_channels = in_channels 171 | 172 | if self.out_channels is None: 173 | self.out_channels = in_channels 174 | 175 | conv_nd = nn.Conv2d 176 | 177 | self.fc1 = nn.Sequential( 178 | conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, 179 | padding=0), 180 | ) 181 | 182 | self.fc2 = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 183 | kernel_size=1, stride=1, padding=0) 184 | 185 | self.fc3 = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 186 | kernel_size=1, stride=1, padding=0) 187 | 188 | self.W = nn.Sequential( 189 | conv_nd(in_channels=self.inter_channels, out_channels=self.out_channels, 190 | kernel_size=1, stride=1, padding=0), 191 | nn.BatchNorm2d(self.out_channels), 192 | ) 193 | nn.init.constant_(self.W[1].weight, 0.0) 194 | nn.init.constant_(self.W[1].bias, 0.0) 195 | 196 | 197 | self.bottleneck = nn.BatchNorm1d(in_channels) 198 | self.bottleneck.bias.requires_grad_(False) # no shift 199 | 200 | nn.init.normal_(self.bottleneck.weight.data, 1.0, 0.01) 201 | nn.init.zeros_(self.bottleneck.bias.data) 202 | 203 | # weighting vector of the part features 204 | self.gate = nn.Parameter(torch.FloatTensor(part)) 205 | nn.init.constant_(self.gate, 1/part) 206 | def forward(self, x, feat, t=None, part=0): 207 | bt, c, h, w = x.shape 208 | b = bt // t 209 | 210 | # get part features 211 | part_feat = F.adaptive_avg_pool2d(x, (part, 1)) 212 | part_feat = part_feat.view(b, t, c, part) 213 | part_feat = part_feat.permute(0, 2, 1, 3) # B, C, T, Part 214 | 215 | part_feat1 = self.fc1(part_feat).view(b, self.inter_channels, -1) # B, C//r, T*Part 216 | part_feat1 = part_feat1.permute(0, 2, 1) # B, T*Part, C//r 217 | 218 | part_feat2 = self.fc2(part_feat).view(b, self.inter_channels, -1) # B, C//r, T*Part 219 | 220 | part_feat3 = self.fc3(part_feat).view(b, self.inter_channels, -1) # B, C//r, T*Part 
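        # fc1/fc2/fc3 act as query/key/value projections of the part
        # features: the matmul below forms a (T*Part) x (T*Part) affinity
        # matrix, softmax-normalised per row, which re-weights the value
        # features into refined_part_feat.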
221 |         part_feat3 = part_feat3.permute(0, 2, 1)  # B, T*Part, C//r
222 | 
223 |         # get cross-part attention
224 |         cpa_att = torch.matmul(part_feat1, part_feat2)  # B, T*Part, T*Part
225 |         cpa_att = F.softmax(cpa_att, dim=-1)
226 | 
227 |         # collect contextual information
228 |         refined_part_feat = torch.matmul(cpa_att, part_feat3)  # B, T*Part, C//r
229 |         refined_part_feat = refined_part_feat.permute(0, 2, 1).contiguous()  # B, C//r, T*Part
230 |         refined_part_feat = refined_part_feat.view(b, self.inter_channels, part)  # B, C//r, Part
231 | 
232 |         gate = F.softmax(self.gate, dim=-1)
233 |         weight_part_feat = torch.matmul(refined_part_feat, gate)
234 |         x = F.adaptive_avg_pool2d(x, (1, 1))
235 |         # weight_part_feat = weight_part_feat + x.view(x.size(0), x.size(1))
236 | 
237 |         weight_part_feat = weight_part_feat + feat
238 |         feat = self.bottleneck(weight_part_feat)
239 | 
240 |         return feat
-------------------------------------------------------------------------------- /reid/models/baseline.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | from torchvision import models
5 | from torch.autograd import Variable
6 | import pretrainedmodels
7 | 
8 | ######################################################################
9 | def weights_init_kaiming(m):
10 |     classname = m.__class__.__name__
11 |     # print(classname)
12 |     if classname.find('Conv') != -1:
13 |         init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')  # For old pytorch, you may use kaiming_normal.
14 |     elif classname.find('Linear') != -1:
15 |         init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')
16 |         init.constant_(m.bias.data, 0.0)
17 |     elif classname.find('BatchNorm1d') != -1:
18 |         init.normal_(m.weight.data, 1.0, 0.02)
19 |         init.constant_(m.bias.data, 0.0)
20 | 
21 | def weights_init_classifier(m):
22 |     classname = m.__class__.__name__
23 |     if classname.find('Linear') != -1:
24 |         init.normal_(m.weight.data, std=0.001)
25 |         init.constant_(m.bias.data, 0.0)
26 | 
27 | # Defines the new fc layer and classification layer
28 | # |--Linear--|--bn--|--relu--|--Linear--|
29 | class ClassBlock(nn.Module):
30 |     def __init__(self, input_dim, class_num, droprate, relu=False, bnorm=True, num_bottleneck=512, linear=True, return_f=False):
31 |         super(ClassBlock, self).__init__()
32 |         self.return_f = return_f
33 |         add_block = []
34 |         if linear:
35 |             add_block += [nn.Linear(input_dim, num_bottleneck)]
36 |         else:
37 |             num_bottleneck = input_dim
38 |         if bnorm:
39 |             add_block += [nn.BatchNorm1d(num_bottleneck)]
40 |         if relu:
41 |             add_block += [nn.LeakyReLU(0.1)]
42 |         if droprate > 0:
43 |             add_block += [nn.Dropout(p=droprate)]
44 |         add_block = nn.Sequential(*add_block)
45 |         add_block.apply(weights_init_kaiming)
46 | 
47 |         classifier = []
48 |         classifier += [nn.Linear(num_bottleneck, class_num)]
49 |         classifier = nn.Sequential(*classifier)
50 |         classifier.apply(weights_init_classifier)
51 | 
52 |         self.add_block = add_block
53 |         self.classifier = classifier
54 |     def forward(self, x):
55 |         x = self.add_block(x)
56 |         if self.return_f:
57 |             f = x
58 |             x = self.classifier(x)
59 |             return x, f
60 |         else:
61 |             x = self.classifier(x)
62 |             return x
63 | 
64 | # Define the ResNet50-based Model
65 | class ft_net(nn.Module):
66 | 
67 |     # def __init__(self, class_num, droprate=0.5, stride=2):
68 |     def __init__(self, pretrained=True, num_classes=0):
69 |         droprate = 0.5
70 |         stride = 2
71 |         super(ft_net, self).__init__()
72 |         model_ft = models.resnet50(pretrained=pretrained)  # respect the `pretrained` argument instead of hardcoding True
73 |         # avg pooling to global pooling
74 |         if stride == 1:
75 |             model_ft.layer4[0].downsample[0].stride = (1,1)
76 |             model_ft.layer4[0].conv2.stride = (1,1)
77 |         model_ft.avgpool = nn.AdaptiveAvgPool2d((1,1))
78 |         self.model = model_ft
79 |         self.classifier = ClassBlock(2048, num_classes, droprate)
80 | 
81 |     def forward(self, x):
82 |         x = self.model.conv1(x)
83 |         x = self.model.bn1(x)
84 |         x = self.model.relu(x)
85 |         x = self.model.maxpool(x)
86 |         x = self.model.layer1(x)
87 |         x = self.model.layer2(x)
88 |         x = self.model.layer3(x)
89 |         x = self.model.layer4(x)
90 |         x = self.model.avgpool(x)
91 |         x = x.view(x.size(0), x.size(1))
92 |         x1 = x
93 |         x = self.classifier(x)
94 |         x2 = x
95 |         return x1, x2
96 | 
97 | # Define the DenseNet121-based Model
98 | class ft_net_dense(nn.Module):
99 | 
100 |     def __init__(self, class_num, droprate=0.5):
101 |         super().__init__()
102 |         model_ft = models.densenet121(pretrained=True)
103 |         model_ft.features.avgpool = nn.AdaptiveAvgPool2d((1,1))
104 |         model_ft.fc = nn.Sequential()
105 |         self.model = model_ft
106 |         # For DenseNet, the feature dim is 1024
107 |         self.classifier = ClassBlock(1024, class_num, droprate)
108 | 
109 |     def forward(self, x):
110 |         x = self.model.features(x)
111 |         x = x.view(x.size(0), x.size(1))
112 |         x = self.classifier(x)
113 |         return x
114 | 
115 | # Define the NAS-based Model
116 | class ft_net_NAS(nn.Module):
117 | 
118 |     def __init__(self, class_num, droprate=0.5):
119 |         super().__init__()
120 |         model_name = 'nasnetalarge'
121 |         # pip install pretrainedmodels
122 |         model_ft = pretrainedmodels.__dict__[model_name](num_classes=1000, pretrained='imagenet')
123 |         model_ft.avg_pool = nn.AdaptiveAvgPool2d((1,1))
124 |         model_ft.dropout = nn.Sequential()
125 |         model_ft.last_linear = nn.Sequential()
126 |         self.model = model_ft
127 |         # For NASNet-A Large, the feature dim is 4032
128 |         self.classifier = ClassBlock(4032, class_num, droprate)
129 | 
130 |     def forward(self, x):
131 |         x = self.model.features(x)
132 |         x = self.model.avg_pool(x)
133 |         x = x.view(x.size(0), x.size(1))
134 |         x = self.classifier(x)
135 |         return x
136 | 
137 | # Define the ResNet50-based Model (Middle-Concat)
138 | # In the spirit of "The Devil is in the Middle: Exploiting Mid-level Representations for Cross-Domain Instance Matching." Yu, Qian, et al. arXiv:1711.08106 (2017).
139 | class ft_net_middle(nn.Module):
140 | 
141 |     def __init__(self, class_num, droprate=0.5):
142 |         super(ft_net_middle, self).__init__()
143 |         model_ft = models.resnet50(pretrained=True)
144 |         # avg pooling to global pooling
145 |         model_ft.avgpool = nn.AdaptiveAvgPool2d((1,1))
146 |         self.model = model_ft
147 |         self.classifier = ClassBlock(2048+1024, class_num, droprate)
148 | 
149 |     def forward(self, x):
150 |         x = self.model.conv1(x)
151 |         x = self.model.bn1(x)
152 |         x = self.model.relu(x)
153 |         x = self.model.maxpool(x)
154 |         x = self.model.layer1(x)
155 |         x = self.model.layer2(x)
156 |         x = self.model.layer3(x)
157 |         # x0  n*1024*1*1
158 |         x0 = self.model.avgpool(x)
159 |         x = self.model.layer4(x)
160 |         # x1  n*2048*1*1
161 |         x1 = self.model.avgpool(x)
162 |         x = torch.cat((x0,x1),1)
163 |         x = x.view(x.size(0), x.size(1))
164 |         x = self.classifier(x)
165 |         return x
166 | 
167 | # Part Model proposed in Yifan Sun et al. 
(2018) 168 | class PCB(nn.Module): 169 | def __init__(self, class_num ): 170 | super(PCB, self).__init__() 171 | 172 | self.part = 6 # We cut the pool5 to 6 parts 173 | model_ft = models.resnet50(pretrained=True) 174 | self.model = model_ft 175 | self.avgpool = nn.AdaptiveAvgPool2d((self.part,1)) 176 | self.dropout = nn.Dropout(p=0.5) 177 | # remove the final downsample 178 | self.model.layer4[0].downsample[0].stride = (1,1) 179 | self.model.layer4[0].conv2.stride = (1,1) 180 | # define 6 classifiers 181 | for i in range(self.part): 182 | name = 'classifier'+str(i) 183 | setattr(self, name, ClassBlock(2048, class_num, droprate=0.5, relu=False, bnorm=True, num_bottleneck=256)) 184 | 185 | def forward(self, x): 186 | x = self.model.conv1(x) 187 | x = self.model.bn1(x) 188 | x = self.model.relu(x) 189 | x = self.model.maxpool(x) 190 | 191 | x = self.model.layer1(x) 192 | x = self.model.layer2(x) 193 | x = self.model.layer3(x) 194 | x = self.model.layer4(x) 195 | x = self.avgpool(x) 196 | x = self.dropout(x) 197 | part = {} 198 | predict = {} 199 | # get six part feature batchsize*2048*6 200 | for i in range(self.part): 201 | part[i] = torch.squeeze(x[:,:,i]) 202 | name = 'classifier'+str(i) 203 | c = getattr(self,name) 204 | predict[i] = c(part[i]) 205 | 206 | # sum prediction 207 | #y = predict[0] 208 | #for i in range(self.part-1): 209 | # y += predict[i+1] 210 | y = [] 211 | for i in range(self.part): 212 | y.append(predict[i]) 213 | return y 214 | 215 | class PCB_test(nn.Module): 216 | def __init__(self,model): 217 | super(PCB_test,self).__init__() 218 | self.part = 6 219 | self.model = model.model 220 | self.avgpool = nn.AdaptiveAvgPool2d((self.part,1)) 221 | # remove the final downsample 222 | self.model.layer4[0].downsample[0].stride = (1,1) 223 | self.model.layer4[0].conv2.stride = (1,1) 224 | 225 | def forward(self, x): 226 | x = self.model.conv1(x) 227 | x = self.model.bn1(x) 228 | x = self.model.relu(x) 229 | x = self.model.maxpool(x) 230 | 231 | x = self.model.layer1(x) 232 | x = self.model.layer2(x) 233 | x = self.model.layer3(x) 234 | x = self.model.layer4(x) 235 | x = self.avgpool(x) 236 | y = x.view(x.size(0),x.size(1),x.size(2)) 237 | return y 238 | ''' 239 | # debug model structure 240 | # Run this code with: 241 | python model.py 242 | ''' 243 | if __name__ == '__main__': 244 | # Here I left a simple forward function. 245 | # Test the model, before you train it. 
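    # A (8, 3, 256, 128) batch through ResNet-50 global pooling should yield
    # an (8, 2048) feature once the classifier is replaced by an identity
    # Sequential below.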
246 |     net = ft_net(num_classes=751)  # the modified __init__ no longer accepts positional class_num / stride arguments
247 |     net.classifier = nn.Sequential()
248 |     print(net)
249 |     input = Variable(torch.FloatTensor(8, 3, 256, 128))
250 |     output, _ = net(input)  # ft_net.forward returns a (feature, logits) pair
251 |     print('net output size:')
252 |     print(output.shape)
253 | 
-------------------------------------------------------------------------------- /reid/evaluators.py: --------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import time
3 | from collections import OrderedDict
4 | from collections import namedtuple
5 | 
6 | import torch
7 | 
8 | from .evaluation_metrics import cmc, mean_ap
9 | from .feature_extraction import extract_cnn_feature, extract_pcb_feature
10 | from .utils.meters import AverageMeter
11 | 
12 | 
13 | def extract_features(model, data_loader, print_freq=1, metric=None):
14 |     model.eval()
15 |     batch_time = AverageMeter()
16 |     data_time = AverageMeter()
17 | 
18 |     features = OrderedDict()
19 |     labels = OrderedDict()
20 | 
21 |     end = time.time()
22 |     for i, (imgs, fnames, pids, _) in enumerate(data_loader):
23 |         data_time.update(time.time() - end)
24 | 
25 |         outputs = extract_cnn_feature(model, imgs)
26 |         for fname, output, pid in zip(fnames, outputs, pids):
27 |             features[fname] = output
28 |             labels[fname] = pid
29 | 
30 |         batch_time.update(time.time() - end)
31 |         end = time.time()
32 | 
33 |         if (i + 1) % print_freq == 0:
34 |             print('Extract Features: [{}/{}]\t'
35 |                   'Time {:.3f} ({:.3f})\t'
36 |                   'Data {:.3f} ({:.3f})\t'
37 |                   .format(i + 1, len(data_loader),
38 |                           batch_time.val, batch_time.avg,
39 |                           data_time.val, data_time.avg))
40 | 
41 |     return features, labels
42 | 
43 | 
44 | def extract_pcb_features(model, data_loader, print_freq=1):
45 |     model.eval()
46 |     batch_time = AverageMeter()
47 |     data_time = AverageMeter()
48 | 
49 |     features = OrderedDict()
50 |     labels = OrderedDict()
51 | 
52 |     end = time.time()
53 |     for i, (imgs, fnames, pids, _) in enumerate(data_loader):
54 |         data_time.update(time.time() - end)
55 |         outputs = extract_pcb_feature(model, imgs)
56 |         for fname, output, pid in zip(fnames, outputs, pids):
57 |             features[fname] = output
58 |             labels[fname] = pid
59 | 
60 |         batch_time.update(time.time() - end)
61 |         end = time.time()
62 | 
63 |         if (i + 1) % print_freq == 0:
64 |             print('Extract Features: [{}/{}]\t'
65 |                   'Time {:.3f} ({:.3f})\t'
66 |                   'Data {:.3f} ({:.3f})\t'
67 |                   .format(i + 1, len(data_loader),
68 |                           batch_time.val, batch_time.avg,
69 |                           data_time.val, data_time.avg))
70 | 
71 |     return features, labels
72 | 
73 | 
74 | def pairwise_distance(features, query=None, gallery=None, metric=None):
75 |     if query is None and gallery is None:
76 |         n = len(features)
77 |         x = torch.cat(list(features.values()))
78 |         x = x.view(n, -1)
79 |         if metric is not None:
80 |             x = metric.transform(x)
81 |         dist = torch.pow(x, 2).sum(dim=1, keepdim=True) * 2
82 |         dist = dist.expand(n, n) - 2 * torch.mm(x, x.t())
83 |         return dist
84 | 
85 |     x = torch.cat([features[f].unsqueeze(0) for f, _, _ in query], 0)
86 |     y = torch.cat([features[f].unsqueeze(0) for f, _, _ in gallery], 0)
87 |     m, n = x.size(0), y.size(0)
88 |     x = x.view(m, -1)
89 |     y = y.view(n, -1)
90 |     if metric is not None:
91 |         x = metric.transform(x)
92 |         y = metric.transform(y)
93 |     dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
94 |            torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
95 |     dist.addmm_(1, -2, x, y.t())
96 |     return dist
97 | 
98 | 
99 | def evaluate_all(distmat, query=None, gallery=None,
100 |                  query_ids=None, gallery_ids=None,
101 | 
query_cams=None, gallery_cams=None, 102 | cmc_topk=(1, 10,20)):#################################### 103 | if query is not None and gallery is not None: 104 | query_ids = [pid for _, pid, _ in query] 105 | gallery_ids = [pid for _, pid, _ in gallery] 106 | query_cams = [cam for _, _, cam in query] 107 | gallery_cams = [cam for _, _, cam in gallery] 108 | else: 109 | assert (query_ids is not None and gallery_ids is not None 110 | and query_cams is not None and gallery_cams is not None) 111 | 112 | # Compute mean AP 113 | mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams) 114 | print('Mean AP: {:4.2%}'.format(mAP)) 115 | 116 | # # Compute all kinds of CMC scores 117 | # cmc_configs = { 118 | # 'allshots': dict(separate_camera_set=False, 119 | # single_gallery_shot=False, 120 | # first_match_break=False), 121 | # 'cuhk03': dict(separate_camera_set=True, 122 | # single_gallery_shot=True, 123 | # first_match_break=False), 124 | # 'hahaha': dict(separate_camera_set=False, 125 | # single_gallery_shot=False, 126 | # first_match_break=True)} 127 | # cmc_scores = {name: cmc(distmat, query_ids, gallery_ids, 128 | # query_cams, gallery_cams, **params) 129 | # for name, params in cmc_configs.items()} 130 | 131 | # Compute all kinds of CMC scores 132 | cmc_configs = { 133 | 'score': dict(separate_camera_set=False, 134 | single_gallery_shot=False, 135 | first_match_break=True)} 136 | cmc_scores = {name: cmc(distmat, query_ids, gallery_ids, 137 | query_cams, gallery_cams, **params) 138 | for name, params in cmc_configs.items()} 139 | 140 | # print('CMC Scores{:>12}{:>12}{:>12}' 141 | # .format('allshots', 'cuhk03', 'hahaha')) 142 | print('CMC Scores{:>12}' 143 | .format('score')) 144 | 145 | # rank_score = namedtuple( 146 | # 'rank_score', 147 | # ['map', 'allshots', 'cuhk03', 'hahaha'], 148 | # ) 149 | rank_score = namedtuple( 150 | 'rank_score', 151 | ['map', 'score'], 152 | ) 153 | # for k in cmc_topk: 154 | # print(' top-{:<4}{:12.1%}{:12.1%}{:12.1%}' 155 | # .format(k, cmc_scores['allshots'][k - 1], 156 | # cmc_scores['cuhk03'][k - 1], 157 | # cmc_scores['hahaha'][k - 1])) 158 | # # print(' top-{:<4}{:12.1%}' 159 | # # .format(k, cmc_scores['allshots'][k - 1])) 160 | # score = rank_score( 161 | # mAP, 162 | # cmc_scores['allshots'], cmc_scores['cuhk03'], 163 | # cmc_scores['hahaha'], 164 | # ) 165 | # # score = rank_score( 166 | # # mAP, 167 | # # cmc_scores['allshots'] 168 | # # ) 169 | # return score 170 | 171 | for k in cmc_topk: 172 | print(' top-{:<4}{:12.2%}' 173 | .format(k, 174 | cmc_scores['score'][k - 1])) 175 | # print(' top-{:<4}{:12.1%}' 176 | # .format(k, cmc_scores['allshots'][k - 1])) 177 | score = rank_score( 178 | mAP, 179 | cmc_scores['score'], 180 | ) 181 | # score = rank_score( 182 | # mAP, 183 | # cmc_scores['allshots'] 184 | # ) 185 | return score 186 | 187 | 188 | class Evaluator(object): 189 | def __init__(self, model, print_freq=1): 190 | super(Evaluator, self).__init__() 191 | self.model = model 192 | self.print_freq = print_freq 193 | 194 | def evaluate(self, data_loader, query, gallery, metric=None): 195 | features, _ = extract_features(self.model, data_loader, print_freq=self.print_freq) 196 | distmat = pairwise_distance(features, query, gallery, metric=metric) 197 | return evaluate_all(distmat, query=query, gallery=gallery) 198 | 199 | def evaMat(self, distMat, query, gallery, saveRank=False, root=None): 200 | if saveRank: 201 | assert root is not None 202 | import cv2 203 | import os.path as osp 204 | import shutil 205 | import os 206 | if 
osp.exists('correct'): 207 | shutil.rmtree('correct') 208 | os.makedirs('correct') 209 | # plot rakning list of 0001 210 | qnames, gnames = [val[0] for val in query], [val[0] for val in gallery] 211 | _, ind = torch.sort(distMat.cpu(), 1) 212 | ind = ind[0, :8] 213 | if root.count('msmt17') == 0: 214 | allNames = [osp.join(root, 'images', gnames[val.item()]) for val in ind] 215 | saveqNames = osp.join(root, 'images', qnames[0]) 216 | else: 217 | allNames = [osp.join(root, 'raw', gnames[val.item()]) for val in ind] 218 | saveqNames = osp.join(root, 'raw', qnames[0]) 219 | allNames = [saveqNames] + allNames 220 | isCorr = [ 221 | 1 if int(saveqNames.split('/')[-1].split('_')[0]) == int(allNames[ii].split('/')[-1].split('_')[0]) 222 | else 0 for ii in range(len(allNames))] 223 | # imshow 224 | ranklist = [] 225 | import numpy as np 226 | for ii, (name, mask) in enumerate(zip(allNames, isCorr)): 227 | img = cv2.resize(cv2.imread(name), (64, 128)) 228 | if ii != 0: 229 | img = cv2.rectangle(img, (0, 0), (64, 128), 230 | (0, 255, 0) if mask == 1 else (0, 0, 255), 2) 231 | ranklist.append(img) 232 | if ii == 0: 233 | ranklist.append(np.zeros((128, 20, 3))) 234 | ranklist = np.concatenate(ranklist, 1) 235 | cv2.imwrite(f'correct/{0}.jpg', ranklist) 236 | return evaluate_all(distMat, query=query, gallery=gallery) 237 | -------------------------------------------------------------------------------- /reid/models/AGW.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | from .resnet2 import * 5 | 6 | class Normalize(nn.Module): 7 | def __init__(self, power=2): 8 | super(Normalize, self).__init__() 9 | self.power = power 10 | 11 | def forward(self, x): 12 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. 
/ self.power)
13 |         out = x.div(norm)
14 |         return out
15 | 
16 | class Non_local(nn.Module):
17 |     def __init__(self, in_channels, reduc_ratio=2):
18 |         super(Non_local, self).__init__()
19 | 
20 |         self.in_channels = in_channels
21 |         self.inter_channels = reduc_ratio // reduc_ratio  # note: this always evaluates to 1; in_channels // reduc_ratio was probably intended, but changing it would alter the architecture and break released checkpoints
22 | 
23 |         self.g = nn.Sequential(
24 |             nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1,
25 |                       padding=0),
26 |         )
27 | 
28 |         self.W = nn.Sequential(
29 |             nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
30 |                       kernel_size=1, stride=1, padding=0),
31 |             nn.BatchNorm2d(self.in_channels),
32 |         )
33 |         nn.init.constant_(self.W[1].weight, 0.0)
34 |         nn.init.constant_(self.W[1].bias, 0.0)
35 | 
36 | 
37 | 
38 |         self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
39 |                                kernel_size=1, stride=1, padding=0)
40 | 
41 |         self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
42 |                              kernel_size=1, stride=1, padding=0)
43 | 
44 |     def forward(self, x):
45 |         '''
46 |         :param x: (b, c, h, w)
47 |         :return:
48 |         '''
49 | 
50 |         batch_size = x.size(0)
51 |         g_x = self.g(x).view(batch_size, self.inter_channels, -1)
52 |         g_x = g_x.permute(0, 2, 1)
53 | 
54 |         theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
55 |         theta_x = theta_x.permute(0, 2, 1)
56 |         phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
57 |         f = torch.matmul(theta_x, phi_x)
58 |         N = f.size(-1)
59 |         # f_div_C = torch.nn.functional.softmax(f, dim=-1)
60 |         f_div_C = f / N
61 | 
62 |         y = torch.matmul(f_div_C, g_x)
63 |         y = y.permute(0, 2, 1).contiguous()
64 |         y = y.view(batch_size, self.inter_channels, *x.size()[2:])
65 |         W_y = self.W(y)
66 |         z = W_y + x
67 | 
68 |         return z
69 | 
70 | 
71 | # #####################################################################
72 | def weights_init_kaiming(m):
73 |     classname = m.__class__.__name__
74 |     # print(classname)
75 |     if classname.find('Conv') != -1:
76 |         init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
77 |     elif classname.find('Linear') != -1:
78 |         init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')
79 |         init.zeros_(m.bias.data)
80 |     elif classname.find('BatchNorm1d') != -1:
81 |         init.normal_(m.weight.data, 1.0, 0.01)
82 |         init.zeros_(m.bias.data)
83 | 
84 | def weights_init_classifier(m):
85 |     classname = m.__class__.__name__
86 |     if classname.find('Linear') != -1:
87 |         init.normal_(m.weight.data, 0, 0.001)
88 |         if m.bias is not None:  # `if m.bias:` would crash on a multi-element bias tensor
89 |             init.zeros_(m.bias.data)
90 | 
91 | 
92 | 
93 | class visible_module(nn.Module):
94 |     def __init__(self, arch='resnet50'):
95 |         super(visible_module, self).__init__()
96 | 
97 |         model_v = resnet50(pretrained=True,
98 |                            last_conv_stride=1, last_conv_dilation=1)
99 |         # avg pooling to global pooling
100 |         self.visible = model_v
101 | 
102 |     def forward(self, x):
103 |         x = self.visible.conv1(x)
104 |         x = self.visible.bn1(x)
105 |         x = self.visible.relu(x)
106 |         x = self.visible.maxpool(x)
107 |         return x
108 | 
109 | 
110 | class thermal_module(nn.Module):
111 |     def __init__(self, arch='resnet50'):
112 |         super(thermal_module, self).__init__()
113 | 
114 |         model_t = resnet50(pretrained=True,
115 |                            last_conv_stride=1, last_conv_dilation=1)
116 |         # avg pooling to global pooling
117 |         self.thermal = model_t
118 | 
119 |     def forward(self, x):
120 |         x = self.thermal.conv1(x)
121 |         x = self.thermal.bn1(x)
122 |         x = self.thermal.relu(x)
123 |         x = self.thermal.maxpool(x)
124 |         return x
125 | 
126 | 
127 | class base_resnet(nn.Module):
128 |     def __init__(self, arch='resnet50'):
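        # base_resnet provides the modality-shared layer1-layer4; the
        # visible_module / thermal_module above keep separate conv1/bn1 stems
        # per modality (the two-stream design of AGW for visible-infrared ReID).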
129 |         super(base_resnet, self).__init__()
130 | 
131 |         model_base = resnet50(pretrained=True,
132 |                               last_conv_stride=1, last_conv_dilation=1)
133 |         # avg pooling to global pooling
134 |         model_base.avgpool = nn.AdaptiveAvgPool2d((1, 1))
135 |         self.base = model_base
136 | 
137 |     def forward(self, x):
138 |         x = self.base.layer1(x)
139 |         x = self.base.layer2(x)
140 |         x = self.base.layer3(x)
141 |         x = self.base.layer4(x)
142 |         return x
143 | 
144 | 
145 | # class embed_net(nn.Module):  # main network
146 | #     def __init__(self, class_num, no_local= 'on', gm_pool = 'on', arch='resnet50'):
147 | #         super(embed_net, self).__init__()
148 | 
149 | #     model = models.create(args.arch, pretrained=True, num_classes=num_classes)
150 | 
151 | #     def __init__(self, depth, checkpoint=None, pretrained=True, num_features=2048,
152 | #                  dropout=0.1, num_classes=0):
153 | #         super(ResNet, self).__init__()
154 | 
155 | class embed_net(nn.Module):  # main network
156 |     def __init__(self, no_local='on', gm_pool='on', arch='resnet50', pretrained=True, num_classes=0):
157 |         super(embed_net, self).__init__()
158 | 
159 |         self.thermal_module = thermal_module(arch=arch)
160 |         self.visible_module = visible_module(arch=arch)
161 |         self.base_resnet = base_resnet(arch=arch)
162 |         self.non_local = no_local
163 |         if self.non_local == 'on':
164 |             layers = [3, 4, 6, 3]
165 |             non_layers = [0, 2, 3, 0]
166 |             self.NL_1 = nn.ModuleList(
167 |                 [Non_local(256) for i in range(non_layers[0])])
168 |             self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])])
169 |             self.NL_2 = nn.ModuleList(
170 |                 [Non_local(512) for i in range(non_layers[1])])
171 |             self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])])
172 |             self.NL_3 = nn.ModuleList(
173 |                 [Non_local(1024) for i in range(non_layers[2])])
174 |             self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])])
175 |             self.NL_4 = nn.ModuleList(
176 |                 [Non_local(2048) for i in range(non_layers[3])])
177 |             self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])])
178 | 
179 | 
180 |         pool_dim = 2048
181 |         self.l2norm = Normalize(2)
182 |         self.bottleneck = nn.BatchNorm1d(pool_dim)
183 |         self.bottleneck.bias.requires_grad_(False)  # no shift
184 | 
185 |         self.classifier = nn.Linear(pool_dim, num_classes, bias=False)
186 | 
187 |         self.bottleneck.apply(weights_init_kaiming)
188 |         self.classifier.apply(weights_init_classifier)
189 |         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
190 |         self.gm_pool = gm_pool
191 | 
192 |     # def forward(self, x1, x2, modal=0):
193 |     def forward(self, x1, modal=1):  # default changed to modal=1 (visible stream); modal=0 causes problems in this setup
194 |         x2 = x1  ###################################
195 |         if modal == 0:
196 |             x1 = self.visible_module(x1)
197 |             x2 = self.thermal_module(x2)
198 |             x = torch.cat((x1, x2), 0)
199 |         elif modal == 1:
200 |             x = self.visible_module(x1)
201 |         elif modal == 2:
202 |             x = self.thermal_module(x2)
203 | 
204 |         # shared block
205 |         if self.non_local == 'on':
206 |             NL1_counter = 0
207 |             if len(self.NL_1_idx) == 0: self.NL_1_idx = [-1]
208 |             for i in range(len(self.base_resnet.base.layer1)):
209 |                 x = self.base_resnet.base.layer1[i](x)
210 |                 if i == self.NL_1_idx[NL1_counter]:
211 |                     _, C, H, W = x.shape
212 |                     x = self.NL_1[NL1_counter](x)
213 |                     NL1_counter += 1
214 |             # Layer 2
215 |             NL2_counter = 0
216 |             if len(self.NL_2_idx) == 0: self.NL_2_idx = [-1]
217 |             for i in range(len(self.base_resnet.base.layer2)):
218 |                 x = self.base_resnet.base.layer2[i](x)
219 |                 if i == self.NL_2_idx[NL2_counter]:
220 |                     _, C, H, W = x.shape
221 |                     x = self.NL_2[NL2_counter](x)
222 |                     NL2_counter += 1
223 |             # Layer 3
224 |             NL3_counter = 0
225 |             if len(self.NL_3_idx) == 0: self.NL_3_idx = [-1]
226 |             for i in range(len(self.base_resnet.base.layer3)):
227 |                 x = self.base_resnet.base.layer3[i](x)
228 |                 if i == self.NL_3_idx[NL3_counter]:
229 |                     _, C, H, W = x.shape
230 |                     x = self.NL_3[NL3_counter](x)
231 |                     NL3_counter += 1
232 |             # Layer 4
233 |             NL4_counter = 0
234 |             if len(self.NL_4_idx) == 0: self.NL_4_idx = [-1]
235 |             for i in range(len(self.base_resnet.base.layer4)):
236 |                 x = self.base_resnet.base.layer4[i](x)
237 |                 if i == self.NL_4_idx[NL4_counter]:
238 |                     _, C, H, W = x.shape
239 |                     x = self.NL_4[NL4_counter](x)
240 |                     NL4_counter += 1
241 |         else:
242 |             x = self.base_resnet(x)
243 |         if self.gm_pool == 'on':
244 |             b, c, h, w = x.shape
245 |             x = x.view(b, c, -1)
246 |             p = 3.0
247 |             x_pool = (torch.mean(x**p, dim=-1) + 1e-12)**(1/p)
248 |         else:
249 |             x_pool = self.avgpool(x)
250 |             x_pool = x_pool.view(x_pool.size(0), x_pool.size(1))
251 | 
252 |         feat = self.bottleneck(x_pool)
253 | 
254 |         if self.training:
255 |             return x_pool, self.classifier(feat)
256 |         else:
257 |             return x_pool, self.classifier(feat)  # keep the eval output format aligned with the training branch
258 |             # return self.l2norm(x_pool), self.l2norm(feat)
-------------------------------------------------------------------------------- /Multiform_attack.py: --------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import sys
4 | import os
5 | from torch.utils.data import DataLoader
6 | 
7 | from reid import models
8 | from torch.nn import functional as F
9 | import os.path as osp
10 | from reid import datasets
11 | 
12 | from reid.utils.data import transforms as T
13 | from torchvision.transforms import Resize
14 | from reid.utils.data.preprocessor import Preprocessor
15 | from reid.evaluators import Evaluator
16 | from torch.optim.optimizer import Optimizer, required
17 | import random
18 | import numpy as np
19 | import math
20 | from reid.evaluators import extract_features
21 | from reid.utils.meters import AverageMeter
22 | import torchvision
23 | import faiss
24 | 
25 | from torchvision import transforms
26 | 
27 | from MOAA.MOAA import Attack
28 | import numpy as np
29 | import argparse
30 | import os
31 | 
32 | CHECK = 1e-5
33 | SAT_MIN = 0.5
34 | MODE = "bilinear"
35 | 
36 | 
37 | 
38 | 
39 | def input(sourceName, mteName, mteName2, targetName, split_id, data_dir, height, width,
40 |           batch_size, workers, combine):
41 |     root = osp.join(data_dir, sourceName)
42 |     rootMte = osp.join(data_dir, mteName)
43 |     rootMte2 = osp.join(data_dir, mteName2)
44 |     rootTgt = osp.join(data_dir, targetName)
45 |     sourceSet = datasets.create(sourceName, root, num_val=0.1, split_id=split_id)
46 |     mteSet = datasets.create(mteName, rootMte, num_val=0.1, split_id=split_id)
47 |     mteSet2 = datasets.create(mteName2, rootMte2, num_val=0.1, split_id=split_id)
48 |     tgtSet = datasets.create(targetName, rootTgt, num_val=0.1, split_id=split_id)
49 |     num_classes = sourceSet.num_trainval_ids if combine else sourceSet.num_train_ids
50 | 
51 |     num_search = mteSet.num_trainval_ids if combine else mteSet.num_train_ids
52 |     num_search2 = mteSet2.num_trainval_ids if combine else mteSet2.num_train_ids
53 | 
54 |     class_tgt = tgtSet.num_trainval_ids if combine else tgtSet.num_train_ids
55 | 
56 |     train_transformer = T.Compose([
57 |         Resize((height, width)),
58 |         transforms.RandomGrayscale(p=0.2),
59 |         T.ToTensor(),
60 |     ])
61 | 
62 |     gradient_based_train = DataLoader(
63 |         Preprocessor(sourceSet.trainval, root=sourceSet.images_dir, transform=train_transformer),
64 |         batch_size=batch_size, num_workers=workers, shuffle=False, pin_memory=True)
65 | 
66 |     search_set = DataLoader(
67 |         Preprocessor(mteSet.trainval, root=mteSet.images_dir, transform=train_transformer),
68 |         batch_size=batch_size, num_workers=workers, shuffle=False, pin_memory=True)
69 | 
70 | 
71 |     search_set2 = DataLoader(
72 |         Preprocessor(mteSet2.trainval, root=mteSet2.images_dir, transform=train_transformer),
73 |         batch_size=batch_size, num_workers=workers, shuffle=False, pin_memory=True)
74 | 
75 | 
76 |     # tgtSet is the dataset of the transferability-attack target model
77 |     return sourceSet, tgtSet, mteSet, mteSet2, num_classes, num_search, num_search2, class_tgt, gradient_based_train, search_set, search_set2
78 | 
79 | 
80 | def rescale_check(check, sat, sat_change, sat_min):
81 |     return sat_change < check and sat > sat_min
82 | 
83 | 
84 | class MI_SGD(Optimizer):
85 |     def __init__(
86 |             self, params, lr=required, momentum=0, dampening=0, weight_decay=0,
87 |             nesterov=False, max_eps=10 / 255
88 |     ):
89 |         if lr is not required and lr < 0.0:
90 |             raise ValueError("Invalid learning rate: {}".format(lr))
91 |         if momentum < 0.0:
92 |             raise ValueError("Invalid momentum value: {}".format(momentum))
93 |         if weight_decay < 0.0:
94 |             raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
95 | 
96 |         defaults = dict(
97 |             lr=lr,
98 |             momentum=momentum,
99 |             dampening=dampening,
100 |             weight_decay=weight_decay,
101 |             nesterov=nesterov,
102 |             sign=False,
103 |         )
104 |         if nesterov and (momentum <= 0 or dampening != 0):
105 |             raise ValueError("Nesterov momentum requires a momentum and zero dampening")
106 |         super(MI_SGD, self).__init__(params, defaults)
107 |         self.sat = 0
108 |         self.sat_prev = 0
109 |         self.max_eps = max_eps
110 | 
111 |     def __setstate__(self, state):
112 |         super(MI_SGD, self).__setstate__(state)
113 |         for group in self.param_groups:
114 |             group.setdefault("nesterov", False)
115 | 
116 |     def rescale(self, ):
117 |         for group in self.param_groups:
118 |             if not group["sign"]:
119 |                 continue
120 |             for p in group["params"]:
121 |                 self.sat_prev = self.sat
122 |                 self.sat = (p.data.abs() >= self.max_eps).sum().item() / p.data.numel()
123 |                 sat_change = abs(self.sat - self.sat_prev)
124 |                 if rescale_check(CHECK, self.sat, sat_change, SAT_MIN):
125 |                     print('rescaled')
126 |                     p.data = p.data / 2
127 | 
128 |     def step(self, closure=None):
129 |         loss = None
130 |         if closure is not None:
131 |             loss = closure()
132 | 
133 |         for group in self.param_groups:
134 |             weight_decay = group["weight_decay"]
135 |             momentum = group["momentum"]
136 |             dampening = group["dampening"]
137 |             nesterov = group["nesterov"]
138 | 
139 |             for p in group["params"]:
140 |                 if p.grad is None:
141 |                     continue
142 |                 d_p = p.grad.data
143 |                 if group["sign"]:
144 |                     d_p = d_p / (d_p.norm(1) + 1e-12)
145 |                 if weight_decay != 0:
146 |                     d_p.add_(weight_decay, p.data)
147 |                 if momentum != 0:
148 |                     param_state = self.state[p]
149 |                     if "momentum_buffer" not in param_state:
150 |                         buf = param_state["momentum_buffer"] = torch.zeros_like(p.data)
151 |                         buf.mul_(momentum).add_(d_p)
152 |                     else:
153 |                         buf = param_state["momentum_buffer"]
154 |                         buf.mul_(momentum).add_(1 - dampening, d_p)
155 |                     if nesterov:
156 |                         d_p = d_p.add(momentum, buf)
157 |                     else:
158 |                         d_p = buf
159 | 
160 |                 if group["sign"]:
161 |                     p.data.add_(-group["lr"], d_p.sign())
162 |                     p.data = torch.clamp(p.data, -self.max_eps, self.max_eps)
163 |                 else:
164 |                     p.data.add_(-group["lr"], d_p)
165 | 
166 |         return loss
167 | 
168 | 
169 | def Update(noiseData, optimizer, gradInfo, max_eps):
optimizer.param_groups[0]["weight_decay"] 171 | momentum = optimizer.param_groups[0]["momentum"] 172 | dampening = optimizer.param_groups[0]["dampening"] 173 | nesterov = optimizer.param_groups[0]["nesterov"] 174 | lr = optimizer.param_groups[0]["lr"] 175 | 176 | d_p = gradInfo 177 | if optimizer.param_groups[0]["sign"]: 178 | d_p = d_p / (d_p.norm(1) + 1e-12) 179 | if weight_decay != 0: 180 | d_p.add_(weight_decay, noiseData) 181 | if momentum != 0: 182 | param_state = optimizer.state[noiseData] 183 | if "momentum_buffer" not in param_state: 184 | buf = param_state["momentum_buffer"] = torch.zeros_like(noiseData.data) 185 | buf = buf * momentum + d_p 186 | else: 187 | buf = param_state["momentum_buffer"] 188 | buf = buf * momentum + (1 - dampening) * d_p 189 | if nesterov: 190 | d_p = d_p + momentum * buf 191 | else: 192 | d_p = buf 193 | 194 | if optimizer.param_groups[0]["sign"]: 195 | noiseData = noiseData - lr * d_p.sign() 196 | noiseData = torch.clamp(noiseData, -max_eps, max_eps) 197 | else: 198 | noiseData = noiseData - lr * d_p.sign() 199 | return noiseData 200 | 201 | 202 | def Multiform_attack(gradient_based_train_loader, search_set_loader, net, noise, epoch, optimizer, 203 | centroids, metaCentroids, normalize): 204 | global args 205 | noise.requires_grad = True 206 | batch_time = AverageMeter() 207 | data_time = AverageMeter() 208 | losses = AverageMeter() 209 | 210 | mean = torch.Tensor(normalize.mean).view(1, 3, 1, 1).cuda() 211 | std = torch.Tensor(normalize.std).view(1, 3, 1, 1).cuda() 212 | 213 | net.eval() 214 | 215 | end = time.time() 216 | optimizer.zero_grad() 217 | optimizer.rescale() 218 | for i, ((input, _, pid, _), (metaTest, _, _, _)) in enumerate(zip(gradient_based_train_loader, search_set_loader)): 219 | data_time.update(time.time() - end) 220 | model.zero_grad() 221 | input = input.cuda() 222 | metaTest = metaTest.cuda() 223 | 224 | 225 | with torch.no_grad(): 226 | normInput = (input - mean) / std 227 | feature, _ = net(normInput) 228 | scores = centroids.mm(F.normalize(feature.t(), p=2, dim=0)) 229 | 230 | _, ranks = torch.sort(scores, dim=0, descending=True) 231 | pos_i = ranks[0, :] 232 | neg_i = ranks[-1, :] 233 | neg_feature = centroids[neg_i, :] 234 | pos_feature = centroids[pos_i, :] 235 | 236 | current_noise = noise 237 | current_noise = F.interpolate( 238 | current_noise.unsqueeze(0), 239 | mode=MODE, size=tuple(input.shape[-2:]), align_corners=True, 240 | ).squeeze() 241 | perturted_input = torch.clamp(input + current_noise, 0, 1) 242 | perturted_input_norm = (perturted_input - mean) / std 243 | perturbed_feature = net(perturted_input_norm)[0] 244 | 245 | optimizer.zero_grad() 246 | 247 | pair_loss = 10 * F.triplet_margin_loss(perturbed_feature, neg_feature, pos_feature, 0.5) 248 | 249 | 250 | pair_loss = pair_loss.view(1) 251 | 252 | loss = pair_loss 253 | 254 | 255 | grad = torch.autograd.grad(loss, noise, create_graph=True)[0] 256 | noiseOneStep = Update(noise, optimizer, grad, MAX_EPS) 257 | 258 | 259 | newNoise = F.interpolate( 260 | noiseOneStep.unsqueeze(0), mode=MODE, 261 | size=tuple(metaTest.shape[-2:]), align_corners=True, 262 | ).squeeze() 263 | 264 | 265 | if epoch % 3 == 0: 266 | search_noise = evolutionary_search(search_set,search_set2, modelTest,modelTest2,noise) 267 | newNoise = newNoise + search_noise 268 | 269 | with torch.no_grad(): 270 | normMte = (metaTest - mean) / std 271 | mteFeat = net(normMte)[0] 272 | scores = metaCentroids.mm(F.normalize(mteFeat.t(), p=2, dim=0)) 273 | 274 | metaLab = scores.max(0, keepdim=True)[1] 275 | 
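        # Meta-test step: the one-step update above is kept in the autograd
        # graph (create_graph=True), so finalLoss can backpropagate through
        # noiseOneStep and tune the noise to also transfer to the held-out
        # search set. As in the gradient branch, the farthest centroid
        # (rank -1) is passed to the triplet loss as the "positive" and the
        # nearest (rank 0) as the "negative", pulling perturbed features away
        # from their own cluster.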
_, ranks = torch.sort(scores, dim=0, descending=True) 276 | pos_i = ranks[0, :] 277 | neg_i = ranks[-1, :] 278 | neg_mte_feat = metaCentroids[neg_i, :] 279 | pos_mte_feat = metaCentroids[pos_i, :] 280 | 281 | perMteInput = torch.clamp(metaTest + newNoise, 0, 1) 282 | normPerMteInput = (perMteInput - mean) / std 283 | normMteFeat = net(normPerMteInput)[0] 284 | 285 | lossTri = 10 * F.triplet_margin_loss( 286 | normMteFeat, neg_mte_feat, pos_mte_feat, 0.5 287 | ) 288 | 289 | oneHotRealMeta = torch.zeros(scores.t().shape).cuda() 290 | oneHotRealMeta.scatter_(1, metaLab.view(-1, 1), float(1)) 291 | 292 | 293 | finalLoss = lossTri + pair_loss 294 | 295 | finalLoss.backward() 296 | 297 | losses.update(pair_loss.item()) 298 | optimizer.step() 299 | 300 | # measure elapsed time 301 | batch_time.update(time.time() - end) 302 | end = time.time() 303 | 304 | if i % args.print_freq == 0: 305 | print( 306 | ">> Train: [{0}][{1}/{2}]\t" 307 | "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" 308 | "Data {data_time.val:.3f} ({data_time.avg:.3f})\t" 309 | "Loss {loss.val:.4f} ({loss.avg:.4f})\t" 310 | "Noise l2: {noise:.4f}".format( 311 | epoch + 1, 312 | i, len(gradient_based_train_loader), 313 | batch_time=batch_time, 314 | data_time=data_time, 315 | loss=losses, lossTri=lossTri.item(), 316 | noise=noise.norm(), 317 | ) 318 | ) 319 | 320 | noise.requires_grad = False 321 | print(f"Train {epoch}: Loss: {losses.avg}") 322 | return losses.avg, noise 323 | 324 | 325 | # Define the evolutionary search method, using the Attack class in MOAA 326 | def evolutionary_search(search_set,search_set2, modelTest,modelTest2,noise): 327 | 328 | params = { 329 | "population_size": 2, 330 | "num_generations": 150, 331 | "mutation_rate": 0.2, 332 | "crossover_rate": 0.3, 333 | "epsilon": 8 / 255.0, 334 | "p_size": noise.size(), 335 | "x": None, # input data 336 | "eps": 8 / 255.0, 337 | "zero_probability": 0.2, 338 | "pm": 0.1, 339 | "pop_size": 2, 340 | "iterations": 150, 341 | "pc": 0.3, 342 | "include_dist": True, 343 | "save_directory": "results", 344 | "tournament_size": 3, 345 | "max_dist": 1.0, 346 | } 347 | 348 | attack = Attack(params,search_set,search_set2,modelTest,modelTest2) 349 | 350 | search_noise = attack.attack(noise) 351 | return search_noise 352 | 353 | 354 | def calDist(qFeat, gFeat): 355 | m, n = qFeat.size(0), gFeat.size(0) 356 | x = qFeat.view(m, -1) 357 | y = gFeat.view(n, -1) 358 | dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 359 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 360 | dist_m.addmm_(1, -2, x, y.t()) 361 | return dist_m 362 | 363 | 364 | def test(dataset, net, noise, args, evaluator, epoch): 365 | print(">> Evaluating network on test datasets...") 366 | 367 | net = net.cuda() 368 | net.eval() 369 | normalize = T.Normalize( 370 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 371 | ) 372 | 373 | def add_noise(img): 374 | n = noise.cpu() 375 | img = img.cpu() 376 | n = F.interpolate( 377 | n.unsqueeze(0), mode=MODE, size=tuple(img.shape[-2:]), align_corners=True 378 | ).squeeze() 379 | return torch.clamp(img + n, 0, 1) 380 | 381 | query_trans = T.Compose([ 382 | T.RectScale(args.height, args.width), 383 | T.ToTensor(), T.Lambda(lambda img: add_noise(img)), 384 | # transforms.RandomGrayscale(p=0.5), 385 | normalize 386 | ]) 387 | test_transformer = T.Compose([ 388 | T.RectScale(args.height, args.width), 389 | transforms.RandomGrayscale(p=1), 390 | T.ToTensor(), normalize 391 | ]) 392 | query_loader = DataLoader( 393 | Preprocessor(dataset.query, 
root=dataset.images_dir, transform=query_trans), 394 | batch_size=args.batch_size, num_workers=0, shuffle=False, pin_memory=True 395 | ) 396 | gallery_loader = DataLoader( 397 | Preprocessor(dataset.gallery, root=dataset.images_dir, transform=test_transformer), 398 | batch_size=args.batch_size, num_workers=8, shuffle=False, pin_memory=True 399 | ) 400 | qFeats, gFeats, testQImage, qnames, gnames = [], [], [], [], [] 401 | with torch.no_grad(): 402 | for (inputs, qname, _, _) in query_loader: 403 | inputs = inputs.cuda() 404 | qFeats.append(net(inputs)[0]) 405 | qnames.extend(qname) 406 | qFeats = torch.cat(qFeats, 0) 407 | for (inputs, gname, _, _) in gallery_loader: 408 | inputs = inputs.cuda() 409 | gFeats.append(net(inputs)[0]) 410 | gnames.extend(gname) 411 | gFeats = torch.cat(gFeats, 0) 412 | distMat = calDist(qFeats, gFeats) 413 | 414 | 415 | # evaluate on test datasets 416 | evaluator.evaMat(distMat, dataset.query, dataset.gallery) 417 | return testQImage 418 | 419 | 420 | 421 | 422 | if __name__ == '__main__': 423 | parser = argparse.ArgumentParser() 424 | 425 | parser.add_argument('--data', type=str, required=True, 426 | help='path to reid dataset') 427 | parser.add_argument('-s', '--source', type=str, default='sysu_v2', 428 | choices=datasets.names()) 429 | parser.add_argument('-a', '--arch', type=str, default='resnet50', 430 | choices=models.names()) 431 | parser.add_argument('-t', '--target', type=str, default='sysu_v2', 432 | choices=datasets.names()) 433 | parser.add_argument('-m', '--mte', type=str, default='sysu_v2', 434 | choices=datasets.names()) 435 | parser.add_argument('-m2', '--mte2', type=str, default='sysu_v2', 436 | choices=datasets.names()) 437 | parser.add_argument('--batch_size', type=int, default=50, required=True, 438 | help='number of examples/minibatch') 439 | parser.add_argument('--num_batches', type=int, required=False, 440 | help='number of batches (default entire dataset)') 441 | parser.add_argument('--resume', type=str, default='', metavar='PATH') 442 | parser.add_argument('--resumeSearchTgt', type=str, default='', metavar='PATH') 443 | parser.add_argument('--resumeSearchTgt2', type=str, default='', metavar='PATH') 444 | parser.add_argument('--resumeTgt', type=str, default='', metavar='PATH') 445 | 446 | parser.add_argument('--split', type=int, default=0) 447 | parser.add_argument('--epoch', type=int, default=20) 448 | parser.add_argument('--height', type=int, default=256, 449 | help="input height, default: 256 for resnet*, " 450 | "144 for inception") 451 | parser.add_argument('--width', type=int, default=128, 452 | help="input width, default: 128 for resnet*, " 453 | "56 for inception") 454 | parser.add_argument('--num-instances', type=int, default=8, 455 | help="each minibatch consist of " 456 | "(batch_size // num_instances) identities, and " 457 | "each identity has num_instances instances, " 458 | "default: 4") 459 | parser.add_argument('--combine_trainval', action='store_true', 460 | help="train and val sets together for training, " 461 | "val set alone for validation") 462 | parser.add_argument('--print_freq', type=int, default=10) 463 | parser.add_argument("--max-eps", default=8, type=int, help="max eps") 464 | args = parser.parse_args() 465 | 466 | 467 | # torch.manual_seed(0) 468 | # torch.cuda.manual_seed_all(0) 469 | # np.random.seed(0) 470 | # random.seed(0) 471 | 472 | sourceSet, tgtSet, mteSet,mteSet2, num_classes,num_search,num_search2, class_tgt, gradient_based_train, search_set,search_set2 = \ 473 | input(args.source, args.mte, 
args.mte2, args.target,
474 |               args.split, args.data, args.height,
475 |               args.width, args.batch_size, 8, args.combine_trainval)
476 | 
477 | 
478 |     model = models.create(args.arch, pretrained=True, num_classes=num_classes)
479 |     modelTest = models.create(args.arch, pretrained=True, num_classes=num_search)
480 |     modelTest2 = models.create(args.arch, pretrained=True, num_classes=num_search2)
481 |     modelTarget = models.create(args.arch, pretrained=True, num_classes=class_tgt)
482 |     if args.resume:
483 |         checkpoint = torch.load(args.resume)
484 |         if 'state_dict' in checkpoint.keys():
485 |             checkpoint = checkpoint['state_dict']
486 |         try:
487 |             model.load_state_dict(checkpoint)
488 |         except:
489 |             allNames = list(checkpoint.keys())
490 |             for name in allNames:
491 |                 if name.count('classifier') != 0:
492 |                     del checkpoint[name]
493 |             model.load_state_dict(checkpoint, strict=False)
494 | 
495 |         checkTest = torch.load(args.resumeSearchTgt)
496 |         if 'state_dict' in checkTest.keys():
497 |             checkTest = checkTest['state_dict']  # was `checkTgt = ...`, which silently kept the unfiltered checkpoint dict
498 |         try:
499 |             modelTest.load_state_dict(checkTest)
500 |         except:
501 |             allNames = list(checkTest.keys())
502 |             for name in allNames:
503 |                 if name.count('classifier') != 0:
504 |                     del checkTest[name]
505 |             modelTest.load_state_dict(checkTest, strict=False)
506 | 
507 |         checkTest2 = torch.load(args.resumeSearchTgt2)
508 |         if 'state_dict' in checkTest2.keys():
509 |             checkTest2 = checkTest2['state_dict']
510 |         try:
511 |             modelTest2.load_state_dict(checkTest2)
512 |         except:
513 |             allNames = list(checkTest2.keys())
514 |             for name in allNames:
515 |                 if name.count('classifier') != 0:
516 |                     del checkTest2[name]
517 |             modelTest2.load_state_dict(checkTest2, strict=False)  # was modelTest, which loaded the wrong model
518 | 
519 |         checkTarget = torch.load(args.resumeTgt)
520 |         if 'state_dict' in checkTarget.keys():
521 |             checkTarget = checkTarget['state_dict']
522 |         try:
523 |             modelTarget.load_state_dict(checkTarget)
524 |         except:
525 |             allNames = list(checkTarget.keys())
526 |             for name in allNames:
527 |                 if name.count('classifier') != 0:
528 |                     del checkTarget[name]
529 |             modelTarget.load_state_dict(checkTarget, strict=False)
530 | 
531 | 
532 |     model.eval()
533 |     modelTest.eval()
534 |     modelTest2.eval()
535 |     modelTarget.eval()
536 |     if torch.cuda.is_available():
537 |         model = model.cuda()
538 |         modelTest = modelTest.cuda()
539 |         modelTest2 = modelTest2.cuda()
540 |         modelTarget = modelTarget.cuda()
541 | 
542 |     features, _ = extract_features(model, gradient_based_train, print_freq=10)
543 |     features = torch.stack([features[f] for f, _, _ in sourceSet.trainval])
544 |     metaFeats, _ = extract_features(model, search_set, print_freq=10)
545 |     metaFeats = torch.stack([metaFeats[f] for f, _, _ in mteSet.trainval])
546 | 
547 | 
548 |     if args.source == "sysu":
549 |         ncentroids = 395
550 |     else:
551 |         ncentroids = 206
552 | 
553 | 
554 |     fDim = features.shape[1]
555 |     cluster, metaClu = faiss.Kmeans(fDim, ncentroids, niter=20, gpu=True), \
556 |                        faiss.Kmeans(fDim, ncentroids, niter=20, gpu=True)
557 |     cluster.train(features.cpu().numpy())
558 |     metaClu.train(metaFeats.cpu().numpy())
559 | 
560 |     centroids = torch.from_numpy(cluster.centroids).cuda().float()
561 |     metaCentroids = torch.from_numpy(metaClu.centroids).cuda().float()
562 |     del metaClu, cluster
563 | 
564 |     evaluator = Evaluator(modelTest, args.print_freq)
565 |     evaSrc = Evaluator(model, args.print_freq)
566 | 
567 |     # universal noise
568 |     noise = torch.zeros((3, args.height, args.width)).cuda()
569 |     normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
570 | 
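    # The attack optimises a single image-agnostic (universal) perturbation:
    # one (3, height, width) tensor shared by all inputs, resized to each
    # batch via F.interpolate, sign-updated by MI_SGD, and clamped to the
    # L-inf ball of radius max_eps / 255.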
noise.requires_grad = True 572 | MAX_EPS = args.max_eps / 255.0 573 | 574 | optimizer = MI_SGD( 575 | [{"params": [noise], "lr": MAX_EPS / 10, "momentum": 1, "sign": True}], 576 | max_eps=MAX_EPS, 577 | ) 578 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=math.exp(-0.01)) 579 | 580 | 581 | import time 582 | 583 | for epoch in range(args.epoch): 584 | scheduler.step() 585 | begin_time = time.time() 586 | loss, noise = Multiform_attack( 587 | gradient_based_train, search_set, model, noise, epoch, optimizer, 588 | centroids, metaCentroids, normalize 589 | ) 590 | 591 | testQImage = test(tgtSet, modelTest, noise, args, evaluator, epoch) 592 | 593 | --------------------------------------------------------------------------------
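A minimal usage sketch, not part of the repository: assuming the perturbation learned above is saved with torch.save(noise.cpu(), 'noise.pth'), it could be applied to a single query image following the add_noise logic inside test(). The file names here are illustrative placeholders.

import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

noise = torch.load('noise.pth')                       # (3, H, W), values in [-max_eps, max_eps]
img = transforms.ToTensor()(Image.open('query.jpg'))  # (3, h, w), values in [0, 1]
# resize the universal noise to the image resolution, as in test()
n = F.interpolate(noise.unsqueeze(0), size=img.shape[-2:],
                  mode='bilinear', align_corners=True).squeeze(0)
adv = torch.clamp(img + n, 0, 1)                      # perturbed query, still a valid image
transforms.ToPILImage()(adv).save('query_adv.png')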