├── AAAI-ID-1540-Yu.pdf
├── LICENSE
├── README.md
├── climb
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── dataloader.cpython-310.pyc
│   │   ├── dataloader.cpython-39.pyc
│   │   ├── dataset.cpython-310.pyc
│   │   ├── dataset.cpython-39.pyc
│   │   ├── loss.cpython-310.pyc
│   │   ├── loss.cpython-39.pyc
│   │   ├── model.cpython-310.pyc
│   │   ├── model.cpython-39.pyc
│   │   ├── optimizer.cpython-310.pyc
│   │   ├── optimizer.cpython-39.pyc
│   │   ├── preprocessing.cpython-310.pyc
│   │   ├── preprocessing.cpython-39.pyc
│   │   ├── sampler.cpython-310.pyc
│   │   ├── sampler.cpython-39.pyc
│   │   ├── spmamba.cpython-310.pyc
│   │   ├── utils.cpython-310.pyc
│   │   ├── utils.cpython-39.pyc
│   │   └── vivim.cpython-310.pyc
│   ├── dataloader.py
│   ├── dataset.py
│   ├── loss.py
│   ├── model.py
│   ├── optimizer.py
│   ├── preprocessing.py
│   ├── processor_climb.py
│   ├── sampler.py
│   ├── spmamba.py
│   ├── utils.py
│   └── vivim.py
├── clip
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-310.pyc
│   │   ├── __init__.cpython-39.pyc
│   │   ├── clip.cpython-310.pyc
│   │   ├── clip.cpython-39.pyc
│   │   ├── model.cpython-310.pyc
│   │   ├── model.cpython-39.pyc
│   │   ├── simple_tokenizer.cpython-310.pyc
│   │   └── simple_tokenizer.cpython-39.pyc
│   ├── bpe_simple_vocab_16e6.txt.gz
│   ├── clip.py
│   ├── model.py
│   └── simple_tokenizer.py
├── config
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-310.pyc
│   │   ├── __init__.cpython-39.pyc
│   │   ├── defaults.cpython-310.pyc
│   │   └── defaults.cpython-39.pyc
│   ├── climb-vit-market.yml
│   ├── climb-vit-msmt.yml
│   └── defaults.py
├── datasets
│   ├── bases.py
│   ├── dukemtmcreid.py
│   ├── keypoint_test.txt
│   ├── keypoint_train.txt
│   ├── make_dataloader.py
│   ├── make_dataloader_clipreid.py
│   ├── make_dataloader_clipreid_ccpa.py
│   ├── market1501.py
│   ├── msmt17.py
│   ├── msmt17_v2.py
│   ├── occ_duke.py
│   ├── preprocessing.py
│   ├── sampler.py
│   └── sampler_ddp.py
├── mamba
│   ├── AUTHORS
│   ├── assets
│   │   └── selection.png
│   ├── benchmarks
│   │   └── benchmark_generation_mamba_simple.py
│   └── csrc
│       └── selective_scan
│           ├── reverse_scan.cuh
│           ├── selective_scan.cpp
│           ├── selective_scan.h
│           ├── selective_scan_bwd_bf16_complex.cu
│           ├── selective_scan_bwd_bf16_real.cu
│           ├── selective_scan_bwd_fp16_complex.cu
│           ├── selective_scan_bwd_fp16_real.cu
│           ├── selective_scan_bwd_fp32_complex.cu
│           ├── selective_scan_bwd_fp32_real.cu
│           ├── selective_scan_bwd_kernel.cuh
│           └── selective_scan_common.h
├── solver
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-310.pyc
│   │   ├── __init__.cpython-39.pyc
│   │   ├── lr_scheduler.cpython-310.pyc
│   │   └── lr_scheduler.cpython-39.pyc
│   ├── cosine_lr.py
│   ├── lr_scheduler.py
│   ├── make_optimizer.py
│   ├── scheduler.py
│   └── scheduler_factory.py
├── train_climb.py
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── logger.cpython-310.pyc
    │   ├── logger.cpython-39.pyc
    │   ├── meter.cpython-310.pyc
    │   └── meter.cpython-39.pyc
    ├── iotools.py
    ├── logger.py
    ├── meter.py
    ├── metrics.py
    └── reranking.py

/AAAI-ID-1540-Yu.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/AAAI-ID-1540-Yu.pdf
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Chenyang Yu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit
persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CLIMB-ReID
2 | CLIMB-ReID: A Hybrid CLIP-Mamba Framework for Person Re-Identification (AAAI 2025)
3 |
--------------------------------------------------------------------------------
/climb/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/climb/__init__.py
--------------------------------------------------------------------------------
/climb/__pycache__/dataloader.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/climb/__pycache__/dataloader.cpython-310.pyc
--------------------------------------------------------------------------------
/climb/__pycache__/dataloader.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/climb/__pycache__/dataloader.cpython-39.pyc
--------------------------------------------------------------------------------
/climb/__pycache__/dataset.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/climb/__pycache__/dataset.cpython-310.pyc
--------------------------------------------------------------------------------
/climb/__pycache__/dataset.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/climb/__pycache__/dataset.cpython-39.pyc
--------------------------------------------------------------------------------
/climb/__pycache__/loss.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/climb/__pycache__/loss.cpython-310.pyc
--------------------------------------------------------------------------------
/climb/__pycache__/loss.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/climb/__pycache__/loss.cpython-39.pyc
--------------------------------------------------------------------------------
/climb/__pycache__/model.cpython-310.pyc:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/climb/__pycache__/model.cpython-310.pyc -------------------------------------------------------------------------------- /climb/__pycache__/model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/climb/__pycache__/model.cpython-39.pyc -------------------------------------------------------------------------------- /climb/__pycache__/optimizer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/climb/__pycache__/optimizer.cpython-310.pyc -------------------------------------------------------------------------------- /climb/__pycache__/optimizer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/climb/__pycache__/optimizer.cpython-39.pyc -------------------------------------------------------------------------------- /climb/__pycache__/preprocessing.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/climb/__pycache__/preprocessing.cpython-310.pyc -------------------------------------------------------------------------------- /climb/__pycache__/preprocessing.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/climb/__pycache__/preprocessing.cpython-39.pyc -------------------------------------------------------------------------------- /climb/__pycache__/sampler.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/climb/__pycache__/sampler.cpython-310.pyc -------------------------------------------------------------------------------- /climb/__pycache__/sampler.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/climb/__pycache__/sampler.cpython-39.pyc -------------------------------------------------------------------------------- /climb/__pycache__/spmamba.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/climb/__pycache__/spmamba.cpython-310.pyc -------------------------------------------------------------------------------- /climb/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/climb/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /climb/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/climb/__pycache__/utils.cpython-39.pyc
--------------------------------------------------------------------------------
/climb/__pycache__/vivim.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/climb/__pycache__/vivim.cpython-310.pyc
--------------------------------------------------------------------------------
/climb/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms as T
3 | from torch.utils.data import DataLoader
4 | from datasets.market1501 import Market1501
5 | # from datasets.msmt17_v2 import MSMT17_V2
6 | from datasets.msmt17 import MSMT17
7 | from .preprocessing import RandomErasing
8 | from .dataset import ImageDataset, IterLoader
9 | from .sampler import RandomIdentitySampler, RandomMultipleGallerySampler
10 |
11 | FACTORY = {
12 |     'market1501': Market1501,
13 |     'msmt17': MSMT17,
14 | }
15 |
16 |
17 | def make_CLIMB_dataloader(cfg, all_iters=False):
18 |     """
19 |     PCL-style dataloader. It returns three loaders: a training loader, a cluster loader and a validation loader.
20 |     - Training loader: PK sampling selects K instances for each of P classes.
21 |     - Cluster loader: a plain loader over the training samples, but with the validation augmentation.
22 |     - Validation loader: a plain loader over the test samples.
23 |
24 |     Args:
25 |     - cfg: experiment config node (see config/defaults.py).
26 |     - all_iters: if `all_iters=True`, the number of training iterations per epoch is
27 |       `num_samples // batch_size`; otherwise it is capped at `cfg.SOLVER.ITERS`.
28 |     """
29 |
30 |     dataset = FACTORY[cfg.DATASETS.NAMES](root=cfg.DATASETS.ROOT_DIR)
31 |     num_workers = cfg.DATALOADER.NUM_WORKERS
32 |     num_classes = dataset.num_train_pids
33 |     cam_num = dataset.num_train_cams
34 |     view_num = dataset.num_train_vids
35 |
36 |     # train loader
37 |     train_transforms = T.Compose([
38 |         T.Resize(cfg.INPUT.SIZE_TRAIN, interpolation=3),  # 3 = PIL.Image.BICUBIC
39 |         T.RandomHorizontalFlip(p=cfg.INPUT.PROB),
40 |         T.Pad(cfg.INPUT.PADDING),
41 |         T.RandomCrop(cfg.INPUT.SIZE_TRAIN),
42 |         T.ToTensor(),
43 |         T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD),
44 |         RandomErasing(probability=cfg.INPUT.RE_PROB, mean=cfg.INPUT.PIXEL_MEAN)
45 |     ])
46 |     train_set = ImageDataset(dataset.train, train_transforms)
47 |     # sampler = RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE)
48 |     sampler = RandomMultipleGallerySampler(dataset.train, cfg.DATALOADER.NUM_INSTANCE)
49 |     train_loader = DataLoader(
50 |         train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH,
51 |         sampler=sampler,
52 |         num_workers=num_workers
53 |     )
54 |     train_loader = IterLoader(train_loader, cfg.SOLVER.ITERS if not all_iters else None)
55 |
56 |     # val loader
57 |     val_transforms = T.Compose([
58 |         T.Resize(cfg.INPUT.SIZE_TEST, interpolation=3),
59 |         T.ToTensor(),
60 |         T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
61 |     ])
62 |     val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms)
63 |     num_queries = len(dataset.query)
64 |     val_loader = DataLoader(
65 |         val_set, batch_size=4096, shuffle=False, num_workers=num_workers
66 |     )
67 |
68 |     # cluster loader
69 |     cluster_set = ImageDataset(dataset.train, val_transforms)
70 |     cluster_loader = DataLoader(
71 |         cluster_set, batch_size=4096, shuffle=False, num_workers=num_workers
72 |     )
73 |
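    # Usage sketch (assuming a yacs `cfg` built from config/defaults.py and one of the yml files in config/):
    #   train_loader, val_loader, cluster_loader, num_queries, num_classes, cam_num, view_num = \
    #       make_CLIMB_dataloader(cfg)
    #   train_loader.new_epoch()
    #   img, pid, camid, trackid = train_loader.next()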
return train_loader, val_loader, cluster_loader, num_queries, num_classes, cam_num, view_num 75 | -------------------------------------------------------------------------------- /climb/dataset.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageFile 2 | 3 | from torch.utils.data import Dataset 4 | import os.path as osp 5 | import random 6 | import torch 7 | ImageFile.LOAD_TRUNCATED_IMAGES = True 8 | 9 | 10 | def read_image(img_path): 11 | """Keep reading image until succeed. 12 | This can avoid IOError incurred by heavy IO process.""" 13 | got_img = False 14 | if not osp.exists(img_path): 15 | raise IOError("{} does not exist".format(img_path)) 16 | while not got_img: 17 | try: 18 | img = Image.open(img_path).convert('RGB') 19 | got_img = True 20 | except IOError: 21 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path)) 22 | pass 23 | return img 24 | 25 | 26 | class BaseDataset(object): 27 | """ 28 | Base class of reid dataset 29 | """ 30 | 31 | def get_imagedata_info(self, data): 32 | pids, cams, tracks = [], [], [] 33 | for _, pid, camid, trackid in data: 34 | pids += [pid] 35 | cams += [camid] 36 | tracks += [trackid] 37 | pids = set(pids) 38 | cams = set(cams) 39 | tracks = set(tracks) 40 | num_pids = len(pids) 41 | num_cams = len(cams) 42 | num_imgs = len(data) 43 | num_views = len(tracks) 44 | return num_pids, num_imgs, num_cams, num_views 45 | 46 | def print_dataset_statistics(self): 47 | raise NotImplementedError 48 | 49 | 50 | class BaseImageDataset(BaseDataset): 51 | """ 52 | Base class of image reid dataset 53 | """ 54 | 55 | def print_dataset_statistics(self, train, query, gallery): 56 | num_train_pids, num_train_imgs, num_train_cams, num_train_views = self.get_imagedata_info(train) 57 | num_query_pids, num_query_imgs, num_query_cams, num_train_views = self.get_imagedata_info(query) 58 | num_gallery_pids, num_gallery_imgs, num_gallery_cams, num_train_views = self.get_imagedata_info(gallery) 59 | 60 | print("Dataset statistics:") 61 | print(" ----------------------------------------") 62 | print(" subset | # ids | # images | # cameras") 63 | print(" ----------------------------------------") 64 | print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams)) 65 | print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams)) 66 | print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams)) 67 | print(" ----------------------------------------") 68 | 69 | 70 | class ImageDataset(Dataset): 71 | def __init__(self, dataset, transform=None): 72 | self.dataset = dataset 73 | self.transform = transform 74 | 75 | def __len__(self): 76 | return len(self.dataset) 77 | 78 | def __getitem__(self, index): 79 | img_path, pid, camid, trackid = self.dataset[index] 80 | img = read_image(img_path) 81 | 82 | if self.transform is not None: 83 | img = self.transform(img) 84 | 85 | return img, pid, camid, trackid 86 | 87 | class PseudoLabelImageDataset(ImageDataset): 88 | def __init__(self, dataset, transform=None): 89 | super().__init__(dataset, transform) 90 | 91 | def __getitem__(self, index): 92 | # override to return pseudo ID 93 | img_path, pid, camid, trackid, pseudo_id = self.dataset[index] 94 | img = read_image(img_path) 95 | 96 | if self.transform is not None: 97 | img = self.transform(img) 98 | 99 | return img, pid, camid, trackid, pseudo_id 100 | 101 | class IterLoader: 102 
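    # Iteration-based wrapper around a DataLoader: call new_epoch() at the start of each
    # epoch, then next() per step; `length` (cfg.SOLVER.ITERS) fixes the iterations per epoch.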
    def __init__(self, loader, length=None):
103 |         self.loader = loader
104 |         self.length = length
105 |         self.iter = None
106 |
107 |     def __len__(self):
108 |         if (self.length is not None):
109 |             return self.length
110 |         return len(self.loader)
111 |
112 |     def new_epoch(self):
113 |         self.iter = iter(self.loader)
114 |
115 |     def next(self):
116 |         try:
117 |             return next(self.iter)
118 |         except (StopIteration, TypeError):  # loader exhausted, or new_epoch() never called
119 |             self.iter = iter(self.loader)
120 |             return next(self.iter)
121 |
--------------------------------------------------------------------------------
/climb/loss.py:
--------------------------------------------------------------------------------
1 | import collections
2 | from abc import ABC
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import autograd, nn
8 | from torch.cuda import amp
9 |
10 |
11 | class CM(autograd.Function):
12 |
13 |     @staticmethod
14 |     @amp.custom_fwd
15 |     def forward(ctx, inputs, targets, features, momentum):
16 |         ctx.features = features
17 |         ctx.momentum = momentum
18 |         ctx.save_for_backward(inputs, targets)
19 |         outputs = inputs.mm(ctx.features.t())
20 |
21 |         return outputs
22 |
23 |     @staticmethod
24 |     @amp.custom_bwd
25 |     def backward(ctx, grad_outputs):
26 |         inputs, targets = ctx.saved_tensors
27 |         grad_inputs = None
28 |         if ctx.needs_input_grad[0]:
29 |             grad_inputs = grad_outputs.mm(ctx.features)
30 |
31 |         # momentum update
32 |         for x, y in zip(inputs, targets):
33 |             ctx.features[y] = ctx.momentum * ctx.features[y] + (1. - ctx.momentum) * x
34 |             ctx.features[y] /= ctx.features[y].norm()
35 |
36 |         return grad_inputs, None, None, None
37 |
38 |
39 |
40 |
41 | class CM_Hard(autograd.Function):
42 |
43 |     @staticmethod
44 |     @amp.custom_fwd
45 |     def forward(ctx, inputs, targets, features, momentum):
46 |         ctx.features = features
47 |         ctx.momentum = momentum
48 |         ctx.save_for_backward(inputs, targets)
49 |         outputs = inputs.mm(ctx.features.t())
50 |
51 |         return outputs
52 |
53 |     @staticmethod
54 |     @amp.custom_bwd
55 |     def backward(ctx, grad_outputs):
56 |         inputs, targets = ctx.saved_tensors
57 |         grad_inputs = None
58 |         if ctx.needs_input_grad[0]:
59 |             grad_inputs = grad_outputs.mm(ctx.features)
60 |
61 |         batch_centers = collections.defaultdict(list)
62 |         for instance_feature, index in zip(inputs, targets.tolist()):
63 |             batch_centers[index].append(instance_feature)
64 |
65 |         for index, features in batch_centers.items():
66 |             distances = []
67 |             for feature in features:
68 |                 distance = feature.unsqueeze(0).mm(ctx.features[index].unsqueeze(0).t())[0][0]
69 |                 distances.append(distance.cpu().numpy())
70 |
71 |             hardest = np.argmin(np.array(distances))  # instance least similar to the centroid, i.e. the hardest one
72 |             ctx.features[index] = ctx.features[index] * ctx.momentum + (1 - ctx.momentum) * features[hardest]
73 |             ctx.features[index] /= ctx.features[index].norm()
74 |
75 |         return grad_inputs, None, None, None
76 |
77 | def cm(inputs, indexes, features, momentum=0.5):
78 |     return CM.apply(inputs, indexes, features, torch.Tensor([momentum]).to(inputs.device))
79 |
80 | def cm_hard(inputs, indexes, features, momentum=0.5):
81 |     return CM_Hard.apply(inputs, indexes, features, torch.Tensor([momentum]).to(inputs.device))
82 |
83 |
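# Minimal usage sketch for the cluster memory (assumed shapes: `centroids` is
# (num_clusters, D) and L2-normalized, `feats` is (B, D), `labels` holds each
# sample's cluster id), mirroring how processor_climb.py drives it:
#   memory = ClusterMemoryAMP(momentum=0.2, use_hard=True).cuda()
#   memory.features = centroids
#   loss = memory(feats, labels)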
84 | class ClusterMemoryAMP(nn.Module, ABC):
85 |     def __init__(self, temp=0.05, momentum=0.2, use_hard=False):
86 |         super(ClusterMemoryAMP, self).__init__()
87 |         self.momentum = momentum
88 |         self.temp = temp
89 |         self.use_hard = use_hard
90 |         self.features = None
91 |
92 |     def forward(self, inputs, targets, cams=None, epoch=None):
93 |         inputs = F.normalize(inputs, dim=1).cuda()
94 |         if self.use_hard:
95 |             outputs = cm_hard(inputs, targets, self.features, self.momentum)
96 |         else:
97 |             outputs = cm(inputs, targets, self.features, self.momentum)
98 |
99 |         outputs /= self.temp
100 |         loss = F.cross_entropy(outputs, targets)
101 |         return loss
102 |
103 |
104 | class CrossEntropyLabelSmooth(nn.Module):
105 |     """Cross entropy loss with label smoothing regularizer.
106 |
107 |     Reference:
108 |     Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
109 |     Equation: y = (1 - epsilon) * y + epsilon / K.
110 |
111 |     Args:
112 |         num_classes (int): number of classes.
113 |         epsilon (float): smoothing weight.
114 |     """
115 |
116 |     def __init__(self, num_classes, epsilon=0.1, use_gpu=True):
117 |         super(CrossEntropyLabelSmooth, self).__init__()
118 |         self.num_classes = num_classes
119 |         self.epsilon = epsilon
120 |         self.use_gpu = use_gpu
121 |         self.logsoftmax = nn.LogSoftmax(dim=1)
122 |
123 |     def forward(self, inputs, targets):
124 |         """
125 |         Args:
126 |             inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
127 |             targets: ground truth labels with shape (batch_size)
128 |         """
129 |         log_probs = self.logsoftmax(inputs)
130 |         targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
131 |         if self.use_gpu: targets = targets.cuda()
132 |         targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
133 |         loss = (- targets * log_probs).mean(0).sum()
134 |         return loss
135 |
136 |
141 | def normalize(x, axis=-1):
142 |     """Normalizing to unit length along the specified dimension.
143 |     Args:
144 |         x: pytorch Variable
145 |     Returns:
146 |         x: pytorch Variable, same shape as input
147 |     """
148 |     x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
149 |     return x
150 |
151 |
152 | def euclidean_dist(x, y):
153 |     """
154 |     Args:
155 |         x: pytorch Variable, with shape [m, d]
156 |         y: pytorch Variable, with shape [n, d]
157 |     Returns:
158 |         dist: pytorch Variable, with shape [m, n]
159 |     """
160 |     m, n = x.size(0), y.size(0)
161 |     xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)  # B, B
162 |     yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
163 |     dist = xx + yy
164 |     dist = dist - 2 * torch.matmul(x, y.t())
165 |     # dist.addmm_(1, -2, x, y.t())
166 |     dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
167 |     return dist
168 |
169 |
170 | def cosine_dist(x, y):
171 |     """
172 |     Args:
173 |         x: pytorch Variable, with shape [m, d]
174 |         y: pytorch Variable, with shape [n, d]
175 |     Returns:
176 |         dist: pytorch Variable, with shape [m, n]
177 |     """
178 |     m, n = x.size(0), y.size(0)
179 |     x_norm = torch.pow(x, 2).sum(1, keepdim=True).sqrt().expand(m, n)
180 |     y_norm = torch.pow(y, 2).sum(1, keepdim=True).sqrt().expand(n, m).t()
181 |     xy_intersection = torch.mm(x, y.t())
182 |     dist = xy_intersection / (x_norm * y_norm)
183 |     dist = (1. - dist) / 2
184 |     return dist
185 |
186 |
187 | def hard_example_mining(dist_mat, labels, return_inds=False):
188 |     """For each anchor, find the hardest positive and negative sample.
189 |     Args:
190 |         dist_mat: pytorch Variable, pairwise distance between samples, shape [N, N]
191 |         labels: pytorch LongTensor, with shape [N]
192 |         return_inds: whether to also return the indices; saves time if `False`.
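        Note: `is_pos` includes the anchor itself (self-distance 0), so `dist_ap` is the
        maximum distance to any sample sharing the anchor's label.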
193 | Returns: 194 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N] 195 | dist_an: pytorch Variable, distance(anchor, negative); shape [N] 196 | p_inds: pytorch LongTensor, with shape [N]; 197 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1 198 | n_inds: pytorch LongTensor, with shape [N]; 199 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1 200 | NOTE: Only consider the case in which all labels have same num of samples, 201 | thus we can cope with all anchors in parallel. 202 | """ 203 | assert len(dist_mat.size()) == 2 204 | assert dist_mat.size(0) == dist_mat.size(1) 205 | N = dist_mat.size(0) 206 | 207 | # shape [N, N] 208 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()) 209 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t()) 210 | 211 | # `dist_ap` means distance(anchor, positive) 212 | # both `dist_ap` and `relative_p_inds` with shape [N, 1] 213 | dist_ap, relative_p_inds = torch.max( 214 | dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True) 215 | # print(dist_mat[is_pos].shape) 216 | # `dist_an` means distance(anchor, negative) 217 | # both `dist_an` and `relative_n_inds` with shape [N, 1] 218 | dist_an, relative_n_inds = torch.min( 219 | dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True) 220 | # shape [N] 221 | dist_ap = dist_ap.squeeze(1) 222 | dist_an = dist_an.squeeze(1) 223 | 224 | if return_inds: 225 | # shape [N, N] 226 | ind = (labels.new().resize_as_(labels) 227 | .copy_(torch.arange(0, N).long()) 228 | .unsqueeze(0).expand(N, N)) 229 | # shape [N, 1] 230 | p_inds = torch.gather( 231 | ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data) 232 | n_inds = torch.gather( 233 | ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data) 234 | # shape [N] 235 | p_inds = p_inds.squeeze(1) 236 | n_inds = n_inds.squeeze(1) 237 | return dist_ap, dist_an, p_inds, n_inds 238 | 239 | return dist_ap, dist_an 240 | 241 | 242 | class TripletLoss(object): 243 | """ 244 | Triplet loss using HARDER example mining, 245 | modified based on original triplet loss using hard example mining 246 | """ 247 | 248 | def __init__(self, margin=None, hard_factor=0.0): 249 | self.margin = margin 250 | self.hard_factor = hard_factor 251 | if margin is not None: 252 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 253 | else: 254 | self.ranking_loss = nn.SoftMarginLoss() 255 | 256 | def __call__(self, global_feat, labels, normalize_feature=False): 257 | if normalize_feature: 258 | global_feat = normalize(global_feat, axis=-1) 259 | dist_mat = euclidean_dist(global_feat, global_feat) #B,B 260 | dist_ap, dist_an = hard_example_mining(dist_mat, labels) 261 | 262 | dist_ap *= (1.0 + self.hard_factor) 263 | dist_an *= (1.0 - self.hard_factor) 264 | 265 | y = dist_an.new().resize_as_(dist_an).fill_(1) 266 | if self.margin is not None: 267 | loss = self.ranking_loss(dist_an, dist_ap, y) 268 | else: 269 | loss = self.ranking_loss(dist_an - dist_ap, y) 270 | return loss -------------------------------------------------------------------------------- /climb/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from timm.models.layers import trunc_normal_ 4 | import os.path 5 | import torch.nn.functional as F 6 | 7 | 8 | from .vivim import MambaLayer 9 | from .spmamba import VSSBlock 10 | from mamba.mamba_ssm.modules.srmamba import SRMamba 11 | from mamba.mamba_ssm.modules.bimamba import BiMamba 12 | from 
mamba.mamba_ssm.modules.mamba_simple import Mamba
13 | def weights_init_kaiming(m):
14 |     classname = m.__class__.__name__
15 |     if classname.find('Linear') != -1:
16 |         nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
17 |         nn.init.constant_(m.bias, 0.0)
18 |
19 |     elif classname.find('Conv') != -1:
20 |         nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
21 |         if m.bias is not None:
22 |             nn.init.constant_(m.bias, 0.0)
23 |     elif classname.find('BatchNorm') != -1:
24 |         if m.affine:
25 |             nn.init.constant_(m.weight, 1.0)
26 |             nn.init.constant_(m.bias, 0.0)
27 |
28 | def weights_init_classifier(m):
29 |     classname = m.__class__.__name__
30 |     if classname.find('Linear') != -1:
31 |         nn.init.normal_(m.weight, std=0.001)
32 |         if m.bias is not None:  # `if m.bias:` is ambiguous for tensors
33 |             nn.init.constant_(m.bias, 0.0)
34 |
35 | import clip.clip as clip
36 | def load_clip_to_cpu(backbone_name, h_resolution, w_resolution, vision_stride_size):
37 |     url = clip._MODELS[backbone_name]
38 |     model_path1 = '/dataset_cc/Pretrain-models/ViT-B-16.pt'  # use a local checkpoint instead of downloading
39 |     model_path2 = '/YCY/Pretrained_models/ViT-B-16.pt'  # use a local checkpoint instead of downloading
40 |     if os.path.exists(model_path1):
41 |         model_path = model_path1
42 |     elif os.path.exists(model_path2):
43 |         model_path = model_path2
44 |     else:
45 |         model_path = clip._download(url)
46 |     try:
47 |         # loading JIT archive
48 |         model = torch.jit.load(model_path, map_location="cpu").eval()
49 |         state_dict = None
50 |
51 |     except RuntimeError:
52 |         state_dict = torch.load(model_path, map_location="cpu")
53 |
54 |     model = clip.build_model(state_dict or model.state_dict(), h_resolution, w_resolution, vision_stride_size)
55 |
56 |     return model
57 |
58 | class CLIMB(nn.Module):
59 |     def __init__(self, num_classes, camera_num, view_num, cfg):
60 |         super(CLIMB, self).__init__()
61 |         self.model_name = cfg.MODEL.NAME
62 |
63 |         self.in_planes = 768
64 |         self.in_planes_proj = 512
65 |         self.camera_num = camera_num
66 |         self.view_num = view_num
67 |         self.sie_coe = cfg.MODEL.SIE_COE
68 |
69 |         self.bottleneck = nn.BatchNorm1d(self.in_planes)
70 |         self.bottleneck.bias.requires_grad_(False)
71 |         self.bottleneck.apply(weights_init_kaiming)
72 |
73 |         self.bottleneck_proj = nn.BatchNorm1d(self.in_planes_proj)
74 |         self.bottleneck_proj.bias.requires_grad_(False)
75 |         self.bottleneck_proj.apply(weights_init_kaiming)
76 |
77 |         self.classifier = nn.Linear(self.in_planes_proj + self.in_planes, num_classes, bias=False)
78 |         self.classifier.apply(weights_init_classifier)
79 |
80 |         self.h_resolution = int((cfg.INPUT.SIZE_TRAIN[0] - 16) // cfg.MODEL.STRIDE_SIZE[0] + 1)
81 |         self.w_resolution = int((cfg.INPUT.SIZE_TRAIN[1] - 16) // cfg.MODEL.STRIDE_SIZE[1] + 1)
82 |         self.vision_stride_size = cfg.MODEL.STRIDE_SIZE[0]
83 |         clip_model = load_clip_to_cpu(self.model_name, self.h_resolution, self.w_resolution, self.vision_stride_size)
84 |         clip_model.to("cuda")
85 |
86 |         self.image_encoder = clip_model.visual
87 |
88 |         # Trick: freeze patch projection for improved stability
89 |         # https://arxiv.org/pdf/2104.02057.pdf
90 |         for _, v in self.image_encoder.conv1.named_parameters():
91 |             v.requires_grad_(False)
92 |         print('Freeze patch projection layer with shape {}'.format(self.image_encoder.conv1.weight.shape))
93 |
94 |         if cfg.MODEL.SIE_CAMERA and cfg.MODEL.SIE_VIEW:
95 |             self.cv_embed = nn.Parameter(torch.zeros(camera_num * view_num, self.in_planes))
96 |             trunc_normal_(self.cv_embed, std=.02)
97 |             print('camera number is : {}'.format(camera_num))
98 |         elif cfg.MODEL.SIE_CAMERA:
99 |             self.cv_embed = nn.Parameter(torch.zeros(camera_num, self.in_planes))
100 |
trunc_normal_(self.cv_embed, std=.02) 101 | print('camera number is : {}'.format(camera_num)) 102 | elif cfg.MODEL.SIE_VIEW: 103 | self.cv_embed = nn.Parameter(torch.zeros(view_num, self.in_planes)) 104 | trunc_normal_(self.cv_embed, std=.02) 105 | print('camera number is : {}'.format(view_num)) 106 | 107 | self.classifier2 = nn.Linear(self.in_planes, num_classes, bias=False) 108 | self.classifier2.apply(weights_init_classifier) 109 | self.bottleneck_proj_sp = nn.BatchNorm1d(self.in_planes) 110 | self.bottleneck_proj_sp.bias.requires_grad_(False) 111 | self.bottleneck_proj_sp.apply(weights_init_kaiming) 112 | self.sp_mamba_bi = nn.Sequential( 113 | nn.LayerNorm(768), 114 | BiMamba( 115 | d_model=768, 116 | d_state=16, 117 | d_conv=4, 118 | expand=2, 119 | ), 120 | ) 121 | self.sp_mamba_raw = nn.Sequential( 122 | nn.LayerNorm(768), 123 | Mamba( 124 | d_model=768, 125 | d_state=16, 126 | d_conv=4, 127 | expand=2, 128 | ), 129 | ) 130 | self.norm2_mamba = nn.LayerNorm(768) 131 | self.norm3_mamba = nn.LayerNorm(768) 132 | self.sp_attention = nn.Sequential( 133 | nn.Linear(768, 192), 134 | nn.Tanh(), 135 | nn.Linear(192, 1) 136 | ) 137 | 138 | def reorder(self, reference, raw): 139 | 140 | # attention_map = attention_map.mean(axis=1) # torch.Size([64, 50, 50]) 141 | reference_norm = F.normalize(reference, dim=-1).unsqueeze(1) # bt, 1, 768 142 | raw_norm = F.normalize(raw, dim=-1) # bt, 128, 768 143 | raw_norm = torch.transpose(raw_norm, 1, 2) # bt, 768, 128 144 | sim = torch.bmm(reference_norm, raw_norm).squeeze(1) # [bt, 1, 768] [bt, 768, 128]= [bt, 1, 128] 145 | 146 | sorted, indices = torch.sort(sim, descending=True) 147 | 148 | selected_patch_embedding = [] 149 | for i in range(indices.size(0)): #bs 150 | all_patch_embeddings_i = raw[i, :,:].squeeze() # torch.Size([128, 768]) 151 | top_k_embedding = torch.index_select(all_patch_embeddings_i, 0, indices[i]) # torch.Size([128, 768]) 152 | top_k_embedding = top_k_embedding.unsqueeze(0) # torch.Size([1, 128, 768]) 153 | selected_patch_embedding.append(top_k_embedding) 154 | selected_patch_embedding = torch.cat(selected_patch_embedding, 0) # torch.Size([64, 128, 768]) 155 | 156 | return selected_patch_embedding 157 | 158 | def forward(self, x, get_image = False, cam_label= None, view_label=None): 159 | if get_image == True: 160 | if cam_label != None and view_label!=None: 161 | cv_embed = self.sie_coe * self.cv_embed[cam_label * self.view_num + view_label] 162 | elif cam_label != None: 163 | cv_embed = self.sie_coe * self.cv_embed[cam_label] 164 | elif view_label!=None: 165 | cv_embed = self.sie_coe * self.cv_embed[view_label] 166 | else: 167 | cv_embed = None 168 | _, image_features, image_features_proj, = self.image_encoder(x, cv_embed) 169 | img_feature = image_features[:,0] 170 | img_feature_proj = image_features_proj[:,0] 171 | 172 | feat = self.bottleneck(img_feature) 173 | feat_proj = self.bottleneck_proj(img_feature_proj) 174 | 175 | out_feat = torch.cat([feat, feat_proj], dim=1) 176 | return out_feat 177 | 178 | if cam_label != None and view_label != None: 179 | cv_embed = self.sie_coe * self.cv_embed[cam_label * self.view_num + view_label] 180 | elif cam_label != None: 181 | cv_embed = self.sie_coe * self.cv_embed[cam_label] 182 | elif view_label != None: 183 | cv_embed = self.sie_coe * self.cv_embed[view_label] 184 | else: 185 | cv_embed = None 186 | _, image_features, image_features_proj, = self.image_encoder(x, cv_embed) 187 | img_feature = image_features[:, 0] 188 | img_feature_proj = image_features_proj[:, 0] 189 | 190 | feat = 
self.bottleneck(img_feature) 191 | feat_proj = self.bottleneck_proj(img_feature_proj) 192 | 193 | out_feat = torch.cat([feat, feat_proj], dim=1) 194 | 195 | feats_for_mamba = image_features.detach() # torch.Size([64, 129, 768]) 196 | # BT, hw, D => BT, D, hw => B, T, D, hw => B, D, T, hw => B, D, T, h, w 197 | # feats_for_mamba = feats_for_mamba.permute(0, 2, 1) # torch.Size([64, 768, 128]) 198 | feats_for_mamba_sp = feats_for_mamba[:, 1:, :].detach() 199 | feats_for_mamba_cls = feats_for_mamba[:, 0, :].detach() # torch.Size([64, 768]) 200 | #### reorder 201 | re_order_mamba_sp = self.reorder(feats_for_mamba_cls, feats_for_mamba_sp) 202 | 203 | B, num_token, D = re_order_mamba_sp.shape 204 | # re_order_mamba_sp = re_order_mamba_sp.reshape(BT, self.h_resolution, self.w_resolution, 205 | # D) # torch.Size([64, 16, 8, 768]) 206 | # mamba_sp_out = self.sp_mamba_raw(re_order_mamba_sp) # torch.Size([64, 128, 768]) 207 | mamba_sp_out = self.sp_mamba_bi(re_order_mamba_sp) # torch.Size([64, 16, 8, 768]) 208 | # mamba_sp_out = mamba_sp_out.reshape(B, self.h_resolution * self.w_resolution, D).contiguous() 209 | mamba_sp_out = torch.cat((feats_for_mamba_cls.unsqueeze(1), mamba_sp_out), dim=1) # torch.Size([64, 129, 768]) 210 | # mamba_sp_out = mamba_sp_out.mean(1) # torch.Size([64, 768]) 211 | mamba_sp_out2 = self.norm2_mamba(mamba_sp_out) # bt, 128, 768 212 | A = self.sp_attention(mamba_sp_out2) # [B, n, K] # torch.Size([8, 1024, 1]) 213 | A = torch.transpose(A, 1, 2) # torch.Size([8, 1, 1024]) 214 | A = F.softmax(A, dim=-1) # [B, K, n] # torch.Size([8, 1, 1024]) 215 | mamba_sp_out2 = torch.bmm(A, mamba_sp_out2) # [B, K, 512] torch.Size([8, 1, 512]) 216 | mamba_sp_out2 = mamba_sp_out2.squeeze(1) # torch.Size([64, 768]) 217 | feat_sp = self.bottleneck_proj_sp(mamba_sp_out2) 218 | 219 | if self.training: 220 | logit = self.classifier(out_feat) 221 | logitsp = self.classifier2(feat_sp) 222 | return out_feat, logit, feat_sp, logitsp 223 | else: 224 | feat_concat = torch.cat((out_feat, feat_sp), dim=1) 225 | return feat_concat, out_feat, feat_sp 226 | 227 | 228 | 229 | def load_param(self, trained_path): 230 | param_dict = torch.load(trained_path) 231 | for i in param_dict: 232 | if not self.training and 'classifier' in i: 233 | continue # ignore classifier weights in evaluation 234 | self.state_dict()[i.replace('module.', '')].copy_(param_dict[i]) 235 | print('Loading pretrained model from {}'.format(trained_path)) 236 | 237 | def load_param_finetune(self, model_path): 238 | param_dict = torch.load(model_path) 239 | for i in param_dict: 240 | self.state_dict()[i].copy_(param_dict[i]) 241 | print('Loading pretrained model for finetuning from {}'.format(model_path)) 242 | 243 | 244 | def make_model(cfg, num_classes, camera_num, view_num): 245 | model = CLIMB(num_classes, camera_num, view_num, cfg) 246 | return model -------------------------------------------------------------------------------- /climb/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def make_CLIMB_optimizer(cfg, model): 4 | params = [] 5 | 6 | for key, value in model.named_parameters(): 7 | if not value.requires_grad: 8 | continue 9 | lr = cfg.SOLVER.BASE_LR 10 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 11 | if "bias" in key: 12 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR 13 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 14 | # if cfg.SOLVER.LARGE_FC_LR: 15 | if "classifier2" in key: 16 | lr = cfg.SOLVER.BASE_LR * 10 17 | print('Using 10 times learning rate for 
fc ') 18 | 19 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 20 | 21 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD': 22 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM) 23 | elif cfg.SOLVER.OPTIMIZER_NAME == 'AdamW': 24 | optimizer = torch.optim.AdamW(params, lr=cfg.SOLVER.BASE_LR, weight_decay=cfg.SOLVER.WEIGHT_DECAY) 25 | else: 26 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params) 27 | 28 | return optimizer -------------------------------------------------------------------------------- /climb/preprocessing.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | 4 | class RandomErasing(object): 5 | """ Randomly selects a rectangle region in an image and erases its pixels. 6 | 'Random Erasing Data Augmentation' by Zhong et al. 7 | See https://arxiv.org/pdf/1708.04896.pdf 8 | Args: 9 | probability: The probability that the Random Erasing operation will be performed. 10 | sl: Minimum proportion of erased area against input image. 11 | sh: Maximum proportion of erased area against input image. 12 | r1: Minimum aspect ratio of erased area. 13 | mean: Erasing value. 14 | """ 15 | 16 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 17 | self.probability = probability 18 | self.mean = mean 19 | self.sl = sl 20 | self.sh = sh 21 | self.r1 = r1 22 | 23 | def __call__(self, img): 24 | 25 | if random.uniform(0, 1) >= self.probability: 26 | return img 27 | 28 | for attempt in range(100): 29 | area = img.size()[1] * img.size()[2] 30 | 31 | target_area = random.uniform(self.sl, self.sh) * area 32 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 33 | 34 | h = int(round(math.sqrt(target_area * aspect_ratio))) 35 | w = int(round(math.sqrt(target_area / aspect_ratio))) 36 | 37 | if w < img.size()[2] and h < img.size()[1]: 38 | x1 = random.randint(0, img.size()[1] - h) 39 | y1 = random.randint(0, img.size()[2] - w) 40 | if img.size()[0] == 3: 41 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 42 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 43 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 44 | else: 45 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 46 | return img 47 | 48 | return img -------------------------------------------------------------------------------- /climb/processor_climb.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | from utils.meter import AverageMeter 7 | from utils.metrics import R1_mAP_eval 8 | from torch.cuda import amp 9 | from .utils import * 10 | from .loss import ClusterMemoryAMP, CrossEntropyLabelSmooth, TripletLoss 11 | 12 | 13 | def train_climb(cfg, 14 | model, 15 | train_loader, 16 | val_loader, 17 | cluster_loader, 18 | optimizer, 19 | scheduler, 20 | num_query, 21 | num_classes): 22 | 23 | log_period = cfg.SOLVER.LOG_PERIOD 24 | checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD 25 | eval_period = cfg.SOLVER.EVAL_PERIOD 26 | 27 | device = "cuda" 28 | epochs = cfg.SOLVER.MAX_EPOCHS 29 | 30 | logger = logging.getLogger("CLIMB") 31 | logger.info('start training') 32 | 33 | # model.to(device) 34 | if device: 35 | model.to(device) 36 | if torch.cuda.device_count() > 1 and cfg.MODEL.DIST_TRAIN: 37 | print('Using {} GPUs for training'.format(torch.cuda.device_count())) 38 | # model = torch.nn.parallel.DistributedDataParallel(model, 
#             device_ids=[local_rank], find_unused_parameters=True)
#             (the DDP wrapper is commented out, so under DIST_TRAIN the model currently runs unwrapped)
39 |         else:
40 |             model = nn.DataParallel(model).cuda()
41 |
42 |     loss_meter = AverageMeter()
43 |     loss_meter1 = AverageMeter()
44 |     loss_meter2 = AverageMeter()
45 |     loss_meter3 = AverageMeter()
46 |     loss_meter4 = AverageMeter()
47 |     acc_meter = AverageMeter()
48 |     acc_meter1 = AverageMeter()
49 |     xent = CrossEntropyLabelSmooth(num_classes)
50 |     tri_loss = TripletLoss()
51 |     logger.info(f'smoothed cross entropy loss on {num_classes} classes.')
52 |
53 |     evaluator = R1_mAP_eval(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)
54 |     # scaler = amp.GradScaler()
55 |     best_performance = 0
56 |     best_epoch = 1
57 |     # training epochs
58 |     for epoch in range(1, epochs + 1):
59 |         loss_meter.reset()
60 |         loss_meter1.reset()
61 |         loss_meter2.reset()
62 |         loss_meter3.reset()
63 |         loss_meter4.reset()
64 |         acc_meter.reset()
65 |         acc_meter1.reset()
66 |
67 |         evaluator.reset()
68 |
69 |         # create memory bank
70 |         image_features, gt_labels = extract_image_features(model, cluster_loader, use_amp=True)
71 |         image_features = image_features.float()
72 |         image_features = F.normalize(image_features, dim=1)
73 |
74 |         num_classes = len(gt_labels.unique()) - 1 if -1 in gt_labels else len(gt_labels.unique())
75 |         logger.info(f'Memory has {num_classes} classes.')
76 |
77 |         train_loader.new_epoch()
78 |
79 |         # CAP memory
80 |         memory = ClusterMemoryAMP(momentum=cfg.MODEL.MEMORY_MOMENTUM, use_hard=True).to(device)
81 |         memory.features = compute_cluster_centroids(image_features, gt_labels).to(device)
82 |         logger.info('Create memory bank with shape = {}'.format(memory.features.shape))
83 |
84 |         # train one iteration
85 |         model.train()
86 |         num_iters = len(train_loader)
87 |         for n_iter in range(num_iters):
88 |             img, target, target_cam, target_view = train_loader.next()  # trackid doubles as the view label
89 |
90 |             optimizer.zero_grad()
91 |
92 |             img = img.to(device)
93 |             target = target.to(device)
94 |             target_cam = target_cam.to(device)
95 |
96 |             if cfg.MODEL.SIE_CAMERA:
97 |                 target_cam = target_cam.to(device)
98 |             else:
99 |                 target_cam = None
100 |             if cfg.MODEL.SIE_VIEW:
101 |                 target_view = target_view.to(device)
102 |             else:
103 |                 target_view = None
104 |
105 |             # with amp.autocast(enabled=True):
106 |             feat, logits, feat_sp, logits_sp = model(img, cam_label=target_cam, view_label=target_view)
107 |             loss1 = memory(feat, target) * cfg.MODEL.PCL_LOSS_WEIGHT
108 |             # if cfg.MODEL.ID_LOSS_WEIGHT > 0:
109 |             loss_id = xent(logits, target) * cfg.MODEL.ID_LOSS_WEIGHT
110 |             loss_id2 = xent(logits_sp, target)
111 |             loss_tri = tri_loss(feat_sp, target)
112 |             loss = loss1 + loss_id + loss_id2 + loss_tri
113 |
114 |             loss.backward()
115 |             optimizer.step()
116 |
117 |             # scaler.step(optimizer)
118 |             # scaler.update()
119 |             acc = (logits.max(1)[1] == target).float().mean()
120 |             acc2 = (logits_sp.max(1)[1] == target).float().mean()
121 |
122 |             loss_meter.update(loss.item(), img.shape[0])
123 |             loss_meter1.update(loss1.item(), img.shape[0])
124 |             loss_meter2.update(loss_id.item(), img.shape[0])
125 |             loss_meter3.update(loss_id2.item(), img.shape[0])
126 |             loss_meter4.update(loss_tri.item(), img.shape[0])
127 |             acc_meter.update(acc, 1)
128 |             acc_meter1.update(acc2, 1)
129 |
130 |             torch.cuda.synchronize()
131 |
132 |             if (n_iter + 1) % log_period == 0:
133 |                 logger.info("Epoch[{}] Iteration[{}/{}] "
134 |                             "Loss_total: {:.3f}, "
135 |                             "Loss1: {:.3f}, "
136 |                             "Loss2: {:.3f}, "
137 |                             "Loss3: {:.3f}, "
138 |                             "Loss4: {:.3f}, "
139 |                             "acc1: {:.3f}, "
140 |                             "acc2: {:.3f}, "
141 |                             "Lr: {:.2e}"
142 |                             .format(epoch, (n_iter + 1), len(train_loader),
143 |                                     loss_meter.avg,
144 |                                     loss_meter1.avg,
145 |                                     loss_meter2.avg,
146 |                                     loss_meter3.avg,
147 |                                     loss_meter4.avg,
148 |                                     acc_meter.avg,
149 |                                     acc_meter1.avg,
150 |                                     scheduler.get_lr()[0]))
151 |
152 |         scheduler.step()
153 |         logger.info("Epoch {} done.".format(epoch))
154 |
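        # Periodic evaluation: the evaluator receives the concatenated feature plus the two
        # branch features (CLIP branch `feat1`, Mamba branch `feat2`), and mAP/CMC are logged
        # for each variant; Rank-1 + mAP of the first variant drives checkpointing.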
155 |         if epoch % eval_period == 0 and epoch >= 55:
156 |             model.eval()
157 |             for n_iter, (img, vid, camid, target_view) in enumerate(val_loader):
158 |                 with torch.no_grad():
159 |                     img = img.to(device)
160 |                     if cfg.MODEL.SIE_CAMERA:
161 |                         camids = camid.to(device)
162 |                     else:
163 |                         camids = None
164 |                     if cfg.MODEL.SIE_VIEW:
165 |                         target_view = target_view.to(device)
166 |                     else:
167 |                         target_view = None
168 |                     feat, feat1, feat2 = model(img, cam_label=camids, view_label=target_view)
169 |                     evaluator.update((feat, feat1, feat2, vid, camid))
170 |             cmc, mAP, cmc01, mAP01, cmc02, mAP02, cmc03, mAP03 = evaluator.compute()
171 |             logger.info("Validation Results - Epoch: {}".format(epoch))
172 |             logger.info("mAP: {:.1%}".format(mAP))
173 |             for r in [1, 5, 10, 20]:
174 |                 logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
175 |
176 |             logger.info("mAP_1: {:.1%}".format(mAP01))
177 |             for r in [1, 5, 10, 20]:
178 |                 logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc01[r - 1]))
179 |             logger.info("mAP_2: {:.1%}".format(mAP02))
180 |             for r in [1, 5, 10, 20]:
181 |                 logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc02[r - 1]))
182 |             logger.info("mAP_3: {:.1%}".format(mAP03))
183 |             for r in [1, 5, 10, 20]:
184 |                 logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc03[r - 1]))
185 |             torch.cuda.empty_cache()
186 |             prec1 = cmc[0] + mAP
187 |             is_best = prec1 > best_performance
188 |             best_performance = max(prec1, best_performance)
189 |             if is_best:
190 |                 best_epoch = epoch
191 |             save_checkpoint(model.state_dict(), is_best, os.path.join(cfg.OUTPUT_DIR, 'checkpoint_ep.pth.tar'))
192 |
193 |     torch.cuda.empty_cache()
194 |     logger.info("==> Best Perform {:.1%}, achieved at epoch {}".format(best_performance, best_epoch))
195 |     logger.info('Training done.')
196 |     print(cfg.OUTPUT_DIR)
197 |
198 | def do_inference(cfg,
199 |                  model,
200 |                  val_loader,
201 |                  num_query):
202 |     device = "cuda"
203 |     logger = logging.getLogger("CLIMB")
204 |     logger.info("Enter inferencing")
205 |     model.to(device)
206 |
207 |     evaluator = R1_mAP_eval(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)
208 |     evaluator.reset()
209 |
210 |
211 |     model.eval()
212 |     for n_iter, (img, pid, camid, target_view) in enumerate(val_loader):
213 |         with torch.no_grad():
214 |             img = img.to(device)
215 |             if cfg.MODEL.SIE_CAMERA:
216 |                 camids = camid.to(device)
217 |             else:
218 |                 camids = None
219 |             if cfg.MODEL.SIE_VIEW:
220 |                 target_view = target_view.to(device)
221 |             else:
222 |                 target_view = None
223 |             feat, _, _ = model(img, cam_label=camids, view_label=target_view)  # concatenated CLIP+Mamba feature
224 |             evaluator.update((feat, pid, camid))
225 |
226 |
227 |     cmc, mAP, _, _, _, _, _ = evaluator.compute()
228 |     logger.info("Validation Results ")
229 |     logger.info("mAP: {:.1%}".format(mAP))
230 |     for r in [1, 5, 10]:
231 |         logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
232 |     return cmc[0], cmc[4]
--------------------------------------------------------------------------------
/climb/sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data.sampler import Sampler
3 | from collections import defaultdict
4 | import copy
5 | import random
6 | import numpy as np
7 |
8 | def No_index(a, b):
9 |     assert isinstance(a, list)
10 |     return [i for i, j in enumerate(a) if j != b]
11 |
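# RandomMultipleGallerySampler: for each identity, pick one anchor at random, then draw
# num_instances-1 companions, preferring samples captured by a camera other than the
# anchor's and falling back to same-camera samples only when no other camera exists.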
12 | class RandomMultipleGallerySampler(Sampler):
13 |     def __init__(self, data_source, num_instances=4):
14 |         super().__init__(data_source)
15 |         self.data_source = data_source
16 |         self.index_pid = defaultdict(int)
17 |         self.pid_cam = defaultdict(list)
18 |         self.pid_index = defaultdict(list)
19 |         self.num_instances = num_instances
20 |
21 |         for index, (_, pid, cam, _) in enumerate(data_source):
22 |             if pid < 0:
23 |                 continue
24 |             self.index_pid[index] = pid
25 |             self.pid_cam[pid].append(cam)
26 |             self.pid_index[pid].append(index)
27 |
28 |         self.pids = list(self.pid_index.keys())
29 |         self.num_samples = len(self.pids)
30 |
31 |     def __len__(self):
32 |         return self.num_samples * self.num_instances
33 |
34 |     def __iter__(self):
35 |         indices = torch.randperm(len(self.pids)).tolist()  # shuffle the identity order
36 |         ret = []
37 |
38 |         for kid in indices:
39 |             i = random.choice(self.pid_index[self.pids[kid]])  # randomly pick one anchor index for this identity
40 |
41 |             _, i_pid, i_cam, _ = self.data_source[i]  # pid and camera of the anchor
42 |
43 |             ret.append(i)
44 |
45 |             pid_i = self.index_pid[i]  # ground-truth pid of the anchor
46 |             cams = self.pid_cam[pid_i]  # camera ids of all samples sharing this pid
47 |             index = self.pid_index[pid_i]  # indices of all samples sharing this pid
48 |             select_cams = No_index(cams, i_cam)  # prefer samples from cameras other than the anchor's
49 |
50 |             if select_cams:
51 |                 if len(select_cams) >= self.num_instances:
52 |                     cam_indexes = np.random.choice(select_cams, size=self.num_instances - 1, replace=False)
53 |                 else:
54 |                     cam_indexes = np.random.choice(select_cams, size=self.num_instances - 1, replace=True)
55 |
56 |                 for kk in cam_indexes:
57 |                     ret.append(index[kk])
58 |
59 |             else:
60 |                 select_indexes = No_index(index, i)  # no other camera available: fall back to other samples of the same pid
61 |                 if not select_indexes:
62 |                     continue
63 |                 if len(select_indexes) >= self.num_instances:
64 |                     ind_indexes = np.random.choice(select_indexes, size=self.num_instances - 1, replace=False)
65 |                 else:
66 |                     ind_indexes = np.random.choice(select_indexes, size=self.num_instances - 1, replace=True)
67 |
68 |                 for kk in ind_indexes:
69 |                     ret.append(index[kk])
70 |
71 |         return iter(ret)
72 |
73 |
74 | class PseudoLabelSampler(Sampler):
75 |     """
76 |     Random identity sampler with PK sampling on pseudo labels.
77 |     Invalid labels (-1) will not be sampled.
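    With batch_size=64 and num_instances=4, for example, each batch covers 16 pseudo identities.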
78 | """ 79 | 80 | def __init__(self, data_source, batch_size, num_instances): 81 | self.data_source = data_source # containing pseudo labels 82 | self.batch_size = batch_size 83 | self.num_instances = num_instances 84 | self.num_pids_per_batch = self.batch_size // self.num_instances 85 | self.index_dic = defaultdict(list) #dict with list value 86 | #{783: [0, 5, 116, 876, 1554, 2041],...,} 87 | for index, (_, _, _, _, pid) in enumerate(self.data_source): 88 | # pseudo label at last 89 | # ignore noisy label -1 90 | if pid != -1: 91 | self.index_dic[pid].append(index) 92 | self.pids = list(self.index_dic.keys()) 93 | 94 | # estimate number of examples in an epoch 95 | self.length = 0 96 | for pid in self.pids: 97 | idxs = self.index_dic[pid] 98 | num = len(idxs) 99 | if num < self.num_instances: 100 | num = self.num_instances 101 | self.length += num - num % self.num_instances 102 | 103 | def __iter__(self): 104 | batch_idxs_dict = defaultdict(list) 105 | 106 | for pid in self.pids: 107 | idxs = copy.deepcopy(self.index_dic[pid]) 108 | if len(idxs) < self.num_instances: 109 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 110 | random.shuffle(idxs) 111 | batch_idxs = [] 112 | for idx in idxs: 113 | batch_idxs.append(idx) 114 | if len(batch_idxs) == self.num_instances: 115 | batch_idxs_dict[pid].append(batch_idxs) 116 | batch_idxs = [] 117 | 118 | avai_pids = copy.deepcopy(self.pids) 119 | final_idxs = [] 120 | 121 | while len(avai_pids) >= self.num_pids_per_batch: 122 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 123 | for pid in selected_pids: 124 | batch_idxs = batch_idxs_dict[pid].pop(0) 125 | final_idxs.extend(batch_idxs) 126 | if len(batch_idxs_dict[pid]) == 0: 127 | avai_pids.remove(pid) 128 | 129 | return iter(final_idxs) 130 | 131 | def __len__(self): 132 | return self.length 133 | 134 | class RandomIdentitySampler(Sampler): 135 | """ 136 | Randomly sample N identities, then for each identity, 137 | randomly sample K instances, therefore batch size is N*K. 138 | Args: 139 | - data_source (list): list of (img_path, pid, camid). 140 | - num_instances (int): number of instances per identity in a batch. 141 | - batch_size (int): number of examples in a batch. 
142 | """ 143 | 144 | def __init__(self, data_source, batch_size, num_instances): 145 | self.data_source = data_source 146 | self.batch_size = batch_size 147 | self.num_instances = num_instances 148 | self.num_pids_per_batch = self.batch_size // self.num_instances 149 | self.index_dic = defaultdict(list) #dict with list value 150 | #{783: [0, 5, 116, 876, 1554, 2041],...,} 151 | for index, (_, pid, _, _) in enumerate(self.data_source): 152 | self.index_dic[pid].append(index) 153 | self.pids = list(self.index_dic.keys()) 154 | 155 | # estimate number of examples in an epoch 156 | self.length = 0 157 | for pid in self.pids: 158 | idxs = self.index_dic[pid] 159 | num = len(idxs) 160 | if num < self.num_instances: 161 | num = self.num_instances 162 | self.length += num - num % self.num_instances 163 | 164 | def __iter__(self): 165 | batch_idxs_dict = defaultdict(list) 166 | 167 | for pid in self.pids: 168 | idxs = copy.deepcopy(self.index_dic[pid]) 169 | if len(idxs) < self.num_instances: 170 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 171 | random.shuffle(idxs) 172 | batch_idxs = [] 173 | for idx in idxs: 174 | batch_idxs.append(idx) 175 | if len(batch_idxs) == self.num_instances: 176 | batch_idxs_dict[pid].append(batch_idxs) 177 | batch_idxs = [] 178 | 179 | avai_pids = copy.deepcopy(self.pids) 180 | final_idxs = [] 181 | 182 | while len(avai_pids) >= self.num_pids_per_batch: 183 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 184 | for pid in selected_pids: 185 | batch_idxs = batch_idxs_dict[pid].pop(0) 186 | final_idxs.extend(batch_idxs) 187 | if len(batch_idxs_dict[pid]) == 0: 188 | avai_pids.remove(pid) 189 | 190 | return iter(final_idxs) 191 | 192 | def __len__(self): 193 | return self.length 194 | 195 | -------------------------------------------------------------------------------- /climb/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.cuda import amp 4 | import tqdm 5 | import random 6 | import copy 7 | import numpy as np 8 | from collections import defaultdict 9 | import torch 10 | import shutil 11 | import os.path as osp 12 | import errno 13 | import os 14 | 15 | 16 | def pk_sampling(batchsize, k, pseudo_labels, samples): 17 | pseudo_labels = pseudo_labels.cpu() 18 | samples = samples.cpu() 19 | 20 | batch_idxs_dict = defaultdict(list) 21 | pids = torch.unique(pseudo_labels).cpu().tolist() 22 | 23 | for pid in pids: 24 | idxs = samples[pseudo_labels == pid].tolist() 25 | if len(idxs) < k: 26 | idxs = np.random.choice(idxs, size=k, replace=True) 27 | random.shuffle(idxs) 28 | batch_idxs = [] 29 | for idx in idxs: 30 | batch_idxs.append(idx) 31 | if len(batch_idxs) == k: 32 | batch_idxs_dict[pid].append(batch_idxs) 33 | batch_idxs = [] 34 | 35 | avai_pids = copy.deepcopy(pids) 36 | final_idxs = [] 37 | 38 | while len(avai_pids) >= (batchsize // k): 39 | selected_pids = random.sample(avai_pids, batchsize//k) 40 | for pid in selected_pids: 41 | batch_idxs = batch_idxs_dict[pid].pop(0) 42 | final_idxs.extend(batch_idxs) 43 | if len(batch_idxs_dict[pid]) == 0: 44 | avai_pids.remove(pid) 45 | 46 | final_idxs = torch.split(torch.tensor(final_idxs), batchsize) 47 | 48 | return iter(final_idxs) 49 | 50 | 51 | def extract_image_features(model, cluster_loader, use_amp=False): 52 | image_features = [] 53 | labels = [] 54 | 55 | model.eval() 56 | with torch.no_grad(): 57 | for _, (img, pid, camid, _) in enumerate(tqdm.tqdm(cluster_loader, 
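                # `get_image=True` in the call below returns the BN-normalized, concatenated
                # CLIP feature from model.forward, which is what the cluster memory is built on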
51 | def extract_image_features(model, cluster_loader, use_amp=False):
52 |     image_features = []
53 |     labels = []
54 |     # run the (frozen) model over the whole training set once, collecting features
55 |     model.eval()
56 |     with torch.no_grad():
57 |         for _, (img, pid, camid, _) in enumerate(tqdm.tqdm(cluster_loader, desc='Extract image features')):
58 |             img = img.cuda()
59 |             target = pid.cuda()
60 |             camid = camid.cuda()
61 |             # with amp.autocast(enabled=use_amp):  (mixed precision is currently disabled)
62 |             image_feature = model(img, get_image=True, cam_label=camid)
63 |             for i, img_feat in zip(target, image_feature):
64 |                 labels.append(i)
65 |                 image_features.append(img_feat.cpu())
66 |     labels_list = torch.stack(labels, dim=0).cuda()
67 |     image_features_list = torch.stack(image_features, dim=0).cuda()  # shape (N, C)
68 |     return image_features_list, labels_list
69 | 
70 | 
71 | def cam_label_split(cluster_labels, all_img_cams):
72 |     """
73 |     Split proxies using camera labels.
74 |     """
75 |     proxy_labels = -1 * torch.ones(cluster_labels.shape).type_as(cluster_labels)
76 |     cnt = 0
77 |     for i in range(0, int(cluster_labels.max() + 1)):
78 |         inds = torch.where(cluster_labels == i)[0]
79 |         local_cams = all_img_cams[inds]
80 |         for cc in torch.unique(local_cams):
81 |             pc_inds = torch.where(local_cams == cc)[0]
82 |             proxy_labels[inds[pc_inds]] = cnt
83 |             cnt += 1
84 |     return proxy_labels
85 | 
86 | def compute_cluster_centroids(features, labels):
87 |     """
88 |     Compute the L2-normed cluster centroid for each class.
89 |     """
90 |     num_classes = len(labels.unique()) - 1 if -1 in labels else len(labels.unique())
91 |     centers = torch.zeros((num_classes, features.shape[1]), dtype=torch.float32)
92 |     for i in range(num_classes):
93 |         idx = torch.where(labels == i)[0]
94 |         temp = features[idx, :]
95 |         if len(temp.shape) == 1:
96 |             temp = temp.reshape(1, -1)
97 |         centers[i, :] = temp.mean(0)
98 |     return F.normalize(centers, dim=1)
99 | 
100 | def mkdir_if_missing(directory):
101 |     if not osp.exists(directory):
102 |         try:
103 |             os.makedirs(directory)
104 |         except OSError as e:
105 |             if e.errno != errno.EEXIST:
106 |                 raise
107 | 
108 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
109 |     mkdir_if_missing(osp.dirname(fpath))
110 |     torch.save(state, fpath)
111 |     if is_best:
112 |         shutil.copy(fpath, osp.join(osp.dirname(fpath), 'best_model.pth.tar'))
--------------------------------------------------------------------------------
/climb/vivim.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | from collections.abc import Sequence
4 | 
5 | import torch.nn as nn
6 | import torch
7 | from functools import partial
8 | from mamba.mamba_ssm.modules.mamba_simple import Mamba
9 | # from mamba.mamba_ssm import Mamba
10 | import torch.nn.functional as F
11 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
12 | import math
13 | 
14 | 
15 | 
16 | class Mlp(nn.Module):
17 |     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
18 |         super().__init__()
19 |         out_features = out_features or in_features
20 |         hidden_features = hidden_features or in_features
21 |         self.fc1 = nn.Linear(in_features, hidden_features)
22 |         # self.dwconv = DWConv(hidden_features)
23 |         self.act = act_layer()
24 |         self.fc2 = nn.Linear(hidden_features, out_features)
25 |         self.drop = nn.Dropout(drop)
26 | 
27 |         self.apply(self._init_weights)
28 | 
29 |     def _init_weights(self, m):
30 |         if isinstance(m, nn.Linear):
31 |             trunc_normal_(m.weight, std=.02)
32 |             if isinstance(m, nn.Linear) and m.bias is not None:
33 |                 nn.init.constant_(m.bias, 0)
34 |         elif isinstance(m, nn.LayerNorm):
35 |             nn.init.constant_(m.bias, 0)
36 |             nn.init.constant_(m.weight, 1.0)
37 |         elif isinstance(m, nn.Conv2d):
38 |             fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
39 |             fan_out //= m.groups
40 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 41 | if m.bias is not None: 42 | m.bias.data.zero_() 43 | 44 | def forward(self, x): # x : torch.Size([8, 1024, 768]) 45 | x = self.fc1(x) 46 | # x = self.dwconv(x, nf, H, W) # torch.Size([8, 3072, 8, 16, 8]) 47 | x = self.act(x) # torch.Size([8, 1024, 3072]) 48 | x = self.drop(x) # torch.Size([8, 1024, 3072]) 49 | x = self.fc2(x) # torch.Size([8, 1024, 768]) 50 | x = self.drop(x) 51 | return x 52 | 53 | class MambaLayer(nn.Module): 54 | def __init__(self, dim, d_state=16, d_conv=4, expand=2, mlp_ratio=4, drop=0., drop_path=0., act_layer=nn.GELU): 55 | super().__init__() 56 | self.dim = dim 57 | self.norm1 = nn.LayerNorm(dim) 58 | self.mamba = Mamba( 59 | d_model=dim, # Model dimension d_model 60 | d_state=d_state, # SSM state expansion factor 61 | d_conv=d_conv, # Local convolution width 62 | expand=expand, # Block expansion factor 63 | # bimamba_type="v3", 64 | bimamba_type="none", 65 | nframes=8 66 | # use_fast_path=False, 67 | ) 68 | 69 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 70 | self.norm2 = nn.LayerNorm(dim) 71 | mlp_hidden_dim = int(dim * mlp_ratio) 72 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 73 | self.ln_2 = nn.LayerNorm(dim) 74 | self.apply(self._init_weights) 75 | 76 | def _init_weights(self, m): 77 | if isinstance(m, nn.Linear): 78 | trunc_normal_(m.weight, std=.02) 79 | if isinstance(m, nn.Linear) and m.bias is not None: 80 | nn.init.constant_(m.bias, 0) 81 | elif isinstance(m, nn.LayerNorm): 82 | nn.init.constant_(m.bias, 0) 83 | nn.init.constant_(m.weight, 1.0) 84 | elif isinstance(m, nn.Conv2d): 85 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 86 | fan_out //= m.groups 87 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 88 | if m.bias is not None: 89 | m.bias.data.zero_() 90 | 91 | def forward(self, x): 92 | # B, C, nf, H, W = x.shape # torch.Size([8, 768, 8, 16, 8]) 93 | # 94 | # assert C == self.dim 95 | # n_tokens = x.shape[2:].numel() # 8 * 16*8 = 1024 96 | # img_dims = x.shape[2:] # torch.Size([8, 16, 8]) 97 | # x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2) # b, num_token, D = b, thw, 768 torch.Size([8, 1024, 768]) 98 | 99 | x_mamba = x + self.drop_path(self.mamba(self.norm1(x))) # torch.Size([8, 1024, 768]) 100 | x_mamba = x_mamba + self.drop_path(self.mlp(self.norm2(x_mamba))) # torch.Size([8, 1024, 768]) 101 | # out = x_mamba.transpose(-1, -2).reshape(B, C, *img_dims) # 102 | x_mamba = self.ln_2(x_mamba) # torch.Size([30, 1024, 768]) 103 | return x_mamba 104 | -------------------------------------------------------------------------------- /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /clip/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/clip/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /clip/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/clip/__pycache__/__init__.cpython-39.pyc 
--------------------------------------------------------------------------------
/clip/__pycache__/clip.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/clip/__pycache__/clip.cpython-310.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/clip.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/clip/__pycache__/clip.cpython-39.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/model.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/clip/__pycache__/model.cpython-310.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/clip/__pycache__/model.cpython-39.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/simple_tokenizer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/clip/__pycache__/simple_tokenizer.cpython-310.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/simple_tokenizer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/clip/__pycache__/simple_tokenizer.cpython-39.pyc
--------------------------------------------------------------------------------
/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/clip/clip.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 | from typing import Union, List
6 | 
7 | import torch
8 | from PIL import Image
9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
10 | from tqdm import tqdm
11 | 
12 | from .model import build_model
13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
14 | 
15 | try:
16 |     from torchvision.transforms import InterpolationMode
17 |     BICUBIC = InterpolationMode.BICUBIC
18 | except ImportError:
19 |     BICUBIC = Image.BICUBIC
20 | 
21 | 
22 | if tuple(map(int, torch.__version__.split(".")[:2])) < (1, 7):  # compare parsed major/minor, not strings
23 |     warnings.warn("PyTorch version 1.7.1 or higher is recommended")
24 | 
25 | 
26 | __all__ = ["available_models", "load", "tokenize"]
27 | _tokenizer = _Tokenizer()
28 | 
29 | _MODELS = {
30 |     "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
31 |     "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
32 |     "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
33 |     "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
34 |     "ViT-B-32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
35 |     "ViT-B-16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
36 | }
37 | 
38 | 
39 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
40 |     os.makedirs(root, exist_ok=True)
41 |     filename = os.path.basename(url)
42 | 
43 |     expected_sha256 = url.split("/")[-2]
44 |     download_target = os.path.join(root, filename)
45 | 
46 |     if os.path.exists(download_target) and not os.path.isfile(download_target):
47 |         raise RuntimeError(f"{download_target} exists and is not a regular file")
48 | 
49 |     if os.path.isfile(download_target):
50 |         if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
51 |             return download_target
52 |         else:
53 |             warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
54 | 
55 |     with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
56 |         with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
57 |             while True:
58 |                 buffer = source.read(8192)
59 |                 if not buffer:
60 |                     break
61 | 
62 |                 output.write(buffer)
63 |                 loop.update(len(buffer))
64 | 
65 |     if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
66 |         raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match")
67 | 
68 |     return download_target
69 | 
70 | 
71 | def _transform(n_px):
72 |     return Compose([
73 |         Resize(n_px, interpolation=BICUBIC),
74 |         CenterCrop(n_px),
75 |         lambda image: image.convert("RGB"),
76 |         ToTensor(),
77 |         Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
78 |     ])
79 | 
80 | 
81 | def available_models() -> List[str]:
82 |     """Returns the names of available CLIP models"""
83 |     return list(_MODELS.keys())
84 | 
85 | 
86 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False):
87 |     """Load a CLIP model
88 | 
89 |     Parameters
90 |     ----------
91 |     name : str
92 |         A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
93 | 
94 |     device : Union[str, torch.device]
95 |         The device to put the loaded model
96 | 
97 |     jit : bool
98 |         Whether to load the optimized JIT model or the more hackable non-JIT model (default).
99 | 100 | Returns 101 | ------- 102 | model : torch.nn.Module 103 | The CLIP model 104 | 105 | preprocess : Callable[[PIL.Image], torch.Tensor] 106 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 107 | """ 108 | if name in _MODELS: 109 | model_path = _download(_MODELS[name]) 110 | elif os.path.isfile(name): 111 | model_path = name 112 | else: 113 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 114 | 115 | try: 116 | # loading JIT archive 117 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 118 | state_dict = None 119 | except RuntimeError: 120 | # loading saved state dict 121 | if jit: 122 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 123 | jit = False 124 | state_dict = torch.load(model_path, map_location="cpu") 125 | 126 | if not jit: 127 | model = build_model(state_dict or model.state_dict()).to(device) 128 | if str(device) == "cpu": 129 | model.float() 130 | return model, _transform(model.visual.input_resolution) 131 | 132 | # patch the device names 133 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 134 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 135 | 136 | def patch_device(module): 137 | try: 138 | graphs = [module.graph] if hasattr(module, "graph") else [] 139 | except RuntimeError: 140 | graphs = [] 141 | 142 | if hasattr(module, "forward1"): 143 | graphs.append(module.forward1.graph) 144 | 145 | for graph in graphs: 146 | for node in graph.findAllNodes("prim::Constant"): 147 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 148 | node.copyAttributes(device_node) 149 | 150 | model.apply(patch_device) 151 | patch_device(model.encode_image) 152 | patch_device(model.encode_text) 153 | 154 | # patch dtype to float32 on CPU 155 | if str(device) == "cpu": 156 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 157 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 158 | float_node = float_input.node() 159 | 160 | def patch_float(module): 161 | try: 162 | graphs = [module.graph] if hasattr(module, "graph") else [] 163 | except RuntimeError: 164 | graphs = [] 165 | 166 | if hasattr(module, "forward1"): 167 | graphs.append(module.forward1.graph) 168 | 169 | for graph in graphs: 170 | for node in graph.findAllNodes("aten::to"): 171 | inputs = list(node.inputs()) 172 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 173 | if inputs[i].node()["value"] == 5: 174 | inputs[i].node().copyAttributes(float_node) 175 | 176 | model.apply(patch_float) 177 | patch_float(model.encode_image) 178 | patch_float(model.encode_text) 179 | 180 | model.float() 181 | 182 | return model, _transform(model.input_resolution.item()) 183 | 184 | 185 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 186 | """ 187 | Returns the tokenized representation of given input string(s) 188 | 189 | Parameters 190 | ---------- 191 | texts : Union[str, List[str]] 192 | An input string or a list of input strings to tokenize 193 | 194 | context_length : int 195 | The context length to use; all CLIP models use 77 as the context length 196 | 197 | truncate: bool 198 | Whether to truncate the text in case its encoding is longer than the context length 199 | 200 | 
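    Example (illustrative):

        tokens = tokenize(["a photo of a person."])  # LongTensor of shape [1, 77]
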
Returns
201 |     -------
202 |     A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
203 |     """
204 |     # import pdb
205 |     # pdb.set_trace()
206 |     if isinstance(texts, str):
207 |         texts = [texts]  # e.g. ['a photo of a face.']
208 | 
209 |     sot_token = _tokenizer.encoder["<|startoftext|>"]  # 49406
210 |     eot_token = _tokenizer.encoder["<|endoftext|>"]  # 49407
211 |     all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
212 |     result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)  # shape (num_texts, 77)
213 | 
214 |     for i, tokens in enumerate(all_tokens):
215 |         if len(tokens) > context_length:  # context_length is 77 for all CLIP models
216 |             if truncate:
217 |                 tokens = tokens[:context_length]
218 |                 tokens[-1] = eot_token
219 |             else:
220 |                 raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
221 |         result[i, :len(tokens)] = torch.tensor(tokens)
222 | 
223 |     return result
224 | 
--------------------------------------------------------------------------------
/clip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 | 
6 | import ftfy
7 | import regex as re
8 | 
9 | 
10 | @lru_cache()
11 | def default_bpe():
12 |     return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 | 
14 | 
15 | @lru_cache()
16 | def bytes_to_unicode():
17 |     """
18 |     Returns a list of utf-8 bytes and a corresponding list of unicode strings.
19 |     The reversible bpe codes work on unicode strings.
20 |     This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 |     When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 |     This is a significant percentage of your normal, say, 32K bpe vocab.
23 |     To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 |     This also avoids mapping to the whitespace/control characters the bpe code barfs on.
25 |     """
26 |     bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 |     cs = bs[:]
28 |     n = 0
29 |     for b in range(2**8):
30 |         if b not in bs:
31 |             bs.append(b)
32 |             cs.append(2**8+n)
33 |             n += 1
34 |     cs = [chr(n) for n in cs]
35 |     return dict(zip(bs, cs))
36 | 
37 | 
38 | def get_pairs(word):
39 |     """Return set of symbol pairs in a word.
40 |     Word is represented as a tuple of symbols (symbols being variable-length strings).
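    Example (illustrative): get_pairs(('h', 'e', 'l', 'l', 'o')) returns
    {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')}.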
41 |     """
42 |     pairs = set()
43 |     prev_char = word[0]
44 |     for char in word[1:]:
45 |         pairs.add((prev_char, char))
46 |         prev_char = char
47 |     return pairs
48 | 
49 | 
50 | def basic_clean(text):
51 |     text = ftfy.fix_text(text)
52 |     text = html.unescape(html.unescape(text))
53 |     return text.strip()
54 | 
55 | 
56 | def whitespace_clean(text):
57 |     text = re.sub(r'\s+', ' ', text)
58 |     text = text.strip()
59 |     return text
60 | 
61 | 
62 | class SimpleTokenizer(object):
63 |     def __init__(self, bpe_path: str = default_bpe()):
64 |         self.byte_encoder = bytes_to_unicode()
65 |         self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 |         merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 |         merges = merges[1:49152-256-2+1]
68 |         merges = [tuple(merge.split()) for merge in merges]
69 |         vocab = list(bytes_to_unicode().values())
70 |         vocab = vocab + [v+'</w>' for v in vocab]  # '</w>' marks end-of-word
71 |         for merge in merges:
72 |             vocab.append(''.join(merge))
73 |         vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 |         self.encoder = dict(zip(vocab, range(len(vocab))))
75 |         self.decoder = {v: k for k, v in self.encoder.items()}
76 |         self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 |         self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 |         self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 | 
80 |     def bpe(self, token):
81 |         if token in self.cache:
82 |             return self.cache[token]
83 |         word = tuple(token[:-1]) + (token[-1] + '</w>',)
84 |         pairs = get_pairs(word)
85 | 
86 |         if not pairs:
87 |             return token+'</w>'
88 | 
89 |         while True:
90 |             bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 |             if bigram not in self.bpe_ranks:
92 |                 break
93 |             first, second = bigram
94 |             new_word = []
95 |             i = 0
96 |             while i < len(word):
97 |                 try:
98 |                     j = word.index(first, i)
99 |                     new_word.extend(word[i:j])
100 |                     i = j
101 |                 except ValueError:
102 |                     new_word.extend(word[i:])
103 |                     break
104 | 
105 |                 if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 |                     new_word.append(first+second)
107 |                     i += 2
108 |                 else:
109 |                     new_word.append(word[i])
110 |                     i += 1
111 |             new_word = tuple(new_word)
112 |             word = new_word
113 |             if len(word) == 1:
114 |                 break
115 |             else:
116 |                 pairs = get_pairs(word)
117 |         word = ' '.join(word)
118 |         self.cache[token] = word
119 |         return word
120 | 
121 |     def encode(self, text):
122 |         bpe_tokens = []
123 |         text = whitespace_clean(basic_clean(text)).lower()
124 |         for token in re.findall(self.pat, text):
125 |             token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 |             bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 |         return bpe_tokens
128 | 
129 |     def decode(self, tokens):
130 |         text = ''.join([self.decoder[token] for token in tokens])
131 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 |         return text
133 | 
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | 
7 | from .defaults import _C as cfg
--------------------------------------------------------------------------------
/config/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/config/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /config/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/config/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /config/__pycache__/defaults.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/config/__pycache__/defaults.cpython-310.pyc -------------------------------------------------------------------------------- /config/__pycache__/defaults.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/config/__pycache__/defaults.cpython-39.pyc -------------------------------------------------------------------------------- /config/climb-vit-market.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | NAME: 'ViT-B-16' 3 | STRIDE_SIZE: [16, 16] 4 | MEMORY_MOMENTUM: 0.2 5 | ID_LOSS_WEIGHT: 0.25 6 | PCL_LOSS_WEIGHT: 1.0 7 | # SIE_CAMERA: True 8 | # SIE_COE : 1.0 9 | 10 | INPUT: 11 | SIZE_TRAIN: [256, 128] 12 | SIZE_TEST: [256, 128] 13 | PROB: 0.5 # random horizontal flip 14 | RE_PROB: 0.5 # random erasing 15 | PADDING: 10 16 | PIXEL_MEAN: [0.5, 0.5, 0.5] 17 | PIXEL_STD: [0.5, 0.5, 0.5] 18 | 19 | DATALOADER: 20 | NUM_INSTANCE: 8 21 | NUM_WORKERS: 0 22 | 23 | SOLVER: 24 | IMS_PER_BATCH: 64 25 | OPTIMIZER_NAME: "SGD" 26 | BASE_LR: 3.5e-4 27 | WARMUP_METHOD: 'linear' 28 | WARMUP_ITERS: 10 29 | WARMUP_FACTOR: 0.1 30 | WEIGHT_DECAY: 5.0e-4 31 | MAX_EPOCHS: 60 32 | CHECKPOINT_PERIOD: 60 33 | LOG_PERIOD: 50 34 | EVAL_PERIOD: 60 35 | ITERS: 200 36 | 37 | STEPS: [30, 50] 38 | GAMMA: 0.1 39 | 40 | TEST: 41 | EVAL: True 42 | IMS_PER_BATCH: 256 43 | RE_RANKING: False 44 | 45 | DATASETS: 46 | NAMES: ('market1501') 47 | ROOT_DIR: ('/home/ok/data/ycy_data') 48 | OUTPUT_DIR: './logs' -------------------------------------------------------------------------------- /config/climb-vit-msmt.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | NAME: 'ViT-B-16' 3 | STRIDE_SIZE: [16, 16] 4 | MEMORY_MOMENTUM: 0.2 5 | ID_LOSS_WEIGHT: 0.25 6 | PCL_LOSS_WEIGHT: 1.0 7 | SIE_CAMERA: True 8 | SIE_COE : 1.0 9 | 10 | INPUT: 11 | SIZE_TRAIN: [256, 128] 12 | SIZE_TEST: [256, 128] 13 | PROB: 0.5 # random horizontal flip 14 | RE_PROB: 0.5 # random erasing 15 | PADDING: 10 16 | PIXEL_MEAN: [0.5, 0.5, 0.5] 17 | PIXEL_STD: [0.5, 0.5, 0.5] 18 | 19 | DATALOADER: 20 | NUM_INSTANCE: 8 21 | NUM_WORKERS: 8 22 | 23 | SOLVER: 24 | IMS_PER_BATCH: 256 25 | OPTIMIZER_NAME: "SGD" 26 | BASE_LR: 3.5e-4 27 | WARMUP_METHOD: 'linear' 28 | WARMUP_ITERS: 10 29 | WARMUP_FACTOR: 0.1 30 | WEIGHT_DECAY: 5.0e-4 31 | MAX_EPOCHS: 80 32 | CHECKPOINT_PERIOD: 60 33 | LOG_PERIOD: 50 34 | EVAL_PERIOD: 1 35 | ITERS: 200 36 | 37 | STEPS: [30, 50, 70] 38 | GAMMA: 0.1 39 | 40 | TEST: 41 | EVAL: True 42 | IMS_PER_BATCH: 256 43 | RE_RANKING: False 44 | 45 | DATASETS: 46 | NAMES: ('msmt17') 47 | ROOT_DIR: ('/home/ok/data/ycy_data') 48 | OUTPUT_DIR: 
'./logs200'
--------------------------------------------------------------------------------
/config/defaults.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 | 
3 | # -----------------------------------------------------------------------------
4 | # Convention about Training / Test specific parameters
5 | # -----------------------------------------------------------------------------
6 | # Whenever an argument can be either used for training or for testing, the
7 | # corresponding name will be post-fixed by _TRAIN for a training parameter or _TEST for a test-specific parameter.
8 | 
9 | # -----------------------------------------------------------------------------
10 | # Config definition
11 | # -----------------------------------------------------------------------------
12 | 
13 | _C = CN()
14 | # -----------------------------------------------------------------------------
15 | # MODEL
16 | # -----------------------------------------------------------------------------
17 | _C.MODEL = CN()
18 | # Use cuda or cpu for training
19 | _C.MODEL.DEVICE = "cuda"
20 | # ID number of the GPU
21 | _C.MODEL.DEVICE_ID = '0'
22 | # Name of the backbone
23 | _C.MODEL.NAME = 'resnet50'
24 | # Last stride of the backbone
25 | _C.MODEL.LAST_STRIDE = 1
26 | # Path to the pretrained backbone model
27 | _C.MODEL.PRETRAIN_PATH = ''
28 | 
29 | # Use an ImageNet-pretrained model to initialize the backbone, or a self-trained model to initialize the whole model
30 | # Options: 'imagenet', 'self', 'finetune'
31 | _C.MODEL.PRETRAIN_CHOICE = 'imagenet'
32 | 
33 | # Whether to train with BNNeck, options: 'bnneck' or 'no'
34 | _C.MODEL.NECK = 'bnneck'
35 | # Whether the training loss includes center loss, options: 'yes' or 'no'. Training with center loss uses a different optimizer configuration
36 | _C.MODEL.IF_WITH_CENTER = 'no'
37 | 
38 | _C.MODEL.ID_LOSS_TYPE = 'softmax'
39 | _C.MODEL.ID_LOSS_WEIGHT = 1.0
40 | _C.MODEL.TRIPLET_LOSS_WEIGHT = 1.0
41 | _C.MODEL.I2T_LOSS_WEIGHT = 1.0
42 | 
43 | _C.MODEL.METRIC_LOSS_TYPE = 'triplet'
44 | # Whether to train in multi-GPU DDP mode, options: 'True', 'False'
45 | _C.MODEL.DIST_TRAIN = False
46 | # Whether to train with soft triplet loss, options: 'True', 'False'
47 | _C.MODEL.NO_MARGIN = False
48 | # Whether to train with label smoothing, options: 'on', 'off'
49 | _C.MODEL.IF_LABELSMOOTH = 'on'
50 | # Whether to train with arcface loss, options: 'True', 'False'
51 | _C.MODEL.COS_LAYER = False
52 | 
53 | # Transformer settings
54 | _C.MODEL.DROP_PATH = 0.1
55 | _C.MODEL.DROP_OUT = 0.0
56 | _C.MODEL.ATT_DROP_RATE = 0.0
57 | _C.MODEL.TRANSFORMER_TYPE = 'None'
58 | _C.MODEL.STRIDE_SIZE = [16, 16]
59 | 
60 | # SIE parameters
61 | _C.MODEL.SIE_COE = 3.0
62 | _C.MODEL.SIE_CAMERA = False
63 | _C.MODEL.SIE_VIEW = False
64 | _C.MODEL.PCL_LOSS_WEIGHT = 1.0
65 | 
66 | # memory bank
67 | _C.MODEL.MEMORY_MOMENTUM = 0.2
68 | 
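# Illustrative aside on MEMORY_MOMENTUM (a sketch of the usual convention, not
# this repo's exact rule -- that lives in the climb/ package): a momentum memory
# bank is conventionally refreshed as an exponential moving average of the
# features assigned to each class y,
#     memory[y] <- m * memory[y] + (1 - m) * f(x),   then L2-normalized,
# where m is the momentum set above as MODEL.MEMORY_MOMENTUM.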
69 | # -----------------------------------------------------------------------------
70 | # INPUT
71 | # -----------------------------------------------------------------------------
72 | _C.INPUT = CN()
73 | # Size of the image during training
74 | _C.INPUT.SIZE_TRAIN = [384, 128]
75 | # Size of the image during test
76 | _C.INPUT.SIZE_TEST = [384, 128]
77 | # Probability of a random horizontal flip
78 | _C.INPUT.PROB = 0.5
79 | # Probability of random erasing
80 | _C.INPUT.RE_PROB = 0.5
81 | # Values to be used for image normalization
82 | _C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406]
83 | # Values to be used for image normalization
84 | _C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225]
85 | # Value of the padding size
86 | _C.INPUT.PADDING = 10
87 | 
88 | # -----------------------------------------------------------------------------
89 | # Dataset
90 | # -----------------------------------------------------------------------------
91 | _C.DATASETS = CN()
92 | # Name of the dataset used for training, as registered in datasets/make_dataloader.py
93 | _C.DATASETS.NAMES = ('market1501')
94 | # Root directory where datasets should be used (and downloaded if not found)
95 | _C.DATASETS.ROOT_DIR = ('../data')
96 | 
97 | _C.DATASETS.ATTR_PATH = ''
98 | _C.DATASETS.ATTR_NUM = 30
99 | 
100 | # -----------------------------------------------------------------------------
101 | # DataLoader
102 | # -----------------------------------------------------------------------------
103 | _C.DATALOADER = CN()
104 | # Number of data-loading threads
105 | _C.DATALOADER.NUM_WORKERS = 8
106 | # Sampler for data loading
107 | _C.DATALOADER.SAMPLER = 'softmax'
108 | # Number of instances per identity in one batch
109 | _C.DATALOADER.NUM_INSTANCE = 16
110 | 
111 | # ---------------------------------------------------------------------------- #
112 | # Solver
113 | _C.SOLVER = CN()
114 | _C.SOLVER.SEED = 1234
115 | _C.SOLVER.MARGIN = 0.3
116 | 
117 | # Number of images per batch.
118 | # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 128, each GPU will
119 | # contain 16 images per batch
120 | _C.SOLVER.IMS_PER_BATCH = 64
121 | # Name of the optimizer
122 | _C.SOLVER.OPTIMIZER_NAME = "Adam"
123 | # Maximum number of epochs
124 | _C.SOLVER.MAX_EPOCHS = 100
125 | # Base learning rate
126 | _C.SOLVER.BASE_LR = 3e-4
127 | # Whether to use a larger learning rate for the fc layer
128 | _C.SOLVER.LARGE_FC_LR = False
129 | # Multiplicative learning-rate factor for bias parameters
130 | _C.SOLVER.BIAS_LR_FACTOR = 1
131 | # Momentum
132 | _C.SOLVER.MOMENTUM = 0.9
133 | 
134 | # Learning rate of SGD to learn the centers of center loss
135 | _C.SOLVER.CENTER_LR = 0.5
136 | # Balancing weight of the center loss
137 | _C.SOLVER.CENTER_LOSS_WEIGHT = 0.0005
138 | 
139 | # Weight-decay settings
140 | _C.SOLVER.WEIGHT_DECAY = 0.0005
141 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0.0005
142 | 
143 | # Decay rate of the learning rate
144 | _C.SOLVER.GAMMA = 0.1
145 | # Decay steps of the learning rate
146 | _C.SOLVER.STEPS = (40, 70)
147 | # Warm-up factor
148 | _C.SOLVER.WARMUP_FACTOR = 0.01
149 | # Warm-up epochs
150 | _C.SOLVER.WARMUP_EPOCHS = 5
151 | _C.SOLVER.WARMUP_LR_INIT = 0.01
152 | _C.SOLVER.LR_MIN = 0.000016
153 | 
154 | 
155 | _C.SOLVER.WARMUP_ITERS = 500
156 | # Method of warm-up, options: 'constant', 'linear'
157 | _C.SOLVER.WARMUP_METHOD = "linear"
158 | 
159 | _C.SOLVER.COSINE_MARGIN = 0.5
160 | _C.SOLVER.COSINE_SCALE = 30
161 | 
162 | # Save a checkpoint every this many epochs
163 | _C.SOLVER.CHECKPOINT_PERIOD = 10
164 | # Print a training log every this many iterations
165 | _C.SOLVER.LOG_PERIOD = 100
166 | # Run validation every this many epochs
167 | _C.SOLVER.EVAL_PERIOD = 10
168 | 
169 | # Iterations per epoch
170 | _C.SOLVER.ITERS = 200
171 | 
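# Illustrative usage (standard yacs API; not code from this file): a training
# script overlays one of the YAML files in config/ onto these defaults, e.g.
#
#     from config import cfg
#     cfg.merge_from_file('config/climb-vit-market.yml')
#     cfg.merge_from_list(['MODEL.DEVICE_ID', '0'])
#     cfg.freeze()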
172 | # ---------------------------------------------------------------------------- #
173 | # TEST
174 | # ---------------------------------------------------------------------------- #
175 | 
176 | _C.TEST = CN()
177 | # Number of images per batch during test
178 | _C.TEST.IMS_PER_BATCH = 128
179 | # Whether to test with re-ranking, options: 'True', 'False'
180 | _C.TEST.RE_RANKING = False
181 | # Path to the trained model
182 | _C.TEST.WEIGHT = ""
183 | # Which feature to use for test, before or after the BNNeck, options: 'before' or 'after'
184 | _C.TEST.NECK_FEAT = 'after'
185 | # Whether the feature is normalized before test; if yes, it is equivalent to cosine distance
186 | _C.TEST.FEAT_NORM = 'yes'
187 | 
188 | # Name for saving the distmat after testing.
189 | _C.TEST.DIST_MAT = "dist_mat.npy"
190 | # Whether to calculate the eval score, options: 'True', 'False'
191 | _C.TEST.EVAL = False
192 | 
193 | # ---------------------------------------------------------------------------- #
194 | # Misc options
195 | # ---------------------------------------------------------------------------- #
196 | # Path to the checkpoint and saved log of the trained model
197 | _C.OUTPUT_DIR = ""
--------------------------------------------------------------------------------
/datasets/bases.py:
--------------------------------------------------------------------------------
1 | from PIL import Image, ImageFile
2 | 
3 | from torch.utils.data import Dataset
4 | import os.path as osp
5 | import random
6 | import torch
7 | import pickle
8 | ImageFile.LOAD_TRUNCATED_IMAGES = True
9 | 
10 | 
11 | def read_image(img_path):
12 |     """Keep reading the image until it succeeds.
13 |     This can avoid the IOError incurred by a heavy IO process."""
14 |     got_img = False
15 |     if not osp.exists(img_path):
16 |         raise IOError("{} does not exist".format(img_path))
17 |     while not got_img:
18 |         try:
19 |             img = Image.open(img_path).convert('RGB')
20 |             got_img = True
21 |         except IOError:
22 |             print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
23 |             pass
24 |     return img
25 | 
26 | 
27 | class BaseDataset(object):
28 |     """
29 |     Base class of reid dataset
30 |     """
31 | 
32 |     def get_imagedata_info(self, data):
33 |         pids, cams, tracks = [], [], []
34 |         for _, pid, camid, trackid in data:
35 |             pids += [pid]
36 |             cams += [camid]
37 |             tracks += [trackid]
38 |         pids = set(pids)
39 |         cams = set(cams)
40 |         tracks = set(tracks)
41 |         num_pids = len(pids)
42 |         num_cams = len(cams)
43 |         num_imgs = len(data)
44 |         num_views = len(tracks)
45 |         return num_pids, num_imgs, num_cams, num_views
46 | 
47 |     def print_dataset_statistics(self):
48 |         raise NotImplementedError
49 | 
50 | 
51 | class BaseImageDataset(BaseDataset):
52 |     """
53 |     Base class of image reid dataset
54 |     """
55 | 
56 |     def print_dataset_statistics(self, train, query, gallery):
57 |         num_train_pids, num_train_imgs, num_train_cams, num_train_views = self.get_imagedata_info(train)
58 |         num_query_pids, num_query_imgs, num_query_cams, num_query_views = self.get_imagedata_info(query)
59 |         num_gallery_pids, num_gallery_imgs, num_gallery_cams, num_gallery_views = self.get_imagedata_info(gallery)
60 | 
61 |         print("Dataset statistics:")
62 |         print("  ----------------------------------------")
63 |         print("  subset   | # ids | # images | # cameras")
64 |         print("  ----------------------------------------")
65 |         print("  train    | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams))
66 |         print("  query    | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams))
67 |         print("  gallery  | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams))
68 |         print("  ----------------------------------------")
69 | 
70 | 
71 | class ImageDataset(Dataset):
72 |     def __init__(self, dataset, transform=None):
73 |         self.dataset = dataset
74 |         self.transform = transform
75 | 
76 |     def __len__(self):
77 |         return len(self.dataset)
78 | 
79 |     def __getitem__(self, index):
80 |         img_path, pid, camid, trackid = self.dataset[index]
81 |         img = read_image(img_path)
82 | 
83 |         if self.transform is not None:
84 |             img = self.transform(img)
85 | 
86 | return img, pid, camid, trackid, img_path.split('/')[-1] 87 | 88 | class AttributeImageDataset(Dataset): 89 | def __init__(self, dataset, attribute_path, pid2label, transform=None): 90 | self.dataset = dataset 91 | self.transform = transform 92 | self.label2pid = {label: pid for pid, label in pid2label.items()} 93 | self.plabel2attr = self.load_attributes(attribute_path) 94 | print(f'Load attribute annotation file {attribute_path}') 95 | 96 | def load_attributes(self, attr_path): 97 | with open(attr_path, 'rb') as f: 98 | attr_dict = pickle.load(f) 99 | train_set_attr = attr_dict['train'] 100 | new_attr_dict = {int(k): v for k, v in train_set_attr.items()} 101 | 102 | all_keys = sorted([ 103 | 'backpack', 'bag', 'handbag', 'downblack', 'downblue', 'downbrown', 'downgray', 'downgreen', 'downpink', 'downpurple', 'downwhite', 'downyellow', 104 | 'upblack', 'upblue', 'upgreen', 'upgray', 'uppurple', 'upred', 'upwhite', 'upyellow', 'clothes', 'down', 'up', 'hair', 'hat', 'gender' 105 | ]) # w/o age 106 | plabel2attr = {} 107 | for _, plabel, _, _ in self.dataset: 108 | if plabel in plabel2attr.keys(): 109 | continue 110 | pid = self.label2pid[plabel] 111 | attr = new_attr_dict[pid] 112 | attr_vector = torch.zeros(len(all_keys)) # w/o age 113 | for i, k in enumerate(all_keys): 114 | attr_vector[i] = attr[k] - 1 # 1 or 2 -> 0 or 1 115 | age_vector = torch.zeros(4) 116 | age_vector[attr['age']-1] = 1 117 | attr_vector = torch.cat([attr_vector, age_vector], dim=0) 118 | plabel2attr[plabel] = attr_vector 119 | return plabel2attr 120 | 121 | 122 | def __len__(self): 123 | return len(self.dataset) 124 | 125 | def __getitem__(self, index): 126 | img_path, pid, camid, trackid = self.dataset[index] 127 | img = read_image(img_path) 128 | 129 | attr = self.plabel2attr[pid] 130 | 131 | if self.transform is not None: 132 | img = self.transform(img) 133 | 134 | return img, pid, camid, trackid, img_path.split('/')[-1], attr -------------------------------------------------------------------------------- /datasets/dukemtmcreid.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: liaoxingyu2@jd.com 5 | """ 6 | 7 | import glob 8 | import re 9 | import urllib 10 | import zipfile 11 | 12 | import os.path as osp 13 | 14 | from utils.iotools import mkdir_if_missing 15 | from .bases import BaseImageDataset 16 | 17 | 18 | class DukeMTMCreID(BaseImageDataset): 19 | """ 20 | DukeMTMC-reID 21 | Reference: 22 | 1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 23 | 2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 
24 |     URL: https://github.com/layumi/DukeMTMC-reID_evaluation
25 | 
26 |     Dataset statistics:
27 |     # identities: 1404 (train + query)
28 |     # images: 16522 (train) + 2228 (query) + 17661 (gallery)
29 |     # cameras: 8
30 |     """
31 |     dataset_dir = 'dukemtmcreid'
32 | 
33 |     def __init__(self, root='', verbose=True, pid_begin=0, **kwargs):
34 |         super(DukeMTMCreID, self).__init__()
35 |         self.dataset_dir = osp.join(root, self.dataset_dir)
36 |         self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip'
37 |         self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_train')
38 |         self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/query')
39 |         self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_test')
40 |         self.pid_begin = pid_begin
41 |         self._download_data()
42 |         self._check_before_run()
43 | 
44 |         train = self._process_dir(self.train_dir, relabel=True)
45 |         query = self._process_dir(self.query_dir, relabel=False)
46 |         gallery = self._process_dir(self.gallery_dir, relabel=False)
47 | 
48 |         if verbose:
49 |             print("=> DukeMTMC-reID loaded")
50 |             self.print_dataset_statistics(train, query, gallery)
51 | 
52 |         self.train = train
53 |         self.query = query
54 |         self.gallery = gallery
55 | 
56 |         self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(self.train)
57 |         self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(self.query)
58 |         self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(self.gallery)
59 | 
60 |     def _download_data(self):
61 |         if osp.exists(self.dataset_dir):
62 |             print("This dataset has been downloaded.")
63 |             return
64 | 
65 |         print("Creating directory {}".format(self.dataset_dir))
66 |         mkdir_if_missing(self.dataset_dir)
67 |         fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url))
68 | 
69 |         print("Downloading DukeMTMC-reID dataset")
70 |         urllib.request.urlretrieve(self.dataset_url, fpath)
71 | 
72 |         print("Extracting files")
73 |         zip_ref = zipfile.ZipFile(fpath, 'r')
74 |         zip_ref.extractall(self.dataset_dir)
75 |         zip_ref.close()
76 | 
77 |     def _check_before_run(self):
78 |         """Check if all files are available before going deeper"""
79 |         if not osp.exists(self.dataset_dir):
80 |             raise RuntimeError("'{}' is not available".format(self.dataset_dir))
81 |         if not osp.exists(self.train_dir):
82 |             raise RuntimeError("'{}' is not available".format(self.train_dir))
83 |         if not osp.exists(self.query_dir):
84 |             raise RuntimeError("'{}' is not available".format(self.query_dir))
85 |         if not osp.exists(self.gallery_dir):
86 |             raise RuntimeError("'{}' is not available".format(self.gallery_dir))
87 | 
88 |     def _process_dir(self, dir_path, relabel=False):
89 |         img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
90 |         pattern = re.compile(r'([-\d]+)_c(\d)')
91 | 
92 |         pid_container = set()
93 |         for img_path in img_paths:
94 |             pid, _ = map(int, pattern.search(img_path).groups())
95 |             pid_container.add(pid)
96 |         pid2label = {pid: label for label, pid in enumerate(pid_container)}
97 | 
98 |         dataset = []
99 |         cam_container = set()
100 |         for img_path in img_paths:
101 |             pid, camid = map(int, pattern.search(img_path).groups())
102 |             assert 1 <= camid <= 8
103 |             camid -= 1  # index starts from 0
104 |             if relabel: pid = pid2label[pid]
105 |             dataset.append((img_path, self.pid_begin + pid, camid, 0))
106 |             cam_container.add(camid)
107 |         print(cam_container, 'cam_container')
108 |         return dataset
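    # Worked example (illustrative, not part of the original file): for a file
    # named '0001_c2_f0046182.jpg', the pattern r'([-\d]+)_c(\d)' yields pid=1
    # and camid=2; after the camid -= 1 shift the stored tuple becomes
    # (img_path, self.pid_begin + 1, 1, 0). The [-\d]+ group also admits the
    # pid -1 that these benchmarks use for junk/distractor images.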
109 | 
--------------------------------------------------------------------------------
/datasets/make_dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms as T
3 | from torch.utils.data import DataLoader
4 | 
5 | from .bases import ImageDataset
6 | from timm.data.random_erasing import RandomErasing
7 | from .sampler import RandomIdentitySampler
8 | from .dukemtmcreid import DukeMTMCreID
9 | from .market1501 import Market1501
10 | from .msmt17 import MSMT17
11 | from .sampler_ddp import RandomIdentitySampler_DDP
12 | import torch.distributed as dist
13 | from .occ_duke import OCC_DukeMTMCreID
14 | from .vehicleid import VehicleID
15 | from .veri import VeRi
16 | 
17 | __factory = {
18 |     'market1501': Market1501,
19 |     'dukemtmc': DukeMTMCreID,
20 |     'msmt17': MSMT17,
21 |     'occ_duke': OCC_DukeMTMCreID,
22 |     'veri': VeRi,
23 |     'VehicleID': VehicleID
24 | }
25 | 
26 | def train_collate_fn(batch):
27 |     """
28 |     # collate_fn receives a list of length batch_size; each element is the tuple returned by __getitem__
29 |     """
30 |     imgs, pids, camids, viewids, _ = zip(*batch)
31 |     pids = torch.tensor(pids, dtype=torch.int64)
32 |     viewids = torch.tensor(viewids, dtype=torch.int64)
33 |     camids = torch.tensor(camids, dtype=torch.int64)
34 |     return torch.stack(imgs, dim=0), pids, camids, viewids,
35 | 
36 | def val_collate_fn(batch):
37 |     imgs, pids, camids, viewids, img_paths = zip(*batch)
38 |     viewids = torch.tensor(viewids, dtype=torch.int64)
39 |     camids_batch = torch.tensor(camids, dtype=torch.int64)
40 |     return torch.stack(imgs, dim=0), pids, camids, camids_batch, viewids, img_paths
41 | 
42 | def make_dataloader(cfg):
43 |     train_transforms = T.Compose([
44 |         T.Resize(cfg.INPUT.SIZE_TRAIN, interpolation=3),
45 |         T.RandomHorizontalFlip(p=cfg.INPUT.PROB),
46 |         T.Pad(cfg.INPUT.PADDING),
47 |         T.RandomCrop(cfg.INPUT.SIZE_TRAIN),
48 |         T.ToTensor(),
49 |         T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD),
50 |         RandomErasing(probability=cfg.INPUT.RE_PROB, mode='pixel', max_count=1, device='cpu'),
51 |         # RandomErasing(probability=cfg.INPUT.RE_PROB, mean=cfg.INPUT.PIXEL_MEAN)
52 |     ])
53 | 
54 |     val_transforms = T.Compose([
55 |         T.Resize(cfg.INPUT.SIZE_TEST),
56 |         T.ToTensor(),
57 |         T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
58 |     ])
59 | 
60 |     num_workers = cfg.DATALOADER.NUM_WORKERS
61 | 
62 |     dataset = __factory[cfg.DATASETS.NAMES](root=cfg.DATASETS.ROOT_DIR)
63 | 
64 |     train_set = ImageDataset(dataset.train, train_transforms)
65 |     train_set_normal = ImageDataset(dataset.train, val_transforms)
66 |     num_classes = dataset.num_train_pids
67 |     cam_num = dataset.num_train_cams
68 |     view_num = dataset.num_train_vids
69 | 
70 |     if 'triplet' in cfg.DATALOADER.SAMPLER:
71 |         if cfg.MODEL.DIST_TRAIN:
72 |             print('DIST_TRAIN START')
73 |             mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // dist.get_world_size()
74 |             data_sampler = RandomIdentitySampler_DDP(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE)
75 |             batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size, True)
76 |             train_loader = torch.utils.data.DataLoader(
77 |                 train_set,
78 |                 num_workers=num_workers,
79 |                 batch_sampler=batch_sampler,
80 |                 collate_fn=train_collate_fn,
81 |                 pin_memory=True,
82 |             )
83 |         else:
84 |             train_loader = DataLoader(
85 |                 train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH,
86 |                 sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE),
87 |                 num_workers=num_workers, collate_fn=train_collate_fn
88 |             )
89 |     elif cfg.DATALOADER.SAMPLER == 'softmax':
90 |         print('using softmax sampler')
91 |         train_loader = DataLoader(
92 |             train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers,
93 |             collate_fn=train_collate_fn
94 |         )
95 |     else:
96 |         print('unsupported sampler! expected softmax or triplet but got {}'.format(cfg.DATALOADER.SAMPLER))
97 | 
98 |     val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms)
99 | 
100 |     val_loader = DataLoader(
101 |         val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers,
102 |         collate_fn=val_collate_fn
103 |     )
104 |     train_loader_normal = DataLoader(
105 |         train_set_normal, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers,
106 |         collate_fn=val_collate_fn
107 |     )
108 |     return train_loader, train_loader_normal, val_loader, len(dataset.query), num_classes, cam_num, view_num
109 | 
--------------------------------------------------------------------------------
/datasets/make_dataloader_clipreid.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms as T
3 | from torch.utils.data import DataLoader
4 | 
5 | from .bases import ImageDataset
6 | from timm.data.random_erasing import RandomErasing
7 | from .sampler import RandomIdentitySampler
8 | from .dukemtmcreid import DukeMTMCreID
9 | from .market1501 import Market1501
10 | from .msmt17 import MSMT17
11 | from .msmt17_v2 import MSMT17_V2
12 | from .sampler_ddp import RandomIdentitySampler_DDP
13 | import torch.distributed as dist
14 | from .occ_duke import OCC_DukeMTMCreID
15 | from .vehicleid import VehicleID
16 | from .veri import VeRi
17 | 
18 | __factory = {
19 |     'market1501': Market1501,
20 |     'dukemtmc': DukeMTMCreID,
21 |     'msmt17': MSMT17,
22 |     # 'msmt17': MSMT17_V2,
23 |     'occ_duke': OCC_DukeMTMCreID,
24 |     'veri': VeRi,
25 |     'VehicleID': VehicleID
26 | }
27 | 
28 | def train_collate_fn(batch):
29 |     """
30 |     # collate_fn receives a list of length batch_size; each element is the tuple returned by __getitem__
31 |     """
32 |     imgs, pids, camids, viewids, _ = zip(*batch)
33 |     pids = torch.tensor(pids, dtype=torch.int64)
34 |     viewids = torch.tensor(viewids, dtype=torch.int64)
35 |     camids = torch.tensor(camids, dtype=torch.int64)
36 |     return torch.stack(imgs, dim=0), pids, camids, viewids,
37 | 
38 | def val_collate_fn(batch):
39 |     imgs, pids, camids, viewids, img_paths = zip(*batch)
40 |     viewids = torch.tensor(viewids, dtype=torch.int64)
41 |     camids_batch = torch.tensor(camids, dtype=torch.int64)
42 |     return torch.stack(imgs, dim=0), pids, camids, camids_batch, viewids, img_paths
43 | 
44 | def make_dataloader(cfg):
45 |     train_transforms = T.Compose([
46 |         T.Resize(cfg.INPUT.SIZE_TRAIN, interpolation=3),
47 |         T.RandomHorizontalFlip(p=cfg.INPUT.PROB),
48 |         T.Pad(cfg.INPUT.PADDING),
49 |         T.RandomCrop(cfg.INPUT.SIZE_TRAIN),
50 |         T.ToTensor(),
51 |         T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD),
52 |         RandomErasing(probability=cfg.INPUT.RE_PROB, mode='pixel', max_count=1, device='cpu'),
53 |         # RandomErasing(probability=cfg.INPUT.RE_PROB, mean=cfg.INPUT.PIXEL_MEAN)
54 |     ])
55 | 
56 |     val_transforms = T.Compose([
57 |         T.Resize(cfg.INPUT.SIZE_TEST),
58 |         T.ToTensor(),
59 |         T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
60 |     ])
61 | 
62 |     num_workers = cfg.DATALOADER.NUM_WORKERS
63 | 
64 |     dataset = __factory[cfg.DATASETS.NAMES](root=cfg.DATASETS.ROOT_DIR)
65 | 
66 |     train_set = ImageDataset(dataset.train, train_transforms)
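    # The same training list is wrapped twice: `train_set` applies the augmented
    # transforms for the stage-2 loaders, while `train_set_normal` below keeps
    # the deterministic test-time transforms and feeds `train_loader_stage1`
    # at the end of this function.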
train_set_normal = ImageDataset(dataset.train, val_transforms) 68 | num_classes = dataset.num_train_pids 69 | cam_num = dataset.num_train_cams 70 | view_num = dataset.num_train_vids 71 | 72 | if 'triplet' in cfg.DATALOADER.SAMPLER: 73 | if cfg.MODEL.DIST_TRAIN: 74 | print('DIST_TRAIN START') 75 | mini_batch_size = cfg.SOLVER.STAGE2.IMS_PER_BATCH // dist.get_world_size() 76 | data_sampler = RandomIdentitySampler_DDP(dataset.train, cfg.SOLVER.STAGE2.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE) 77 | batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size, True) 78 | train_loader_stage2 = torch.utils.data.DataLoader( 79 | train_set, 80 | num_workers=num_workers, 81 | batch_sampler=batch_sampler, 82 | collate_fn=train_collate_fn, 83 | pin_memory=True, 84 | ) 85 | else: 86 | train_loader_stage2 = DataLoader( 87 | train_set, batch_size=cfg.SOLVER.STAGE2.IMS_PER_BATCH, 88 | sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.STAGE2.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE), 89 | num_workers=num_workers, collate_fn=train_collate_fn 90 | ) 91 | elif cfg.DATALOADER.SAMPLER == 'softmax': 92 | print('using softmax sampler') 93 | train_loader_stage2 = DataLoader( 94 | train_set, batch_size=cfg.SOLVER.STAGE2.IMS_PER_BATCH, shuffle=True, num_workers=num_workers, 95 | collate_fn=train_collate_fn 96 | ) 97 | else: 98 | print('unsupported sampler! expected softmax or triplet but got {}'.format(cfg.SAMPLER)) 99 | 100 | val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms) 101 | 102 | val_loader = DataLoader( 103 | val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers, 104 | collate_fn=val_collate_fn 105 | ) 106 | train_loader_stage1 = DataLoader( 107 | train_set_normal, batch_size=cfg.SOLVER.STAGE1.IMS_PER_BATCH, shuffle=True, num_workers=num_workers, 108 | collate_fn=train_collate_fn 109 | ) 110 | return train_loader_stage2, train_loader_stage1, val_loader, len(dataset.query), num_classes, cam_num, view_num 111 | -------------------------------------------------------------------------------- /datasets/make_dataloader_clipreid_ccpa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as T 3 | from torch.utils.data import DataLoader 4 | 5 | from .bases import ImageDataset, AttributeImageDataset 6 | from timm.data.random_erasing import RandomErasing 7 | from .sampler import RandomIdentitySampler 8 | from .dukemtmcreid import DukeMTMCreID 9 | from .market1501 import Market1501 10 | from .msmt17 import MSMT17 11 | # from .msmt17_v2 import MSMT17_V2 12 | from .sampler_ddp import RandomIdentitySampler_DDP 13 | import torch.distributed as dist 14 | from .occ_duke import OCC_DukeMTMCreID 15 | from .vehicleid import VehicleID 16 | from .veri import VeRi 17 | 18 | __factory = { 19 | 'market1501': Market1501, 20 | 'dukemtmc': DukeMTMCreID, 21 | 'msmt17': MSMT17, 22 | # 'msmt17': MSMT17_V2, 23 | 'occ_duke': OCC_DukeMTMCreID, 24 | 'veri': VeRi, 25 | 'VehicleID': VehicleID 26 | } 27 | 28 | def train_collate_fn(batch): 29 | """ 30 | # collate_fn这个函数的输入就是一个list,list的长度是一个batch size,list中的每个元素都是__getitem__得到的结果 31 | """ 32 | imgs, pids, camids, viewids , _, attrs = zip(*batch) 33 | pids = torch.tensor(pids, dtype=torch.int64) 34 | viewids = torch.tensor(viewids, dtype=torch.int64) 35 | camids = torch.tensor(camids, dtype=torch.int64) 36 | attrs = torch.stack(attrs, dim=0) 37 | return torch.stack(imgs, dim=0), pids, camids, viewids, attrs 38 | 39 | def 
val_collate_fn(batch): 40 | imgs, pids, camids, viewids, img_paths = zip(*batch) 41 | viewids = torch.tensor(viewids, dtype=torch.int64) 42 | camids_batch = torch.tensor(camids, dtype=torch.int64) 43 | return torch.stack(imgs, dim=0), pids, camids, camids_batch, viewids, img_paths 44 | 45 | def make_dataloader(cfg): 46 | train_transforms = T.Compose([ 47 | T.Resize(cfg.INPUT.SIZE_TRAIN, interpolation=3), 48 | T.RandomHorizontalFlip(p=cfg.INPUT.PROB), 49 | T.Pad(cfg.INPUT.PADDING), 50 | T.RandomCrop(cfg.INPUT.SIZE_TRAIN), 51 | T.ToTensor(), 52 | T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD), 53 | RandomErasing(probability=cfg.INPUT.RE_PROB, mode='pixel', max_count=1, device='cpu'), 54 | # RandomErasing(probability=cfg.INPUT.RE_PROB, mean=cfg.INPUT.PIXEL_MEAN) 55 | ]) 56 | 57 | val_transforms = T.Compose([ 58 | T.Resize(cfg.INPUT.SIZE_TEST), 59 | T.ToTensor(), 60 | T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD) 61 | ]) 62 | 63 | num_workers = cfg.DATALOADER.NUM_WORKERS 64 | 65 | dataset = __factory[cfg.DATASETS.NAMES](root=cfg.DATASETS.ROOT_DIR) 66 | 67 | train_set = AttributeImageDataset(dataset.train, cfg.DATASETS.ATTR_PATH, dataset.pid2label, train_transforms) 68 | train_set_normal = AttributeImageDataset(dataset.train, cfg.DATASETS.ATTR_PATH, dataset.pid2label, val_transforms) 69 | num_classes = dataset.num_train_pids 70 | cam_num = dataset.num_train_cams 71 | view_num = dataset.num_train_vids 72 | 73 | if 'triplet' in cfg.DATALOADER.SAMPLER: 74 | if cfg.MODEL.DIST_TRAIN: 75 | print('DIST_TRAIN START') 76 | mini_batch_size = cfg.SOLVER.STAGE2.IMS_PER_BATCH // dist.get_world_size() 77 | data_sampler = RandomIdentitySampler_DDP(dataset.train, cfg.SOLVER.STAGE2.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE) 78 | batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size, True) 79 | train_loader_stage2 = torch.utils.data.DataLoader( 80 | train_set, 81 | num_workers=num_workers, 82 | batch_sampler=batch_sampler, 83 | collate_fn=train_collate_fn, 84 | pin_memory=True, 85 | ) 86 | else: 87 | train_loader_stage2 = DataLoader( 88 | train_set, batch_size=cfg.SOLVER.STAGE2.IMS_PER_BATCH, 89 | sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.STAGE2.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE), 90 | num_workers=num_workers, collate_fn=train_collate_fn 91 | ) 92 | elif cfg.DATALOADER.SAMPLER == 'softmax': 93 | print('using softmax sampler') 94 | train_loader_stage2 = DataLoader( 95 | train_set, batch_size=cfg.SOLVER.STAGE2.IMS_PER_BATCH, shuffle=True, num_workers=num_workers, 96 | collate_fn=train_collate_fn 97 | ) 98 | else: 99 | print('unsupported sampler! 
expected softmax or triplet but got {}'.format(cfg.DATALOADER.SAMPLER))
100 |
101 | val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms)
102 |
103 | val_loader = DataLoader(
104 | val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers,
105 | collate_fn=val_collate_fn
106 | )
107 | train_loader_stage1 = DataLoader(
108 | train_set_normal, batch_size=cfg.SOLVER.STAGE1.IMS_PER_BATCH, shuffle=True, num_workers=num_workers,
109 | collate_fn=train_collate_fn
110 | )
111 | return train_loader_stage2, train_loader_stage1, val_loader, len(dataset.query), num_classes, cam_num, view_num
112 |
-------------------------------------------------------------------------------- /datasets/market1501.py: --------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import glob
8 | import re
9 |
10 | import os.path as osp
11 |
12 | from .bases import BaseImageDataset
13 | from collections import defaultdict
14 | import pickle
15 | class Market1501(BaseImageDataset):
16 | """
17 | Market1501
18 | Reference:
19 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
20 | URL: http://www.liangzheng.org/Project/project_reid.html
21 |
22 | Dataset statistics:
23 | # identities: 1501 (+1 for background)
24 | # images: 12936 (train) + 3368 (query) + 15913 (gallery)
25 | """
26 | dataset_dir = 'Market1501'
27 |
28 | def __init__(self, root='', verbose=True, pid_begin = 0, **kwargs):
29 | super(Market1501, self).__init__()
30 | # self.dataset_dir = osp.join(root, self.dataset_dir)
31 | root1 = '/dataset_cc/Mars/'
32 | root2 = '/media/ycy/ba8af05f-f397-4839-a318-f469b124cbab/data/Market-1501'
33 | root3 = '/YCY/dataset/Market'
34 | if osp.exists(root1):
35 | self.root = root1
36 | elif osp.exists(root2):
37 | self.root = root2
38 | elif osp.exists(root3):
39 | self.root = root3
40 | self.train_dir = osp.join(self.root, 'bounding_box_train')
41 | self.query_dir = osp.join(self.root, 'query')
42 | self.gallery_dir = osp.join(self.root, 'bounding_box_test')
43 |
44 | self._check_before_run()
45 | self.pid_begin = pid_begin
46 | train = self._process_dir(self.train_dir, relabel=True, bind_pid2label=True)
47 | query = self._process_dir(self.query_dir, relabel=False)
48 | gallery = self._process_dir(self.gallery_dir, relabel=False)
49 |
50 | if verbose:
51 | print("=> Market1501 loaded")
52 | self.print_dataset_statistics(train, query, gallery)
53 |
54 | self.train = train
55 | self.query = query
56 | self.gallery = gallery
57 |
58 | self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(self.train)
59 | self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(self.query)
60 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(self.gallery)
61 |
62 | def _check_before_run(self):
63 | """Check if all files are available before going deeper"""
64 | if not osp.exists(self.root):
65 | raise RuntimeError("'{}' is not available".format(self.root))
66 | if not osp.exists(self.train_dir):
67 | raise RuntimeError("'{}' is not available".format(self.train_dir))
68 | if not osp.exists(self.query_dir):
69 | raise RuntimeError("'{}' is not available".format(self.query_dir))
70 | if not osp.exists(self.gallery_dir):
71 | raise RuntimeError("'{}' is not available".format(self.gallery_dir))
72 |
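# A hypothetical worked example of the filename convention that _process_dir
# below relies on: Market1501 images are named like '0002_c1s1_000451_03.jpg',
# where the leading field is the person id and 'c1' is the camera. A minimal
# sketch of the same parsing (toy filename, not taken from the dataset):
#
#   import re
#   pattern = re.compile(r'([-\d]+)_c(\d)')
#   pid, camid = map(int, pattern.search('0002_c1s1_000451_03.jpg').groups())
#   assert (pid, camid) == (2, 1)  # pid == -1 marks junk images, skipped below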
73 | def _process_dir(self, dir_path, relabel=False, bind_pid2label=False): 74 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 75 | pattern = re.compile(r'([-\d]+)_c(\d)') 76 | 77 | pid_container = set() 78 | for img_path in sorted(img_paths): 79 | pid, _ = map(int, pattern.search(img_path).groups()) 80 | if pid == -1: continue # junk images are just ignored 81 | pid_container.add(pid) 82 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 83 | if bind_pid2label: 84 | self.pid2label = pid2label 85 | dataset = [] 86 | for img_path in sorted(img_paths): 87 | pid, camid = map(int, pattern.search(img_path).groups()) 88 | if pid == -1: continue # junk images are just ignored 89 | assert 0 <= pid <= 1501 # pid == 0 means background 90 | assert 1 <= camid <= 6 91 | camid -= 1 # index starts from 0 92 | if relabel: pid = pid2label[pid] 93 | 94 | dataset.append((img_path, self.pid_begin + pid, camid, 0)) 95 | return dataset -------------------------------------------------------------------------------- /datasets/msmt17.py: -------------------------------------------------------------------------------- 1 | 2 | import glob 3 | import re 4 | 5 | import os.path as osp 6 | 7 | from .bases import BaseImageDataset 8 | 9 | 10 | class MSMT17(BaseImageDataset): 11 | """ 12 | MSMT17 13 | 14 | Reference: 15 | Wei et al. Person Transfer GAN to Bridge Domain Gap for Person Re-Identification. CVPR 2018. 16 | 17 | URL: http://www.pkuvmc.com/publications/msmt17.html 18 | 19 | Dataset statistics: 20 | # identities: 4101 21 | # images: 32621 (train) + 11659 (query) + 82161 (gallery) 22 | # cameras: 15 23 | """ 24 | dataset_dir = 'MSMT17_v32' 25 | 26 | def __init__(self, root='', verbose=True, pid_begin=0, **kwargs): 27 | super(MSMT17, self).__init__() 28 | self.pid_begin = pid_begin 29 | # self.dataset_dir = osp.join(root, self.dataset_dir) 30 | root1 = '/dataset_cc/MSMT17_v3/' 31 | root2 = '/ycy/ba8af05f-f397-4839-a318-f469b124cbab/data/MSMT17_v3' 32 | root3 = '/YCY/dataset/MSMT17_v3/MSMT17_v3' 33 | if osp.exists(root1): 34 | self.root = root1 35 | elif osp.exists(root2): 36 | self.root = root2 37 | elif osp.exists(root3): 38 | self.root = root3 39 | 40 | self.train_dir = osp.join(self.root, 'train') 41 | self.test_dir = osp.join(self.root, 'test') 42 | self.list_train_path = osp.join(self.root, 'list_train.txt') 43 | self.list_val_path = osp.join(self.root, 'list_val.txt') 44 | self.list_query_path = osp.join(self.root, 'list_query.txt') 45 | self.list_gallery_path = osp.join(self.root, 'list_gallery.txt') 46 | 47 | self._check_before_run() 48 | train = self._process_dir(self.train_dir, self.list_train_path) 49 | val = self._process_dir(self.train_dir, self.list_val_path) 50 | train += val 51 | query = self._process_dir(self.test_dir, self.list_query_path) 52 | gallery = self._process_dir(self.test_dir, self.list_gallery_path) 53 | if verbose: 54 | print("=> MSMT17 loaded") 55 | self.print_dataset_statistics(train, query, gallery) 56 | 57 | self.train = train 58 | self.query = query 59 | self.gallery = gallery 60 | 61 | self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(self.train) 62 | self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(self.query) 63 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(self.gallery) 64 | def _check_before_run(self): 65 | """Check if all files are available before going 
deeper""" 66 | if not osp.exists(self.root): 67 | raise RuntimeError("'{}' is not available".format(self.root)) 68 | if not osp.exists(self.train_dir): 69 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 70 | if not osp.exists(self.test_dir): 71 | raise RuntimeError("'{}' is not available".format(self.test_dir)) 72 | 73 | def _process_dir(self, dir_path, list_path): 74 | with open(list_path, 'r') as txt: 75 | lines = txt.readlines() 76 | dataset = [] 77 | pid_container = set() 78 | cam_container = set() 79 | for img_idx, img_info in enumerate(lines): 80 | img_path, pid = img_info.split(' ') 81 | pid = int(pid) # no need to relabel 82 | camid = int(img_path.split('_')[2]) 83 | img_path = osp.join(dir_path, img_path) 84 | dataset.append((img_path, self.pid_begin+pid, camid-1, 0)) 85 | pid_container.add(pid) 86 | cam_container.add(camid) 87 | print(cam_container, 'cam_container') 88 | # check if pid starts from 0 and increments with 1 89 | for idx, pid in enumerate(pid_container): 90 | assert idx == pid, "See code comment for explanation" 91 | return dataset -------------------------------------------------------------------------------- /datasets/msmt17_v2.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | import re 5 | 6 | from .bases import BaseImageDataset 7 | 8 | class MSMT17_V2(BaseImageDataset): 9 | 10 | dataset_dir = 'MSMT17' 11 | 12 | def __init__(self, root, verbose=True, **kwargs): 13 | super(MSMT17_V2, self).__init__() 14 | self.dataset_dir = osp.join(root, self.dataset_dir) 15 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 16 | self.query_dir = osp.join(self.dataset_dir, 'query') 17 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 18 | 19 | self._check_before_run() 20 | 21 | train = self._process_dir(self.train_dir, relabel=True) 22 | query = self._process_dir(self.query_dir, relabel=False) 23 | gallery = self._process_dir(self.gallery_dir, relabel=False) 24 | 25 | if verbose: 26 | print("=> MSMT17 loaded") 27 | self.print_dataset_statistics(train, query, gallery) 28 | 29 | self.train = train 30 | self.query = query 31 | self.gallery = gallery 32 | 33 | self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(self.train) 34 | self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(self.query) 35 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(self.gallery) 36 | 37 | def _check_before_run(self): 38 | """Check if all files are available before going deeper""" 39 | if not osp.exists(self.dataset_dir): 40 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 41 | if not osp.exists(self.train_dir): 42 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 43 | if not osp.exists(self.query_dir): 44 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 45 | if not osp.exists(self.gallery_dir): 46 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 47 | 48 | def _process_dir(self, dir_path, relabel=False): 49 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 50 | #pattern = re.compile(r'([-\d]+)_c(\d)') # pattern for market and duke 51 | pattern = re.compile(r'([-\d]+)_c([-\d]+)') # pattern for msmt17 52 | pid_container = set() 53 | for img_path in 
img_paths:
54 | pid, _ = map(int, pattern.search(img_path).groups())
55 | if pid == -1: continue # junk images are just ignored
56 | pid_container.add(pid)
57 | pid2label = {pid: label for label, pid in enumerate(pid_container)}
58 |
59 | dataset = []
60 | for img_path in img_paths:
61 | pid, camid = map(int, pattern.search(img_path).groups())
62 | if pid == -1: continue # junk images are just ignored
63 | #assert 0 <= pid <= 1501 # pid == 0 means background
64 | assert 1 <= camid <= 15
65 | camid -= 1 # index starts from 0
66 | if relabel: pid = pid2label[pid]
67 | dataset.append((img_path, pid, camid, 0))
68 |
69 | return dataset
70 |
-------------------------------------------------------------------------------- /datasets/occ_duke.py: --------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: liaoxingyu2@jd.com
5 | """
6 |
7 | import glob
8 | import re
9 | import urllib.request
10 | import zipfile
11 |
12 | import os.path as osp
13 |
14 | from utils.iotools import mkdir_if_missing
15 | from .bases import BaseImageDataset
16 |
17 |
18 | class OCC_DukeMTMCreID(BaseImageDataset):
19 | """
20 | DukeMTMC-reID
21 | Reference:
22 | 1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016.
23 | 2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017.
24 | URL: https://github.com/layumi/DukeMTMC-reID_evaluation
25 |
26 | Dataset statistics:
27 | # identities: 1404 (train + query)
28 | # images:16522 (train) + 2228 (query) + 17661 (gallery)
29 | # cameras: 8
30 | """
31 | dataset_dir = 'dukemtmcreid'
32 |
33 | def __init__(self, root='', verbose=True, pid_begin=0, **kwargs):
34 | super(OCC_DukeMTMCreID, self).__init__()
35 | self.dataset_dir = osp.join(root, self.dataset_dir)
36 | self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip'
37 | self.train_dir = osp.join(self.dataset_dir, 'Occluded_Duke/bounding_box_train')
38 | self.query_dir = osp.join(self.dataset_dir, 'Occluded_Duke/query')
39 | self.gallery_dir = osp.join(self.dataset_dir, 'Occluded_Duke/bounding_box_test')
40 | self.pid_begin = pid_begin
41 | self._download_data()
42 | self._check_before_run()
43 |
44 | train = self._process_dir(self.train_dir, relabel=True)
45 | query = self._process_dir(self.query_dir, relabel=False)
46 | gallery = self._process_dir(self.gallery_dir, relabel=False)
47 |
48 | if verbose:
49 | print("=> DukeMTMC-reID loaded")
50 | self.print_dataset_statistics(train, query, gallery)
51 |
52 | self.train = train
53 | self.query = query
54 | self.gallery = gallery
55 |
56 | self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(self.train)
57 | self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(self.query)
58 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(self.gallery)
59 |
60 | def _download_data(self):
61 | if osp.exists(self.dataset_dir):
62 | print("This dataset has been downloaded.")
63 | return
64 |
65 | print("Creating directory {}".format(self.dataset_dir))
66 | mkdir_if_missing(self.dataset_dir)
67 | fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url))
68 |
69 | print("Downloading DukeMTMC-reID dataset")
70 | urllib.request.urlretrieve(self.dataset_url, fpath)
71 |
72 | print("Extracting
files") 73 | zip_ref = zipfile.ZipFile(fpath, 'r') 74 | zip_ref.extractall(self.dataset_dir) 75 | zip_ref.close() 76 | 77 | def _check_before_run(self): 78 | """Check if all files are available before going deeper""" 79 | if not osp.exists(self.dataset_dir): 80 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 81 | if not osp.exists(self.train_dir): 82 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 83 | if not osp.exists(self.query_dir): 84 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 85 | if not osp.exists(self.gallery_dir): 86 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 87 | 88 | def _process_dir(self, dir_path, relabel=False): 89 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 90 | pattern = re.compile(r'([-\d]+)_c(\d)') 91 | 92 | pid_container = set() 93 | for img_path in img_paths: 94 | pid, _ = map(int, pattern.search(img_path).groups()) 95 | pid_container.add(pid) 96 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 97 | 98 | dataset = [] 99 | cam_container = set() 100 | for img_path in img_paths: 101 | pid, camid = map(int, pattern.search(img_path).groups()) 102 | assert 1 <= camid <= 8 103 | camid -= 1 # index starts from 0 104 | if relabel: pid = pid2label[pid] 105 | dataset.append((img_path, self.pid_begin + pid, camid, 1)) 106 | cam_container.add(camid) 107 | print(cam_container, 'cam_container') 108 | return dataset 109 | -------------------------------------------------------------------------------- /datasets/preprocessing.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | 4 | 5 | class RandomErasing(object): 6 | """ Randomly selects a rectangle region in an image and erases its pixels. 7 | 'Random Erasing Data Augmentation' by Zhong et al. 8 | See https://arxiv.org/pdf/1708.04896.pdf 9 | Args: 10 | probability: The probability that the Random Erasing operation will be performed. 11 | sl: Minimum proportion of erased area against input image. 12 | sh: Maximum proportion of erased area against input image. 13 | r1: Minimum aspect ratio of erased area. 14 | mean: Erasing value. 
15 | """ 16 | 17 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 18 | self.probability = probability 19 | self.mean = mean 20 | self.sl = sl 21 | self.sh = sh 22 | self.r1 = r1 23 | 24 | def __call__(self, img): 25 | 26 | if random.uniform(0, 1) >= self.probability: 27 | return img 28 | 29 | for attempt in range(100): 30 | area = img.size()[1] * img.size()[2] 31 | 32 | target_area = random.uniform(self.sl, self.sh) * area 33 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 34 | 35 | h = int(round(math.sqrt(target_area * aspect_ratio))) 36 | w = int(round(math.sqrt(target_area / aspect_ratio))) 37 | 38 | if w < img.size()[2] and h < img.size()[1]: 39 | x1 = random.randint(0, img.size()[1] - h) 40 | y1 = random.randint(0, img.size()[2] - w) 41 | if img.size()[0] == 3: 42 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 43 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 44 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 45 | else: 46 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 47 | return img 48 | 49 | return img 50 | 51 | -------------------------------------------------------------------------------- /datasets/sampler.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.sampler import Sampler 2 | from collections import defaultdict 3 | import copy 4 | import random 5 | import numpy as np 6 | 7 | class RandomIdentitySampler(Sampler): 8 | """ 9 | Randomly sample N identities, then for each identity, 10 | randomly sample K instances, therefore batch size is N*K. 11 | Args: 12 | - data_source (list): list of (img_path, pid, camid). 13 | - num_instances (int): number of instances per identity in a batch. 14 | - batch_size (int): number of examples in a batch. 
15 | """ 16 | 17 | def __init__(self, data_source, batch_size, num_instances): 18 | self.data_source = data_source 19 | self.batch_size = batch_size 20 | self.num_instances = num_instances 21 | self.num_pids_per_batch = self.batch_size // self.num_instances 22 | self.index_dic = defaultdict(list) #dict with list value 23 | #{783: [0, 5, 116, 876, 1554, 2041],...,} 24 | for index, (_, pid, _, _) in enumerate(self.data_source): 25 | self.index_dic[pid].append(index) 26 | self.pids = list(self.index_dic.keys()) 27 | 28 | # estimate number of examples in an epoch 29 | self.length = 0 30 | for pid in self.pids: 31 | idxs = self.index_dic[pid] 32 | num = len(idxs) 33 | if num < self.num_instances: 34 | num = self.num_instances 35 | self.length += num - num % self.num_instances 36 | 37 | def __iter__(self): 38 | batch_idxs_dict = defaultdict(list) 39 | 40 | for pid in self.pids: 41 | idxs = copy.deepcopy(self.index_dic[pid]) 42 | if len(idxs) < self.num_instances: 43 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 44 | random.shuffle(idxs) 45 | batch_idxs = [] 46 | for idx in idxs: 47 | batch_idxs.append(idx) 48 | if len(batch_idxs) == self.num_instances: 49 | batch_idxs_dict[pid].append(batch_idxs) 50 | batch_idxs = [] 51 | 52 | avai_pids = copy.deepcopy(self.pids) 53 | final_idxs = [] 54 | 55 | while len(avai_pids) >= self.num_pids_per_batch: 56 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 57 | for pid in selected_pids: 58 | batch_idxs = batch_idxs_dict[pid].pop(0) 59 | final_idxs.extend(batch_idxs) 60 | if len(batch_idxs_dict[pid]) == 0: 61 | avai_pids.remove(pid) 62 | 63 | return iter(final_idxs) 64 | 65 | def __len__(self): 66 | return self.length 67 | 68 | -------------------------------------------------------------------------------- /datasets/sampler_ddp.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.sampler import Sampler 2 | from collections import defaultdict 3 | import copy 4 | import random 5 | import numpy as np 6 | import math 7 | import torch.distributed as dist 8 | _LOCAL_PROCESS_GROUP = None 9 | import torch 10 | import pickle 11 | 12 | def _get_global_gloo_group(): 13 | """ 14 | Return a process group based on gloo backend, containing all the ranks 15 | The result is cached. 16 | """ 17 | if dist.get_backend() == "nccl": 18 | return dist.new_group(backend="gloo") 19 | else: 20 | return dist.group.WORLD 21 | 22 | def _serialize_to_tensor(data, group): 23 | backend = dist.get_backend(group) 24 | assert backend in ["gloo", "nccl"] 25 | device = torch.device("cpu" if backend == "gloo" else "cuda") 26 | 27 | buffer = pickle.dumps(data) 28 | if len(buffer) > 1024 ** 3: 29 | print( 30 | "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( 31 | dist.get_rank(), len(buffer) / (1024 ** 3), device 32 | ) 33 | ) 34 | storage = torch.ByteStorage.from_buffer(buffer) 35 | tensor = torch.ByteTensor(storage).to(device=device) 36 | return tensor 37 | 38 | def _pad_to_largest_tensor(tensor, group): 39 | """ 40 | Returns: 41 | list[int]: size of the tensor, on each rank 42 | Tensor: padded tensor that has the max size 43 | """ 44 | world_size = dist.get_world_size(group=group) 45 | assert ( 46 | world_size >= 1 47 | ), "comm.gather/all_gather must be called from ranks within the given group!" 
48 | local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) 49 | size_list = [ 50 | torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size) 51 | ] 52 | dist.all_gather(size_list, local_size, group=group) 53 | size_list = [int(size.item()) for size in size_list] 54 | 55 | max_size = max(size_list) 56 | 57 | # we pad the tensor because torch all_gather does not support 58 | # gathering tensors of different shapes 59 | if local_size != max_size: 60 | padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device) 61 | tensor = torch.cat((tensor, padding), dim=0) 62 | return size_list, tensor 63 | 64 | def all_gather(data, group=None): 65 | """ 66 | Run all_gather on arbitrary picklable data (not necessarily tensors). 67 | Args: 68 | data: any picklable object 69 | group: a torch process group. By default, will use a group which 70 | contains all ranks on gloo backend. 71 | Returns: 72 | list[data]: list of data gathered from each rank 73 | """ 74 | if dist.get_world_size() == 1: 75 | return [data] 76 | if group is None: 77 | group = _get_global_gloo_group() 78 | if dist.get_world_size(group) == 1: 79 | return [data] 80 | 81 | tensor = _serialize_to_tensor(data, group) 82 | 83 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 84 | max_size = max(size_list) 85 | 86 | # receiving Tensor from all ranks 87 | tensor_list = [ 88 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list 89 | ] 90 | dist.all_gather(tensor_list, tensor, group=group) 91 | 92 | data_list = [] 93 | for size, tensor in zip(size_list, tensor_list): 94 | buffer = tensor.cpu().numpy().tobytes()[:size] 95 | data_list.append(pickle.loads(buffer)) 96 | 97 | return data_list 98 | 99 | def shared_random_seed(): 100 | """ 101 | Returns: 102 | int: a random number that is the same across all workers. 103 | If workers need a shared RNG, they can use this shared seed to 104 | create one. 105 | All workers must call this function, otherwise it will deadlock. 106 | """ 107 | ints = np.random.randint(2 ** 31) 108 | all_ints = all_gather(ints) 109 | return all_ints[0] 110 | 111 | class RandomIdentitySampler_DDP(Sampler): 112 | """ 113 | Randomly sample N identities, then for each identity, 114 | randomly sample K instances, therefore batch size is N*K. 115 | Args: 116 | - data_source (list): list of (img_path, pid, camid). 117 | - num_instances (int): number of instances per identity in a batch. 118 | - batch_size (int): number of examples in a batch. 
119 | """ 120 | 121 | def __init__(self, data_source, batch_size, num_instances): 122 | self.data_source = data_source 123 | self.batch_size = batch_size 124 | self.world_size = dist.get_world_size() 125 | self.num_instances = num_instances 126 | self.mini_batch_size = self.batch_size // self.world_size 127 | self.num_pids_per_batch = self.mini_batch_size // self.num_instances 128 | self.index_dic = defaultdict(list) 129 | 130 | for index, (_, pid, _, _) in enumerate(self.data_source): 131 | self.index_dic[pid].append(index) 132 | self.pids = list(self.index_dic.keys()) 133 | 134 | # estimate number of examples in an epoch 135 | self.length = 0 136 | for pid in self.pids: 137 | idxs = self.index_dic[pid] 138 | num = len(idxs) 139 | if num < self.num_instances: 140 | num = self.num_instances 141 | self.length += num - num % self.num_instances 142 | 143 | self.rank = dist.get_rank() 144 | #self.world_size = dist.get_world_size() 145 | self.length //= self.world_size 146 | 147 | def __iter__(self): 148 | seed = shared_random_seed() 149 | np.random.seed(seed) 150 | self._seed = int(seed) 151 | final_idxs = self.sample_list() 152 | length = int(math.ceil(len(final_idxs) * 1.0 / self.world_size)) 153 | #final_idxs = final_idxs[self.rank * length:(self.rank + 1) * length] 154 | final_idxs = self.__fetch_current_node_idxs(final_idxs, length) 155 | self.length = len(final_idxs) 156 | return iter(final_idxs) 157 | 158 | 159 | def __fetch_current_node_idxs(self, final_idxs, length): 160 | total_num = len(final_idxs) 161 | block_num = (length // self.mini_batch_size) 162 | index_target = [] 163 | for i in range(0, block_num * self.world_size, self.world_size): 164 | index = range(self.mini_batch_size * self.rank + self.mini_batch_size * i, min(self.mini_batch_size * self.rank + self.mini_batch_size * (i+1), total_num)) 165 | index_target.extend(index) 166 | index_target_npy = np.array(index_target) 167 | final_idxs = list(np.array(final_idxs)[index_target_npy]) 168 | return final_idxs 169 | 170 | 171 | def sample_list(self): 172 | #np.random.seed(self._seed) 173 | avai_pids = copy.deepcopy(self.pids) 174 | batch_idxs_dict = {} 175 | 176 | batch_indices = [] 177 | while len(avai_pids) >= self.num_pids_per_batch: 178 | selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False).tolist() 179 | for pid in selected_pids: 180 | if pid not in batch_idxs_dict: 181 | idxs = copy.deepcopy(self.index_dic[pid]) 182 | if len(idxs) < self.num_instances: 183 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True).tolist() 184 | np.random.shuffle(idxs) 185 | batch_idxs_dict[pid] = idxs 186 | 187 | avai_idxs = batch_idxs_dict[pid] 188 | for _ in range(self.num_instances): 189 | batch_indices.append(avai_idxs.pop(0)) 190 | 191 | if len(avai_idxs) < self.num_instances: avai_pids.remove(pid) 192 | 193 | return batch_indices 194 | 195 | def __len__(self): 196 | return self.length 197 | -------------------------------------------------------------------------------- /mamba/AUTHORS: -------------------------------------------------------------------------------- 1 | Tri Dao, tri@tridao.me 2 | Albert Gu, agu@andrew.cmu.edu 3 | -------------------------------------------------------------------------------- /mamba/assets/selection.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/mamba/assets/selection.png 
-------------------------------------------------------------------------------- /mamba/benchmarks/benchmark_generation_mamba_simple.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao, Albert Gu. 2 | 3 | import argparse 4 | import time 5 | import json 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from einops import rearrange 11 | 12 | from transformers import AutoTokenizer, AutoModelForCausalLM 13 | 14 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 15 | 16 | 17 | parser = argparse.ArgumentParser(description="Generation benchmarking") 18 | parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m") 19 | parser.add_argument("--prompt", type=str, default=None) 20 | parser.add_argument("--promptlen", type=int, default=100) 21 | parser.add_argument("--genlen", type=int, default=100) 22 | parser.add_argument("--temperature", type=float, default=1.0) 23 | parser.add_argument("--topk", type=int, default=1) 24 | parser.add_argument("--topp", type=float, default=1.0) 25 | parser.add_argument("--batch", type=int, default=1) 26 | args = parser.parse_args() 27 | 28 | repeats = 3 29 | device = "cuda" 30 | dtype = torch.float16 31 | 32 | print(f"Loading model {args.model_name}") 33 | is_mamba = args.model_name.startswith("state-spaces/mamba-") or "mamba" in args.model_name 34 | 35 | if is_mamba: 36 | tokenizer = AutoTokenizer.from_pretrained("/home/zhulianghui/VisionProjects/mamba/ckpts/gpt-neox-20b-tokenizer") 37 | model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype) 38 | else: 39 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 40 | model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype) 41 | model.eval() 42 | print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") 43 | 44 | torch.random.manual_seed(0) 45 | if args.prompt is None: 46 | input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda") 47 | attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda") 48 | else: 49 | tokens = tokenizer(args.prompt, return_tensors="pt") 50 | input_ids = tokens.input_ids.to(device=device) 51 | attn_mask = tokens.attention_mask.to(device=device) 52 | max_length = input_ids.shape[1] + args.genlen 53 | 54 | if is_mamba: 55 | fn = lambda: model.generate( 56 | input_ids=input_ids, 57 | max_length=max_length, 58 | cg=True, 59 | return_dict_in_generate=True, 60 | output_scores=True, 61 | enable_timing=False, 62 | temperature=args.temperature, 63 | top_k=args.topk, 64 | top_p=args.topp, 65 | ) 66 | else: 67 | fn = lambda: model.generate( 68 | input_ids=input_ids, 69 | attention_mask=attn_mask, 70 | max_length=max_length, 71 | return_dict_in_generate=True, 72 | pad_token_id=tokenizer.eos_token_id, 73 | do_sample=True, 74 | temperature=args.temperature, 75 | top_k=args.topk, 76 | top_p=args.topp, 77 | ) 78 | out = fn() 79 | if args.prompt is not None: 80 | print(tokenizer.batch_decode(out.sequences.tolist())) 81 | 82 | torch.cuda.synchronize() 83 | start = time.time() 84 | for _ in range(repeats): 85 | fn() 86 | torch.cuda.synchronize() 87 | print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}") 88 | print(f"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms") 89 | 
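The benchmark above follows the standard CUDA timing recipe: one untimed call to warm things up (and, on the Mamba path with cg=True, trigger CUDA-graph capture), then torch.cuda.synchronize() on both sides of the timed loop so the host clock covers all queued GPU work. The same pattern in isolation (a sketch; `fn` is any callable that launches GPU work):

import time
import torch

def benchmark_ms(fn, repeats=3):
    fn()                        # warm-up: autotuning / graph capture happens here
    torch.cuda.synchronize()    # drain queued kernels before starting the clock
    start = time.time()
    for _ in range(repeats):
        fn()
    torch.cuda.synchronize()    # wait for the last kernel before stopping the clock
    return (time.time() - start) / repeats * 1000  # mean milliseconds per call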
-------------------------------------------------------------------------------- /mamba/csrc/selective_scan/reverse_scan.cuh: --------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | #pragma once
6 |
7 | #include <cub/config.cuh>
8 |
9 | #include <cub/util_ptx.cuh>
10 | #include <cub/util_type.cuh>
11 | #include <cub/block/block_raking_layout.cuh>
12 | // #include <cub/detail/uninitialized_copy.cuh>
13 | #include "uninitialized_copy.cuh"
14 |
15 | /**
16 | * Perform a reverse sequential reduction over \p LENGTH elements of the \p input array. The aggregate is returned.
17 | */
18 | template <
19 | int LENGTH,
20 | typename T,
21 | typename ReductionOp>
22 | __device__ __forceinline__ T ThreadReverseReduce(const T (&input)[LENGTH], ReductionOp reduction_op) {
23 | static_assert(LENGTH > 0);
24 | T retval = input[LENGTH - 1];
25 | #pragma unroll
26 | for (int i = LENGTH - 2; i >= 0; --i) { retval = reduction_op(retval, input[i]); }
27 | return retval;
28 | }
29 |
30 | /**
31 | * Perform a sequential inclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned.
32 | */
33 | template <
34 | int LENGTH,
35 | typename T,
36 | typename ScanOp>
37 | __device__ __forceinline__ T ThreadReverseScanInclusive(
38 | const T (&input)[LENGTH],
39 | T (&output)[LENGTH],
40 | ScanOp scan_op,
41 | const T postfix)
42 | {
43 | T inclusive = postfix;
44 | #pragma unroll
45 | for (int i = LENGTH - 1; i >= 0; --i) {
46 | inclusive = scan_op(inclusive, input[i]);
47 | output[i] = inclusive;
48 | }
49 | }
50 |
51 | /**
52 | * Perform a sequential exclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned.
53 | */
54 | template <
55 | int LENGTH,
56 | typename T,
57 | typename ScanOp>
58 | __device__ __forceinline__ T ThreadReverseScanExclusive(
59 | const T (&input)[LENGTH],
60 | T (&output)[LENGTH],
61 | ScanOp scan_op,
62 | const T postfix)
63 | {
64 | // Careful, output may be aliased to input
65 | T exclusive = postfix;
66 | T inclusive;
67 | #pragma unroll
68 | for (int i = LENGTH - 1; i >= 0; --i) {
69 | inclusive = scan_op(exclusive, input[i]);
70 | output[i] = exclusive;
71 | exclusive = inclusive;
72 | }
73 | return inclusive;
74 | }
75 |
76 |
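// Worked example of the postfix (reverse) scans above, with scan_op = + :
// input = [1, 2, 3], postfix seed = 10
//   inclusive reverse scan -> output = [16, 15, 13] (element i combines items i..LENGTH-1, then the seed)
//   exclusive reverse scan -> output = [15, 13, 10], returned aggregate = 16 (same result shifted one step right)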
77 | /**
78 | * \brief WarpReverseScan provides SHFL-based variants of parallel postfix scan of items partitioned across a CUDA thread warp.
79 | *
80 | * LOGICAL_WARP_THREADS must be a power-of-two
81 | */
82 | template <
83 | typename T, ///< Data type being scanned
84 | int LOGICAL_WARP_THREADS ///< Number of threads per logical warp
85 | >
86 | struct WarpReverseScan {
87 | //---------------------------------------------------------------------
88 | // Constants and type definitions
89 | //---------------------------------------------------------------------
90 |
91 | /// Whether the logical warp size and the PTX warp size coincide
92 | static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0));
93 | /// The number of warp scan steps
94 | static constexpr int STEPS = cub::Log2<LOGICAL_WARP_THREADS>::VALUE;
95 | static_assert(LOGICAL_WARP_THREADS == 1 << STEPS);
96 |
97 |
98 | //---------------------------------------------------------------------
99 | // Thread fields
100 | //---------------------------------------------------------------------
101 |
102 | /// Lane index in logical warp
103 | unsigned int lane_id;
104 |
105 | /// Logical warp index in 32-thread physical warp
106 | unsigned int warp_id;
107 |
108 | /// 32-thread physical warp member mask of logical warp
109 | unsigned int member_mask;
110 |
111 | //---------------------------------------------------------------------
112 | // Construction
113 | //---------------------------------------------------------------------
114 |
115 | /// Constructor
116 | explicit __device__ __forceinline__
117 | WarpReverseScan()
118 | : lane_id(cub::LaneId())
119 | , warp_id(IS_ARCH_WARP ? 0 : (lane_id / LOGICAL_WARP_THREADS))
120 | , member_mask(cub::WarpMask<LOGICAL_WARP_THREADS>(warp_id))
121 | {
122 | if (!IS_ARCH_WARP) {
123 | lane_id = lane_id % LOGICAL_WARP_THREADS;
124 | }
125 | }
126 |
127 |
128 | /// Broadcast
129 | __device__ __forceinline__ T Broadcast(
130 | T input, ///< [in] The value to broadcast
131 | int src_lane) ///< [in] Which warp lane is to do the broadcasting
132 | {
133 | return cub::ShuffleIndex<LOGICAL_WARP_THREADS>(input, src_lane, member_mask);
134 | }
135 |
136 |
137 | /// Inclusive scan
138 | template <typename ScanOpT>
139 | __device__ __forceinline__ void InclusiveReverseScan(
140 | T input, ///< [in] Calling thread's input item.
141 | T &inclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input.
142 | ScanOpT scan_op) ///< [in] Binary scan operator
143 | {
144 | inclusive_output = input;
145 | #pragma unroll
146 | for (int STEP = 0; STEP < STEPS; STEP++) {
147 | int offset = 1 << STEP;
148 | T temp = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
149 | inclusive_output, offset, LOGICAL_WARP_THREADS - 1, member_mask
150 | );
151 | // Perform scan op if from a valid peer
152 | inclusive_output = static_cast<int>(lane_id) >= LOGICAL_WARP_THREADS - offset
153 | ? inclusive_output : scan_op(temp, inclusive_output);
154 | }
155 | }
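// Example of InclusiveReverseScan above for a 4-lane logical warp with
// scan_op = + and per-lane inputs [1, 2, 3, 4]:
//   step 0 (offset 1): [1+2, 2+3, 3+4, 4] = [3, 5, 7, 4]
//   step 1 (offset 2): [3+7, 5+4, 7, 4]   = [10, 9, 7, 4]
// so lane i ends up holding the sum of lanes i..3; lanes within `offset` of
// the warp boundary keep their value because ShuffleDown has no valid peer there.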
165 | { 166 | T inclusive_output; 167 | InclusiveReverseScan(input, inclusive_output, scan_op); 168 | warp_aggregate = cub::ShuffleIndex(inclusive_output, 0, member_mask); 169 | // initial value unknown 170 | exclusive_output = cub::ShuffleDown( 171 | inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask 172 | ); 173 | } 174 | 175 | /** 176 | * \brief Computes both inclusive and exclusive reverse scans using the specified binary scan functor across the calling warp. Because no initial value is supplied, the \p exclusive_output computed for the last warp-lane is undefined. 177 | */ 178 | template 179 | __device__ __forceinline__ void ReverseScan( 180 | T input, ///< [in] Calling thread's input item. 181 | T &inclusive_output, ///< [out] Calling thread's inclusive-scan output item. 182 | T &exclusive_output, ///< [out] Calling thread's exclusive-scan output item. 183 | ScanOpT scan_op) ///< [in] Binary scan operator 184 | { 185 | InclusiveReverseScan(input, inclusive_output, scan_op); 186 | // initial value unknown 187 | exclusive_output = cub::ShuffleDown( 188 | inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask 189 | ); 190 | } 191 | 192 | }; 193 | 194 | /** 195 | * \brief BlockReverseScan provides variants of raking-based parallel postfix scan across a CUDA thread block. 196 | */ 197 | template < 198 | typename T, ///< Data type being scanned 199 | int BLOCK_DIM_X, ///< The thread block length in threads along the X dimension 200 | bool MEMOIZE=false ///< Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure 201 | > 202 | struct BlockReverseScan { 203 | //--------------------------------------------------------------------- 204 | // Types and constants 205 | //--------------------------------------------------------------------- 206 | 207 | /// Constants 208 | /// The thread block size in threads 209 | static constexpr int BLOCK_THREADS = BLOCK_DIM_X; 210 | 211 | /// Layout type for padded thread block raking grid 212 | using BlockRakingLayout = cub::BlockRakingLayout; 213 | // The number of reduction elements is not a multiple of the number of raking threads for now 214 | static_assert(BlockRakingLayout::UNGUARDED); 215 | 216 | /// Number of raking threads 217 | static constexpr int RAKING_THREADS = BlockRakingLayout::RAKING_THREADS; 218 | /// Number of raking elements per warp synchronous raking thread 219 | static constexpr int SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH; 220 | /// Cooperative work can be entirely warp synchronous 221 | static constexpr bool WARP_SYNCHRONOUS = (int(BLOCK_THREADS) == int(RAKING_THREADS)); 222 | 223 | /// WarpReverseScan utility type 224 | using WarpReverseScan = WarpReverseScan; 225 | 226 | /// Shared memory storage layout type 227 | struct _TempStorage { 228 | typename BlockRakingLayout::TempStorage raking_grid; ///< Padded thread block raking grid 229 | }; 230 | 231 | 232 | /// Alias wrapper allowing storage to be unioned 233 | struct TempStorage : cub::Uninitialized<_TempStorage> {}; 234 | 235 | 236 | //--------------------------------------------------------------------- 237 | // Per-thread fields 238 | //--------------------------------------------------------------------- 239 | 240 | // Thread fields 241 | _TempStorage &temp_storage; 242 | unsigned int linear_tid; 243 | T cached_segment[SEGMENT_LENGTH]; 244 | 245 | 246 | //--------------------------------------------------------------------- 247 | // Utility methods 248 | 
193 |
194 | /**
195 | * \brief BlockReverseScan provides variants of raking-based parallel postfix scan across a CUDA thread block.
196 | */
197 | template <
198 | typename T, ///< Data type being scanned
199 | int BLOCK_DIM_X, ///< The thread block length in threads along the X dimension
200 | bool MEMOIZE=false ///< Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure
201 | >
202 | struct BlockReverseScan {
203 | //---------------------------------------------------------------------
204 | // Types and constants
205 | //---------------------------------------------------------------------
206 |
207 | /// Constants
208 | /// The thread block size in threads
209 | static constexpr int BLOCK_THREADS = BLOCK_DIM_X;
210 |
211 | /// Layout type for padded thread block raking grid
212 | using BlockRakingLayout = cub::BlockRakingLayout<T, BLOCK_THREADS>;
213 | // The number of reduction elements is not a multiple of the number of raking threads for now
214 | static_assert(BlockRakingLayout::UNGUARDED);
215 |
216 | /// Number of raking threads
217 | static constexpr int RAKING_THREADS = BlockRakingLayout::RAKING_THREADS;
218 | /// Number of raking elements per warp synchronous raking thread
219 | static constexpr int SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH;
220 | /// Cooperative work can be entirely warp synchronous
221 | static constexpr bool WARP_SYNCHRONOUS = (int(BLOCK_THREADS) == int(RAKING_THREADS));
222 |
223 | /// WarpReverseScan utility type
224 | using WarpReverseScan = WarpReverseScan<T, RAKING_THREADS>;
225 |
226 | /// Shared memory storage layout type
227 | struct _TempStorage {
228 | typename BlockRakingLayout::TempStorage raking_grid; ///< Padded thread block raking grid
229 | };
230 |
231 |
232 | /// Alias wrapper allowing storage to be unioned
233 | struct TempStorage : cub::Uninitialized<_TempStorage> {};
234 |
235 |
236 | //---------------------------------------------------------------------
237 | // Per-thread fields
238 | //---------------------------------------------------------------------
239 |
240 | // Thread fields
241 | _TempStorage &temp_storage;
242 | unsigned int linear_tid;
243 | T cached_segment[SEGMENT_LENGTH];
244 |
245 |
246 | //---------------------------------------------------------------------
247 | // Utility methods
248 | //---------------------------------------------------------------------
249 |
250 | /// Performs upsweep raking reduction, returning the aggregate
251 | template <typename ScanOp>
252 | __device__ __forceinline__ T Upsweep(ScanOp scan_op) {
253 | T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
254 | // Read data into registers
255 | #pragma unroll
256 | for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
257 | T raking_partial = cached_segment[SEGMENT_LENGTH - 1];
258 | #pragma unroll
259 | for (int i = SEGMENT_LENGTH - 2; i >= 0; --i) {
260 | raking_partial = scan_op(raking_partial, cached_segment[i]);
261 | }
262 | return raking_partial;
263 | }
264 |
265 |
266 | /// Performs exclusive downsweep raking scan
267 | template <typename ScanOp>
268 | __device__ __forceinline__ void ExclusiveDownsweep(
269 | ScanOp scan_op,
270 | T raking_partial)
271 | {
272 | T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
273 | // Read data back into registers
274 | if (!MEMOIZE) {
275 | #pragma unroll
276 | for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
277 | }
278 | ThreadReverseScanExclusive(cached_segment, cached_segment, scan_op, raking_partial);
279 | // Write data back to smem
280 | #pragma unroll
281 | for (int i = 0; i < SEGMENT_LENGTH; ++i) { smem_raking_ptr[i] = cached_segment[i]; }
282 | }
283 |
284 |
285 | //---------------------------------------------------------------------
286 | // Constructors
287 | //---------------------------------------------------------------------
288 |
289 | /// Constructor
290 | __device__ __forceinline__ BlockReverseScan(
291 | TempStorage &temp_storage)
292 | :
293 | temp_storage(temp_storage.Alias()),
294 | linear_tid(cub::RowMajorTid(BLOCK_DIM_X, 1, 1))
295 | {}
296 |
297 |
298 | /// Computes an exclusive thread block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes one input element. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs.
299 | template <
300 | typename ScanOp,
301 | typename BlockPostfixCallbackOp>
302 | __device__ __forceinline__ void ExclusiveReverseScan(
303 | T input, ///< [in] Calling thread's input item
304 | T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input)
305 | ScanOp scan_op, ///< [in] Binary scan operator
306 | BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a thread block-wide postfix to be applied to all inputs.
307 | {
308 | if (WARP_SYNCHRONOUS) {
309 | // Short-circuit directly to warp-synchronous scan
310 | T block_aggregate;
311 | WarpReverseScan warp_scan;
312 | warp_scan.ExclusiveReverseScan(input, exclusive_output, scan_op, block_aggregate);
313 | // Obtain warp-wide postfix in lane0, then broadcast to other lanes
314 | T block_postfix = block_postfix_callback_op(block_aggregate);
315 | block_postfix = warp_scan.Broadcast(block_postfix, 0);
316 | exclusive_output = linear_tid == BLOCK_THREADS - 1 ?
block_postfix : scan_op(block_postfix, exclusive_output); 317 | } else { 318 | // Place thread partial into shared memory raking grid 319 | T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid); 320 | detail::uninitialized_copy(placement_ptr, input); 321 | cub::CTA_SYNC(); 322 | // Reduce parallelism down to just raking threads 323 | if (linear_tid < RAKING_THREADS) { 324 | WarpReverseScan warp_scan; 325 | // Raking upsweep reduction across shared partials 326 | T upsweep_partial = Upsweep(scan_op); 327 | // Warp-synchronous scan 328 | T exclusive_partial, block_aggregate; 329 | warp_scan.ExclusiveReverseScan(upsweep_partial, exclusive_partial, scan_op, block_aggregate); 330 | // Obtain block-wide postfix in lane0, then broadcast to other lanes 331 | T block_postfix = block_postfix_callback_op(block_aggregate); 332 | block_postfix = warp_scan.Broadcast(block_postfix, 0); 333 | // Update postfix with warpscan exclusive partial 334 | T downsweep_postfix = linear_tid == RAKING_THREADS - 1 335 | ? block_postfix : scan_op(block_postfix, exclusive_partial); 336 | // Exclusive raking downsweep scan 337 | ExclusiveDownsweep(scan_op, downsweep_postfix); 338 | } 339 | cub::CTA_SYNC(); 340 | // Grab thread postfix from shared memory 341 | exclusive_output = *placement_ptr; 342 | 343 | // // Compute warp scan in each warp. 344 | // // The exclusive output from the last lane in each warp is invalid. 345 | // T inclusive_output; 346 | // WarpReverseScan warp_scan; 347 | // warp_scan.ReverseScan(input, inclusive_output, exclusive_output, scan_op); 348 | 349 | // // Compute the warp-wide postfix and block-wide aggregate for each warp. Warp postfix for the last warp is invalid. 350 | // T block_aggregate; 351 | // T warp_postfix = ComputeWarpPostfix(scan_op, inclusive_output, block_aggregate); 352 | 353 | // // Apply warp postfix to our lane's partial 354 | // if (warp_id != 0) { 355 | // exclusive_output = scan_op(warp_postfix, exclusive_output); 356 | // if (lane_id == 0) { exclusive_output = warp_postfix; } 357 | // } 358 | 359 | // // Use the first warp to determine the thread block postfix, returning the result in lane0 360 | // if (warp_id == 0) { 361 | // T block_postfix = block_postfix_callback_op(block_aggregate); 362 | // if (lane_id == 0) { 363 | // // Share the postfix with all threads 364 | // detail::uninitialized_copy(&temp_storage.block_postfix, 365 | // block_postfix); 366 | 367 | // exclusive_output = block_postfix; // The block postfix is the exclusive output for tid0 368 | // } 369 | // } 370 | 371 | // cub::CTA_SYNC(); 372 | 373 | // // Incorporate thread block postfix into outputs 374 | // T block_postfix = temp_storage.block_postfix; 375 | // if (linear_tid > 0) { exclusive_output = scan_op(block_postfix, exclusive_output); } 376 | } 377 | } 378 | 379 | 380 | /** 381 | * \brief Computes an inclusive block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. 
382 | */ 383 | template < 384 | int ITEMS_PER_THREAD, 385 | typename ScanOp, 386 | typename BlockPostfixCallbackOp> 387 | __device__ __forceinline__ void InclusiveReverseScan( 388 | T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items 389 | T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input) 390 | ScanOp scan_op, ///< [in] Binary scan functor 391 | BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a block-wide postfix to be applied to the logical input sequence. 392 | { 393 | // Reduce consecutive thread items in registers 394 | T thread_postfix = ThreadReverseReduce(input, scan_op); 395 | // Exclusive thread block-scan 396 | ExclusiveReverseScan(thread_postfix, thread_postfix, scan_op, block_postfix_callback_op); 397 | // Inclusive scan in registers with postfix as seed 398 | ThreadReverseScanInclusive(input, output, scan_op, thread_postfix); 399 | } 400 | 401 | }; -------------------------------------------------------------------------------- /mamba/csrc/selective_scan/selective_scan.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | //////////////////////////////////////////////////////////////////////////////////////////////////// 8 | 9 | struct SSMScanParamsBase { 10 | using index_t = uint32_t; 11 | 12 | int batch, seqlen, n_chunks; 13 | index_t a_batch_stride; 14 | index_t b_batch_stride; 15 | index_t out_batch_stride; 16 | 17 | // Common data pointers. 18 | void *__restrict__ a_ptr; 19 | void *__restrict__ b_ptr; 20 | void *__restrict__ out_ptr; 21 | void *__restrict__ x_ptr; 22 | }; 23 | 24 | //////////////////////////////////////////////////////////////////////////////////////////////////// 25 | 26 | struct SSMParamsBase { 27 | using index_t = uint32_t; 28 | 29 | int batch, dim, seqlen, dstate, n_groups, n_chunks; 30 | int dim_ngroups_ratio; 31 | bool is_variable_B; 32 | bool is_variable_C; 33 | 34 | bool delta_softplus; 35 | 36 | index_t A_d_stride; 37 | index_t A_dstate_stride; 38 | index_t B_batch_stride; 39 | index_t B_d_stride; 40 | index_t B_dstate_stride; 41 | index_t B_group_stride; 42 | index_t C_batch_stride; 43 | index_t C_d_stride; 44 | index_t C_dstate_stride; 45 | index_t C_group_stride; 46 | index_t u_batch_stride; 47 | index_t u_d_stride; 48 | index_t delta_batch_stride; 49 | index_t delta_d_stride; 50 | index_t z_batch_stride; 51 | index_t z_d_stride; 52 | index_t out_batch_stride; 53 | index_t out_d_stride; 54 | index_t out_z_batch_stride; 55 | index_t out_z_d_stride; 56 | 57 | // Common data pointers. 
58 | void *__restrict__ A_ptr;
59 | void *__restrict__ B_ptr;
60 | void *__restrict__ C_ptr;
61 | void *__restrict__ D_ptr;
62 | void *__restrict__ u_ptr;
63 | void *__restrict__ delta_ptr;
64 | void *__restrict__ delta_bias_ptr;
65 | void *__restrict__ out_ptr;
66 | void *__restrict__ x_ptr;
67 | void *__restrict__ z_ptr;
68 | void *__restrict__ out_z_ptr;
69 | };
70 |
71 | struct SSMParamsBwd: public SSMParamsBase {
72 | index_t dout_batch_stride;
73 | index_t dout_d_stride;
74 | index_t dA_d_stride;
75 | index_t dA_dstate_stride;
76 | index_t dB_batch_stride;
77 | index_t dB_group_stride;
78 | index_t dB_d_stride;
79 | index_t dB_dstate_stride;
80 | index_t dC_batch_stride;
81 | index_t dC_group_stride;
82 | index_t dC_d_stride;
83 | index_t dC_dstate_stride;
84 | index_t du_batch_stride;
85 | index_t du_d_stride;
86 | index_t dz_batch_stride;
87 | index_t dz_d_stride;
88 | index_t ddelta_batch_stride;
89 | index_t ddelta_d_stride;
90 |
91 | // Common data pointers.
92 | void *__restrict__ dout_ptr;
93 | void *__restrict__ dA_ptr;
94 | void *__restrict__ dB_ptr;
95 | void *__restrict__ dC_ptr;
96 | void *__restrict__ dD_ptr;
97 | void *__restrict__ du_ptr;
98 | void *__restrict__ dz_ptr;
99 | void *__restrict__ ddelta_ptr;
100 | void *__restrict__ ddelta_bias_ptr;
101 | };
102 |
-------------------------------------------------------------------------------- /mamba/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu: --------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<at::BFloat16, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
-------------------------------------------------------------------------------- /mamba/csrc/selective_scan/selective_scan_bwd_bf16_real.cu: --------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream);
-------------------------------------------------------------------------------- /mamba/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu: --------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<at::Half, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
-------------------------------------------------------------------------------- /mamba/csrc/selective_scan/selective_scan_bwd_fp16_real.cu: --------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<at::Half, float>(SSMParamsBwd &params, cudaStream_t stream);
-------------------------------------------------------------------------------- /mamba/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu: --------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<float, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
-------------------------------------------------------------------------------- /mamba/csrc/selective_scan/selective_scan_bwd_fp32_real.cu: --------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<float, float>(SSMParamsBwd &params, cudaStream_t stream);
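The six .cu files above exist purely to parallelize compilation, as their comments note: each one explicitly instantiates `selective_scan_bwd_cuda<input_t, weight_t>` for a single cell of the {float, at::Half, at::BFloat16} × {real float, complex_t} grid, so nvcc can build the heavy kernel template concurrently instead of once in a monolithic translation unit.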
-------------------------------------------------------------------------------- /mamba/csrc/selective_scan/selective_scan_common.h: --------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | #pragma once
6 |
7 | #include <cuda_bf16.h>
8 | #include <cuda_fp16.h>
9 | #include <c10/util/complex.h> // For scalar_value_type
10 |
11 | #define MAX_DSTATE 256
12 |
13 | using complex_t = c10::complex<float>;
14 |
15 | inline __device__ float2 operator+(const float2 & a, const float2 & b){
16 | return {a.x + b.x, a.y + b.y};
17 | }
18 |
19 | inline __device__ float3 operator+(const float3 &a, const float3 &b) {
20 | return {a.x + b.x, a.y + b.y, a.z + b.z};
21 | }
22 |
23 | inline __device__ float4 operator+(const float4 & a, const float4 & b){
24 | return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
25 | }
26 |
27 | ////////////////////////////////////////////////////////////////////////////////////////////////////
28 |
29 | template<int BYTES> struct BytesToType {};
30 |
31 | template<> struct BytesToType<16> {
32 | using Type = uint4;
33 | static_assert(sizeof(Type) == 16);
34 | };
35 |
36 | template<> struct BytesToType<8> {
37 | using Type = uint64_t;
38 | static_assert(sizeof(Type) == 8);
39 | };
40 |
41 | template<> struct BytesToType<4> {
42 | using Type = uint32_t;
43 | static_assert(sizeof(Type) == 4);
44 | };
45 |
46 | template<> struct BytesToType<2> {
47 | using Type = uint16_t;
48 | static_assert(sizeof(Type) == 2);
49 | };
50 |
51 | template<> struct BytesToType<1> {
52 | using Type = uint8_t;
53 | static_assert(sizeof(Type) == 1);
54 | };
55 |
56 | ////////////////////////////////////////////////////////////////////////////////////////////////////
57 |
58 | template<typename scalar_t, int N>
59 | struct Converter{
60 | static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) {
61 | #pragma unroll
62 | for (int i = 0; i < N; ++i) { dst[i] = src[i]; }
63 | }
64 | };
65 |
66 | template<int N>
67 | struct Converter<at::Half, N>{
68 | static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
69 | static_assert(N % 2 == 0);
70 | auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);
71 | auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
72 | #pragma unroll
73 | for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); }
74 | }
75 | };
76 |
77 | #if __CUDA_ARCH__ >= 800
78 | template<int N>
79 | struct Converter<at::BFloat16, N>{
80 | static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
81 | static_assert(N % 2 == 0);
82 | auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);
83 | auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
84 | #pragma unroll
85 | for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); }
86 | }
87 | };
88 | #endif
89 |
90 | ////////////////////////////////////////////////////////////////////////////////////////////////////
91 |
92 | // From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp
93 | // and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696
94 | __device__ __forceinline__ complex_t cexp2f(complex_t z) {
95 | float t = exp2f(z.real_);
96 | float c, s;
97 | sincosf(z.imag_, &s, &c);
98 | return complex_t(c * t, s * t);
99 | }
100 |
101 | __device__ __forceinline__ complex_t cexpf(complex_t z) {
102 | float t = expf(z.real_);
103 | float c, s;
104 | sincosf(z.imag_, &s, &c);
105 | return complex_t(c * t, s * t);
106 | }
107 |
108 | template<typename scalar_t> struct SSMScanOp;
109 |
110 | template<>
111 | struct SSMScanOp<float> {
112 | __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {
113 | return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
114 | }
115 | };
116 |
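// The scan operators here implement the SSM recurrence x_t = a_t * x_{t-1} + b_t
// as an associative combine on (a, b) pairs:
//   (a0, b0) combined-with (a1, b1) = (a1 * a0, a1 * b0 + b1)
// Hypothetical worked example: composing (a=2, b=1) then (a=3, b=5) yields
// (6, 8), and indeed 3 * (2 * x + 1) + 5 = 6 * x + 8 for any running state x,
// which is what lets the sequential recurrence be evaluated as a parallel scan.
// The complex specialization below packs (re(a), im(a), re(b), im(b)) into a float4.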
template<>
struct SSMScanOp<complex_t> {
    __device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const {
        complex_t a0 = complex_t(ab0.x, ab0.y);
        complex_t b0 = complex_t(ab0.z, ab0.w);
        complex_t a1 = complex_t(ab1.x, ab1.y);
        complex_t b1 = complex_t(ab1.z, ab1.w);
        complex_t out_a = a1 * a0;
        complex_t out_b = a1 * b0 + b1;
        return make_float4(out_a.real_, out_a.imag_, out_b.real_, out_b.imag_);
    }
};

// A stateful callback functor that maintains a running prefix to be applied
// during consecutive scan operations.
template <typename scalar_t> struct SSMScanPrefixCallbackOp {
    using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
    scan_t running_prefix;
    // Constructor
    __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {}
    // Callback operator to be entered by the first warp of threads in the block.
    // Thread-0 is responsible for returning a value for seeding the block-wide scan.
    __device__ scan_t operator()(scan_t block_aggregate) {
        scan_t old_prefix = running_prefix;
        running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);
        return old_prefix;
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Ktraits>
inline __device__ void load_input(typename Ktraits::input_t *u,
                                  typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
                                  typename Ktraits::BlockLoadT::TempStorage &smem_load,
                                  int seqlen) {
    if constexpr (Ktraits::kIsEvenLen) {
        auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
        using vec_t = typename Ktraits::vec_t;
        Ktraits::BlockLoadVecT(smem_load_vec).Load(
            reinterpret_cast<vec_t*>(u),
            reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals)
        );
    } else {
        Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
    }
}

template<typename Ktraits>
inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
                                   typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
                                   typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight,
                                   int seqlen) {
    constexpr int kNItems = Ktraits::kNItems;
    if constexpr (!Ktraits::kIsComplex) {
        typename Ktraits::input_t B_vals_load[kNItems];
        if constexpr (Ktraits::kIsEvenLen) {
            auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
            using vec_t = typename Ktraits::vec_t;
            Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
                reinterpret_cast<vec_t*>(Bvar),
                reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load)
            );
        } else {
            Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
        }
        // #pragma unroll
        // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
        Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals);
    } else {
        typename Ktraits::input_t B_vals_load[kNItems * 2];
        if constexpr (Ktraits::kIsEvenLen) {
            auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
            using vec_t = typename Ktraits::vec_t;
            Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
                reinterpret_cast<vec_t*>(Bvar),
                reinterpret_cast<vec_t(&)[Ktraits::kNLoads * 2]>(B_vals_load)
            );
        } else {
            Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
        }
        #pragma unroll
        for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); }
    }
}

template<typename Ktraits>
inline __device__ void store_output(typename Ktraits::input_t *out,
                                    const float (&out_vals)[Ktraits::kNItems],
                                    typename Ktraits::BlockStoreT::TempStorage &smem_store,
                                    int seqlen) {
    typename Ktraits::input_t write_vals[Ktraits::kNItems];
    #pragma unroll
    for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
    if constexpr (Ktraits::kIsEvenLen) {
        auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
        using vec_t = typename Ktraits::vec_t;
        Ktraits::BlockStoreVecT(smem_store_vec).Store(
            reinterpret_cast<vec_t*>(out),
            reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals)
        );
    } else {
        Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
    }
}
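The float path of SSMScanOp above is what makes the block-wide scan legal: a pair (a, b) represents the affine map x -> a*x + b, and composing two such steps is associative, so CUB can combine them in any bracketing. A small Python check of that identity (illustrative, not repo code):

from functools import reduce

def combine(ab0, ab1):
    # Mirrors SSMScanOp<float>: applying (a0, b0) then (a1, b1)
    # equals the single step (a1 * a0, a1 * b0 + b1).
    a0, b0 = ab0
    a1, b1 = ab1
    return (a1 * a0, a1 * b0 + b1)

steps = [(0.9, 0.1), (0.5, 0.2), (0.8, 0.3)]

x = 0.0
for a, b in steps:  # sequential recurrence x_t = a_t * x_{t-1} + b_t
    x = a * x + b

a_tot, b_tot = reduce(combine, steps)
assert abs(b_tot - x) < 1e-12  # with x_{-1} = 0, the composed map yields the same state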
--------------------------------------------------------------------------------
/solver/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/solver/__init__.py
--------------------------------------------------------------------------------
/solver/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/solver/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/solver/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/solver/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/solver/__pycache__/lr_scheduler.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/solver/__pycache__/lr_scheduler.cpython-310.pyc
--------------------------------------------------------------------------------
/solver/__pycache__/lr_scheduler.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/solver/__pycache__/lr_scheduler.cpython-39.pyc
--------------------------------------------------------------------------------
/solver/cosine_lr.py:
--------------------------------------------------------------------------------
""" Cosine Scheduler

Cosine LR schedule with warmup, cycle/restarts, noise.

Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import math
import torch

from .scheduler import Scheduler


_logger = logging.getLogger(__name__)


class CosineLRScheduler(Scheduler):
    """
    Cosine decay with restarts.
    This is described in the paper https://arxiv.org/abs/1608.03983.

    Inspiration from
    https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 t_initial: int,
                 t_mul: float = 1.,
                 lr_min: float = 0.,
                 decay_rate: float = 1.,
                 warmup_t=0,
                 warmup_lr_init=0,
                 warmup_prefix=False,
                 cycle_limit=0,
                 t_in_epochs=True,
                 noise_range_t=None,
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=42,
                 initialize=True) -> None:
        super().__init__(
            optimizer, param_group_field="lr",
            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
            initialize=initialize)

        assert t_initial > 0
        assert lr_min >= 0
        if t_initial == 1 and t_mul == 1 and decay_rate == 1:
            _logger.warning("Cosine annealing scheduler will have no effect on the learning "
                            "rate since t_initial = t_mul = decay_rate = 1.")
        self.t_initial = t_initial
        self.t_mul = t_mul
        self.lr_min = lr_min
        self.decay_rate = decay_rate
        self.cycle_limit = cycle_limit
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.warmup_prefix = warmup_prefix
        self.t_in_epochs = t_in_epochs
        if self.warmup_t:
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            super().update_groups(self.warmup_lr_init)
        else:
            self.warmup_steps = [1 for _ in self.base_values]

    def _get_lr(self, t):
        if t < self.warmup_t:
            lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
        else:
            if self.warmup_prefix:
                t = t - self.warmup_t

            if self.t_mul != 1:
                i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul))
                t_i = self.t_mul ** i * self.t_initial
                t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial
            else:
                i = t // self.t_initial
                t_i = self.t_initial
                t_curr = t - (self.t_initial * i)

            gamma = self.decay_rate ** i
            lr_min = self.lr_min * gamma
            lr_max_values = [v * gamma for v in self.base_values]

            if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit):
                lrs = [
                    lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values
                ]
            else:
                lrs = [self.lr_min for _ in self.base_values]

        return lrs

    def get_epoch_values(self, epoch: int):
        if self.t_in_epochs:
            return self._get_lr(epoch)
        else:
            return None

    def get_update_values(self, num_updates: int):
        if not self.t_in_epochs:
            return self._get_lr(num_updates)
        else:
            return None

    def get_cycle_length(self, cycles=0):
        if not cycles:
            cycles = self.cycle_limit
        cycles = max(1, cycles)
        if self.t_mul == 1.0:
            return self.t_initial * cycles
        else:
            return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul)))
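A minimal usage sketch of the scheduler above; all hyperparameter values here are invented for illustration:

import torch
from solver.cosine_lr import CosineLRScheduler

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=0.01)
sched = CosineLRScheduler(opt, t_initial=120, lr_min=1e-5,
                          warmup_t=5, warmup_lr_init=1e-6, cycle_limit=1)
for epoch in range(120):
    # ... train one epoch ...
    sched.step(epoch)  # epoch counts are passed in explicitly (see the Scheduler base class)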
--------------------------------------------------------------------------------
/solver/lr_scheduler.py:
--------------------------------------------------------------------------------
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from bisect import bisect_right
import torch


# FIXME ideally this would be achieved with a CombinedLRScheduler,
# separating MultiStepLR with WarmupLR
# but the current LRScheduler design doesn't allow it

class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(
            self,
            optimizer,
            milestones,  # steps
            gamma=0.1,
            warmup_factor=1.0 / 3,
            warmup_iters=500,
            warmup_method="linear",
            last_epoch=-1,
    ):
        if not list(milestones) == sorted(milestones):
            raise ValueError(
                "Milestones should be a list of increasing integers. Got {}".format(milestones)
            )

        if warmup_method not in ("constant", "linear"):
            raise ValueError(
                "Only 'constant' or 'linear' warmup_method accepted, "
                "got {}".format(warmup_method)
            )
        self.milestones = milestones
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        warmup_factor = 1
        if self.last_epoch < self.warmup_iters:
            if self.warmup_method == "constant":
                warmup_factor = self.warmup_factor
            elif self.warmup_method == "linear":
                alpha = self.last_epoch / self.warmup_iters
                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
        return [
            base_lr
            * warmup_factor
            * self.gamma ** bisect_right(self.milestones, self.last_epoch)
            for base_lr in self.base_lrs
        ]

--------------------------------------------------------------------------------
/solver/make_optimizer.py:
--------------------------------------------------------------------------------
import torch

def make_optimizer(cfg, model, center_criterion):
    params = []
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue
        lr = cfg.SOLVER.BASE_LR
        weight_decay = cfg.SOLVER.WEIGHT_DECAY
        if "bias" in key:
            lr = cfg.SOLVER.BASE_LR * 2
            weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
        if cfg.SOLVER.LARGE_FC_LR:
            if "classifier" in key or "arcface" in key:
                lr = cfg.SOLVER.BASE_LR * 2
                print('Using two times learning rate for fc')

        params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]

    if cfg.SOLVER.OPTIMIZER_NAME == 'SGD':
        optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.OPTIMIZER_NAME == 'AdamW':
        optimizer = torch.optim.AdamW(params, lr=cfg.SOLVER.BASE_LR, weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    else:
        optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params)
    optimizer_center = torch.optim.SGD(center_criterion.parameters(), lr=cfg.SOLVER.CENTER_LR)

    return optimizer, optimizer_center
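To see the per-parameter grouping policy in make_optimizer, a toy call with a stand-in config; SimpleNamespace mimics the yacs cfg, and the field names match the ones read above:

import torch
from types import SimpleNamespace
from solver.make_optimizer import make_optimizer

solver = SimpleNamespace(BASE_LR=3.5e-4, WEIGHT_DECAY=5e-4, WEIGHT_DECAY_BIAS=5e-4,
                         LARGE_FC_LR=False, OPTIMIZER_NAME='AdamW',
                         MOMENTUM=0.9, CENTER_LR=0.5)
cfg = SimpleNamespace(SOLVER=solver)
model = torch.nn.Linear(8, 4)
center = torch.nn.Linear(4, 2)  # stand-in for the center-loss criterion
opt, opt_center = make_optimizer(cfg, model, center)
print([g['lr'] for g in opt.param_groups])  # bias group gets 2x BASE_LR: [0.00035, 0.0007]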
--------------------------------------------------------------------------------
/solver/scheduler.py:
--------------------------------------------------------------------------------
from typing import Dict, Any

import torch


class Scheduler:
    """ Parameter Scheduler Base Class
    A scheduler base class that can be used to schedule any optimizer parameter groups.

    Unlike the builtin PyTorch schedulers, this is intended to be consistently called
    * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value
    * At the END of each optimizer update, after incrementing the update count, to calculate next update's value

    The schedulers built on this should try to remain as stateless as possible (for simplicity).

    This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch'
    and -1 values for special behaviour. All epoch and update counts must be tracked in the training
    code and explicitly passed in to the schedulers on the corresponding step or step_update call.

    Based on ideas from:
    * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler
    * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 param_group_field: str,
                 noise_range_t=None,
                 noise_type='normal',
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=None,
                 initialize: bool = True) -> None:
        self.optimizer = optimizer
        self.param_group_field = param_group_field
        self._initial_param_group_field = f"initial_{param_group_field}"
        if initialize:
            for i, group in enumerate(self.optimizer.param_groups):
                if param_group_field not in group:
                    raise KeyError(f"{param_group_field} missing from param_groups[{i}]")
                group.setdefault(self._initial_param_group_field, group[param_group_field])
        else:
            for i, group in enumerate(self.optimizer.param_groups):
                if self._initial_param_group_field not in group:
                    raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
        self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups]
        self.metric = None  # any point to having this for all?
        self.noise_range_t = noise_range_t
        self.noise_pct = noise_pct
        self.noise_type = noise_type
        self.noise_std = noise_std
        self.noise_seed = noise_seed if noise_seed is not None else 42
        self.update_groups(self.base_values)

    def state_dict(self) -> Dict[str, Any]:
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.__dict__.update(state_dict)

    def get_epoch_values(self, epoch: int):
        return None

    def get_update_values(self, num_updates: int):
        return None

    def step(self, epoch: int, metric: float = None) -> None:
        self.metric = metric
        values = self.get_epoch_values(epoch)
        if values is not None:
            values = self._add_noise(values, epoch)
            self.update_groups(values)

    def step_update(self, num_updates: int, metric: float = None):
        self.metric = metric
        values = self.get_update_values(num_updates)
        if values is not None:
            values = self._add_noise(values, num_updates)
            self.update_groups(values)

    def update_groups(self, values):
        if not isinstance(values, (list, tuple)):
            values = [values] * len(self.optimizer.param_groups)
        for param_group, value in zip(self.optimizer.param_groups, values):
            param_group[self.param_group_field] = value

    def _add_noise(self, lrs, t):
        if self.noise_range_t is not None:
            if isinstance(self.noise_range_t, (list, tuple)):
                apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
            else:
                apply_noise = t >= self.noise_range_t
            if apply_noise:
                g = torch.Generator()
                g.manual_seed(self.noise_seed + t)
                if self.noise_type == 'normal':
                    while True:
                        # resample if noise out of percent limit, brute force but shouldn't spin much
                        noise = torch.randn(1, generator=g).item()
                        if abs(noise) < self.noise_pct:
                            break
                else:
                    noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
                lrs = [v + v * noise for v in lrs]
        return lrs
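The base class only asks subclasses for new values via get_epoch_values / get_update_values; a minimal hypothetical subclass (not shipped with the repo) shows the contract:

import torch
from solver.scheduler import Scheduler

class StepDecayScheduler(Scheduler):
    # Illustrative only: decay lr by gamma every decay_epochs epochs,
    # driven entirely by the epoch value passed to step().
    def __init__(self, optimizer, decay_epochs=40, gamma=0.1):
        super().__init__(optimizer, param_group_field="lr")
        self.decay_epochs = decay_epochs
        self.gamma = gamma

    def get_epoch_values(self, epoch):
        return [v * (self.gamma ** (epoch // self.decay_epochs)) for v in self.base_values]

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
sched = StepDecayScheduler(opt)
sched.step(epoch=80)
print(opt.param_groups[0]['lr'])  # 0.1 * 0.1**2 = 0.001 (up to float rounding)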
--------------------------------------------------------------------------------
/solver/scheduler_factory.py:
--------------------------------------------------------------------------------
""" Scheduler Factory
Hacked together by / Copyright 2020 Ross Wightman
"""
from .cosine_lr import CosineLRScheduler


def create_scheduler(optimizer, num_epochs, lr_min, warmup_lr_init, warmup_t, noise_range=None):

    lr_scheduler = CosineLRScheduler(
        optimizer,
        t_initial=num_epochs,
        lr_min=lr_min,
        t_mul=1.,
        decay_rate=0.1,
        warmup_lr_init=warmup_lr_init,
        warmup_t=warmup_t,
        cycle_limit=1,
        t_in_epochs=True,
        noise_range_t=noise_range,
        noise_pct=0.67,
        noise_std=1.,
        noise_seed=42,
    )

    return lr_scheduler
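A typical call site for the factory; the numbers are illustrative, and note that the factory pins a single cosine cycle with decay_rate=0.1:

import torch
from solver.scheduler_factory import create_scheduler

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=5e-4)
sched = create_scheduler(opt, num_epochs=60, lr_min=5e-6, warmup_lr_init=5e-7, warmup_t=5)
for epoch in range(60):
    sched.step(epoch)  # 5 warmup epochs, then one cosine cycle down to lr_min
    # ... train one epoch ...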
--------------------------------------------------------------------------------
/train_climb.py:
--------------------------------------------------------------------------------
from utils.logger import setup_logger
import random
import torch
import numpy as np
import os
import argparse
from config import cfg
from solver.lr_scheduler import WarmupMultiStepLR
from climb.dataloader import make_CLIMB_dataloader
from climb.processor_climb import train_climb
from climb.optimizer import make_CLIMB_optimizer
from climb.model import make_model

def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

if __name__ == '__main__':

    parser = argparse.ArgumentParser(description="ReID Baseline Training")
    parser.add_argument(
        "--config_file", default="./config/climb-vit-msmt.yml", help="path to config file", type=str
    )

    parser.add_argument("opts", help="Modify config options using the command-line", default=None,
                        nargs=argparse.REMAINDER)
    parser.add_argument("--local_rank", default=0, type=int)
    args = parser.parse_args()

    if args.config_file != "":
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    set_seed(cfg.SOLVER.SEED)

    output_dir = cfg.OUTPUT_DIR
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    logger = setup_logger("CLIMB", output_dir, if_train=True)
    logger.info("Saving model in the path :{}".format(cfg.OUTPUT_DIR))
    logger.info(args)

    if args.config_file != "":
        logger.info("Loaded configuration file {}".format(args.config_file))
        with open(args.config_file, 'r') as cf:
            config_str = "\n" + cf.read()
            logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    train_loader, val_loader, cluster_loader, num_query, num_classes, camera_num, view_num = make_CLIMB_dataloader(cfg)
    model = make_model(cfg, num_classes, camera_num=camera_num, view_num=view_num)
    optimizer = make_CLIMB_optimizer(cfg, model)
    scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
                                  cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)

    train_climb(
        cfg,
        model,
        train_loader,
        val_loader,
        cluster_loader,
        optimizer,
        scheduler,
        num_query,
        num_classes
    )

--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/utils/__init__.py
--------------------------------------------------------------------------------
/utils/__pycache__/logger.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/utils/__pycache__/logger.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/logger.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/utils/__pycache__/logger.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/meter.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/utils/__pycache__/meter.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/meter.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AsuradaYuci/CLIMB-ReID/42c5839e40eb63d3bd6ad0a8ff524f9d1fb9aaa7/utils/__pycache__/meter.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/iotools.py:
--------------------------------------------------------------------------------
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""

import errno
import json
import os

import os.path as osp


def mkdir_if_missing(directory):
    if not osp.exists(directory):
        try:
            os.makedirs(directory)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise


def check_isfile(path):
    isfile = osp.isfile(path)
    if not isfile:
        print("=> Warning: no file found at '{}' (ignored)".format(path))
    return isfile


def read_json(fpath):
    with open(fpath, 'r') as f:
        obj = json.load(f)
    return obj


def write_json(obj, fpath):
    mkdir_if_missing(osp.dirname(fpath))
    with open(fpath, 'w') as f:
        json.dump(obj, f, indent=4, separators=(',', ': '))
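Round-trip sketch for the JSON helpers above (the path and numbers are illustrative):

from utils.iotools import read_json, write_json

stats = {"mAP": 0.891, "rank1": 0.952}  # made-up values
write_json(stats, "./logs/example/metrics.json")  # creates missing parent dirs
assert read_json("./logs/example/metrics.json") == stats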
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
import logging
import os
import sys
import os.path as osp


def setup_logger(name, save_dir, if_train):
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)

    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    if save_dir:
        if not osp.exists(save_dir):
            os.makedirs(save_dir)
        if if_train:
            fh = logging.FileHandler(os.path.join(save_dir, "train_log.txt"), mode='w')
        else:
            fh = logging.FileHandler(os.path.join(save_dir, "test_log.txt"), mode='w')
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    return logger

--------------------------------------------------------------------------------
/utils/meter.py:
--------------------------------------------------------------------------------
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
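Typical training-loop usage of AverageMeter (the loss values are made up):

from utils.meter import AverageMeter

loss_meter = AverageMeter()
loss_meter.reset()
for loss, batch_size in [(0.9, 64), (0.7, 64), (0.5, 32)]:
    loss_meter.update(loss, n=batch_size)
print(loss_meter.avg)  # sample-weighted mean: (0.9*64 + 0.7*64 + 0.5*32) / 160 = 0.74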
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import os
from utils.reranking import re_ranking


def euclidean_distance(qf, gf):
    m = qf.shape[0]
    n = gf.shape[0]
    dist_mat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
               torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    dist_mat.addmm_(qf, gf.t(), beta=1, alpha=-2)  # keyword form; the old positional (beta, alpha, ...) signature is removed in recent PyTorch
    return dist_mat.cpu().numpy()

def cosine_similarity(qf, gf):
    epsilon = 0.00001
    dist_mat = qf.mm(gf.t())
    qf_norm = torch.norm(qf, p=2, dim=1, keepdim=True)  # mx1
    gf_norm = torch.norm(gf, p=2, dim=1, keepdim=True)  # nx1
    qg_normdot = qf_norm.mm(gf_norm.t())

    dist_mat = dist_mat.mul(1 / qg_normdot).cpu().numpy()
    dist_mat = np.clip(dist_mat, -1 + epsilon, 1 - epsilon)
    dist_mat = np.arccos(dist_mat)
    return dist_mat


def org_cosine_similarity(qf, gf):

    q_norm = torch.norm(qf, p=2, dim=1, keepdim=True)
    g_norm = torch.norm(gf, p=2, dim=1, keepdim=True)
    qf = qf.div(q_norm.expand_as(qf))  # torch.Size([3873, 2048])
    gf = gf.div(g_norm.expand_as(gf))  # torch.Size([3384, 2048])
    dist_mat = - torch.mm(qf, gf.t())

    return dist_mat


def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
    """Evaluation with market1501 metric
    Key: for each query identity, its gallery images from the same camera view are discarded.
    """
    num_q, num_g = distmat.shape
    # distmat g
    #    q    1 3 2 4
    #         4 1 2 3
    if num_g < max_rank:
        max_rank = num_g
        print("Note: number of gallery samples is quite small, got {}".format(num_g))
    indices = np.argsort(distmat, axis=1)
    #  0 2 1 3
    #  1 2 3 0
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
    # compute cmc curve for each query
    all_cmc = []
    all_AP = []
    num_valid_q = 0.  # number of valid query
    for q_idx in range(num_q):
        # get query pid and camid
        q_pid = q_pids[q_idx]
        q_camid = q_camids[q_idx]

        # remove gallery samples that have the same pid and camid with query
        order = indices[q_idx]  # select one row
        remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
        keep = np.invert(remove)

        # compute cmc curve
        # binary vector, positions with value 1 are correct matches
        orig_cmc = matches[q_idx][keep]
        if not np.any(orig_cmc):
            # this condition is true when query identity does not appear in gallery
            continue

        cmc = orig_cmc.cumsum()
        cmc[cmc > 1] = 1

        all_cmc.append(cmc[:max_rank])
        num_valid_q += 1.

        # compute average precision
        # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
        num_rel = orig_cmc.sum()
        tmp_cmc = orig_cmc.cumsum()
        # tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
        y = np.arange(1, tmp_cmc.shape[0] + 1) * 1.0
        tmp_cmc = tmp_cmc / y
        tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
        AP = tmp_cmc.sum() / num_rel
        all_AP.append(AP)

    assert num_valid_q > 0, "Error: all query identities do not appear in gallery"

    all_cmc = np.asarray(all_cmc).astype(np.float32)
    all_cmc = all_cmc.sum(0) / num_valid_q
    mAP = np.mean(all_AP)

    return all_cmc, mAP


class R1_mAP_eval():
    def __init__(self, num_query, max_rank=50, feat_norm=True, reranking=False):
        super(R1_mAP_eval, self).__init__()
        self.num_query = num_query
        self.max_rank = max_rank
        self.feat_norm = feat_norm
        self.reranking = reranking

    def reset(self):
        self.feats = []
        self.feats0 = []
        self.feats1 = []
        self.pids = []
        self.camids = []

    def update(self, output):  # called once for each batch
        feat, feat0, feat1, pid, camid = output
        self.feats.append(feat.cpu())
        self.feats0.append(feat0.cpu())
        self.feats1.append(feat1.cpu())
        self.pids.extend(np.asarray(pid))
        self.camids.extend(np.asarray(camid))

    def compute(self):  # called after each epoch
        feats = torch.cat(self.feats, dim=0)
        feats0 = torch.cat(self.feats0, dim=0)
        feats1 = torch.cat(self.feats1, dim=0)
        if self.feat_norm:
            print("The test feature is normalized")
            feats = torch.nn.functional.normalize(feats, dim=1, p=2)  # along channel
            feats0 = torch.nn.functional.normalize(feats0, dim=1, p=2)  # along channel
            feats1 = torch.nn.functional.normalize(feats1, dim=1, p=2)  # along channel
        # query
        qf = feats[:self.num_query]
        qf0 = feats0[:self.num_query]
        qf1 = feats1[:self.num_query]
        q_pids = np.asarray(self.pids[:self.num_query])
        q_camids = np.asarray(self.camids[:self.num_query])
        # gallery
        gf = feats[self.num_query:]
        gf0 = feats0[self.num_query:]
        gf1 = feats1[self.num_query:]
        g_pids = np.asarray(self.pids[self.num_query:])

        g_camids = np.asarray(self.camids[self.num_query:])
        if self.reranking:
            print('=> Enter reranking')
            # distmat = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3)
            distmat = re_ranking(qf, gf, k1=50, k2=15, lambda_value=0.3)

        else:
            print('=> Computing DistMat with euclidean_distance')
            distmat = euclidean_distance(qf, gf)
        # distmat = org_cosine_similarity(qf, gf)
        distmat01 = org_cosine_similarity(qf0, gf0)
        distmat02 = org_cosine_similarity(qf1, gf1)
        cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids)
        cmc01, mAP01 = eval_func(distmat01, q_pids, g_pids, q_camids, g_camids)
        cmc02, mAP02 = eval_func(distmat02, q_pids, g_pids, q_camids, g_camids)
        cmc03, mAP03 = eval_func(distmat01 + distmat02, q_pids, g_pids, q_camids, g_camids)
        return cmc, mAP, cmc01, mAP01, cmc02, mAP02, cmc03, mAP03
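A toy check of eval_func on a 2-query / 3-gallery problem; all numbers are invented, and the camera ids differ so nothing is filtered:

import numpy as np
from utils.metrics import eval_func

distmat = np.array([[0.1, 0.4, 0.9],
                    [0.8, 0.2, 0.3]])
q_pids = np.array([1, 2]); g_pids = np.array([1, 2, 2])
q_camids = np.array([0, 0]); g_camids = np.array([1, 1, 1])
cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=3)
print(cmc[0], mAP)  # rank-1 accuracy and mean average precision (both 1.0 here)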
--------------------------------------------------------------------------------
/utils/reranking.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri, 25 May 2018 20:29:09


"""

"""
CVPR2017 paper: Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017.
url: http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf
Matlab version: https://github.com/zhunzhong07/person-re-ranking
"""

"""
API

probFea: all feature vectors of the query set (torch tensor)
galFea: all feature vectors of the gallery set (torch tensor)
k1, k2, lambda: parameters, the original paper uses (k1=20, k2=6, lambda=0.3)
MemorySave: set to 'True' when using MemorySave mode
Minibatch: available when 'MemorySave' is 'True'
"""

import numpy as np
import torch


def re_ranking(probFea, galFea, k1, k2, lambda_value, local_distmat=None, only_local=False):
    # if the features are numpy arrays, convert them to torch tensors first
    query_num = probFea.size(0)
    all_num = query_num + galFea.size(0)
    if only_local:
        original_dist = local_distmat
    else:
        feat = torch.cat([probFea, galFea])
        # print('using GPU to compute original distance')
        distmat = torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num) + \
                  torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num).t()
        distmat.addmm_(feat, feat.t(), beta=1, alpha=-2)  # keyword form; the old positional signature is removed in recent PyTorch
        original_dist = distmat.cpu().numpy()
        del feat
        if local_distmat is not None:
            original_dist = original_dist + local_distmat
    gallery_num = original_dist.shape[0]
    original_dist = np.transpose(original_dist / np.max(original_dist, axis=0))
    V = np.zeros_like(original_dist).astype(np.float16)
    initial_rank = np.argsort(original_dist).astype(np.int32)

    # print('starting re_ranking')
    for i in range(all_num):
        # k-reciprocal neighbors
        forward_k_neigh_index = initial_rank[i, :k1 + 1]
        backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
        fi = np.where(backward_k_neigh_index == i)[0]
        k_reciprocal_index = forward_k_neigh_index[fi]
        k_reciprocal_expansion_index = k_reciprocal_index
        for j in range(len(k_reciprocal_index)):
            candidate = k_reciprocal_index[j]
            candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1]
            candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,
                                               :int(np.around(k1 / 2)) + 1]
            fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
            candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
            if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len(
                    candidate_k_reciprocal_index):
                k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index)

        k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
        weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
        V[i, k_reciprocal_expansion_index] = weight / np.sum(weight)
    original_dist = original_dist[:query_num, ]
    if k2 != 1:
        V_qe = np.zeros_like(V, dtype=np.float16)
        for i in range(all_num):
            V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
        V = V_qe
        del V_qe
    del initial_rank
    invIndex = []
    for i in range(gallery_num):
        invIndex.append(np.where(V[:, i] != 0)[0])

    jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)

    for i in range(query_num):
        temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16)
        indNonZero = np.where(V[i, :] != 0)[0]
        indImages = [invIndex[ind] for ind in indNonZero]
        for j in range(len(indNonZero)):
            temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]],
                                                                               V[indImages[j], indNonZero[j]])
        jaccard_dist[i] = 1 - temp_min / (2 - temp_min)

    final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value
    del original_dist
    del V
    del jaccard_dist
    final_dist = final_dist[:query_num, query_num:]
    return final_dist

--------------------------------------------------------------------------------
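Usage sketch for re_ranking with random features (sizes arbitrary; k1/k2/lambda are the paper defaults):

import torch
from utils.reranking import re_ranking

qf = torch.randn(8, 512)   # query features
gf = torch.randn(20, 512)  # gallery features
dist = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3)
print(dist.shape)  # (8, 20): query-to-gallery distances after k-reciprocal re-ranking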