├── tools ├── __init__.py ├── test.py └── main.py ├── utils ├── __init__.py ├── logger.py ├── iotools.py ├── lr_scheduler.py └── reid_metric.py ├── modeling ├── backbones │ └── __init__.py ├── layer │ ├── __init__.py │ ├── gem_pool.py │ ├── non_local.py │ └── center_loss.py └── __init__.py ├── Transformer-ReID-Survey ├── UnTransReID_USL_ReID │ ├── examples │ │ └── logs │ │ │ └── log.txt │ ├── solver │ │ ├── __init__.py │ │ ├── scheduler_factory.py │ │ ├── lr_scheduler.py │ │ ├── make_optimizer.py │ │ └── cosine_lr.py │ ├── train_market.sh │ ├── train_msmt.sh │ └── clustercontrast │ │ ├── evaluation_metrics │ │ ├── __init__.py │ │ └── classification.py │ │ ├── __init__.py │ │ ├── utils │ │ ├── osutils.py │ │ ├── meters.py │ │ ├── __init__.py │ │ ├── data │ │ │ ├── __init__.py │ │ │ ├── preprocessor.py │ │ │ ├── base_dataset.py │ │ │ └── transforms.py │ │ ├── infomap_utils.py │ │ ├── logging.py │ │ ├── serialization.py │ │ ├── faiss_utils.py │ │ └── pos_embed.py │ │ ├── models │ │ ├── contrastive_loss.py │ │ ├── kmeans.py │ │ ├── __init__.py │ │ ├── dsbn.py │ │ └── cm.py │ │ ├── datasets │ │ ├── __init__.py │ │ ├── dukemtmcreid.py │ │ ├── market1501.py │ │ └── msmt17.py │ │ └── trainers.py ├── UnTransReID_VI_ReID │ ├── clustercontrast │ │ ├── model_vit │ │ │ ├── backbones │ │ │ │ └── __init__.py │ │ │ └── __init__.py │ │ ├── model_vit_fft │ │ │ ├── backbones │ │ │ │ └── __init__.py │ │ │ └── __init__.py │ │ ├── model_vit_lptn │ │ │ ├── backbones │ │ │ │ └── __init__.py │ │ │ └── __init__.py │ │ ├── model_vit_token │ │ │ ├── backbones │ │ │ │ └── __init__.py │ │ │ └── __init__.py │ │ ├── model_vit_cmrefine │ │ │ ├── backbones │ │ │ │ └── __init__.py │ │ │ └── __init__.py │ │ ├── evaluation_metrics │ │ │ ├── __init__.py │ │ │ └── classification.py │ │ ├── __init__.py │ │ ├── utils │ │ │ ├── osutils.py │ │ │ ├── meters.py │ │ │ ├── __init__.py │ │ │ ├── data │ │ │ │ ├── __init__.py │ │ │ │ ├── base_dataset.py │ │ │ │ ├── preprocessor.py │ │ │ │ └── transforms.py │ │ │ ├── infomap_utils.py │ │ │ ├── logging.py │ │ │ ├── serialization.py │ │ │ └── faiss_utils.py │ │ ├── models │ │ │ ├── kmeans.py │ │ │ ├── __init__.py │ │ │ └── dsbn.py │ │ └── datasets │ │ │ ├── __init__.py │ │ │ ├── sysu_all.py │ │ │ ├── regdb_rgb.py │ │ │ ├── regdb_ir.py │ │ │ ├── sysu_rgb.py │ │ │ └── sysu_ir.py │ ├── solver │ │ ├── __init__.py │ │ ├── scheduler_factory.py │ │ ├── lr_scheduler.py │ │ ├── make_optimizer.py │ │ └── cosine_lr.py │ ├── config │ │ └── __init__.py │ ├── train_sysu.sh │ ├── setup.py │ ├── vit_small_ics_384.yml │ ├── vit_base_ics_288.yml │ ├── vit_base_ics_384.yml │ └── train_regdb.sh ├── Animal-Re-ID-main │ ├── Tools │ │ ├── market_change.py │ │ ├── ATRW_test.py │ │ └── split_train_test.py │ └── README.md └── README.md ├── config └── __init__.py ├── data ├── __init__.py ├── datasets │ ├── __init__.py │ ├── dataset_loader.py │ ├── partial_ilids.py │ ├── partial_reid.py │ ├── eval_reid.py │ ├── market1501.py │ ├── veri.py │ ├── msmt17.py │ ├── bases.py │ └── dukemtmcreid.py ├── triplet_sampler.py ├── build.py └── transforms.py ├── video-reid-AWG ├── train_video_agw_plus.sh ├── models │ ├── __init__.py │ ├── gem_pool.py │ └── non_local.py ├── README.md ├── samplers.py ├── lr_scheduler.py ├── eval_metrics.py └── transforms.py ├── Experiment-AGW-duke.sh ├── Experiment-AGW-market.sh ├── Experiment-AGW-partial.sh ├── Test-AGW-duke.sh ├── Test-AGW-market.sh ├── LICENSE ├── configs └── AGW_baseline.yml └── .gitignore /tools/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | -------------------------------------------------------------------------------- /modeling/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/examples/logs/log.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | from .defaults import _C as cfg 4 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | from .build import make_data_loader 4 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/model_vit/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/model_vit_fft/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/model_vit_lptn/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/model_vit_token/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/model_vit_cmrefine/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/model_vit/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_model import make_model -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/model_vit_fft/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_model import make_model -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/model_vit_lptn/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_model import make_model -------------------------------------------------------------------------------- 
/Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/model_vit_token/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_model import make_model -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/model_vit_cmrefine/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_model import make_model -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .lr_scheduler import WarmupMultiStepLR 2 | from .make_optimizer import make_optimizer -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .lr_scheduler import WarmupMultiStepLR 2 | from .make_optimizer import make_optimizer -------------------------------------------------------------------------------- /video-reid-AWG/train_video_agw_plus.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python3 ./main_video_person_reid.py --arch AGW_Plus_Baseline \ 2 | --train-dataset mars --test-dataset mars --save-dir ./mars_agw_plus -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/config/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .defaults import _C as cfg 8 | from .defaults import _C as cfg_test 9 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/train_market.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3 python examples/cluster_contrast_train_usl.py -b 256 -a vit_base -d market1501 --iters 200 --eps 0.6 --self-norm --use-hard --hw-ratio 2 --num-instances 8 --conv-stem 2 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/train_msmt.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3 python examples/cluster_contrast_train_usl.py -b 256 -a vit_base -d msmt17 --self-norm --use-hard --hw-ratio 2 --num-instances 8 --conv-stem --iters 400 --eps 0.7 --k1 30 -------------------------------------------------------------------------------- /modeling/layer/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | from .center_loss import CenterLoss 4 | from .triplet_loss import CrossEntropyLabelSmooth, TripletLoss, WeightedRegularizedTriplet 5 | from .non_local import Non_local 6 | from .gem_pool import GeneralizedMeanPooling, GeneralizedMeanPoolingP -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/evaluation_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from
.classification import accuracy 4 | from .ranking import cmc, mean_ap 5 | 6 | __all__ = [ 7 | 'accuracy', 8 | 'cmc', 9 | 'mean_ap' 10 | ] 11 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/clustercontrast/evaluation_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .classification import accuracy 4 | from .ranking import cmc, mean_ap 5 | 6 | __all__ = [ 7 | 'accuracy', 8 | 'cmc', 9 | 'mean_ap' 10 | ] 11 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/clustercontrast/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . import datasets 4 | from . import evaluation_metrics 5 | from . import models 6 | from . import utils 7 | from . import evaluators 8 | from . import trainers 9 | 10 | __version__ = '0.1.0' 11 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . import datasets 4 | from . import evaluation_metrics 5 | from . import models 6 | from . import utils 7 | from . import evaluators 8 | from . import trainers 9 | 10 | __version__ = '0.1.0' 11 | -------------------------------------------------------------------------------- /modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | from .baseline import Baseline 4 | 5 | def build_model(cfg, num_classes): 6 | model = Baseline(num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.PRETRAIN_PATH, cfg.MODEL.NAME, 7 | cfg.MODEL.GENERALIZED_MEAN_POOL, cfg.MODEL.PRETRAIN_CHOICE) 8 | return model -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/clustercontrast/utils/osutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import errno 4 | 5 | 6 | def mkdir_if_missing(dir_path): 7 | try: 8 | os.makedirs(dir_path) 9 | except OSError as e: 10 | if e.errno != errno.EEXIST: 11 | raise 12 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/utils/osutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import errno 4 | 5 | 6 | def mkdir_if_missing(dir_path): 7 | try: 8 | os.makedirs(dir_path) 9 | except OSError as e: 10 | if e.errno != errno.EEXIST: 11 | raise 12 | -------------------------------------------------------------------------------- /video-reid-AWG/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .models import * 4 | 5 | __factory = { 6 | 'AGW_Plus_Baseline': AGW_Plus_Baseline, 7 | } 8 | 9 | def get_names(): 10 | return __factory.keys() 11 | 12 | def init_model(name, *args, **kwargs): 13 | if name not in __factory.keys(): 14 | raise KeyError("Unknown model: {}".format(name)) 15 | return __factory[name](*args, 
**kwargs) 16 | -------------------------------------------------------------------------------- /Experiment-AGW-duke.sh: -------------------------------------------------------------------------------- 1 | # Dataset: dukemtmc 2 | # imagesize: 256x128 3 | # batchsize: 16x4 4 | # warmup_step 10 5 | # random erase prob 0.5 6 | # last stride 1 7 | # with center loss 8 | # weight regularized triplet loss 9 | # generalized mean pooling 10 | # non local blocks 11 | python3 tools/main.py --config_file='configs/AGW_baseline.yml' MODEL.DEVICE_ID "('2')" \ 12 | DATASETS.NAMES "('dukemtmc')" OUTPUT_DIR "('./log/dukemtmc/Experiment-AGW-baseline')" -------------------------------------------------------------------------------- /Experiment-AGW-market.sh: -------------------------------------------------------------------------------- 1 | # Dataset: market1501 2 | # imagesize: 256x128 3 | # batchsize: 16x4 4 | # warmup_step 10 5 | # random erase prob 0.5 6 | # last stride 1 7 | # with center loss 8 | # weight regularized triplet loss 9 | # generalized mean pooling 10 | # non local blocks 11 | python3 tools/main.py --config_file='configs/AGW_baseline.yml' MODEL.DEVICE_ID "('3')" \ 12 | DATASETS.NAMES "('market1501')" OUTPUT_DIR "('./log/market1501/Experiment-AGW-baseline')" -------------------------------------------------------------------------------- /Experiment-AGW-partial.sh: -------------------------------------------------------------------------------- 1 | # Dataset: train on market1501, eval on partial_reid and partial_ilids 2 | # imagesize: 256x128 3 | # batchsize: 16x4 4 | # warmup_step 10 5 | # random erase prob 0.5 6 | # last stride 1 7 | # with center loss 8 | # weight regularized triplet loss 9 | # generalized mean pooling 10 | # non local blocks 11 | python3 tools/main.py --config_file='configs/AGW_baseline.yml' MODEL.DEVICE_ID "('1')" \ 12 | DATASETS.NAMES "('market1501')" TEST.PARTIAL_REID "('on')" OUTPUT_DIR "('./log/market1501/Experiment-AGW-baseline-partial')" 13 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/Animal-Re-ID-main/Tools/market_change.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | 5 | directory = '' 6 | new_name_format = '{:>4}_c{}s{}_{}' 7 | 8 | 9 | for filename in os.listdir(directory): 10 | if filename.endswith(".jpg") and "_" in filename: 11 | id, eid = filename.split("_")[:2] 12 | 13 | new_name = new_name_format.format(id, random.randint(1, 6), random.randint(1, 6), eid) 14 | 15 | old_path = os.path.join(directory, filename) 16 | new_path = os.path.join(directory, new_name) 17 | 18 | os.rename(old_path, new_path) -------------------------------------------------------------------------------- /Test-AGW-duke.sh: -------------------------------------------------------------------------------- 1 | # Dataset: dukemtmc 2 | # imagesize: 256x128 3 | # batchsize: 16x4 4 | # warmup_step 10 5 | # random erase prob 0.5 6 | # last stride 1 7 | # with center loss 8 | # weight regularized triplet loss 9 | # generalized mean pooling 10 | # non local blocks 11 | # without re-ranking: add TEST.RE_RANKING "('on')" for re-ranking 12 | python3 tools/main.py --config_file='configs/AGW_baseline.yml' MODEL.DEVICE_ID "('1')" \ 13 | DATASETS.NAMES "('dukemtmc')" MODEL.PRETRAIN_CHOICE "('self')" \ 14 | TEST.WEIGHT "('./pretrained/dukemtmc_AGW.pth')" TEST.EVALUATE_ONLY "('on')" OUTPUT_DIR "('./log/Test')" 15 | 
-------------------------------------------------------------------------------- /Test-AGW-market.sh: -------------------------------------------------------------------------------- 1 | # Dataset: market1501 2 | # imagesize: 256x128 3 | # batchsize: 16x4 4 | # warmup_step 10 5 | # random erase prob 0.5 6 | # last stride 1 7 | # with center loss 8 | # weight regularized triplet loss 9 | # generalized mean pooling 10 | # non local blocks 11 | # without re-ranking: add TEST.RE_RANKING "('on')" for re-ranking 12 | python3 tools/main.py --config_file='configs/AGW_baseline.yml' MODEL.DEVICE_ID "('1')" \ 13 | DATASETS.NAMES "('market1501')" MODEL.PRETRAIN_CHOICE "('self')" \ 14 | TEST.WEIGHT "('./pretrained/market1501_AGW.pth')" TEST.EVALUATE_ONLY "('on')" OUTPUT_DIR "('./log/Test')" 15 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/utils/meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/clustercontrast/utils/meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/train_sysu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -A chenjun3 3 | #SBATCH -p a100x4 4 | #SBATCH -N 1 5 | #SBATCH --ntasks=1 6 | #SBATCH --ntasks-per-node=4 7 | #SBATCH --cpus-per-task=4 8 | #SBATCH --gres=gpu:2 9 | #SBATCH --exclude=g0154,g0150,g0158 10 | #SBATCH -o sysu_2p_384_g.log 11 | # module load scl/gcc5.3 12 | module load nvidia/cuda/11.6 13 | # rm -rf x.txt 14 | # CUDA_VISIBLE_DEVICES=0,1,2,3 python cluster_contrast_camera_cmrouting.py -b 256 -a agw -d sysu_all --iters 200 --momentum 0.1 --eps 0.6 --num-instances 16 15 | CUDA_VISIBLE_DEVICES=0,1 python sysu_train.py -b 256 -a agw -d sysu_all --iters 200 --momentum 0.1 --eps 0.6 --num-instances 16 16 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/clustercontrast/models/contrastive_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ConLoss(object): 6 | """ 7 | Contrastive loss. 
8 | """ 9 | 10 | def __init__(self, temperature=0.8): 11 | self.temperature = temperature 12 | 13 | def __call__(self, z1, z2): 14 | z1 = torch.nn.functional.normalize(z1, dim=1) 15 | z2 = torch.nn.functional.normalize(z2, dim=1) 16 | 17 | logits = z1 @ z2.T 18 | logits /= self.temperature 19 | n = z2.shape[0] 20 | labels = torch.arange(0, n, dtype=torch.long).cuda() 21 | loss = torch.nn.functional.cross_entropy(logits, labels) 22 | return loss -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | setup(name='ClusterContrast', 5 | version='1.0.0', 6 | description='Cluster Contrast for Unsupervised Person Re-Identification', 7 | author='GuangYuan wang', 8 | author_email='yixuan.wgy@alibaba-inc.com', 9 | # url='', 10 | install_requires=[ 11 | 'numpy', 'torch', 'torchvision', 12 | 'six', 'h5py', 'Pillow', 'scipy', 13 | 'scikit-learn', 'metric-learn', 'faiss_gpu'], 14 | packages=find_packages(), 15 | keywords=[ 16 | 'Unsupervised Learning', 17 | 'Contrastive Learning', 18 | 'Object Re-identification' 19 | ]) 20 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/clustercontrast/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | 6 | def to_numpy(tensor): 7 | if torch.is_tensor(tensor): 8 | return tensor.cpu().numpy() 9 | elif type(tensor).__module__ != 'numpy': 10 | raise ValueError("Cannot convert {} to numpy array" 11 | .format(type(tensor))) 12 | return tensor 13 | 14 | 15 | def to_torch(ndarray): 16 | if type(ndarray).__module__ == 'numpy': 17 | return torch.from_numpy(ndarray) 18 | elif not torch.is_tensor(ndarray): 19 | raise ValueError("Cannot convert {} to torch tensor" 20 | .format(type(ndarray))) 21 | return ndarray 22 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | 6 | def to_numpy(tensor): 7 | if torch.is_tensor(tensor): 8 | return tensor.cpu().numpy() 9 | elif type(tensor).__module__ != 'numpy': 10 | raise ValueError("Cannot convert {} to numpy array" 11 | .format(type(tensor))) 12 | return tensor 13 | 14 | 15 | def to_torch(ndarray): 16 | if type(ndarray).__module__ == 'numpy': 17 | return torch.from_numpy(ndarray) 18 | elif not torch.is_tensor(ndarray): 19 | raise ValueError("Cannot convert {} to torch tensor" 20 | .format(type(ndarray))) 21 | return ndarray 22 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/clustercontrast/evaluation_metrics/classification.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from ..utils import to_torch 5 | 6 | 7 | def accuracy(output, target, topk=(1,)): 8 | with torch.no_grad(): 9 | output, target = to_torch(output), to_torch(target) 10 | maxk = max(topk) 11 | batch_size = target.size(0) 12 | 13 | _, pred = output.topk(maxk, 1, True, True) 14 | pred = pred.t() 15 | correct = 
pred.eq(target.view(1, -1).expand_as(pred)) 16 | 17 | ret = [] 18 | for k in topk: 19 | correct_k = correct[:k].reshape(-1).float().sum(dim=0, keepdim=True) 20 | ret.append(correct_k.mul_(1. / batch_size)) 21 | return ret 22 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/evaluation_metrics/classification.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from ..utils import to_torch 5 | 6 | 7 | def accuracy(output, target, topk=(1,)): 8 | with torch.no_grad(): 9 | output, target = to_torch(output), to_torch(target) 10 | maxk = max(topk) 11 | batch_size = target.size(0) 12 | 13 | _, pred = output.topk(maxk, 1, True, True) 14 | pred = pred.t() 15 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 16 | 17 | ret = [] 18 | for k in topk: 19 | correct_k = correct[:k].reshape(-1).float().sum(dim=0, keepdim=True) 20 | ret.append(correct_k.mul_(1. / batch_size)) 21 | return ret 22 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/clustercontrast/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .base_dataset import BaseDataset, BaseImageDataset 4 | from .preprocessor import Preprocessor 5 | 6 | 7 | class IterLoader: 8 | def __init__(self, loader, length=None): 9 | self.loader = loader 10 | self.length = length 11 | self.iter = None 12 | 13 | def __len__(self): 14 | if self.length is not None: 15 | return self.length 16 | 17 | return len(self.loader) 18 | 19 | def new_epoch(self): 20 | self.iter = iter(self.loader) 21 | 22 | def next(self): 23 | try: 24 | return next(self.iter) 25 | except: 26 | self.iter = iter(self.loader) 27 | return next(self.iter) 28 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .base_dataset import BaseDataset, BaseImageDataset 4 | from .preprocessor import Preprocessor 5 | 6 | 7 | class IterLoader: 8 | def __init__(self, loader, length=None): 9 | self.loader = loader 10 | self.length = length 11 | self.iter = None 12 | 13 | def __len__(self): 14 | if self.length is not None: 15 | return self.length 16 | 17 | return len(self.loader) 18 | 19 | def new_epoch(self): 20 | self.iter = iter(self.loader) 21 | 22 | def next(self): 23 | try: 24 | return next(self.iter) 25 | except: 26 | self.iter = iter(self.loader) 27 | return next(self.iter) 28 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import logging 4 | import os 5 | import sys 6 | 7 | 8 | def setup_logger(name, save_dir, distributed_rank): 9 | logger = logging.getLogger(name) 10 | logger.setLevel(logging.DEBUG) 11 | # don't log results for the non-master process 12 | if distributed_rank > 0: 13 | return logger 14 | ch = logging.StreamHandler(stream=sys.stdout) 15 | ch.setLevel(logging.DEBUG) 16 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 17 | ch.setFormatter(formatter) 18 | 
logger.addHandler(ch) 19 | 20 | if save_dir: 21 | fh = logging.FileHandler(os.path.join(save_dir, "log.txt"), mode='w') 22 | fh.setLevel(logging.DEBUG) 23 | fh.setFormatter(formatter) 24 | logger.addHandler(fh) 25 | 26 | return logger 27 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/clustercontrast/utils/infomap_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | 4 | class TextColors: 5 | HEADER = '\033[35m' 6 | OKBLUE = '\033[34m' 7 | OKGREEN = '\033[32m' 8 | WARNING = '\033[33m' 9 | FATAL = '\033[31m' 10 | ENDC = '\033[0m' 11 | BOLD = '\033[1m' 12 | UNDERLINE = '\033[4m' 13 | 14 | 15 | class Timer(): 16 | def __init__(self, name='task', verbose=True): 17 | self.name = name 18 | self.verbose = verbose 19 | 20 | def __enter__(self): 21 | self.start = time.time() 22 | return self 23 | 24 | def __exit__(self, exc_type, exc_val, exc_tb): 25 | if self.verbose: 26 | print('[Time] {} consumes {:.4f} s'.format( 27 | self.name, 28 | time.time() - self.start)) 29 | return exc_type is None -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/utils/infomap_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | 4 | class TextColors: 5 | HEADER = '\033[35m' 6 | OKBLUE = '\033[34m' 7 | OKGREEN = '\033[32m' 8 | WARNING = '\033[33m' 9 | FATAL = '\033[31m' 10 | ENDC = '\033[0m' 11 | BOLD = '\033[1m' 12 | UNDERLINE = '\033[4m' 13 | 14 | 15 | class Timer(): 16 | def __init__(self, name='task', verbose=True): 17 | self.name = name 18 | self.verbose = verbose 19 | 20 | def __enter__(self): 21 | self.start = time.time() 22 | return self 23 | 24 | def __exit__(self, exc_type, exc_val, exc_tb): 25 | if self.verbose: 26 | print('[Time] {} consumes {:.4f} s'.format( 27 | self.name, 28 | time.time() - self.start)) 29 | return exc_type is None -------------------------------------------------------------------------------- /data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | from .cuhk03 import CUHK03 4 | from .dukemtmcreid import DukeMTMCreID 5 | from .market1501 import Market1501 6 | from .msmt17 import MSMT17 7 | from .veri import VeRi 8 | from .partial_ilids import PartialILIDS 9 | from .partial_reid import PartialREID 10 | from .dataset_loader import ImageDataset 11 | 12 | __factory = { 13 | 'market1501': Market1501, 14 | 'cuhk03': CUHK03, 15 | 'dukemtmc': DukeMTMCreID, 16 | 'msmt17': MSMT17, 17 | 'veri': VeRi, 18 | 'partial_reid' : PartialREID, 19 | 'partial_ilids' : PartialILIDS, 20 | } 21 | 22 | 23 | def get_names(): 24 | return __factory.keys() 25 | 26 | 27 | def init_dataset(name, *args, **kwargs): 28 | if name not in __factory.keys(): 29 | raise KeyError("Unknown datasets: {}".format(name)) 30 | return __factory[name](*args, **kwargs) 31 | -------------------------------------------------------------------------------- /utils/iotools.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import errno 4 | import json 5 | import os 6 | 7 | import os.path as osp 8 | 9 | 10 | def mkdir_if_missing(directory): 11 | if not osp.exists(directory): 12 | try: 13 | os.makedirs(directory) 14 | except OSError as e: 15 | if e.errno != errno.EEXIST: 16 | raise 17 | 18 | 19 | def 
check_isfile(path): 20 | isfile = osp.isfile(path) 21 | if not isfile: 22 | print("=> Warning: no file found at '{}' (ignored)".format(path)) 23 | return isfile 24 | 25 | 26 | def read_json(fpath): 27 | with open(fpath, 'r') as f: 28 | obj = json.load(f) 29 | return obj 30 | 31 | 32 | def write_json(obj, fpath): 33 | mkdir_if_missing(osp.dirname(fpath)) 34 | with open(fpath, 'w') as f: 35 | json.dump(obj, f, indent=4, separators=(',', ': ')) 36 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/models/kmeans.py: -------------------------------------------------------------------------------- 1 | # Written by Yixiao Ge 2 | 3 | import warnings 4 | 5 | import faiss 6 | import torch 7 | 8 | from ..utils import to_numpy, to_torch 9 | 10 | __all__ = ["label_generator_kmeans"] 11 | 12 | 13 | @torch.no_grad() 14 | def label_generator_kmeans(features, num_classes=500, cuda=True): 15 | 16 | assert num_classes, "num_classes for kmeans is null" 17 | 18 | # k-means cluster by faiss 19 | cluster = faiss.Kmeans( 20 | features.size(-1), num_classes, niter=300, verbose=True, gpu=cuda 21 | ) 22 | 23 | cluster.train(to_numpy(features)) 24 | 25 | _, labels = cluster.index.search(to_numpy(features), 1) 26 | labels = labels.reshape(-1) 27 | 28 | centers = to_torch(cluster.centroids).float() 29 | # labels = to_torch(labels).long() 30 | 31 | # k-means does not have outlier points 32 | assert not (-1 in labels) 33 | 34 | return labels, centers, num_classes, None 35 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/clustercontrast/models/kmeans.py: -------------------------------------------------------------------------------- 1 | # Written by Yixiao Ge 2 | 3 | import warnings 4 | 5 | import faiss 6 | import torch 7 | 8 | from ..utils import to_numpy, to_torch 9 | 10 | __all__ = ["label_generator_kmeans"] 11 | 12 | 13 | @torch.no_grad() 14 | def label_generator_kmeans(features, num_classes=500, cuda=True): 15 | 16 | assert num_classes, "num_classes for kmeans is null" 17 | 18 | # k-means cluster by faiss 19 | cluster = faiss.Kmeans( 20 | features.size(-1), num_classes, niter=300, verbose=True, gpu=cuda 21 | ) 22 | 23 | cluster.train(to_numpy(features)) 24 | 25 | _, labels = cluster.index.search(to_numpy(features), 1) 26 | labels = labels.reshape(-1) 27 | 28 | centers = to_torch(cluster.centroids).float() 29 | # labels = to_torch(labels).long() 30 | 31 | # k-means does not have outlier points 32 | assert not (-1 in labels) 33 | 34 | return labels, centers, num_classes, None 35 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/Animal-Re-ID-main/Tools/ATRW_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import shutil 4 | 5 | json_file_path = "gt_test_plain.json" 6 | image_folder_path = "" 7 | output_folder_path = "" 8 | 9 | os.makedirs(output_folder_path, exist_ok=True) 10 | 11 | with open(json_file_path, "r") as file: 12 | data = json.load(file) 13 | 14 | for item in data: 15 | entityid = int(item["entityid"]) 16 | imgid = int(item["imgid"]) 17 | query = item["query"] 18 | if query == "multi": 19 | img_files = [file_name for file_name in os.listdir(image_folder_path) if file_name.endswith(".jpg")] 20 | matching_files = [file_name for file_name in img_files if int(os.path.splitext(file_name)[0]) == imgid] 21 | 22 | 
if len(matching_files) > 0: 23 | old_file_name = os.path.join(image_folder_path, matching_files[0]) 24 | new_file_name = os.path.join(output_folder_path, f"{entityid}_{imgid}.jpg") 25 | 26 | shutil.copy2(old_file_name, new_file_name) -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/utils/logging.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import sys 4 | 5 | from .osutils import mkdir_if_missing 6 | 7 | 8 | class Logger(object): 9 | def __init__(self, fpath=None): 10 | self.console = sys.stdout 11 | self.file = None 12 | if fpath is not None: 13 | mkdir_if_missing(os.path.dirname(fpath)) 14 | self.file = open(fpath, 'w') 15 | 16 | def __del__(self): 17 | self.close() 18 | 19 | def __enter__(self): 20 | pass 21 | 22 | def __exit__(self, *args): 23 | self.close() 24 | 25 | def write(self, msg): 26 | self.console.write(msg) 27 | if self.file is not None: 28 | self.file.write(msg) 29 | 30 | def flush(self): 31 | self.console.flush() 32 | if self.file is not None: 33 | self.file.flush() 34 | os.fsync(self.file.fileno()) 35 | 36 | def close(self): 37 | self.console.close() 38 | if self.file is not None: 39 | self.file.close() 40 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/clustercontrast/utils/logging.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import sys 4 | 5 | from .osutils import mkdir_if_missing 6 | 7 | 8 | class Logger(object): 9 | def __init__(self, fpath=None): 10 | self.console = sys.stdout 11 | self.file = None 12 | if fpath is not None: 13 | mkdir_if_missing(os.path.dirname(fpath)) 14 | self.file = open(fpath, 'w') 15 | 16 | def __del__(self): 17 | self.close() 18 | 19 | def __enter__(self): 20 | pass 21 | 22 | def __exit__(self, *args): 23 | self.close() 24 | 25 | def write(self, msg): 26 | self.console.write(msg) 27 | if self.file is not None: 28 | self.file.write(msg) 29 | 30 | def flush(self): 31 | self.console.flush() 32 | if self.file is not None: 33 | self.file.flush() 34 | os.fsync(self.file.fileno()) 35 | 36 | def close(self): 37 | self.console.close() 38 | if self.file is not None: 39 | self.file.close() 40 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 mangye16 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/clustercontrast/utils/data/preprocessor.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import os.path as osp 4 | from torch.utils.data import DataLoader, Dataset 5 | import numpy as np 6 | import random 7 | import math 8 | from PIL import Image 9 | 10 | 11 | class Preprocessor(Dataset): 12 | def __init__(self, dataset, root=None, transform=None): 13 | super(Preprocessor, self).__init__() 14 | self.dataset = dataset 15 | self.root = root 16 | self.transform = transform 17 | 18 | def __len__(self): 19 | return len(self.dataset) 20 | 21 | def __getitem__(self, indices): 22 | return self._get_single_item(indices) 23 | 24 | def _get_single_item(self, index): 25 | fname, pid, camid = self.dataset[index] 26 | fpath = fname 27 | if self.root is not None: 28 | fpath = osp.join(self.root, fname) 29 | 30 | img = Image.open(fpath).convert('RGB') 31 | 32 | if self.transform is not None: 33 | img = self.transform(img) 34 | 35 | return img, fname, pid, camid, index 36 | -------------------------------------------------------------------------------- /configs/AGW_baseline.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/cgv841/.cache/torch/checkpoints/resnet50-19c8e357.pth' 4 | CENTER_LOSS: 'on' 5 | CENTER_FEAT_DIM: 2048 6 | NAME: 'resnet50_nl' 7 | WEIGHT_REGULARIZED_TRIPLET: 'on' 8 | GENERALIZED_MEAN_POOL: 'on' 9 | 10 | INPUT: 11 | IMG_SIZE: [256, 128] 12 | PROB: 0.5 # random horizontal flip 13 | RE_PROB: 0.5 # random erasing 14 | PADDING: 10 15 | 16 | DATASETS: 17 | NAMES: ('market1501') 18 | 19 | DATALOADER: 20 | PK_SAMPLER: 'on' 21 | NUM_INSTANCE: 4 22 | NUM_WORKERS: 8 23 | 24 | SOLVER: 25 | OPTIMIZER_NAME: 'Adam' 26 | MAX_EPOCHS: 120 27 | BASE_LR: 0.00035 28 | 29 | CENTER_LR: 0.5 30 | CENTER_LOSS_WEIGHT: 0.0005 31 | 32 | WEIGHT_DECAY: 0.0005 33 | IMS_PER_BATCH: 64 34 | 35 | STEPS: [40, 70] 36 | GAMMA: 0.1 37 | 38 | WARMUP_FACTOR: 0.01 39 | WARMUP_ITERS: 10 40 | WARMUP_METHOD: 'linear' 41 | 42 | CHECKPOINT_PERIOD: 40 43 | LOG_PERIOD: 20 44 | EVAL_PERIOD: 40 45 | 46 | TEST: 47 | IMS_PER_BATCH: 128 48 | RE_RANKING: 'off' 49 | WEIGHT: "path" 50 | FEAT_NORM: 'on' 51 | EVALUATE_ONLY: 'off' 52 | PARTIAL_REID: 'off' 53 | 54 | OUTPUT_DIR: "./log/market1501/Experiment-AGW-baseline" 55 | 56 | 57 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/solver/scheduler_factory.py: -------------------------------------------------------------------------------- 1 | """ Scheduler Factory 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | from .cosine_lr import CosineLRScheduler 5 | 6 | 7 | def create_scheduler(cfg, optimizer): 8 | num_epochs = cfg.SOLVER.MAX_EPOCHS 9 | # type 1 10 | # lr_min = 0.01 * cfg.SOLVER.BASE_LR 11 | # warmup_lr_init = 0.001 * cfg.SOLVER.BASE_LR 12 | # type 2 13 | lr_min = 0.002 * cfg.SOLVER.BASE_LR 14 | warmup_lr_init = 0.01 * cfg.SOLVER.BASE_LR 15 | # type 3 16 | # lr_min = 0.001 * 
cfg.SOLVER.BASE_LR 17 | # warmup_lr_init = 0.01 * cfg.SOLVER.BASE_LR 18 | 19 | warmup_t = cfg.SOLVER.WARMUP_EPOCHS 20 | noise_range = None 21 | 22 | lr_scheduler = CosineLRScheduler( 23 | optimizer, 24 | t_initial=num_epochs, 25 | lr_min=lr_min, 26 | t_mul= 1., 27 | decay_rate=0.1, 28 | warmup_lr_init=warmup_lr_init, 29 | warmup_t=warmup_t, 30 | cycle_limit=1, 31 | t_in_epochs=True, 32 | noise_range_t=noise_range, 33 | noise_pct= 0.67, 34 | noise_std= 1., 35 | noise_seed=42, 36 | ) 37 | 38 | return lr_scheduler 39 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/solver/scheduler_factory.py: -------------------------------------------------------------------------------- 1 | """ Scheduler Factory 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | from .cosine_lr import CosineLRScheduler 5 | 6 | 7 | def create_scheduler(cfg, optimizer): 8 | num_epochs = cfg.SOLVER.MAX_EPOCHS 9 | # type 1 10 | # lr_min = 0.01 * cfg.SOLVER.BASE_LR 11 | # warmup_lr_init = 0.001 * cfg.SOLVER.BASE_LR 12 | # type 2 13 | lr_min = 0.002 * cfg.SOLVER.BASE_LR 14 | warmup_lr_init = 0.01 * cfg.SOLVER.BASE_LR 15 | # type 3 16 | # lr_min = 0.001 * cfg.SOLVER.BASE_LR 17 | # warmup_lr_init = 0.01 * cfg.SOLVER.BASE_LR 18 | 19 | warmup_t = cfg.SOLVER.WARMUP_EPOCHS 20 | noise_range = None 21 | 22 | lr_scheduler = CosineLRScheduler( 23 | optimizer, 24 | t_initial=num_epochs, 25 | lr_min=lr_min, 26 | t_mul= 1., 27 | decay_rate=0.1, 28 | warmup_lr_init=warmup_lr_init, 29 | warmup_t=warmup_t, 30 | cycle_limit=1, 31 | t_in_epochs=True, 32 | noise_range_t=noise_range, 33 | noise_pct= 0.67, 34 | noise_std= 1., 35 | noise_seed=42, 36 | ) 37 | 38 | return lr_scheduler 39 | -------------------------------------------------------------------------------- /data/datasets/dataset_loader.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | 4 | import os.path as osp 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | 9 | def read_image(img_path): 10 | """Keep reading image until succeed. 11 | This can avoid IOError incurred by heavy IO process.""" 12 | got_img = False 13 | if not osp.exists(img_path): 14 | raise IOError("{} does not exist".format(img_path)) 15 | while not got_img: 16 | try: 17 | img = Image.open(img_path).convert('RGB') 18 | got_img = True 19 | except IOError: 20 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path)) 21 | pass 22 | return img 23 | 24 | 25 | class ImageDataset(Dataset): 26 | """Image Person ReID Dataset""" 27 | 28 | def __init__(self, dataset, transform=None): 29 | self.dataset = dataset 30 | self.transform = transform 31 | 32 | def __len__(self): 33 | return len(self.dataset) 34 | 35 | def __getitem__(self, index): 36 | img_path, pid, camid = self.dataset[index] 37 | img = read_image(img_path) 38 | 39 | if self.transform is not None: 40 | img = self.transform(img) 41 | 42 | return img, pid, camid, img_path 43 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/README.md: -------------------------------------------------------------------------------- 1 | # IJCV 2024: Transformer for Object Re-Identification: A Survey 2 | 3 | ## Highlights 4 | 5 | - An in-depth analysis of Transformer's strengths, highlighting its impact across four key Re-ID directions: image/video-based, limited data/annotations, cross-modal, and special scenarios. 
6 | 7 | - A new Transformer-based unsupervised baseline, UntransReID, achieving state-of-the-art performance on both single- and cross-modal Re-ID. 8 | 9 | - A unified experimental standard for animal Re-ID, designed to address its unique challenges and evaluate the potential of Transformer-based approaches. 10 | 11 | 12 | ### Citation 13 | 14 | Please kindly cite this paper in your publications if it helps your research: 15 | ``` 16 | @article{ye2024transformer, 17 | title={Transformer for Object Re-Identification: A Survey}, 18 | author={Ye, Mang and Chen, Shuoyi and Li, Chenyue and Zheng, Wei-Shi and Crandall, David and Du, Bo}, 19 | journal={arXiv preprint arXiv:2401.06960}, 20 | year={2024} 21 | } 22 | ``` 23 | # Acknowledgements 24 | Our implementation is mainly based on the [cluster-contrast-reid](https://github.com/alibaba/cluster-contrast-reid) and [TransReID-SSL](https://github.com/damo-cv/TransReID-SSL) open-source codebases. We sincerely thank the authors for their excellent work. 25 | The pre-trained model can be downloaded from [TransReID-SSL](https://github.com/damo-cv/TransReID-SSL). 26 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/clustercontrast/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import warnings 3 | 4 | from .market1501 import Market1501 5 | from .msmt17 import MSMT17 6 | from .dukemtmcreid import DukeMTMCreID 7 | 8 | 9 | 10 | __factory = { 11 | 'market1501': Market1501, 12 | 'msmt17': MSMT17, 13 | 'dukemtmcreid': DukeMTMCreID 14 | } 15 | 16 | 17 | def names(): 18 | return sorted(__factory.keys()) 19 | 20 | 21 | def create(name, root, *args, **kwargs): 22 | """ 23 | Create a dataset instance. 24 | 25 | Parameters 26 | ---------- 27 | name : str 28 | The dataset name. 29 | root : str 30 | The path to the dataset directory. 31 | split_id : int, optional 32 | The index of data split. Default: 0 33 | num_val : int or float, optional 34 | When int, it means the number of validation identities. When float, 35 | it means the proportion of validation to all the trainval. Default: 100 36 | download : bool, optional 37 | If True, will download the dataset. Default: False 38 | """ 39 | if name not in __factory: 40 | raise KeyError("Unknown dataset:", name) 41 | return __factory[name](root, *args, **kwargs) 42 | 43 | 44 | def get_dataset(name, root, *args, **kwargs): 45 | warnings.warn("get_dataset is deprecated. Use create instead.") 46 | return create(name, root, *args, **kwargs) 47 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import warnings 3 | 4 | from .sysu_all import sysu_all 5 | from .sysu_ir import sysu_ir 6 | from .sysu_rgb import sysu_rgb 7 | from .regdb_ir import regdb_ir 8 | from .regdb_rgb import regdb_rgb 9 | __factory = { 10 | 'sysu_all': sysu_all, 11 | 'sysu_ir': sysu_ir, 12 | 'sysu_rgb': sysu_rgb, 13 | 'regdb_ir': regdb_ir, 14 | 'regdb_rgb': regdb_rgb 15 | } 16 | 17 | 18 | def names(): 19 | return sorted(__factory.keys()) 20 | 21 | 22 | def create(name, root, trial=0, *args, **kwargs): 23 | """ 24 | Create a dataset instance. 25 | 26 | Parameters 27 | ---------- 28 | name : str 29 | The dataset name. 
30 | root : str 31 | The path to the dataset directory. 32 | split_id : int, optional 33 | The index of data split. Default: 0 34 | num_val : int or float, optional 35 | When int, it means the number of validation identities. When float, 36 | it means the proportion of validation to all the trainval. Default: 100 37 | download : bool, optional 38 | If True, will download the dataset. Default: False 39 | """ 40 | if name not in __factory: 41 | raise KeyError("Unknown dataset:", name) 42 | return __factory[name](root, trial=trial, *args, **kwargs) 43 | 44 | 45 | def get_dataset(name, root, *args, **kwargs): 46 | warnings.warn("get_dataset is deprecated. Use create instead.") 47 | return create(name, root, *args, **kwargs) 48 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/vit_small_ics_384.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_PATH: '/home/chenshuoyi/.cache/torch/checkpoints/vit_base_ics_cfs_lup.pth' 3 | PRETRAIN_HW_RATIO: 2 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | # DEVICE_ID: ('3') 10 | TRANSFORMER_TYPE: 'vit_small_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | STEM_CONV: True # False for vanilla ViT-S 13 | GEM_POOLING: False 14 | # DIST_TRAIN: True 15 | 16 | INPUT: 17 | SIZE_TRAIN: [288, 144] #[384, 128] 18 | SIZE_TEST: [288, 144] 19 | PROB: 0.5 # random horizontal flip 20 | RE_PROB: 0.5 # random erasing 21 | PADDING: 10 22 | PIXEL_MEAN: [0.5, 0.5, 0.5] 23 | PIXEL_STD: [0.5, 0.5, 0.5] 24 | 25 | # DATASETS: 26 | # NAMES: ('market1501') 27 | # ROOT_DIR: ('/home/michuan.lh/datasets') 28 | 29 | # DATALOADER: 30 | # SAMPLER: 'softmax_triplet' 31 | # NUM_INSTANCE: 4 32 | # NUM_WORKERS: 8 33 | 34 | SOLVER: 35 | OPTIMIZER_NAME: 'SGD' 36 | MAX_EPOCHS: 120 37 | BASE_LR: 0.0004 38 | WARMUP_EPOCHS: 20 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'cosine' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 20 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | # OUTPUT_DIR: '../../log/transreid/market/vit_base_ics_cfs_lup_384' 58 | -------------------------------------------------------------------------------- /video-reid-AWG/README.md: -------------------------------------------------------------------------------- 1 | # AGW Baseline for Video Person ReID 2 | 3 | This repository contains a PyTorch implementation of the AGW baseline for video-based person ReID. 4 | The code is mainly based on [Video-Person-ReID](https://github.com/jiyanggao/Video-Person-ReID). 5 | 6 | ## Quick Start 7 | 8 | ### 1. Prepare dataset 9 | 10 | 11 | 1. Create a directory named `mars/` under `data/`. 12 | 2. Download dataset to `data/mars/` from http://www.liangzheng.com.cn/Project/project_mars.html. 13 | 3. Extract `bbox_train.zip` and `bbox_test.zip`. 14 | 4. Download split information from https://github.com/liangzheng06/MARS-evaluation/tree/master/info and put `info/` in `data/mars` (we want to follow the standard split in [8]). The data structure would look like: 15 | ``` 16 | mars/ 17 | bbox_test/ 18 | bbox_train/ 19 | info/ 20 | ``` 21 | ### 2. 
Train 22 | 23 | To train an AGW+ model on MARS with GPU device 0, run: 24 | ``` 25 | CUDA_VISIBLE_DEVICES=0 python ./main_video_person_reid.py --arch AGW_Plus_Baseline \ 26 | --train-dataset mars --test-dataset mars --save-dir ./mars_agw_plus 27 | ``` 28 | 29 | ## Citation 30 | 31 | Please kindly cite this paper in your publications if it helps your research: 32 | ``` 33 | @article{arxiv20reidsurvey, 34 | title={Deep Learning for Person Re-identification: A Survey and Outlook}, 35 | author={Ye, Mang and Shen, Jianbing and Lin, Gaojie and Xiang, Tao and Shao, Ling and Hoi, Steven C. H.}, 36 | journal={arXiv preprint arXiv:2001.04193}, 37 | year={2020}, 38 | } 39 | ``` 40 | 41 | Contact: mangye16@gmail.com 42 | -------------------------------------------------------------------------------- /video-reid-AWG/samplers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | import numpy as np 4 | import torch 5 | 6 | class RandomIdentitySampler(torch.utils.data.Sampler): 7 | # class RandomIdentitySampler(torch.utils.data.sampler.Sampler): 8 | """ 9 | Randomly sample N identities, then for each identity, 10 | randomly sample K instances, therefore batch size is N*K. 11 | 12 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py. 13 | 14 | Args: 15 | data_source (Dataset): dataset to sample from. 16 | num_instances (int): number of instances per identity. 17 | """ 18 | def __init__(self, data_source, num_instances=4): 19 | self.data_source = data_source 20 | self.num_instances = num_instances 21 | self.index_dic = defaultdict(list) 22 | for index, (_, pid, _) in enumerate(data_source): 23 | self.index_dic[pid].append(index) 24 | self.pids = list(self.index_dic.keys()) 25 | self.num_identities = len(self.pids) 26 | 27 | def __iter__(self): 28 | indices = torch.randperm(self.num_identities) 29 | ret = [] 30 | for i in indices: 31 | pid = self.pids[i] 32 | t = self.index_dic[pid] 33 | replace = len(t) < self.num_instances 34 | t = np.random.choice(t, size=self.num_instances, replace=replace) 35 | ret.extend(t) 36 | return iter(ret) 37 | 38 | def __len__(self): 39 | return self.num_identities * self.num_instances 40 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/vit_base_ics_288.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_PATH: '/home/chenshuoyi/.cache/torch/checkpoints/vit_base_ics_cfs_lup.pth' 3 | PRETRAIN_HW_RATIO: 2 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | # DEVICE_ID: ('3') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | STEM_CONV: True # False for vanilla ViT-S 13 | GEM_POOLING: False 14 | # DIST_TRAIN: True 15 | SIE_COE: 3.0 16 | INPUT: 17 | SIZE_TRAIN: [288, 144] #[384, 128] #[288, 144] # 18 | SIZE_TEST: [288, 144] # [384, 128] # [288, 144] # 19 | PROB: 0.5 # random horizontal flip 20 | RE_PROB: 0.5 # random erasing 21 | PADDING: 10 22 | PIXEL_MEAN: [0.5, 0.5, 0.5] 23 | PIXEL_STD: [0.5, 0.5, 0.5] 24 | 25 | # DATASETS: 26 | # NAMES: ('market1501') 27 | # ROOT_DIR: ('/home/michuan.lh/datasets') 28 | 29 | # DATALOADER: 30 | # SAMPLER: 'softmax_triplet' 31 | # NUM_INSTANCE: 4 32 | # NUM_WORKERS: 8 33 | 34 | SOLVER: 35 | OPTIMIZER_NAME: 'SGD' 36 
| MAX_EPOCHS: 120 37 | BASE_LR: 0.0004 38 | WARMUP_EPOCHS: 20 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'cosine' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 20 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | # OUTPUT_DIR: '../../log/transreid/market/vit_base_ics_cfs_lup_384' 58 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/vit_base_ics_384.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_PATH: '/home/chenshuoyi/.cache/torch/checkpoints/vit_base_ics_cfs_lup.pth' 3 | PRETRAIN_HW_RATIO: 2 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | # DEVICE_ID: ('3') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | STEM_CONV: True # False for vanilla ViT-S 13 | GEM_POOLING: False 14 | # DIST_TRAIN: True 15 | SIE_COE: 3.0 16 | INPUT: 17 | SIZE_TRAIN: [384, 128] #[384, 128] #[288, 144] # 18 | SIZE_TEST: [384, 128] # [384, 128] # [288, 144] # 19 | PROB: 0.5 # random horizontal flip 20 | RE_PROB: 0.5 # random erasing 21 | PADDING: 10 22 | PIXEL_MEAN: [0.5, 0.5, 0.5] 23 | PIXEL_STD: [0.5, 0.5, 0.5] 24 | 25 | # DATASETS: 26 | # NAMES: ('market1501') 27 | # ROOT_DIR: ('/home/michuan.lh/datasets') 28 | 29 | # DATALOADER: 30 | # SAMPLER: 'softmax_triplet' 31 | # NUM_INSTANCE: 4 32 | # NUM_WORKERS: 8 33 | 34 | SOLVER: 35 | OPTIMIZER_NAME: 'SGD' 36 | MAX_EPOCHS: 120 37 | BASE_LR: 0.0004 38 | WARMUP_EPOCHS: 20 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'cosine' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 20 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | # OUTPUT_DIR: '../../log/transreid/market/vit_base_ics_cfs_lup_384' 58 | -------------------------------------------------------------------------------- /modeling/layer/gem_pool.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class GeneralizedMeanPooling(nn.Module): 8 | r"""Applies a 2D power-average adaptive pooling over an input signal composed of several input planes. 9 | The function computed is: :math:`f(X) = pow(mean(pow(X, p)), 1/p)` 10 | - At p = infinity, one gets Max Pooling 11 | - At p = 1, one gets Average Pooling 12 | The output is of size H x W, for any input size. 13 | The number of output features is equal to the number of input planes. 14 | Args: 15 | output_size: the target output size of the image of the form H x W. 16 | Can be a tuple (H, W) or a single H for a square image H x H 17 | H and W can be either an ``int``, or ``None`` which means the size will 18 | be the same as that of the input. 
19 | """ 20 | 21 | def __init__(self, norm, output_size=1, eps=1e-6): 22 | super(GeneralizedMeanPooling, self).__init__() 23 | assert norm > 0 24 | self.p = float(norm) 25 | self.output_size = output_size 26 | self.eps = eps 27 | 28 | def forward(self, x): 29 | x = x.clamp(min=self.eps).pow(self.p) 30 | return torch.nn.functional.adaptive_avg_pool2d(x, self.output_size).pow(1. / self.p) 31 | 32 | def __repr__(self): 33 | return self.__class__.__name__ + '(' \ 34 | + str(self.p) + ', ' \ 35 | + 'output_size=' + str(self.output_size) + ')' 36 | 37 | 38 | class GeneralizedMeanPoolingP(GeneralizedMeanPooling): 39 | """ Same, but norm is trainable 40 | """ 41 | def __init__(self, norm=3, output_size=1, eps=1e-6): 42 | super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps) 43 | self.p = nn.Parameter(torch.ones(1) * norm) -------------------------------------------------------------------------------- /video-reid-AWG/models/gem_pool.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class GeneralizedMeanPooling(nn.Module): 8 | r"""Applies a 2D power-average adaptive pooling over an input signal composed of several input planes. 9 | The function computed is: :math:`f(X) = pow(mean(pow(X, p)), 1/p)` 10 | - At p = infinity, one gets Max Pooling 11 | - At p = 1, one gets Average Pooling 12 | The output is of size H x W, for any input size. 13 | The number of output features is equal to the number of input planes. 14 | Args: 15 | output_size: the target output size of the image of the form H x W. 16 | Can be a tuple (H, W) or a single H for a square image H x H 17 | H and W can be either an ``int``, or ``None`` which means the size will 18 | be the same as that of the input. 19 | """ 20 | 21 | def __init__(self, norm, output_size=1, eps=1e-6): 22 | super(GeneralizedMeanPooling, self).__init__() 23 | assert norm > 0 24 | self.p = float(norm) 25 | self.output_size = output_size 26 | self.eps = eps 27 | 28 | def forward(self, x): 29 | x = x.clamp(min=self.eps).pow(self.p) 30 | return torch.nn.functional.adaptive_avg_pool2d(x, self.output_size).pow(1. 
/ self.p) 31 | 32 | def __repr__(self): 33 | return self.__class__.__name__ + '(' \ 34 | + str(self.p) + ', ' \ 35 | + 'output_size=' + str(self.output_size) + ')' 36 | 37 | 38 | class GeneralizedMeanPoolingP(GeneralizedMeanPooling): 39 | """ Same, but norm is trainable 40 | """ 41 | def __init__(self, norm=3, output_size=1, eps=1e-6): 42 | super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps) 43 | self.p = nn.Parameter(torch.ones(1) * norm) -------------------------------------------------------------------------------- /Transformer-ReID-Survey/Animal-Re-ID-main/Tools/split_train_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import shutil 4 | from collections import defaultdict 5 | 6 | def split_dataset_by_id(dataset_folder, train_folder, test_folder, train_ratio=0.7): 7 | 8 | image_files = [f for f in os.listdir(dataset_folder) if f.endswith('.jpg')] 9 | 10 | 11 | id_groups = defaultdict(list) 12 | for image_file in image_files: 13 | id_part = image_file.split('_')[0] 14 | id_groups[id_part].append(image_file) 15 | 16 | 17 | id_list = list(id_groups.keys()) 18 | random.shuffle(id_list) 19 | split_index = int(len(id_list) * train_ratio) 20 | train_ids = id_list[:split_index] 21 | test_ids = id_list[split_index:] 22 | 23 | 24 | train_images = set() 25 | test_images = set() 26 | 27 | for id_part in train_ids: 28 | train_images.update(id_groups[id_part]) 29 | 30 | for id_part in test_ids: 31 | test_images.update(id_groups[id_part]) 32 | 33 | 34 | for image_file in train_images: 35 | source_path = os.path.join(dataset_folder, image_file) 36 | destination_path = os.path.join(train_folder, image_file) 37 | shutil.move(source_path, destination_path) 38 | 39 | 40 | for image_file in test_images: 41 | source_path = os.path.join(dataset_folder, image_file) 42 | destination_path = os.path.join(test_folder, image_file) 43 | shutil.move(source_path, destination_path) 44 | 45 | if __name__ == "__main__": 46 | dataset_folder = "dataset" # dataset_folder 47 | train_folder = "train" # train_folder 48 | test_folder = "test" # test_folder 49 | 50 | os.makedirs(train_folder, exist_ok=True) 51 | os.makedirs(test_folder, exist_ok=True) 52 | train_ratio = 0.7 # train_set ratio 53 | split_dataset_by_id(dataset_folder, train_folder, test_folder, train_ratio) -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/clustercontrast/utils/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import numpy as np 3 | 4 | 5 | class BaseDataset(object): 6 | """ 7 | Base class of reid dataset 8 | """ 9 | 10 | def get_imagedata_info(self, data): 11 | pids, cams = [], [] 12 | for _, pid, camid in data: 13 | pids += [pid] 14 | cams += [camid] 15 | pids = set(pids) 16 | cams = set(cams) 17 | num_pids = len(pids) 18 | num_cams = len(cams) 19 | num_imgs = len(data) 20 | return num_pids, num_imgs, num_cams 21 | 22 | def print_dataset_statistics(self): 23 | raise NotImplementedError 24 | 25 | @property 26 | def images_dir(self): 27 | return None 28 | 29 | 30 | class BaseImageDataset(BaseDataset): 31 | """ 32 | Base class of image reid dataset 33 | """ 34 | 35 | def print_dataset_statistics(self, train, query, gallery): 36 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train) 37 | num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query) 38 | 
num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery) 39 | 40 | print("Dataset statistics:") 41 | print(" ----------------------------------------") 42 | print(" subset | # ids | # images | # cameras") 43 | print(" ----------------------------------------") 44 | print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams)) 45 | print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams)) 46 | print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams)) 47 | print(" ----------------------------------------") 48 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/utils/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import numpy as np 3 | 4 | 5 | class BaseDataset(object): 6 | """ 7 | Base class of reid dataset 8 | """ 9 | 10 | def get_imagedata_info(self, data): 11 | pids, cams = [], [] 12 | for _, pid, camid in data: 13 | pids += [pid] 14 | cams += [camid] 15 | pids = set(pids) 16 | cams = set(cams) 17 | num_pids = len(pids) 18 | num_cams = len(cams) 19 | num_imgs = len(data) 20 | return num_pids, num_imgs, num_cams 21 | 22 | def print_dataset_statistics(self): 23 | raise NotImplementedError 24 | 25 | @property 26 | def images_dir(self): 27 | return None 28 | 29 | 30 | class BaseImageDataset(BaseDataset): 31 | """ 32 | Base class of image reid dataset 33 | """ 34 | 35 | def print_dataset_statistics(self, train, query, gallery): 36 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train) 37 | num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query) 38 | num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery) 39 | 40 | print("Dataset statistics:") 41 | print(" ----------------------------------------") 42 | print(" subset | # ids | # images | # cameras") 43 | print(" ----------------------------------------") 44 | print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams)) 45 | print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams)) 46 | print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams)) 47 | print(" ----------------------------------------") 48 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/Animal-Re-ID-main/README.md: -------------------------------------------------------------------------------- 1 | # Wildlife-dataset 2 | | Dataset | Dataset Download | Annotations | 3 | | ---- | ---- | ---- | 4 | | iPanda-50 | [iPanda-50](https://github.com/iPandaDateset/iPanda-50) | [label link](https://drive.google.com/drive/folders/1jhk8qgyWMbL1Ykd_GlAjh2Vn2e_wMJmc?usp=sharing) | 5 | | ATRW(Amur Tiger Re-identification) | [ATRW](https://www.kaggle.com/datasets/quadeer15sh/amur-tiger-reidentification) | [label link](https://drive.google.com/drive/folders/1HlFVl5SPcKFWElo9cwq7eTyL1qwEeSSD?usp=sharing) | 6 | | ELPephants | [ELPephants](https://cornell.app.box.com/s/qh9clpzm5e2vgsjmcaca0kqasj2vt1f6.)| [label link](https://drive.google.com/drive/folders/ELPephants) | 7 | | SealID | [SealID](https://etsin.fairdata.fi/dataset/22b5191e-f24b-4457-93d3-95797c900fc0.)| [label link](https://drive.google.com/drive/folders/SealID) | 8 | | 
GZGC-G | [GZGC-G](https://lila.science/datasets/great-zebra-giraffe-id) | [label link](https://drive.google.com/drive/folders/GZGC-G) | 9 | | GZGC-Z | [GZGC-Z](https://lila.science/datasets/great-zebra-giraffe-id) | [label link](https://drive.google.com/drive/folders/GZGC-Z) | 10 | | LeopardID | [LeopardID](https://lila.science/datasets/leopard-id-2022/) | [label link](https://drive.google.com/drive/folders/LeopardID) | 11 | 12 | 13 | ``` 14 | Wildlife dataset 15 | ├─Tools 16 | | (tools to process datasets, converting dataset formats to the Market1501 format) 17 | ├─README.md 18 | ``` 19 | 20 | # Tools 21 | Run **ATRW_test.py** to process the ATRW test set data format. You need to first download the ATRW test set and the gt_test_plain.json file. 22 | Run **market_change.py** to convert the data to the Market1501 format. 23 | Run **split_train_test.py** to split the dataset into a train set and a test set. 24 | 25 | 26 | ## Our project will be constantly updated [HERE](https://github.com/JigglypuffStitch/Animal-Re-ID). 27 | -------------------------------------------------------------------------------- /utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | from bisect import bisect_right 3 | import torch 4 | 5 | 6 | # FIXME ideally this would be achieved with a CombinedLRScheduler, 7 | # separating MultiStepLR with WarmupLR 8 | # but the current LRScheduler design doesn't allow it 9 | 10 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 11 | def __init__( 12 | self, 13 | optimizer, 14 | milestones, 15 | gamma=0.1, 16 | warmup_factor=1.0 / 3, 17 | warmup_iters=500, 18 | warmup_method="linear", 19 | last_epoch=-1, 20 | ): 21 | if not list(milestones) == sorted(milestones): 22 | raise ValueError( 23 | "Milestones should be a list of" " increasing integers.
Got {}", 24 | milestones, 25 | ) 26 | 27 | if warmup_method not in ("constant", "linear"): 28 | raise ValueError( 29 | "Only 'constant' or 'linear' warmup_method accepted" 30 | "got {}".format(warmup_method) 31 | ) 32 | self.milestones = milestones 33 | self.gamma = gamma 34 | self.warmup_factor = warmup_factor 35 | self.warmup_iters = warmup_iters 36 | self.warmup_method = warmup_method 37 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 38 | 39 | def get_lr(self): 40 | warmup_factor = 1 41 | if self.last_epoch < self.warmup_iters: 42 | if self.warmup_method == "constant": 43 | warmup_factor = self.warmup_factor 44 | elif self.warmup_method == "linear": 45 | alpha = self.last_epoch / self.warmup_iters 46 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 47 | return [ 48 | base_lr 49 | * warmup_factor 50 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 51 | for base_lr in self.base_lrs 52 | ] 53 | -------------------------------------------------------------------------------- /video-reid-AWG/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | from bisect import bisect_right 3 | import torch 4 | 5 | 6 | # FIXME ideally this would be achieved with a CombinedLRScheduler, 7 | # separating MultiStepLR with WarmupLR 8 | # but the current LRScheduler design doesn't allow it 9 | 10 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 11 | def __init__( 12 | self, 13 | optimizer, 14 | milestones, 15 | gamma=0.1, 16 | warmup_factor=1.0 / 3, 17 | warmup_iters=500, 18 | warmup_method="linear", 19 | last_epoch=-1, 20 | ): 21 | if not list(milestones) == sorted(milestones): 22 | raise ValueError( 23 | "Milestones should be a list of" " increasing integers. 
Got {}", 24 | milestones, 25 | ) 26 | 27 | if warmup_method not in ("constant", "linear"): 28 | raise ValueError( 29 | "Only 'constant' or 'linear' warmup_method accepted" 30 | "got {}".format(warmup_method) 31 | ) 32 | self.milestones = milestones 33 | self.gamma = gamma 34 | self.warmup_factor = warmup_factor 35 | self.warmup_iters = warmup_iters 36 | self.warmup_method = warmup_method 37 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 38 | 39 | def get_lr(self): 40 | warmup_factor = 1 41 | if self.last_epoch < self.warmup_iters: 42 | if self.warmup_method == "constant": 43 | warmup_factor = self.warmup_factor 44 | elif self.warmup_method == "linear": 45 | alpha = self.last_epoch / self.warmup_iters 46 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 47 | return [ 48 | base_lr 49 | * warmup_factor 50 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 51 | for base_lr in self.base_lrs 52 | ] 53 | -------------------------------------------------------------------------------- /modeling/layer/non_local.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class Non_local(nn.Module): 8 | def __init__(self, in_channels, reduc_ratio=2): 9 | super(Non_local, self).__init__() 10 | 11 | self.in_channels = in_channels 12 | self.inter_channels = reduc_ratio//reduc_ratio 13 | 14 | self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 15 | kernel_size=1, stride=1, padding=0) 16 | 17 | self.W = nn.Sequential( 18 | nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, 19 | kernel_size=1, stride=1, padding=0), 20 | nn.BatchNorm2d(self.in_channels), 21 | ) 22 | nn.init.constant_(self.W[1].weight, 0.0) 23 | nn.init.constant_(self.W[1].bias, 0.0) 24 | 25 | self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 26 | kernel_size=1, stride=1, padding=0) 27 | 28 | self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 29 | kernel_size=1, stride=1, padding=0) 30 | 31 | def forward(self, x): 32 | ''' 33 | :param x: (b, t, h, w) 34 | :return x: (b, t, h, w) 35 | ''' 36 | batch_size = x.size(0) 37 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 38 | g_x = g_x.permute(0, 2, 1) 39 | 40 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 41 | theta_x = theta_x.permute(0, 2, 1) 42 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 43 | f = torch.matmul(theta_x, phi_x) 44 | N = f.size(-1) 45 | f_div_C = f / N 46 | 47 | y = torch.matmul(f_div_C, g_x) 48 | y = y.permute(0, 2, 1).contiguous() 49 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 50 | W_y = self.W(y) 51 | z = W_y + x 52 | return z 53 | -------------------------------------------------------------------------------- /video-reid-AWG/models/non_local.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class Non_local(nn.Module): 8 | def __init__(self, in_channels, reduc_ratio=2): 9 | super(Non_local, self).__init__() 10 | 11 | self.in_channels = in_channels 12 | self.inter_channels = reduc_ratio // reduc_ratio 13 | 14 | self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 15 | kernel_size=1, stride=1, padding=0) 16 | 17 | self.W = nn.Sequential( 18 | nn.Conv2d(in_channels=self.inter_channels, 
out_channels=self.in_channels, 19 | kernel_size=1, stride=1, padding=0), 20 | nn.BatchNorm2d(self.in_channels), 21 | ) 22 | nn.init.constant_(self.W[1].weight, 0.0) 23 | nn.init.constant_(self.W[1].bias, 0.0) 24 | 25 | self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 26 | kernel_size=1, stride=1, padding=0) 27 | 28 | self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 29 | kernel_size=1, stride=1, padding=0) 30 | 31 | def forward(self, x): 32 | ''' 33 | :param x: (b, c, h, w) 34 | :return x: (b, c, h, w) 35 | ''' 36 | batch_size = x.size(0) 37 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 38 | g_x = g_x.permute(0, 2, 1) 39 | 40 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 41 | theta_x = theta_x.permute(0, 2, 1) 42 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 43 | f = torch.matmul(theta_x, phi_x) 44 | N = f.size(-1) 45 | f_div_C = f / N 46 | 47 | y = torch.matmul(f_div_C, g_x) 48 | y = y.permute(0, 2, 1).contiguous() 49 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 50 | W_y = self.W(y) 51 | z = W_y + x 52 | return z 53 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/clustercontrast/utils/serialization.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import json 3 | import os.path as osp 4 | import shutil 5 | 6 | import torch 7 | from torch.nn import Parameter 8 | 9 | from .osutils import mkdir_if_missing 10 | 11 | 12 | def read_json(fpath): 13 | with open(fpath, 'r') as f: 14 | obj = json.load(f) 15 | return obj 16 | 17 | 18 | def write_json(obj, fpath): 19 | mkdir_if_missing(osp.dirname(fpath)) 20 | with open(fpath, 'w') as f: 21 | json.dump(obj, f, indent=4, separators=(',', ': ')) 22 | 23 | 24 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): 25 | mkdir_if_missing(osp.dirname(fpath)) 26 | torch.save(state, fpath) 27 | if is_best: 28 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar')) 29 | 30 | 31 | def load_checkpoint(fpath): 32 | if osp.isfile(fpath): 33 | # checkpoint = torch.load(fpath) 34 | checkpoint = torch.load(fpath, map_location=torch.device('cpu')) 35 | print("=> Loaded checkpoint '{}'".format(fpath)) 36 | return checkpoint 37 | else: 38 | raise ValueError("=> No checkpoint found at '{}'".format(fpath)) 39 | 40 | 41 | def copy_state_dict(state_dict, model, strip=None): 42 | tgt_state = model.state_dict() 43 | copied_names = set() 44 | for name, param in state_dict.items(): 45 | if strip is not None and name.startswith(strip): 46 | name = name[len(strip):] 47 | if name not in tgt_state: 48 | continue 49 | if isinstance(param, Parameter): 50 | param = param.data 51 | if param.size() != tgt_state[name].size(): 52 | print('mismatch:', name, param.size(), tgt_state[name].size()) 53 | continue 54 | tgt_state[name].copy_(param) 55 | copied_names.add(name) 56 | 57 | missing = set(tgt_state.keys()) - copied_names 58 | if len(missing) > 0: 59 | print("missing keys in state_dict:", missing) 60 | 61 | return model 62 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/utils/data/preprocessor.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import 
os.path as osp 4 | from torch.utils.data import DataLoader, Dataset 5 | import numpy as np 6 | import random 7 | import math 8 | from PIL import Image 9 | 10 | 11 | class Preprocessor(Dataset): 12 | def __init__(self, dataset, root=None, transform=None): 13 | super(Preprocessor, self).__init__() 14 | self.dataset = dataset 15 | self.root = root 16 | self.transform = transform 17 | 18 | def __len__(self): 19 | return len(self.dataset) 20 | 21 | def __getitem__(self, indices): 22 | return self._get_single_item(indices) 23 | 24 | def _get_single_item(self, index): 25 | fname, pid, camid = self.dataset[index] 26 | fpath = fname 27 | if self.root is not None: 28 | fpath = osp.join(self.root, fname) 29 | 30 | img = Image.open(fpath).convert('RGB') 31 | 32 | if self.transform is not None: 33 | img = self.transform(img) 34 | 35 | return img, fname, pid, camid, index 36 | 37 | class Preprocessor_color(Dataset): 38 | def __init__(self, dataset, root=None, transform=None,transform1=None): 39 | super(Preprocessor_color, self).__init__() 40 | self.dataset = dataset 41 | self.root = root 42 | self.transform = transform 43 | self.transform1 = transform1 44 | def __len__(self): 45 | return len(self.dataset) 46 | 47 | def __getitem__(self, indices): 48 | return self._get_single_item(indices) 49 | 50 | def _get_single_item(self, index): 51 | fname, pid, camid = self.dataset[index] 52 | fpath = fname 53 | if self.root is not None: 54 | fpath = osp.join(self.root, fname) 55 | 56 | img_ori = Image.open(fpath).convert('RGB') 57 | 58 | if self.transform is not None: 59 | img = self.transform(img_ori) 60 | img1 = self.transform1(img_ori) 61 | return img, img1,fname, pid, camid, index 62 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/utils/serialization.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import json 3 | import os.path as osp 4 | import shutil 5 | 6 | import torch 7 | from torch.nn import Parameter 8 | 9 | from .osutils import mkdir_if_missing 10 | 11 | 12 | def read_json(fpath): 13 | with open(fpath, 'r') as f: 14 | obj = json.load(f) 15 | return obj 16 | 17 | 18 | def write_json(obj, fpath): 19 | mkdir_if_missing(osp.dirname(fpath)) 20 | with open(fpath, 'w') as f: 21 | json.dump(obj, f, indent=4, separators=(',', ': ')) 22 | 23 | 24 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): 25 | mkdir_if_missing(osp.dirname(fpath)) 26 | torch.save(state, fpath) 27 | if is_best: 28 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar')) 29 | 30 | 31 | 32 | 33 | def load_checkpoint(fpath): 34 | if osp.isfile(fpath): 35 | # checkpoint = torch.load(fpath) 36 | checkpoint = torch.load(fpath, map_location=torch.device('cpu')) 37 | print("=> Loaded checkpoint '{}'".format(fpath)) 38 | return checkpoint 39 | else: 40 | raise ValueError("=> No checkpoint found at '{}'".format(fpath)) 41 | 42 | 43 | def copy_state_dict(state_dict, model, strip=None): 44 | tgt_state = model.state_dict() 45 | copied_names = set() 46 | for name, param in state_dict.items(): 47 | if strip is not None and name.startswith(strip): 48 | name = name[len(strip):] 49 | if name not in tgt_state: 50 | continue 51 | if isinstance(param, Parameter): 52 | param = param.data 53 | if param.size() != tgt_state[name].size(): 54 | print('mismatch:', name, param.size(), tgt_state[name].size()) 55 | continue 56 | 
tgt_state[name].copy_(param) 57 | copied_names.add(name) 58 | 59 | missing = set(tgt_state.keys()) - copied_names 60 | if len(missing) > 0: 61 | print("missing keys in state_dict:", missing) 62 | 63 | return model 64 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | from bisect import bisect_right 7 | import torch 8 | 9 | 10 | # FIXME ideally this would be achieved with a CombinedLRScheduler, 11 | # separating MultiStepLR with WarmupLR 12 | # but the current LRScheduler design doesn't allow it 13 | 14 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 15 | def __init__( 16 | self, 17 | optimizer, 18 | milestones, # steps 19 | gamma=0.1, 20 | warmup_factor=1.0 / 3, 21 | warmup_iters=500, 22 | warmup_method="linear", 23 | last_epoch=-1, 24 | ): 25 | if not list(milestones) == sorted(milestones): 26 | raise ValueError( 27 | "Milestones should be a list of" " increasing integers. Got {}".format(milestones) 28 | ) 29 | 30 | 31 | if warmup_method not in ("constant", "linear"): 32 | raise ValueError( 33 | "Only 'constant' or 'linear' warmup_method accepted, " 34 | "got {}".format(warmup_method) 35 | ) 36 | self.milestones = milestones 37 | self.gamma = gamma 38 | self.warmup_factor = warmup_factor 39 | self.warmup_iters = warmup_iters 40 | self.warmup_method = warmup_method 41 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 42 | 43 | def get_lr(self): 44 | warmup_factor = 1 45 | if self.last_epoch < self.warmup_iters: 46 | if self.warmup_method == "constant": 47 | warmup_factor = self.warmup_factor 48 | elif self.warmup_method == "linear": 49 | alpha = self.last_epoch / self.warmup_iters 50 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 51 | return [ 52 | base_lr 53 | * warmup_factor 54 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 55 | for base_lr in self.base_lrs 56 | ] 57 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | from bisect import bisect_right 7 | import torch 8 | 9 | 10 | # FIXME ideally this would be achieved with a CombinedLRScheduler, 11 | # separating MultiStepLR with WarmupLR 12 | # but the current LRScheduler design doesn't allow it 13 | 14 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 15 | def __init__( 16 | self, 17 | optimizer, 18 | milestones, # steps 19 | gamma=0.1, 20 | warmup_factor=1.0 / 3, 21 | warmup_iters=500, 22 | warmup_method="linear", 23 | last_epoch=-1, 24 | ): 25 | if not list(milestones) == sorted(milestones): 26 | raise ValueError( 27 | "Milestones should be a list of" " increasing integers.
Got {}", 28 | milestones, 29 | ) 30 | 31 | if warmup_method not in ("constant", "linear"): 32 | raise ValueError( 33 | "Only 'constant' or 'linear' warmup_method accepted" 34 | "got {}".format(warmup_method) 35 | ) 36 | self.milestones = milestones 37 | self.gamma = gamma 38 | self.warmup_factor = warmup_factor 39 | self.warmup_iters = warmup_iters 40 | self.warmup_method = warmup_method 41 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 42 | 43 | def get_lr(self): 44 | warmup_factor = 1 45 | if self.last_epoch < self.warmup_iters: 46 | if self.warmup_method == "constant": 47 | warmup_factor = self.warmup_factor 48 | elif self.warmup_method == "linear": 49 | alpha = self.last_epoch / self.warmup_iters 50 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 51 | return [ 52 | base_lr 53 | * warmup_factor 54 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 55 | for base_lr in self.base_lrs 56 | ] 57 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/clustercontrast/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .resnet import * 4 | from .resnet_ibn import * 5 | from .vision_transformer import * 6 | __factory = { 7 | 'resnet18': resnet18, 8 | 'resnet34': resnet34, 9 | 'resnet50': resnet50, 10 | 'resnet101': resnet101, 11 | 'resnet152': resnet152, 12 | 'resnet_ibn50a': resnet_ibn50a, 13 | 'resnet_ibn101a': resnet_ibn101a, 14 | 'vit_small': vit_small, 15 | 'vit_base': vit_base, 16 | } 17 | 18 | 19 | def names(): 20 | return sorted(__factory.keys()) 21 | 22 | 23 | def create(name, *args, **kwargs): 24 | """ 25 | Create a model instance. 26 | 27 | Parameters 28 | ---------- 29 | name : str 30 | Model name. Can be one of 'inception', 'resnet18', 'resnet34', 31 | 'resnet50', 'resnet101', and 'resnet152'. 32 | pretrained : bool, optional 33 | Only applied for 'resnet*' models. If True, will use ImageNet pretrained 34 | model. Default: True 35 | cut_at_pooling : bool, optional 36 | If True, will cut the model before the last global pooling layer and 37 | ignore the remaining kwargs. Default: False 38 | num_features : int, optional 39 | If positive, will append a Linear layer after the global pooling layer, 40 | with this number of output units, followed by a BatchNorm layer. 41 | Otherwise these layers will not be appended. Default: 256 for 42 | 'inception', 0 for 'resnet*' 43 | norm : bool, optional 44 | If True, will normalize the feature to be unit L2-norm for each sample. 45 | Otherwise will append a ReLU layer after the above Linear layer if 46 | num_features > 0. Default: False 47 | dropout : float, optional 48 | If positive, will append a Dropout layer with this dropout rate. 49 | Default: 0 50 | num_classes : int, optional 51 | If positive, will append a Linear layer at the end as the classifier 52 | with this number of output units. 
Default: 0 53 | """ 54 | if name not in __factory: 55 | raise KeyError("Unknown model:", name) 56 | return __factory[name](*args, **kwargs) 57 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .resnet import * 4 | from .resnet_ibn import * 5 | from .resnet_ibn import * 6 | from .agw import * 7 | from .agw_part import * 8 | __factory = { 9 | 'resnet18': resnet18, 10 | 'resnet34': resnet34, 11 | 'resnet50': resnet50, 12 | 'resnet101': resnet101, 13 | 'resnet152': resnet152, 14 | 'resnet_ibn50a': resnet_ibn50a, 15 | 'resnet_ibn101a': resnet_ibn101a, 16 | 'agw':agw, 17 | 'agw_part':agw 18 | } 19 | 20 | 21 | def names(): 22 | return sorted(__factory.keys()) 23 | 24 | 25 | def create(name, *args, **kwargs): 26 | """ 27 | Create a model instance. 28 | 29 | Parameters 30 | ---------- 31 | name : str 32 | Model name. Can be one of 'inception', 'resnet18', 'resnet34', 33 | 'resnet50', 'resnet101', and 'resnet152'. 34 | pretrained : bool, optional 35 | Only applied for 'resnet*' models. If True, will use ImageNet pretrained 36 | model. Default: True 37 | cut_at_pooling : bool, optional 38 | If True, will cut the model before the last global pooling layer and 39 | ignore the remaining kwargs. Default: False 40 | num_features : int, optional 41 | If positive, will append a Linear layer after the global pooling layer, 42 | with this number of output units, followed by a BatchNorm layer. 43 | Otherwise these layers will not be appended. Default: 256 for 44 | 'inception', 0 for 'resnet*' 45 | norm : bool, optional 46 | If True, will normalize the feature to be unit L2-norm for each sample. 47 | Otherwise will append a ReLU layer after the above Linear layer if 48 | num_features > 0. Default: False 49 | dropout : float, optional 50 | If positive, will append a Dropout layer with this dropout rate. 51 | Default: 0 52 | num_classes : int, optional 53 | If positive, will append a Linear layer at the end as the classifier 54 | with this number of output units. 
Default: 0 55 | """ 56 | if name not in __factory: 57 | raise KeyError("Unknown model:", name) 58 | return __factory[name](*args, **kwargs) 59 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/clustercontrast/trainers.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | from .utils.meters import AverageMeter 4 | from .models.contrastive_loss import ConLoss 5 | import torch 6 | 7 | 8 | class ClusterContrastTrainer(object): 9 | def __init__(self, encoder, memory=None): 10 | super(ClusterContrastTrainer, self).__init__() 11 | self.encoder = encoder 12 | self.memory = memory 13 | 14 | 15 | def train(self, epoch, data_loader, optimizer, print_freq=10, train_iters=400): 16 | self.encoder.train() 17 | 18 | batch_time = AverageMeter() 19 | data_time = AverageMeter() 20 | 21 | losses = AverageMeter() 22 | con_loss = ConLoss() 23 | 24 | end = time.time() 25 | for i in range(train_iters): 26 | # load data 27 | inputs = data_loader.next() 28 | data_time.update(time.time() - end) 29 | 30 | # process inputs 31 | inputs, labels, indexes = self._parse_data(inputs) 32 | 33 | # forward 34 | f_out = self._forward(inputs, False) 35 | f_out_mask = self._forward(inputs, True) 36 | 37 | loss = self.memory(f_out, labels) + 0.1*con_loss(f_out_mask, f_out) 38 | 39 | optimizer.zero_grad() 40 | loss.backward() 41 | optimizer.step() 42 | 43 | 44 | losses.update(loss.item()) 45 | 46 | # print log 47 | batch_time.update(time.time() - end) 48 | end = time.time() 49 | 50 | if (i + 1) % print_freq == 0: 51 | print('Epoch: [{}][{}/{}]\t' 52 | 'Time {:.3f} ({:.3f})\t' 53 | 'Data {:.3f} ({:.3f})\t' 54 | 'Loss {:.3f} ({:.3f})' 55 | .format(epoch, i + 1, len(data_loader), 56 | batch_time.val, batch_time.avg, 57 | data_time.val, data_time.avg, 58 | losses.val, losses.avg)) 59 | 60 | def _parse_data(self, inputs): 61 | imgs, _, pids, _, indexes = inputs 62 | return imgs.cuda(), pids.cuda(), indexes.cuda() 63 | 64 | def _forward(self, inputs, mask): 65 | return self.encoder(inputs, mask) 66 | 67 | -------------------------------------------------------------------------------- /data/datasets/partial_ilids.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import glob 4 | import re 5 | 6 | import os.path as osp 7 | 8 | from .bases import BaseImageDataset 9 | 10 | 11 | class PartialILIDS(BaseImageDataset): 12 | 13 | dataset_dir = 'partial_ilids' 14 | 15 | def __init__(self, root='./toDataset', verbose=True, **kwargs): 16 | super(PartialILIDS, self).__init__() 17 | self.dataset_dir = osp.join(root, self.dataset_dir) 18 | self.query_dir = osp.join(self.dataset_dir, 'query') 19 | self.gallery_dir = osp.join(self.dataset_dir, 'gallery') 20 | 21 | self._check_before_run() 22 | 23 | query, gallery= self._process(self.query_dir, self.gallery_dir) 24 | 25 | if verbose: 26 | print("=> partial_ilids loaded") 27 | self.print_dataset_statistics(query, query, gallery) 28 | 29 | self.query = query 30 | self.gallery = gallery 31 | 32 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 33 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 34 | 35 | def _check_before_run(self): 36 | """Check if all files are available before going deeper""" 37 | if not osp.exists(self.dataset_dir): 38 | raise 
RuntimeError("'{}' is not available".format(self.dataset_dir)) 39 | if not osp.exists(self.query_dir): 40 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 41 | if not osp.exists(self.gallery_dir): 42 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 43 | 44 | def _process(self, query_path, gallery_path): 45 | query_img_paths = glob.glob(osp.join(query_path, '*.jpg')) 46 | gallery_img_paths = glob.glob(osp.join(gallery_path, '*.jpg')) 47 | query_paths = [] 48 | pattern = re.compile(r'([-\d]+)_(\d*)') 49 | for img_path in query_img_paths: 50 | pid, camid = map(int, pattern.search(img_path).groups()) 51 | query_paths.append([img_path, pid, camid]) 52 | gallery_paths = [] 53 | for img_path in gallery_img_paths: 54 | pid, camid = map(int, pattern.search(img_path).groups()) 55 | gallery_paths.append([img_path, pid, camid]) 56 | return query_paths, gallery_paths -------------------------------------------------------------------------------- /data/datasets/partial_reid.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import glob 4 | import re 5 | 6 | import os.path as osp 7 | 8 | from .bases import BaseImageDataset 9 | 10 | 11 | class PartialREID(BaseImageDataset): 12 | 13 | dataset_dir = 'partial_reid' 14 | 15 | def __init__(self, root='./toDataset', verbose=True, **kwargs): 16 | super(PartialREID, self).__init__() 17 | self.dataset_dir = osp.join(root, self.dataset_dir) 18 | self.query_dir = osp.join(self.dataset_dir, 'partial_body_images') 19 | self.gallery_dir = osp.join(self.dataset_dir, 'whole_body_images') 20 | 21 | self._check_before_run() 22 | 23 | query, gallery = self._process(self.query_dir, self.gallery_dir) 24 | 25 | if verbose: 26 | print("=> partial_reid loaded") 27 | self.print_dataset_statistics(query, query, gallery) 28 | 29 | self.query = query 30 | self.gallery = gallery 31 | 32 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 33 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 34 | 35 | def _check_before_run(self): 36 | """Check if all files are available before going deeper""" 37 | if not osp.exists(self.dataset_dir): 38 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 39 | if not osp.exists(self.query_dir): 40 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 41 | if not osp.exists(self.gallery_dir): 42 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 43 | 44 | def _process(self, query_path, gallery_path): 45 | query_img_paths = glob.glob(osp.join(query_path, '*.jpg')) 46 | gallery_img_paths = glob.glob(osp.join(gallery_path, '*.jpg')) 47 | query_paths = [] 48 | pattern = re.compile(r'([-\d]+)_(\d*)') 49 | for img_path in query_img_paths: 50 | pid, camid = map(int, pattern.search(img_path).groups()) 51 | query_paths.append([img_path, pid, camid]) 52 | gallery_paths = [] 53 | for img_path in gallery_img_paths: 54 | pid, camid = map(int, pattern.search(img_path).groups()) 55 | gallery_paths.append([img_path, pid, camid]) 56 | return query_paths, gallery_paths -------------------------------------------------------------------------------- /video-reid-AWG/eval_metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import numpy as np 3 | import copy 4 | 5 | 6 | def evaluate(distmat, q_pids, g_pids, 
q_camids, g_camids, max_rank=50): 7 | num_q, num_g = distmat.shape 8 | if num_g < max_rank: 9 | max_rank = num_g 10 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 11 | indices = np.argsort(distmat, axis=1) 12 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 13 | 14 | # compute cmc curve for each query 15 | all_cmc = [] 16 | all_AP = [] 17 | all_INP = [] 18 | num_valid_q = 0. 19 | for q_idx in range(num_q): 20 | # get query pid and camid 21 | q_pid = q_pids[q_idx] 22 | q_camid = q_camids[q_idx] 23 | 24 | # remove gallery samples that have the same pid and camid with query 25 | order = indices[q_idx] 26 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 27 | keep = np.invert(remove) 28 | 29 | # compute cmc curve 30 | orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 31 | if not np.any(orig_cmc): 32 | # this condition is true when query identity does not appear in gallery 33 | continue 34 | 35 | cmc = orig_cmc.cumsum() 36 | 37 | pos_idx = np.where(orig_cmc == 1) 38 | max_pos_idx = np.max(pos_idx) 39 | inp = cmc[max_pos_idx]/ (max_pos_idx + 1.0) 40 | all_INP.append(inp) 41 | 42 | cmc[cmc > 1] = 1 43 | 44 | all_cmc.append(cmc[:max_rank]) 45 | num_valid_q += 1. 46 | 47 | # compute average precision 48 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 49 | num_rel = orig_cmc.sum() 50 | tmp_cmc = orig_cmc.cumsum() 51 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)] 52 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 53 | AP = tmp_cmc.sum() / num_rel 54 | all_AP.append(AP) 55 | 56 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 57 | 58 | all_cmc = np.asarray(all_cmc).astype(np.float32) 59 | all_cmc = all_cmc.sum(0) / num_valid_q 60 | mAP = np.mean(all_AP) 61 | mINP = np.mean(all_INP) 62 | 63 | return all_cmc, mAP, mINP 64 | 65 | 66 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/train_regdb.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -A chenjun3 3 | #SBATCH -p a100x4 4 | #SBATCH -N 1 5 | #SBATCH --ntasks=4 6 | #SBATCH --ntasks-per-node=4 7 | #SBATCH --cpus-per-task=4 8 | #SBATCH --gres=gpu:2 9 | #SBATCH -o regdb_2p_384_g.log 10 | module load nvidia/cuda/11.6 11 | 12 | for trial in 1 2 3 4 5 6 7 8 9 10 13 | do 14 | CUDA_VISIBLE_DEVICES=4,5,6,7 python regdb_train.py -b 256 -a agw -d regdb_rgb --iters 50 --momentum 0.95 --eps 0.6 --num-instances 16 --trial $trial 15 | done 16 | echo 'Done!' 
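A quick self-contained check of the `evaluate` function in `video-reid-AWG/eval_metrics.py` above (editor's sketch: the import path is an assumption and the numbers are toy data, not results from the repository):
```
import numpy as np
from eval_metrics import evaluate  # hypothetical import path; adjust to where the file lives

# Two queries against four gallery images; a smaller distance means more similar.
distmat = np.array([[0.1, 0.4, 0.9, 0.8],
                    [0.2, 0.4, 0.3, 0.9]])
q_pids, q_camids = np.array([0, 1]), np.array([0, 0])
g_pids, g_camids = np.array([0, 1, 1, 2]), np.array([1, 1, 1, 1])

cmc, mAP, mINP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=4)
# Query 0 hits at rank 1; query 1's first correct match sits at rank 2,
# so cmc[0] = 0.5, mAP ~ 0.79 and mINP ~ 0.83.
print(cmc, mAP, mINP)
```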
17 | #cluster_contrast_camera_cmsub_andcmass_regdb.py 18 | 19 | # CUDA_VISIBLE_DEVICES=0,1,2,3 python cluster_contrast_camera_cmsub_andcmass_regdb.py -b 256 -a agw -d regdb_all --iters 200 --momentum 0.1 --eps 0.6 --num-instances 16 20 | # CUDA_VISIBLE_DEVICES=0,1,2,3 python cluster_contrast_camera_cmsub_s3mergecamera.py -b 256 -a agw -d sysu_all --iters 200 --momentum 0.1 --eps 0.6 --num-instances 16 21 | 22 | 23 | # CUDA_VISIBLE_DEVICES=0,1,2,3 python cluster_contrast_camera.py -b 256 -a agw -d sysu_all --iters 200 --momentum 0.1 --eps 0.6 --num-instances 16 24 | 25 | # CUDA_VISIBLE_DEVICES=0,1,2,3 python examples/cluster_contrast_train_usl_infomap.py -b 256 -a resnet50 -d market1501 --iters 200 --momentum 0.1 --eps 0.5 --k1 15 --k2 4 --num-instances 16 26 | 27 | # CUDA_VISIBLE_DEVICES=0,1,2,3 python examples/cluster_contrast_train_usl.py -b 256 -a resnet50 -d msmt17 --iters 400 --momentum 0.1 --eps 0.6 --num-instances 16 28 | # CUDA_VISIBLE_DEVICES=0,1,2,3 python examples/cluster_contrast_train_usl_infomap.py -b 256 -a resnet50 -d msmt17 --iters 400 --momentum 0.1 --eps 0.5 --k1 15 --k2 4 --num-instances 16 29 | 30 | 31 | # CUDA_VISIBLE_DEVICES=0,1,2,3 python examples/cluster_contrast_train_usl.py -b 256 -a resnet50 -d dukemtmcreid --iters 200 --momentum 0.1 --eps 0.6 --num-instances 16 32 | # CUDA_VISIBLE_DEVICES=0,1,2,3 python examples/cluster_contrast_train_usl_infomap.py -b 256 -a resnet50 -d dukemtmcreid --iters 200 --momentum 0.1 --eps 0.5 --k1 15 --k2 4 --num-instances 16 33 | 34 | 35 | # CUDA_VISIBLE_DEVICES=0,1,2,3 python examples/cluster_contrast_train_usl.py -b 256 -a resnet50 -d veri --iters 400 --momentum 0.1 --eps 0.6 --num-instances 16 --height 224 --width 224 36 | # CUDA_VISIBLE_DEVICES=0,1,2,3 python examples/cluster_contrast_train_usl_infomap.py -b 256 -a resnet50 -d veri --iters 400 --momentum 0.1 --eps 0.5 --k1 15 --k2 4 --num-instances 16 --height 224 --width 224 37 | 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | .idea/ 131 | -------------------------------------------------------------------------------- /data/datasets/eval_reid.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import numpy as np 4 | 5 | 6 | def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50): 7 | """Evaluation with market1501 metric 8 | Key: for each query identity, its gallery images from the same camera view are discarded. 9 | """ 10 | num_q, num_g = distmat.shape 11 | if num_g < max_rank: 12 | max_rank = num_g 13 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 14 | indices = np.argsort(distmat, axis=1) 15 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 16 | 17 | # compute cmc curve for each query 18 | all_cmc = [] 19 | all_AP = [] 20 | all_INP = [] 21 | num_valid_q = 0. 
# number of valid query 22 | for q_idx in range(num_q): 23 | # get query pid and camid 24 | q_pid = q_pids[q_idx] 25 | q_camid = q_camids[q_idx] 26 | 27 | # remove gallery samples that have the same pid and camid with query 28 | order = indices[q_idx] 29 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 30 | keep = np.invert(remove) 31 | 32 | # compute cmc curve 33 | # binary vector, positions with value 1 are correct matches 34 | orig_cmc = matches[q_idx][keep] 35 | if not np.any(orig_cmc): 36 | # this condition is true when query identity does not appear in gallery 37 | continue 38 | 39 | cmc = orig_cmc.cumsum() 40 | 41 | pos_idx = np.where(orig_cmc == 1) 42 | max_pos_idx = np.max(pos_idx) 43 | inp = cmc[max_pos_idx]/ (max_pos_idx + 1.0) 44 | all_INP.append(inp) 45 | 46 | cmc[cmc > 1] = 1 47 | 48 | all_cmc.append(cmc[:max_rank]) 49 | num_valid_q += 1. 50 | 51 | # compute average precision 52 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 53 | num_rel = orig_cmc.sum() 54 | tmp_cmc = orig_cmc.cumsum() 55 | tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)] 56 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 57 | AP = tmp_cmc.sum() / num_rel 58 | all_AP.append(AP) 59 | 60 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 61 | 62 | all_cmc = np.asarray(all_cmc).astype(np.float32) 63 | all_cmc = all_cmc.sum(0) / num_valid_q 64 | mAP = np.mean(all_AP) 65 | mINP = np.mean(all_INP) 66 | 67 | return all_cmc, mAP, mINP 68 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/solver/make_optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def make_optimizer(cfg, model): 5 | params = [] 6 | for key, value in model.named_parameters(): 7 | if not value.requires_grad: 8 | continue 9 | lr = cfg.SOLVER.BASE_LR 10 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 11 | if "bias" in key: 12 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR 13 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 14 | if cfg.SOLVER.LARGE_FC_LR: 15 | if "classifier" in key or "arcface" in key: 16 | lr = cfg.SOLVER.BASE_LR * 2 17 | print('Using two times learning rate for fc ') 18 | 19 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 20 | 21 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD': 22 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM) 23 | elif cfg.SOLVER.OPTIMIZER_NAME == 'AdamW': 24 | optimizer = torch.optim.AdamW(params, lr=cfg.SOLVER.BASE_LR, weight_decay=cfg.SOLVER.WEIGHT_DECAY) 25 | else: 26 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params) 27 | # optimizer_center = torch.optim.SGD(center_criterion.parameters(), lr=cfg.SOLVER.CENTER_LR) 28 | 29 | return optimizer 30 | # def make_optimizer(cfg, model, center_criterion): 31 | # params = [] 32 | # for key, value in model.named_parameters(): 33 | # if not value.requires_grad: 34 | # continue 35 | # lr = cfg.SOLVER.BASE_LR 36 | # weight_decay = cfg.SOLVER.WEIGHT_DECAY 37 | # if "bias" in key: 38 | # lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR 39 | # weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 40 | # if cfg.SOLVER.LARGE_FC_LR: 41 | # if "classifier" in key or "arcface" in key: 42 | # lr = cfg.SOLVER.BASE_LR * 2 43 | # print('Using two times learning rate for fc ') 44 | 45 | # params += [{"params": [value], "lr": lr, "weight_decay": 
weight_decay}] 46 | 47 | # if cfg.SOLVER.OPTIMIZER_NAME == 'SGD': 48 | # optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM) 49 | # elif cfg.SOLVER.OPTIMIZER_NAME == 'AdamW': 50 | # optimizer = torch.optim.AdamW(params, lr=cfg.SOLVER.BASE_LR, weight_decay=cfg.SOLVER.WEIGHT_DECAY) 51 | # else: 52 | # optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params) 53 | # optimizer_center = torch.optim.SGD(center_criterion.parameters(), lr=cfg.SOLVER.CENTER_LR) 54 | 55 | # return optimizer, optimizer_center 56 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/solver/make_optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def make_optimizer(cfg, model): 5 | params = [] 6 | for key, value in model.named_parameters(): 7 | if not value.requires_grad: 8 | continue 9 | lr = cfg.SOLVER.BASE_LR 10 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 11 | if "bias" in key: 12 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR 13 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 14 | if cfg.SOLVER.LARGE_FC_LR: 15 | if "classifier" in key or "arcface" in key: 16 | lr = cfg.SOLVER.BASE_LR * 2 17 | print('Using two times learning rate for fc ') 18 | 19 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 20 | 21 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD': 22 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM) 23 | elif cfg.SOLVER.OPTIMIZER_NAME == 'AdamW': 24 | optimizer = torch.optim.AdamW(params, lr=cfg.SOLVER.BASE_LR, weight_decay=cfg.SOLVER.WEIGHT_DECAY) 25 | else: 26 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params) 27 | # optimizer_center = torch.optim.SGD(center_criterion.parameters(), lr=cfg.SOLVER.CENTER_LR) 28 | 29 | return optimizer 30 | # def make_optimizer(cfg, model, center_criterion): 31 | # params = [] 32 | # for key, value in model.named_parameters(): 33 | # if not value.requires_grad: 34 | # continue 35 | # lr = cfg.SOLVER.BASE_LR 36 | # weight_decay = cfg.SOLVER.WEIGHT_DECAY 37 | # if "bias" in key: 38 | # lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR 39 | # weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 40 | # if cfg.SOLVER.LARGE_FC_LR: 41 | # if "classifier" in key or "arcface" in key: 42 | # lr = cfg.SOLVER.BASE_LR * 2 43 | # print('Using two times learning rate for fc ') 44 | 45 | # params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 46 | 47 | # if cfg.SOLVER.OPTIMIZER_NAME == 'SGD': 48 | # optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM) 49 | # elif cfg.SOLVER.OPTIMIZER_NAME == 'AdamW': 50 | # optimizer = torch.optim.AdamW(params, lr=cfg.SOLVER.BASE_LR, weight_decay=cfg.SOLVER.WEIGHT_DECAY) 51 | # else: 52 | # optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params) 53 | # optimizer_center = torch.optim.SGD(center_criterion.parameters(), lr=cfg.SOLVER.CENTER_LR) 54 | 55 | # return optimizer, optimizer_center 56 | -------------------------------------------------------------------------------- /data/triplet_sampler.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import copy 4 | import random 5 | from collections import defaultdict 6 | 7 | import numpy as np 8 | from torch.utils.data.sampler import Sampler 9 | 10 | class RandomIdentitySampler(Sampler): 11 | """ 12 | 
Randomly sample N identities, then for each identity, 13 | randomly sample K instances, therefore batch size is N*K. 14 | Args: 15 | - data_source (list): list of (img_path, pid, camid). 16 | - num_instances (int): number of instances per identity in a batch. 17 | - batch_size (int): number of examples in a batch. 18 | """ 19 | 20 | def __init__(self, data_source, batch_size, num_instances): 21 | self.data_source = data_source 22 | self.batch_size = batch_size 23 | self.num_instances = num_instances 24 | self.num_pids_per_batch = self.batch_size // self.num_instances 25 | self.index_dic = defaultdict(list) 26 | for index, (_, pid, _) in enumerate(self.data_source): 27 | self.index_dic[pid].append(index) 28 | self.pids = list(self.index_dic.keys()) 29 | 30 | # estimate number of examples in an epoch 31 | self.length = 0 32 | for pid in self.pids: 33 | idxs = self.index_dic[pid] 34 | num = len(idxs) 35 | if num < self.num_instances: 36 | num = self.num_instances 37 | self.length += num - num % self.num_instances 38 | 39 | def __iter__(self): 40 | batch_idxs_dict = defaultdict(list) 41 | 42 | for pid in self.pids: 43 | idxs = copy.deepcopy(self.index_dic[pid]) 44 | if len(idxs) < self.num_instances: 45 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 46 | random.shuffle(idxs) 47 | batch_idxs = [] 48 | for idx in idxs: 49 | batch_idxs.append(idx) 50 | if len(batch_idxs) == self.num_instances: 51 | batch_idxs_dict[pid].append(batch_idxs) 52 | batch_idxs = [] 53 | 54 | avai_pids = copy.deepcopy(self.pids) 55 | final_idxs = [] 56 | 57 | while len(avai_pids) >= self.num_pids_per_batch: 58 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 59 | for pid in selected_pids: 60 | batch_idxs = batch_idxs_dict[pid].pop(0) 61 | final_idxs.extend(batch_idxs) 62 | if len(batch_idxs_dict[pid]) == 0: 63 | avai_pids.remove(pid) 64 | 65 | self.length = len(final_idxs) 66 | return iter(final_idxs) 67 | 68 | def __len__(self): 69 | return self.length 70 | -------------------------------------------------------------------------------- /data/build.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | from .datasets import init_dataset, ImageDataset 7 | from .triplet_sampler import RandomIdentitySampler 8 | from .transforms import build_transforms 9 | 10 | 11 | def train_collate_fn(batch): 12 | imgs, pids, _, _, = zip(*batch) 13 | pids = torch.tensor(pids, dtype=torch.int64) 14 | return torch.stack(imgs, dim=0), pids 15 | 16 | 17 | def val_collate_fn(batch): 18 | imgs, pids, camids, _ = zip(*batch) 19 | return torch.stack(imgs, dim=0), pids, camids 20 | 21 | 22 | def make_data_loader(cfg): 23 | transforms = build_transforms(cfg) 24 | dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR) 25 | 26 | num_classes = dataset.num_train_pids 27 | num_workers = cfg.DATALOADER.NUM_WORKERS 28 | train_set = ImageDataset(dataset.train, transforms['train']) 29 | data_loader={} 30 | if cfg.DATALOADER.PK_SAMPLER == 'on': 31 | data_loader['train'] = DataLoader( 32 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, 33 | sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE), 34 | num_workers=num_workers, collate_fn=train_collate_fn 35 | ) 36 | else: 37 | data_loader['train'] = DataLoader( 38 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers, 39 | collate_fn=train_collate_fn 
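            # (editor's note) this branch falls back to plain random shuffling when the
            # P x K identity sampler is disabled (cfg.DATALOADER.PK_SAMPLER != 'on')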
40 | ) 41 | 42 | if cfg.TEST.PARTIAL_REID == 'off': 43 | eval_set = ImageDataset(dataset.query + dataset.gallery, transforms['eval']) 44 | data_loader['eval'] = DataLoader( 45 | eval_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers, 46 | collate_fn=val_collate_fn 47 | ) 48 | else: 49 | dataset_reid = init_dataset('partial_reid', root=cfg.DATASETS.ROOT_DIR) 50 | dataset_ilids = init_dataset('partial_ilids', root=cfg.DATASETS.ROOT_DIR) 51 | eval_set_reid = ImageDataset(dataset_reid.query + dataset_reid.gallery, transforms['eval']) 52 | data_loader['eval_reid'] = DataLoader( 53 | eval_set_reid, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers, 54 | collate_fn=val_collate_fn 55 | ) 56 | eval_set_ilids = ImageDataset(dataset_ilids.query + dataset_ilids.gallery, transforms['eval']) 57 | data_loader['eval_ilids'] = DataLoader( 58 | eval_set_ilids, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers, 59 | collate_fn=val_collate_fn 60 | ) 61 | return data_loader, len(dataset.query), num_classes 62 | -------------------------------------------------------------------------------- /modeling/layer/center_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class CenterLoss(nn.Module): 8 | """Center loss. 9 | 10 | Reference: 11 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. 12 | 13 | Args: 14 | num_classes (int): number of classes. 15 | feat_dim (int): feature dimension. 16 | """ 17 | 18 | def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True): 19 | super(CenterLoss, self).__init__() 20 | self.num_classes = num_classes 21 | self.feat_dim = feat_dim 22 | self.use_gpu = use_gpu 23 | 24 | if self.use_gpu: 25 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda()) 26 | else: 27 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) 28 | 29 | def forward(self, x, labels): 30 | """ 31 | Args: 32 | x: feature matrix with shape (batch_size, feat_dim). 33 | labels: ground truth labels with shape (batch_size,).
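Computes loss = (1/batch_size) * sum_i ||x_i - centers[labels[i]]||^2: the (batch_size, num_classes) squared-Euclidean distance matrix built below is masked by a one-hot comparison against the labels, then summed and averaged.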
34 | """ 35 | assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)" 36 | 37 | batch_size = x.size(0) 38 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \ 39 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() 40 | distmat.addmm_(1, -2, x, self.centers.t()) 41 | 42 | classes = torch.arange(self.num_classes).long() 43 | if self.use_gpu: classes = classes.cuda() 44 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) 45 | mask = labels.eq(classes.expand(batch_size, self.num_classes)) 46 | 47 | dist = distmat * mask.float() 48 | loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size 49 | #dist = [] 50 | #for i in range(batch_size): 51 | # value = distmat[i][mask[i]] 52 | # value = value.clamp(min=1e-12, max=1e+12) # for numerical stability 53 | # dist.append(value) 54 | #dist = torch.cat(dist) 55 | #loss = dist.mean() 56 | return loss 57 | 58 | 59 | if __name__ == '__main__': 60 | use_gpu = False 61 | center_loss = CenterLoss(use_gpu=use_gpu) 62 | features = torch.rand(16, 2048) 63 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).long() 64 | if use_gpu: 65 | features = torch.rand(16, 2048).cuda() 66 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).cuda() 67 | 68 | loss = center_loss(features, targets) 69 | print(loss) 70 | -------------------------------------------------------------------------------- /tools/test.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import logging 3 | 4 | import torch 5 | import torch.nn as nn 6 | from ignite.engine import Engine 7 | 8 | from utils.reid_metric import r1_mAP_mINP, r1_mAP_mINP_reranking 9 | 10 | 11 | def create_supervised_evaluator(model, metrics, device=None): 12 | """ 13 | Factory function for creating an evaluator for supervised models 14 | 15 | Args: 16 | model (`torch.nn.Module`): the model to evaluate 17 | metrics (dict of str - :class:`ignite.metrics.Metric`): a map of metric names to Metrics 18 | device (str, optional): device type specification (default: None). 19 | Applies to both model and batches. 20 | Returns: 21 | Engine: an evaluator engine with supervised inference function 22 | """ 23 | 24 | def _inference(engine, batch): 25 | model.eval() 26 | with torch.no_grad(): 27 | data, pids, camids = batch 28 | data = data.to(device) if torch.cuda.device_count() >= 1 else data 29 | feat = model(data) 30 | return feat, pids, camids 31 | 32 | engine = Engine(_inference) 33 | 34 | for name, metric in metrics.items(): 35 | metric.attach(engine, name) 36 | 37 | return engine 38 | 39 | 40 | def do_test( 41 | cfg, 42 | model, 43 | data_loader, 44 | num_query 45 | ): 46 | device = cfg.MODEL.DEVICE 47 | 48 | logger = logging.getLogger("reid_baseline") 49 | logger.info("Enter inferencing") 50 | if cfg.TEST.RE_RANKING == 'off': 51 | print("Create evaluator") 52 | evaluator = create_supervised_evaluator(model, metrics={'r1_mAP_mINP': r1_mAP_mINP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, 53 | device=device) 54 | elif cfg.TEST.RE_RANKING == 'on': 55 | print("Create evaluator for reranking") 56 | evaluator = create_supervised_evaluator(model, metrics={'r1_mAP_mINP': r1_mAP_mINP_reranking(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, 57 | device=device) 58 | else: 59 | print("Unsupported re_ranking config. 
Only support for on or off, but got {}.".format(cfg.TEST.RE_RANKING)) 60 | 61 | evaluator.run(data_loader['eval']) 62 | cmc, mAP, mINP = evaluator.state.metrics['r1_mAP_mINP'] 63 | logger.info('Validation Results') 64 | logger.info("mINP: {:.1%}".format(mINP)) 65 | logger.info("mAP: {:.1%}".format(mAP)) 66 | if cfg.TEST.PARTIAL_REID == 'off': 67 | for r in [1, 5, 10]: 68 | logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1])) 69 | else: 70 | for r in [1, 3, 5, 10]: 71 | logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1])) 72 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/clustercontrast/models/dsbn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # Domain-specific BatchNorm 5 | 6 | class DSBN2d(nn.Module): 7 | def __init__(self, planes): 8 | super(DSBN2d, self).__init__() 9 | self.num_features = planes 10 | self.BN_S = nn.BatchNorm2d(planes) 11 | self.BN_T = nn.BatchNorm2d(planes) 12 | 13 | def forward(self, x): 14 | if (not self.training): 15 | return self.BN_T(x) 16 | 17 | bs = x.size(0) 18 | assert (bs%2==0) 19 | split = torch.split(x, int(bs/2), 0) 20 | out1 = self.BN_S(split[0].contiguous()) 21 | out2 = self.BN_T(split[1].contiguous()) 22 | out = torch.cat((out1, out2), 0) 23 | return out 24 | 25 | class DSBN1d(nn.Module): 26 | def __init__(self, planes): 27 | super(DSBN1d, self).__init__() 28 | self.num_features = planes 29 | self.BN_S = nn.BatchNorm1d(planes) 30 | self.BN_T = nn.BatchNorm1d(planes) 31 | 32 | def forward(self, x): 33 | if (not self.training): 34 | return self.BN_T(x) 35 | 36 | bs = x.size(0) 37 | assert (bs%2==0) 38 | split = torch.split(x, int(bs/2), 0) 39 | out1 = self.BN_S(split[0].contiguous()) 40 | out2 = self.BN_T(split[1].contiguous()) 41 | out = torch.cat((out1, out2), 0) 42 | return out 43 | 44 | def convert_dsbn(model): 45 | for _, (child_name, child) in enumerate(model.named_children()): 46 | assert(not next(model.parameters()).is_cuda) 47 | if isinstance(child, nn.BatchNorm2d): 48 | m = DSBN2d(child.num_features) 49 | m.BN_S.load_state_dict(child.state_dict()) 50 | m.BN_T.load_state_dict(child.state_dict()) 51 | setattr(model, child_name, m) 52 | elif isinstance(child, nn.BatchNorm1d): 53 | m = DSBN1d(child.num_features) 54 | m.BN_S.load_state_dict(child.state_dict()) 55 | m.BN_T.load_state_dict(child.state_dict()) 56 | setattr(model, child_name, m) 57 | else: 58 | convert_dsbn(child) 59 | 60 | def convert_bn(model, use_target=True): 61 | for _, (child_name, child) in enumerate(model.named_children()): 62 | assert(not next(model.parameters()).is_cuda) 63 | if isinstance(child, DSBN2d): 64 | m = nn.BatchNorm2d(child.num_features) 65 | if use_target: 66 | m.load_state_dict(child.BN_T.state_dict()) 67 | else: 68 | m.load_state_dict(child.BN_S.state_dict()) 69 | setattr(model, child_name, m) 70 | elif isinstance(child, DSBN1d): 71 | m = nn.BatchNorm1d(child.num_features) 72 | if use_target: 73 | m.load_state_dict(child.BN_T.state_dict()) 74 | else: 75 | m.load_state_dict(child.BN_S.state_dict()) 76 | setattr(model, child_name, m) 77 | else: 78 | convert_bn(child, use_target=use_target) 79 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/models/dsbn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 
| 4 | # Domain-specific BatchNorm 5 | 6 | class DSBN2d(nn.Module): 7 | def __init__(self, planes): 8 | super(DSBN2d, self).__init__() 9 | self.num_features = planes 10 | self.BN_S = nn.BatchNorm2d(planes) 11 | self.BN_T = nn.BatchNorm2d(planes) 12 | 13 | def forward(self, x): 14 | if (not self.training): 15 | return self.BN_T(x) 16 | 17 | bs = x.size(0) 18 | assert (bs%2==0) 19 | split = torch.split(x, int(bs/2), 0) 20 | out1 = self.BN_S(split[0].contiguous()) 21 | out2 = self.BN_T(split[1].contiguous()) 22 | out = torch.cat((out1, out2), 0) 23 | return out 24 | 25 | class DSBN1d(nn.Module): 26 | def __init__(self, planes): 27 | super(DSBN1d, self).__init__() 28 | self.num_features = planes 29 | self.BN_S = nn.BatchNorm1d(planes) 30 | self.BN_T = nn.BatchNorm1d(planes) 31 | 32 | def forward(self, x): 33 | if (not self.training): 34 | return self.BN_T(x) 35 | 36 | bs = x.size(0) 37 | assert (bs%2==0) 38 | split = torch.split(x, int(bs/2), 0) 39 | out1 = self.BN_S(split[0].contiguous()) 40 | out2 = self.BN_T(split[1].contiguous()) 41 | out = torch.cat((out1, out2), 0) 42 | return out 43 | 44 | def convert_dsbn(model): 45 | for _, (child_name, child) in enumerate(model.named_children()): 46 | assert(not next(model.parameters()).is_cuda) 47 | if isinstance(child, nn.BatchNorm2d): 48 | m = DSBN2d(child.num_features) 49 | m.BN_S.load_state_dict(child.state_dict()) 50 | m.BN_T.load_state_dict(child.state_dict()) 51 | setattr(model, child_name, m) 52 | elif isinstance(child, nn.BatchNorm1d): 53 | m = DSBN1d(child.num_features) 54 | m.BN_S.load_state_dict(child.state_dict()) 55 | m.BN_T.load_state_dict(child.state_dict()) 56 | setattr(model, child_name, m) 57 | else: 58 | convert_dsbn(child) 59 | 60 | def convert_bn(model, use_target=True): 61 | for _, (child_name, child) in enumerate(model.named_children()): 62 | assert(not next(model.parameters()).is_cuda) 63 | if isinstance(child, DSBN2d): 64 | m = nn.BatchNorm2d(child.num_features) 65 | if use_target: 66 | m.load_state_dict(child.BN_T.state_dict()) 67 | else: 68 | m.load_state_dict(child.BN_S.state_dict()) 69 | setattr(model, child_name, m) 70 | elif isinstance(child, DSBN1d): 71 | m = nn.BatchNorm1d(child.num_features) 72 | if use_target: 73 | m.load_state_dict(child.BN_T.state_dict()) 74 | else: 75 | m.load_state_dict(child.BN_S.state_dict()) 76 | setattr(model, child_name, m) 77 | else: 78 | convert_bn(child, use_target=use_target) 79 | -------------------------------------------------------------------------------- /data/transforms.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import math 4 | import random 5 | import torchvision.transforms as T 6 | 7 | 8 | def build_transforms(cfg): 9 | normalize_transform = T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD) 10 | transforms = {} 11 | if cfg.TEST.PARTIAL_REID == 'off': 12 | transforms['train'] = T.Compose([ 13 | T.Resize(cfg.INPUT.IMG_SIZE), 14 | T.RandomHorizontalFlip(p=cfg.INPUT.PROB), 15 | T.Pad(cfg.INPUT.PADDING), 16 | T.RandomCrop(cfg.INPUT.IMG_SIZE), 17 | T.ToTensor(), 18 | normalize_transform, 19 | RandomErasing(probability=cfg.INPUT.RE_PROB, mean=cfg.INPUT.PIXEL_MEAN) 20 | ]) 21 | else: 22 | transforms['train'] = T.Compose([ 23 | T.Resize(cfg.INPUT.IMG_SIZE), 24 | T.RandomHorizontalFlip(p=cfg.INPUT.PROB), 25 | T.RandomResizedCrop(size=256, scale=(0.5, 1.0), ratio=(1.0, 3.0)), 26 | T.Resize(cfg.INPUT.IMG_SIZE), 27 | T.ToTensor(), 28 | normalize_transform, 29 | RandomErasing(probability=cfg.INPUT.RE_PROB, 
mean=cfg.INPUT.PIXEL_MEAN) 30 | ]) 31 | 32 | transforms['eval'] = T.Compose([ 33 | T.Resize(cfg.INPUT.IMG_SIZE), 34 | T.ToTensor(), 35 | normalize_transform 36 | ]) 37 | 38 | return transforms 39 | 40 | 41 | class RandomErasing(object): 42 | """ Randomly selects a rectangle region in an image and erases its pixels. 43 | 'Random Erasing Data Augmentation' by Zhong et al. 44 | See https://arxiv.org/pdf/1708.04896.pdf 45 | Args: 46 | probability: The probability that the Random Erasing operation will be performed. 47 | sl: Minimum proportion of erased area against input image. 48 | sh: Maximum proportion of erased area against input image. 49 | r1: Minimum aspect ratio of erased area. 50 | mean: Erasing value. 51 | """ 52 | 53 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 54 | self.probability = probability 55 | self.mean = mean 56 | self.sl = sl 57 | self.sh = sh 58 | self.r1 = r1 59 | 60 | def __call__(self, img): 61 | 62 | if random.uniform(0, 1) >= self.probability: 63 | return img 64 | 65 | for attempt in range(100): 66 | area = img.size()[1] * img.size()[2] 67 | 68 | target_area = random.uniform(self.sl, self.sh) * area 69 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 70 | 71 | h = int(round(math.sqrt(target_area * aspect_ratio))) 72 | w = int(round(math.sqrt(target_area / aspect_ratio))) 73 | 74 | if w < img.size()[2] and h < img.size()[1]: 75 | x1 = random.randint(0, img.size()[1] - h) 76 | y1 = random.randint(0, img.size()[2] - w) 77 | if img.size()[0] == 3: 78 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 79 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 80 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 81 | else: 82 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 83 | return img 84 | 85 | return img 86 | -------------------------------------------------------------------------------- /data/datasets/market1501.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import glob 4 | import re 5 | 6 | import os.path as osp 7 | 8 | from .bases import BaseImageDataset 9 | 10 | 11 | class Market1501(BaseImageDataset): 12 | """ 13 | Market1501 14 | Reference: 15 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 
16 | URL: http://www.liangzheng.org/Project/project_reid.html 17 | 18 | Dataset statistics: 19 | # identities: 1501 (+1 for background) 20 | # images: 12936 (train) + 3368 (query) + 15913 (gallery) 21 | """ 22 | dataset_dir = 'market1501' 23 | 24 | def __init__(self, root='./toDataset', verbose=True, **kwargs): 25 | super(Market1501, self).__init__() 26 | self.dataset_dir = osp.join(root, self.dataset_dir) 27 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 28 | self.query_dir = osp.join(self.dataset_dir, 'query') 29 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 30 | 31 | self._check_before_run() 32 | 33 | train = self._process_dir(self.train_dir, relabel=True) 34 | query = self._process_dir(self.query_dir, relabel=False) 35 | gallery = self._process_dir(self.gallery_dir, relabel=False) 36 | 37 | if verbose: 38 | print("=> Market1501 loaded") 39 | self.print_dataset_statistics(train, query, gallery) 40 | 41 | self.train = train 42 | self.query = query 43 | self.gallery = gallery 44 | 45 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 46 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 47 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 48 | 49 | def _check_before_run(self): 50 | """Check if all files are available before going deeper""" 51 | if not osp.exists(self.dataset_dir): 52 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 53 | if not osp.exists(self.train_dir): 54 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 55 | if not osp.exists(self.query_dir): 56 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 57 | if not osp.exists(self.gallery_dir): 58 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 59 | 60 | def _process_dir(self, dir_path, relabel=False): 61 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 62 | pattern = re.compile(r'([-\d]+)_c(\d)') 63 | 64 | pid_container = set() 65 | for img_path in img_paths: 66 | pid, _ = map(int, pattern.search(img_path).groups()) 67 | if pid == -1: continue # junk images are just ignored 68 | pid_container.add(pid) 69 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 70 | 71 | dataset = [] 72 | for img_path in img_paths: 73 | pid, camid = map(int, pattern.search(img_path).groups()) 74 | if pid == -1: continue # junk images are just ignored 75 | assert 0 <= pid <= 1501 # pid == 0 means background 76 | assert 1 <= camid <= 6 77 | camid -= 1 # index starts from 0 78 | if relabel: pid = pid2label[pid] 79 | dataset.append((img_path, pid, camid)) 80 | 81 | return dataset 82 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/clustercontrast/datasets/dukemtmcreid.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os.path as osp 3 | import re 4 | from ..utils.data import BaseImageDataset 5 | 6 | 7 | def process_dir(dir_path, relabel=False): 8 | img_paths = glob.glob(osp.join(dir_path, "*.jpg")) 9 | pattern = re.compile(r"([-\d]+)_c(\d)") 10 | 11 | # get all identities 12 | pid_container = set() 13 | for img_path in img_paths: 14 | pid, _ = map(int, pattern.search(img_path).groups()) 15 | if pid == -1: 16 | continue 17 | pid_container.add(pid) 18 | 19 | pid2label = {pid: label for label, pid in 
enumerate(pid_container)} 20 | 21 | data = [] 22 | for img_path in img_paths: 23 | pid, camid = map(int, pattern.search(img_path).groups()) 24 | if (pid not in pid_container) or (pid == -1): 25 | continue 26 | 27 | assert 1 <= camid <= 8 28 | camid -= 1 29 | 30 | if relabel: 31 | pid = pid2label[pid] 32 | data.append((img_path, pid, camid)) 33 | 34 | return data 35 | 36 | 37 | class DukeMTMCreID(BaseImageDataset): 38 | 39 | """DukeMTMC-reID. 40 | Reference: 41 | - Ristani et al. Performance Measures and a Data Set for Multi-Target, 42 | Multi-Camera Tracking. ECCVW 2016. 43 | - Zheng et al. Unlabeled Samples Generated by GAN Improve the Person 44 | Re-identification Baseline in vitro. ICCV 2017. 45 | URL: ``_ 46 | 47 | Dataset statistics: 48 | - identities: 1404 (train + query). 49 | - images: 16522 (train) + 2228 (query) + 17661 (gallery). 50 | - cameras: 8. 51 | """ 52 | 53 | dataset_dir = "DukeMTMC-reID" 54 | 55 | def __init__(self, root, verbose=True): 56 | super(DukeMTMCreID, self).__init__() 57 | self.root = osp.abspath(osp.expanduser(root)) 58 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 59 | 60 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 61 | self.query_dir = osp.join(self.dataset_dir, 'query') 62 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 63 | self._check_before_run()  # validate paths before parsing (was defined below but never called) 64 | train = process_dir(dir_path=self.train_dir, relabel=True) 65 | query = process_dir(dir_path=self.query_dir, relabel=False) 66 | gallery = process_dir(dir_path=self.gallery_dir, relabel=False) 67 | 68 | self.train = train 69 | self.query = query 70 | self.gallery = gallery 71 | 72 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 73 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 74 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 75 | 76 | def _check_before_run(self): 77 | """Check if all files are available before going deeper""" 78 | if not osp.exists(self.dataset_dir): 79 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 80 | if not osp.exists(self.train_dir): 81 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 82 | if not osp.exists(self.query_dir): 83 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 84 | if not osp.exists(self.gallery_dir): 85 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 86 | -------------------------------------------------------------------------------- /video-reid-AWG/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torchvision.transforms import * 4 | from PIL import Image 5 | import random 6 | import numpy as np 7 | 8 | import math 9 | class RandomErasing(object): 10 | """ Randomly selects a rectangle region in an image and erases its pixels. 11 | 'Random Erasing Data Augmentation' by Zhong et al. 12 | See https://arxiv.org/pdf/1708.04896.pdf 13 | Args: 14 | probability: The probability that the Random Erasing operation will be performed. 15 | sl: Minimum proportion of erased area against input image. 16 | sh: Maximum proportion of erased area against input image. 17 | r1: Minimum aspect ratio of erased area. 18 | mean: Erasing value.
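Note: this transform operates on a (C, H, W) tensor, so place it after ToTensor()/Normalize() in a Compose pipeline (as build_transforms in data/transforms.py does above).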
19 | """ 20 | 21 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 22 | self.probability = probability 23 | self.mean = mean 24 | self.sl = sl 25 | self.sh = sh 26 | self.r1 = r1 27 | 28 | def __call__(self, img): 29 | 30 | if random.uniform(0, 1) >= self.probability: 31 | return img 32 | 33 | for attempt in range(100): 34 | area = img.size()[1] * img.size()[2] 35 | 36 | target_area = random.uniform(self.sl, self.sh) * area 37 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 38 | 39 | h = int(round(math.sqrt(target_area * aspect_ratio))) 40 | w = int(round(math.sqrt(target_area / aspect_ratio))) 41 | 42 | if w < img.size()[2] and h < img.size()[1]: 43 | x1 = random.randint(0, img.size()[1] - h) 44 | y1 = random.randint(0, img.size()[2] - w) 45 | if img.size()[0] == 3: 46 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 47 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 48 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 49 | else: 50 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 51 | return img 52 | 53 | return img 54 | 55 | class Random2DTranslation(object): 56 | """ 57 | With a probability, first increase image size to (1 + 1/8), and then perform random crop. 58 | 59 | Args: 60 | height (int): target height. 61 | width (int): target width. 62 | p (float): probability of performing this transformation. Default: 0.5. 63 | """ 64 | def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR): 65 | self.height = height 66 | self.width = width 67 | self.p = p 68 | self.interpolation = interpolation 69 | 70 | def __call__(self, img): 71 | """ 72 | Args: 73 | img (PIL Image): Image to be cropped. 74 | 75 | Returns: 76 | PIL Image: Cropped image. 77 | """ 78 | if random.random() < self.p: 79 | return img.resize((self.width, self.height), self.interpolation) 80 | new_width, new_height = int(round(self.width * 1.125)), int(round(self.height * 1.125)) 81 | resized_img = img.resize((new_width, new_height), self.interpolation) 82 | x_maxrange = new_width - self.width 83 | y_maxrange = new_height - self.height 84 | x1 = int(round(random.uniform(0, x_maxrange))) 85 | y1 = int(round(random.uniform(0, y_maxrange))) 86 | croped_img = resized_img.crop((x1, y1, x1 + self.width, y1 + self.height)) 87 | return croped_img 88 | 89 | if __name__ == '__main__': 90 | pass 91 | -------------------------------------------------------------------------------- /data/datasets/veri.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import glob 3 | import re 4 | 5 | import os.path as osp 6 | 7 | from .bases import BaseImageDataset 8 | 9 | 10 | class VeRi(BaseImageDataset): 11 | """ 12 | VeRi-776 13 | Reference: 14 | Liu, Xinchen, et al. "Large-scale vehicle re-identification in urban surveillance videos." ICME 2016. 
15 | 16 | URL:https://vehiclereid.github.io/VeRi/ 17 | 18 | Dataset statistics: 19 | # identities: 776 20 | # images: 37778 (train) + 1678 (query) + 11579 (gallery) 21 | # cameras: 20 22 | """ 23 | 24 | dataset_dir = 'veri' 25 | 26 | def __init__(self, root='./toDataset', verbose=True, **kwargs): 27 | super(VeRi, self).__init__() 28 | self.dataset_dir = osp.join(root, self.dataset_dir) 29 | self.train_dir = osp.join(self.dataset_dir, 'image_train') 30 | self.query_dir = osp.join(self.dataset_dir, 'image_query') 31 | self.gallery_dir = osp.join(self.dataset_dir, 'image_test') 32 | 33 | self._check_before_run() 34 | 35 | train = self._process_dir(self.train_dir, relabel=True) 36 | query = self._process_dir(self.query_dir, relabel=False) 37 | gallery = self._process_dir(self.gallery_dir, relabel=False) 38 | 39 | if verbose: 40 | print("=> VeRi-776 loaded") 41 | self.print_dataset_statistics(train, query, gallery) 42 | 43 | self.train = train 44 | self.query = query 45 | self.gallery = gallery 46 | 47 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 48 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 49 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 50 | 51 | def _check_before_run(self): 52 | """Check if all files are available before going deeper""" 53 | if not osp.exists(self.dataset_dir): 54 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 55 | if not osp.exists(self.train_dir): 56 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 57 | if not osp.exists(self.query_dir): 58 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 59 | if not osp.exists(self.gallery_dir): 60 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 61 | 62 | def _process_dir(self, dir_path, relabel=False): 63 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 64 | pattern = re.compile(r'([-\d]+)_c(\d+)') 65 | 66 | pid_container = set() 67 | for img_path in img_paths: 68 | pid, _ = map(int, pattern.search(img_path).groups()) 69 | if pid == -1: continue # junk images are just ignored 70 | pid_container.add(pid) 71 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 72 | 73 | dataset = [] 74 | for img_path in img_paths: 75 | pid, camid = map(int, pattern.search(img_path).groups()) 76 | if pid == -1: continue # junk images are just ignored 77 | assert 0 <= pid <= 776 # pid == 0 means background 78 | assert 1 <= camid <= 20 79 | camid -= 1 # index starts from 0 80 | if relabel: pid = pid2label[pid] 81 | dataset.append((img_path, pid, camid)) 82 | 83 | return dataset 84 | 85 | -------------------------------------------------------------------------------- /utils/reid_metric.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import numpy as np 4 | import torch 5 | from ignite.metrics import Metric 6 | 7 | from data.datasets.eval_reid import eval_func 8 | from .re_ranking import re_ranking 9 | 10 | 11 | class r1_mAP_mINP(Metric): 12 | def __init__(self, num_query, max_rank=50, feat_norm='on'): 13 | super(r1_mAP_mINP, self).__init__() 14 | self.num_query = num_query 15 | self.max_rank = max_rank 16 | self.feat_norm = feat_norm 17 | 18 | def reset(self): 19 | self.feats = [] 20 | self.pids = [] 21 | self.camids = [] 22 | 23 | def update(self, output): 24 | feat, pid, camid = output 25 | 
self.feats.append(feat) 26 | self.pids.extend(np.asarray(pid)) 27 | self.camids.extend(np.asarray(camid)) 28 | 29 | def compute(self): 30 | feats = torch.cat(self.feats, dim=0) 31 | if self.feat_norm == 'on': 32 | print("The test feature is normalized") 33 | feats = torch.nn.functional.normalize(feats, dim=1, p=2) 34 | # query 35 | qf = feats[:self.num_query] 36 | q_pids = np.asarray(self.pids[:self.num_query]) 37 | q_camids = np.asarray(self.camids[:self.num_query]) 38 | # gallery 39 | gf = feats[self.num_query:] 40 | g_pids = np.asarray(self.pids[self.num_query:]) 41 | g_camids = np.asarray(self.camids[self.num_query:]) 42 | m, n = qf.shape[0], gf.shape[0] 43 | distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 44 | torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() 45 | distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)  # non-deprecated addmm_ signature 46 | distmat = distmat.cpu().numpy() 47 | cmc, mAP, mINP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids) 48 | 49 | return cmc, mAP, mINP 50 | 51 | 52 | class r1_mAP_mINP_reranking(Metric): 53 | def __init__(self, num_query, max_rank=50, feat_norm='on'): 54 | super(r1_mAP_mINP_reranking, self).__init__() 55 | self.num_query = num_query 56 | self.max_rank = max_rank 57 | self.feat_norm = feat_norm 58 | 59 | def reset(self): 60 | self.feats = [] 61 | self.pids = [] 62 | self.camids = [] 63 | 64 | def update(self, output): 65 | feat, pid, camid = output 66 | self.feats.append(feat) 67 | self.pids.extend(np.asarray(pid)) 68 | self.camids.extend(np.asarray(camid)) 69 | 70 | def compute(self): 71 | feats = torch.cat(self.feats, dim=0) 72 | if self.feat_norm == 'on': 73 | print("The test feature is normalized") 74 | feats = torch.nn.functional.normalize(feats, dim=1, p=2) 75 | 76 | # query 77 | qf = feats[:self.num_query] 78 | q_pids = np.asarray(self.pids[:self.num_query]) 79 | q_camids = np.asarray(self.camids[:self.num_query]) 80 | # gallery 81 | gf = feats[self.num_query:] 82 | g_pids = np.asarray(self.pids[self.num_query:]) 83 | g_camids = np.asarray(self.camids[self.num_query:]) 84 | # m, n = qf.shape[0], gf.shape[0] 85 | # distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 86 | # torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() 87 | # distmat.addmm_(1, -2, qf, gf.t()) 88 | # distmat = distmat.cpu().numpy() 89 | print("Enter reranking") 90 | distmat = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3) 91 | cmc, mAP, mINP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids) 92 | 93 | return cmc, mAP, mINP -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/clustercontrast/datasets/market1501.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | import re 5 | from ..utils.data import BaseImageDataset 6 | 7 | 8 | class Market1501(BaseImageDataset): 9 | """ 10 | Market1501 11 | Reference: 12 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
13 | URL: http://www.liangzheng.org/Project/project_reid.html 14 | 15 | Dataset statistics: 16 | # identities: 1501 (+1 for background) 17 | # images: 12936 (train) + 3368 (query) + 15913 (gallery) 18 | """ 19 | dataset_dir = 'Market-1501-v15.09.15' 20 | 21 | def __init__(self, root, verbose=True, **kwargs): 22 | super(Market1501, self).__init__() 23 | self.dataset_dir = '/data0/ReIDData/market1501/' 24 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 25 | self.query_dir = osp.join(self.dataset_dir, 'query') 26 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 27 | 28 | self._check_before_run() 29 | 30 | train = self._process_dir(self.train_dir, relabel=True) 31 | query = self._process_dir(self.query_dir, relabel=False) 32 | gallery = self._process_dir(self.gallery_dir, relabel=False) 33 | 34 | if verbose: 35 | print("=> Market1501 loaded") 36 | self.print_dataset_statistics(train, query, gallery) 37 | 38 | self.train = train 39 | self.query = query 40 | self.gallery = gallery 41 | 42 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 43 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 44 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 45 | 46 | def _check_before_run(self): 47 | """Check if all files are available before going deeper""" 48 | if not osp.exists(self.dataset_dir): 49 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 50 | if not osp.exists(self.train_dir): 51 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 52 | if not osp.exists(self.query_dir): 53 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 54 | if not osp.exists(self.gallery_dir): 55 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 56 | 57 | def _process_dir(self, dir_path, relabel=False): 58 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 59 | pattern = re.compile(r'([-\d]+)_c(\d)') 60 | 61 | pid_container = set() 62 | for img_path in img_paths: 63 | pid, _ = map(int, pattern.search(img_path).groups()) 64 | if pid == -1: 65 | continue # junk images are just ignored 66 | pid_container.add(pid) 67 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 68 | 69 | dataset = [] 70 | for img_path in img_paths: 71 | pid, camid = map(int, pattern.search(img_path).groups()) 72 | if pid == -1: 73 | continue # junk images are just ignored 74 | assert 0 <= pid <= 1501 # pid == 0 means background 75 | assert 1 <= camid <= 6 76 | camid -= 1 # index starts from 0 77 | if relabel: 78 | pid = pid2label[pid] 79 | dataset.append((img_path, pid, camid)) 80 | 81 | return dataset 82 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/datasets/sysu_all.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | import re 5 | from ..utils.data import BaseImageDataset 6 | 7 | 8 | class sysu_all(BaseImageDataset): 9 | """ 10 | SYSU-MM01 (all-modality split) 11 | Reference: 12 | Wu et al. RGB-Infrared Cross-Modality Person Re-Identification. ICCV 2017.
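(The 'all_modify' directory is this repo's Market-1501-style re-packaging of SYSU-MM01; note that __init__ below hard-codes root to /data0/ReIDData/.)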
13 | 14 | 15 | Dataset statistics (full SYSU-MM01): 16 | # identities: 491 17 | # cameras: 6 (4 RGB + 2 infrared) 18 | """ 19 | dataset_dir = 'sysu/all_modify/' 20 | 21 | def __init__(self, root, verbose=True, **kwargs): 22 | super(sysu_all, self).__init__() 23 | root='/data0/ReIDData/' 24 | self.dataset_dir = osp.join(root, self.dataset_dir) 25 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 26 | self.query_dir = osp.join(self.dataset_dir, 'query') 27 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 28 | 29 | self._check_before_run() 30 | 31 | train = self._process_dir(self.train_dir, relabel=True) 32 | query = self._process_dir(self.query_dir, relabel=False) 33 | gallery = self._process_dir(self.gallery_dir, relabel=False) 34 | 35 | if verbose: 36 | print("=> sysu_all loaded") 37 | self.print_dataset_statistics(train, query, gallery) 38 | 39 | self.train = train 40 | self.query = query 41 | self.gallery = gallery 42 | 43 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 44 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 45 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 46 | 47 | def _check_before_run(self): 48 | """Check if all files are available before going deeper""" 49 | if not osp.exists(self.dataset_dir): 50 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 51 | if not osp.exists(self.train_dir): 52 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 53 | if not osp.exists(self.query_dir): 54 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 55 | if not osp.exists(self.gallery_dir): 56 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 57 | 58 | def _process_dir(self, dir_path, relabel=False): 59 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 60 | pattern = re.compile(r'([-\d]+)_c(\d)') 61 | 62 | pid_container = set() 63 | for img_path in img_paths: 64 | pid, _ = map(int, pattern.search(img_path).groups()) 65 | if pid == -1: 66 | continue # junk images are just ignored 67 | pid_container.add(pid) 68 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 69 | 70 | dataset = [] 71 | for img_path in img_paths: 72 | pid, camid = map(int, pattern.search(img_path).groups()) 73 | if pid == -1: 74 | continue # junk images are just ignored 75 | assert 0 <= pid <= 1501 # pid == 0 means background 76 | assert 1 <= camid <= 6 77 | camid -= 1 # index starts from 0 78 | if relabel: 79 | pid = pid2label[pid] 80 | dataset.append((img_path, pid, camid)) 81 | 82 | return dataset 83 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/clustercontrast/datasets/msmt17.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | 4 | import glob 5 | import re 6 | from ..utils.data import BaseImageDataset 7 | 8 | 9 | class MSMT17(BaseImageDataset): 10 | """ 11 | MSMT17 12 | Reference: 13 | Wei et al. Person Transfer GAN to Bridge Domain Gap for Person Re-Identification. CVPR 2018.
14 | URL: http://www.pkuvmc.com/publications/msmt17.html 15 | Dataset statistics: 16 | # identities: 4101 17 | # images: 32621 (train) + 11659 (query) + 82161 (gallery) 18 | # cameras: 15 19 | """ 20 | dataset_dir = 'MSMT17' 21 | 22 | def __init__(self, root='', verbose=True, pid_begin=0, **kwargs): 23 | super(MSMT17, self).__init__() 24 | self.pid_begin = pid_begin 25 | self.dataset_dir = osp.join(root, self.dataset_dir) 26 | self.dataset_dir = '/data0/ReIDData/MSMT17_V2/' 27 | self.train_dir = osp.join(self.dataset_dir, 'train') 28 | self.test_dir = osp.join(self.dataset_dir, 'test') 29 | self.list_train_path = osp.join(self.dataset_dir, 'list_train.txt') 30 | self.list_val_path = osp.join(self.dataset_dir, 'list_val.txt') 31 | self.list_query_path = osp.join(self.dataset_dir, 'list_query.txt') 32 | self.list_gallery_path = osp.join(self.dataset_dir, 'list_gallery.txt') 33 | 34 | self._check_before_run() 35 | train = self._process_dir(self.train_dir, self.list_train_path) 36 | val = self._process_dir(self.train_dir, self.list_val_path) 37 | train += val 38 | query = self._process_dir(self.test_dir, self.list_query_path) 39 | gallery = self._process_dir(self.test_dir, self.list_gallery_path) 40 | if verbose: 41 | print("=> MSMT17 loaded") 42 | self.print_dataset_statistics(train, query, gallery) 43 | 44 | self.train = train 45 | self.query = query 46 | self.gallery = gallery 47 | 48 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 49 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 50 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 51 | def _check_before_run(self): 52 | """Check if all files are available before going deeper""" 53 | if not osp.exists(self.dataset_dir): 54 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 55 | if not osp.exists(self.train_dir): 56 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 57 | if not osp.exists(self.test_dir): 58 | raise RuntimeError("'{}' is not available".format(self.test_dir)) 59 | 60 | def _process_dir(self, dir_path, list_path): 61 | with open(list_path, 'r') as txt: 62 | lines = txt.readlines() 63 | dataset = [] 64 | pid_container = set() 65 | cam_container = set() 66 | for img_idx, img_info in enumerate(lines): 67 | img_path, pid = img_info.split(' ') 68 | pid = int(pid) # no need to relabel 69 | camid = int(img_path.split('_')[2]) 70 | img_path = osp.join(dir_path, img_path) 71 | dataset.append((img_path, self.pid_begin +pid, camid-1)) 72 | pid_container.add(pid) 73 | cam_container.add(camid) 74 | # check if pid starts from 0 and increments with 1 75 | for idx, pid in enumerate(pid_container): 76 | assert idx == pid, "See code comment for explanation" 77 | return dataset 78 | -------------------------------------------------------------------------------- /data/datasets/msmt17.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import glob 3 | import re 4 | 5 | import os.path as osp 6 | 7 | from .bases import BaseImageDataset 8 | 9 | 10 | class MSMT17(BaseImageDataset): 11 | """ 12 | MSMT17 13 | 14 | Reference: 15 | Wei et al. Person Transfer GAN to Bridge Domain Gap for Person Re-Identification. CVPR 2018. 
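(Unlike the list-file loader above, this variant expects MSMT17 re-organized into Market-1501-style bounding_box_train / query / bounding_box_test folders.)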
16 | 17 | URL: http://www.pkuvmc.com/publications/msmt17.html 18 | 19 | Dataset statistics: 20 | # identities: 4101 21 | # images: 32621 (train) + 11659 (query) + 82161 (gallery) 22 | # cameras: 15 23 | """ 24 | dataset_dir = 'msmt17' 25 | 26 | def __init__(self, root='./toDataset', verbose=True, **kwargs): 27 | super(MSMT17, self).__init__() 28 | self.dataset_dir = osp.join(root, self.dataset_dir) 29 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 30 | self.query_dir = osp.join(self.dataset_dir, 'query') 31 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 32 | 33 | self._check_before_run() 34 | 35 | train = self._process_dir(self.train_dir, relabel=True) 36 | query = self._process_dir(self.query_dir, relabel=False) 37 | gallery = self._process_dir(self.gallery_dir, relabel=False) 38 | 39 | if verbose: 40 | print("=> MSMT17 loaded") 41 | self.print_dataset_statistics(train, query, gallery) 42 | 43 | self.train = train 44 | self.query = query 45 | self.gallery = gallery 46 | 47 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 48 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 49 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 50 | 51 | def _check_before_run(self): 52 | """Check if all files are available before going deeper""" 53 | if not osp.exists(self.dataset_dir): 54 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 55 | if not osp.exists(self.train_dir): 56 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 57 | if not osp.exists(self.query_dir): 58 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 59 | if not osp.exists(self.gallery_dir): 60 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 61 | 62 | def _process_dir(self, dir_path, relabel=False): 63 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 64 | pattern = re.compile(r'([-\d]+)_c(\d)') 65 | 66 | pid_container = set() 67 | for img_path in img_paths: 68 | pid, _ = map(int, pattern.search(img_path).groups()) 69 | if pid == -1: continue # junk images are just ignored 70 | pid_container.add(pid) 71 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 72 | 73 | dataset = [] 74 | for img_path in img_paths: 75 | pid, camid = map(int, pattern.search(img_path).groups()) 76 | # if pid == -1: continue # junk images are just ignored 77 | # assert 0 <= pid <= 1501 # pid == 0 means background 78 | # assert 1 <= camid <= 6 79 | camid -= 1 # index starts from 0 80 | if relabel: pid = pid2label[pid] 81 | dataset.append((img_path, pid, camid)) 82 | 83 | return dataset -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/clustercontrast/utils/faiss_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import faiss 4 | import torch 5 | 6 | def swig_ptr_from_FloatTensor(x): 7 | assert x.is_contiguous() 8 | assert x.dtype == torch.float32 9 | return faiss.cast_integer_to_float_ptr( 10 | x.storage().data_ptr() + x.storage_offset() * 4) 11 | 12 | def swig_ptr_from_LongTensor(x): 13 | assert x.is_contiguous() 14 | assert x.dtype == torch.int64, 'dtype=%s' % x.dtype 15 | 16 | return
faiss.cast_integer_to_idx_t_ptr( 17 | x.storage().data_ptr() + x.storage_offset() * 8) 18 | 19 | def search_index_pytorch(index, x, k, D=None, I=None): 20 | """call the search function of an index with pytorch tensor I/O (CPU 21 | and GPU supported)""" 22 | assert x.is_contiguous() 23 | n, d = x.size() 24 | assert d == index.d 25 | 26 | if D is None: 27 | D = torch.empty((n, k), dtype=torch.float32, device=x.device) 28 | else: 29 | assert D.size() == (n, k) 30 | 31 | if I is None: 32 | I = torch.empty((n, k), dtype=torch.int64, device=x.device) 33 | else: 34 | assert I.size() == (n, k) 35 | torch.cuda.synchronize() 36 | xptr = swig_ptr_from_FloatTensor(x) 37 | Iptr = swig_ptr_from_LongTensor(I) 38 | Dptr = swig_ptr_from_FloatTensor(D) 39 | index.search_c(n, xptr, 40 | k, Dptr, Iptr) 41 | torch.cuda.synchronize() 42 | return D, I 43 | 44 | def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None, 45 | metric=faiss.METRIC_L2): 46 | assert xb.device == xq.device 47 | 48 | nq, d = xq.size() 49 | if xq.is_contiguous(): 50 | xq_row_major = True 51 | elif xq.t().is_contiguous(): 52 | xq = xq.t() # I initially wrote xq:t(), Lua is still haunting me :-) 53 | xq_row_major = False 54 | else: 55 | raise TypeError('matrix should be row or column-major') 56 | 57 | xq_ptr = swig_ptr_from_FloatTensor(xq) 58 | 59 | nb, d2 = xb.size() 60 | assert d2 == d 61 | if xb.is_contiguous(): 62 | xb_row_major = True 63 | elif xb.t().is_contiguous(): 64 | xb = xb.t() 65 | xb_row_major = False 66 | else: 67 | raise TypeError('matrix should be row or column-major') 68 | xb_ptr = swig_ptr_from_FloatTensor(xb) 69 | 70 | if D is None: 71 | D = torch.empty(nq, k, device=xb.device, dtype=torch.float32) 72 | else: 73 | assert D.shape == (nq, k) 74 | assert D.device == xb.device 75 | 76 | if I is None: 77 | I = torch.empty(nq, k, device=xb.device, dtype=torch.int64) 78 | else: 79 | assert I.shape == (nq, k) 80 | assert I.device == xb.device 81 | 82 | D_ptr = swig_ptr_from_FloatTensor(D) 83 | I_ptr = swig_ptr_from_LongTensor(I) 84 | 85 | faiss.bruteForceKnn(res, metric, 86 | xb_ptr, xb_row_major, nb, 87 | xq_ptr, xq_row_major, nq, 88 | d, k, D_ptr, I_ptr) 89 | 90 | return D, I 91 | 92 | def index_init_gpu(ngpus, feat_dim): 93 | flat_config = [] 94 | for i in range(ngpus): 95 | cfg = faiss.GpuIndexFlatConfig() 96 | cfg.useFloat16 = False 97 | cfg.device = i 98 | flat_config.append(cfg) 99 | 100 | res = [faiss.StandardGpuResources() for i in range(ngpus)] 101 | indexes = [faiss.GpuIndexFlatL2(res[i], feat_dim, flat_config[i]) for i in range(ngpus)] 102 | index = faiss.IndexShards(feat_dim) 103 | for sub_index in indexes: 104 | index.add_shard(sub_index) 105 | index.reset() 106 | return index 107 | 108 | def index_init_cpu(feat_dim): 109 | return faiss.IndexFlatL2(feat_dim) 110 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/utils/faiss_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import faiss 4 | import torch 5 | 6 | def swig_ptr_from_FloatTensor(x): 7 | assert x.is_contiguous() 8 | assert x.dtype == torch.float32 9 | return faiss.cast_integer_to_float_ptr( 10 | x.storage().data_ptr() + x.storage_offset() * 4) 11 | 12 | def swig_ptr_from_LongTensor(x): 13 | assert x.is_contiguous() 14 | assert x.dtype == torch.int64, 'dtype=%s' % x.dtype 15 | 16 | return faiss.cast_integer_to_idx_t_ptr( 17 | x.storage().data_ptr() + x.storage_offset() * 8) 18 | 19 
| def search_index_pytorch(index, x, k, D=None, I=None): 20 | """call the search function of an index with pytorch tensor I/O (CPU 21 | and GPU supported)""" 22 | assert x.is_contiguous() 23 | n, d = x.size() 24 | assert d == index.d 25 | 26 | if D is None: 27 | D = torch.empty((n, k), dtype=torch.float32, device=x.device) 28 | else: 29 | assert D.size() == (n, k) 30 | 31 | if I is None: 32 | I = torch.empty((n, k), dtype=torch.int64, device=x.device) 33 | else: 34 | assert I.size() == (n, k) 35 | torch.cuda.synchronize() 36 | xptr = swig_ptr_from_FloatTensor(x) 37 | Iptr = swig_ptr_from_LongTensor(I) 38 | Dptr = swig_ptr_from_FloatTensor(D) 39 | index.search_c(n, xptr, 40 | k, Dptr, Iptr) 41 | torch.cuda.synchronize() 42 | return D, I 43 | 44 | def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None, 45 | metric=faiss.METRIC_L2): 46 | assert xb.device == xq.device 47 | 48 | nq, d = xq.size() 49 | if xq.is_contiguous(): 50 | xq_row_major = True 51 | elif xq.t().is_contiguous(): 52 | xq = xq.t() # I initially wrote xq:t(), Lua is still haunting me :-) 53 | xq_row_major = False 54 | else: 55 | raise TypeError('matrix should be row or column-major') 56 | 57 | xq_ptr = swig_ptr_from_FloatTensor(xq) 58 | 59 | nb, d2 = xb.size() 60 | assert d2 == d 61 | if xb.is_contiguous(): 62 | xb_row_major = True 63 | elif xb.t().is_contiguous(): 64 | xb = xb.t() 65 | xb_row_major = False 66 | else: 67 | raise TypeError('matrix should be row or column-major') 68 | xb_ptr = swig_ptr_from_FloatTensor(xb) 69 | 70 | if D is None: 71 | D = torch.empty(nq, k, device=xb.device, dtype=torch.float32) 72 | else: 73 | assert D.shape == (nq, k) 74 | assert D.device == xb.device 75 | 76 | if I is None: 77 | I = torch.empty(nq, k, device=xb.device, dtype=torch.int64) 78 | else: 79 | assert I.shape == (nq, k) 80 | assert I.device == xb.device 81 | 82 | D_ptr = swig_ptr_from_FloatTensor(D) 83 | I_ptr = swig_ptr_from_LongTensor(I) 84 | 85 | faiss.bruteForceKnn(res, metric, 86 | xb_ptr, xb_row_major, nb, 87 | xq_ptr, xq_row_major, nq, 88 | d, k, D_ptr, I_ptr) 89 | 90 | return D, I 91 | 92 | def index_init_gpu(ngpus, feat_dim): 93 | flat_config = [] 94 | for i in range(ngpus): 95 | cfg = faiss.GpuIndexFlatConfig() 96 | cfg.useFloat16 = False 97 | cfg.device = i 98 | flat_config.append(cfg) 99 | 100 | res = [faiss.StandardGpuResources() for i in range(ngpus)] 101 | indexes = [faiss.GpuIndexFlatL2(res[i], feat_dim, flat_config[i]) for i in range(ngpus)] 102 | index = faiss.IndexShards(feat_dim) 103 | for sub_index in indexes: 104 | index.add_shard(sub_index) 105 | index.reset() 106 | return index 107 | 108 | def index_init_cpu(feat_dim): 109 | return faiss.IndexFlatL2(feat_dim) 110 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/clustercontrast/utils/data/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torchvision.transforms import * 4 | from PIL import Image 5 | import random 6 | import math 7 | import numpy as np 8 | 9 | class RectScale(object): 10 | def __init__(self, height, width, interpolation=Image.BILINEAR): 11 | self.height = height 12 | self.width = width 13 | self.interpolation = interpolation 14 | 15 | def __call__(self, img): 16 | w, h = img.size 17 | if h == self.height and w == self.width: 18 | return img 19 | return img.resize((self.width, self.height), self.interpolation) 20 | 21 | 22 | class RandomSizedRectCrop(object): 23 | 
def __init__(self, height, width, interpolation=Image.BILINEAR): 24 | self.height = height 25 | self.width = width 26 | self.interpolation = interpolation 27 | 28 | def __call__(self, img): 29 | for attempt in range(10): 30 | area = img.size[0] * img.size[1] 31 | target_area = random.uniform(0.64, 1.0) * area 32 | aspect_ratio = random.uniform(2, 3) 33 | 34 | h = int(round(math.sqrt(target_area * aspect_ratio))) 35 | w = int(round(math.sqrt(target_area / aspect_ratio))) 36 | 37 | if w <= img.size[0] and h <= img.size[1]: 38 | x1 = random.randint(0, img.size[0] - w) 39 | y1 = random.randint(0, img.size[1] - h) 40 | 41 | img = img.crop((x1, y1, x1 + w, y1 + h)) 42 | assert(img.size == (w, h)) 43 | 44 | return img.resize((self.width, self.height), self.interpolation) 45 | 46 | # Fallback 47 | scale = RectScale(self.height, self.width, 48 | interpolation=self.interpolation) 49 | return scale(img) 50 | 51 | 52 | class RandomErasing(object): 53 | """ Randomly selects a rectangle region in an image and erases its pixels. 54 | 'Random Erasing Data Augmentation' by Zhong et al. 55 | See https://arxiv.org/pdf/1708.04896.pdf 56 | Args: 57 | probability: The probability that the Random Erasing operation will be performed. 58 | sl: Minimum proportion of erased area against input image. 59 | sh: Maximum proportion of erased area against input image. 60 | r1: Minimum aspect ratio of erased area. 61 | mean: Erasing value. 62 | """ 63 | 64 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 65 | self.probability = probability 66 | self.mean = mean 67 | self.sl = sl 68 | self.sh = sh 69 | self.r1 = r1 70 | 71 | def __call__(self, img): 72 | 73 | if random.uniform(0, 1) >= self.probability: 74 | return img 75 | 76 | for attempt in range(100): 77 | area = img.size()[1] * img.size()[2] 78 | 79 | target_area = random.uniform(self.sl, self.sh) * area 80 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 81 | 82 | h = int(round(math.sqrt(target_area * aspect_ratio))) 83 | w = int(round(math.sqrt(target_area / aspect_ratio))) 84 | 85 | if w < img.size()[2] and h < img.size()[1]: 86 | x1 = random.randint(0, img.size()[1] - h) 87 | y1 = random.randint(0, img.size()[2] - w) 88 | if img.size()[0] == 3: 89 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 90 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 91 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 92 | else: 93 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 94 | return img 95 | 96 | return img 97 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/datasets/regdb_rgb.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | import re 5 | from ..utils.data import BaseImageDataset 6 | 7 | 8 | class regdb_rgb(BaseImageDataset): 9 | """ 10 | RegDB (visible/RGB split) 11 | Reference: 12 | Nguyen et al. Person Recognition System Based on a Combination of Body Images from Visible Light and Thermal Cameras. Sensors 2017.
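(RegDB ships as 10 predefined trials; the `trial` argument below selects the per-trial Market-1501-style split directory under rgb_modify/.)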
13 | 14 | 15 | Dataset statistics (full RegDB): 16 | # identities: 412 17 | # images: 4120 visible + 4120 thermal (10 each per identity) 18 | """ 19 | dataset_dir = 'RegDB/rgb_modify/' 20 | 21 | def __init__(self, root, trial=0, verbose=True, **kwargs): 22 | super(regdb_rgb, self).__init__() 23 | root='/data0/ReIDData/' 24 | # print('regdb_rgb',trial) 25 | self.dataset_dir = osp.join(root, self.dataset_dir) 26 | self.train_dir = osp.join(self.dataset_dir, str(trial)+'/'+'bounding_box_train') 27 | 28 | 29 | 30 | self.query_dir = osp.join(self.dataset_dir, str(trial)+'/'+'query')#osp.join(self.dataset_dir, 'query') 31 | self.gallery_dir = osp.join(self.dataset_dir, str(trial)+'/'+'bounding_box_test') 32 | 33 | self._check_before_run() 34 | 35 | train = self._process_dir(self.train_dir, relabel=True) 36 | query = self._process_dir(self.query_dir, relabel=False) 37 | gallery = self._process_dir(self.gallery_dir, relabel=False) 38 | 39 | if verbose: 40 | print("=> regdb_rgb loaded", trial) 41 | self.print_dataset_statistics(train, query, gallery) 42 | 43 | self.train = train 44 | self.query = query 45 | self.gallery = gallery 46 | 47 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 48 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 49 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 50 | 51 | def _check_before_run(self): 52 | """Check if all files are available before going deeper""" 53 | if not osp.exists(self.dataset_dir): 54 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 55 | if not osp.exists(self.train_dir): 56 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 57 | if not osp.exists(self.query_dir): 58 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 59 | if not osp.exists(self.gallery_dir): 60 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 61 | 62 | def _process_dir(self, dir_path, relabel=False): 63 | img_paths = glob.glob(osp.join(dir_path, '*.bmp')) 64 | pattern = re.compile(r'([-\d]+)_c(\d)') 65 | 66 | pid_container = set() 67 | for img_path in img_paths: 68 | pid, _ = map(int, pattern.search(img_path).groups()) 69 | if pid == -1: 70 | continue # junk images are just ignored 71 | pid_container.add(pid) 72 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 73 | 74 | dataset = [] 75 | for img_path in img_paths: 76 | pid, camid = map(int, pattern.search(img_path).groups()) 77 | if pid == -1: 78 | continue # junk images are just ignored 79 | assert 0 <= pid <= 1501 # pid == 0 means background 80 | assert 1 <= camid <= 6 81 | camid -= 1 # index starts from 0 82 | if relabel: 83 | pid = pid2label[pid] 84 | dataset.append((img_path, pid, camid)) 85 | 86 | return dataset 87 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/utils/data/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torchvision.transforms import * 4 | from PIL import Image 5 | import random 6 | import math 7 | import numpy as np 8 | 9 | class RectScale(object): 10 | def __init__(self, height, width, interpolation=Image.BILINEAR): 11 | self.height = height 12 | self.width = width 13 | self.interpolation =
/Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/utils/data/transforms.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | from torchvision.transforms import *
4 | from PIL import Image
5 | import random
6 | import math
7 | import numpy as np
8 | 
9 | class RectScale(object):
10 |     def __init__(self, height, width, interpolation=Image.BILINEAR):
11 |         self.height = height
12 |         self.width = width
13 |         self.interpolation = interpolation
14 | 
15 |     def __call__(self, img):
16 |         w, h = img.size
17 |         if h == self.height and w == self.width:
18 |             return img
19 |         return img.resize((self.width, self.height), self.interpolation)
20 | 
21 | 
22 | class RandomSizedRectCrop(object):
23 |     def __init__(self, height, width, interpolation=Image.BILINEAR):
24 |         self.height = height
25 |         self.width = width
26 |         self.interpolation = interpolation
27 | 
28 |     def __call__(self, img):
29 |         for attempt in range(10):
30 |             area = img.size[0] * img.size[1]
31 |             target_area = random.uniform(0.64, 1.0) * area
32 |             aspect_ratio = random.uniform(2, 3)  # tall crops (h/w in [2, 3]), suited to pedestrian images
33 | 
34 |             h = int(round(math.sqrt(target_area * aspect_ratio)))
35 |             w = int(round(math.sqrt(target_area / aspect_ratio)))
36 | 
37 |             if w <= img.size[0] and h <= img.size[1]:
38 |                 x1 = random.randint(0, img.size[0] - w)
39 |                 y1 = random.randint(0, img.size[1] - h)
40 | 
41 |                 img = img.crop((x1, y1, x1 + w, y1 + h))
42 |                 assert(img.size == (w, h))
43 | 
44 |                 return img.resize((self.width, self.height), self.interpolation)
45 | 
46 |         # Fallback
47 |         scale = RectScale(self.height, self.width,
48 |                           interpolation=self.interpolation)
49 |         return scale(img)
50 | 
51 | 
52 | class RandomErasing(object):
53 |     """ Randomly selects a rectangle region in an image and erases its pixels.
54 |         'Random Erasing Data Augmentation' by Zhong et al.
55 |         See https://arxiv.org/pdf/1708.04896.pdf
56 |     Args:
57 |         probability: The probability that the Random Erasing operation will be performed.
58 |         sl: Minimum proportion of erased area against input image.
59 |         sh: Maximum proportion of erased area against input image.
60 |         r1: Minimum aspect ratio of erased area.
61 |         mean: Erasing value.
62 |     """
63 | 
64 |     def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)):
65 |         self.probability = probability
66 |         self.mean = mean
67 |         self.sl = sl
68 |         self.sh = sh
69 |         self.r1 = r1
70 | 
71 |     def __call__(self, img):
72 | 
73 |         if random.uniform(0, 1) >= self.probability:
74 |             return img
75 | 
76 |         for attempt in range(100):
77 |             area = img.size()[1] * img.size()[2]  # operates on a CHW tensor, i.e. after ToTensor()
78 | 
79 |             target_area = random.uniform(self.sl, self.sh) * area
80 |             aspect_ratio = random.uniform(self.r1, 1 / self.r1)
81 | 
82 |             h = int(round(math.sqrt(target_area * aspect_ratio)))
83 |             w = int(round(math.sqrt(target_area / aspect_ratio)))
84 | 
85 |             if w < img.size()[2] and h < img.size()[1]:
86 |                 x1 = random.randint(0, img.size()[1] - h)  # top (row) offset
87 |                 y1 = random.randint(0, img.size()[2] - w)  # left (column) offset
88 |                 if img.size()[0] == 3:
89 |                     img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
90 |                     img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
91 |                     img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
92 |                 else:  # single-channel input
93 |                     img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
94 |                 return img
95 | 
96 |         return img
97 | 
--------------------------------------------------------------------------------
/Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/datasets/regdb_ir.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os.path as osp
3 | import glob
4 | import re
5 | from ..utils.data import BaseImageDataset
6 | 
7 | 
8 | class regdb_ir(BaseImageDataset):
9 |     """
10 |     RegDB, thermal (infrared) modality (loader adapted from Market1501).
11 |     Reference:
12 |     Nguyen et al. Person Recognition System Based on a Combination of Body Images from Visible Light and Thermal Cameras. Sensors 2017.
13 |     URL: http://dm.dongguk.edu/link.html
14 | 
15 |     Dataset statistics:
16 |     # identities: 412
17 |     # images: 10 thermal and 10 visible images per identity
18 |     """
19 |     dataset_dir = 'RegDB/ir_modify/'
20 | 
21 |     def __init__(self, root='/data0/ReIDData/', trial=0, verbose=True, **kwargs):
22 |         super(regdb_ir, self).__init__()
23 |         # NOTE: 'root' was previously hard-coded to '/data0/ReIDData/' in the body,
24 |         # silently overriding the argument; it is now only the default value.
25 |         self.dataset_dir = osp.join(root, self.dataset_dir)
26 |         self.train_dir = osp.join(self.dataset_dir, str(trial), 'bounding_box_train')
27 |         self.query_dir = osp.join(self.dataset_dir, str(trial), 'query')
28 |         self.gallery_dir = osp.join(self.dataset_dir, str(trial), 'bounding_box_test')
29 | 
30 |         self._check_before_run()
31 | 
32 |         train = self._process_dir(self.train_dir, relabel=True)
33 |         query = self._process_dir(self.query_dir, relabel=False)
34 |         gallery = self._process_dir(self.gallery_dir, relabel=False)
35 | 
36 |         if verbose:
37 |             print("=> regdb_ir loaded, trial", trial)
38 |             self.print_dataset_statistics(train, query, gallery)
39 | 
40 |         self.train = train
41 |         self.query = query
42 |         self.gallery = gallery
43 | 
44 |         self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
45 |         self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
46 |         self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
47 | 
48 |     def _check_before_run(self):
49 |         """Check if all files are available before going deeper"""
50 |         if not osp.exists(self.dataset_dir):
51 |             raise RuntimeError("'{}' is not available".format(self.dataset_dir))
52 |         if not osp.exists(self.train_dir):
53 |             raise RuntimeError("'{}' is not available".format(self.train_dir))
54 |         if not osp.exists(self.query_dir):
55 |             raise RuntimeError("'{}' is not available".format(self.query_dir))
56 |         if not osp.exists(self.gallery_dir):
57 |             raise RuntimeError("'{}' is not available".format(self.gallery_dir))
58 | 
59 |     def _process_dir(self, dir_path, relabel=False):
60 |         img_paths = glob.glob(osp.join(dir_path, '*.bmp'))
61 |         pattern = re.compile(r'([-\d]+)_c(\d)')
62 | 
63 |         pid_container = set()
64 |         for img_path in img_paths:
65 |             pid, _ = map(int, pattern.search(img_path).groups())
66 |             if pid == -1:
67 |                 continue  # junk images are just ignored
68 |             pid_container.add(pid)
69 |         pid2label = {pid: label for label, pid in enumerate(pid_container)}
70 | 
71 |         dataset = []
72 |         for img_path in img_paths:
73 |             pid, camid = map(int, pattern.search(img_path).groups())
74 |             if pid == -1:
75 |                 continue  # junk images are just ignored
76 |             assert 0 <= pid <= 1501  # sanity bound inherited from the Market1501 loader
77 |             assert 1 <= camid <= 6
78 |             camid -= 1  # index starts from 0
79 |             if relabel:
80 |                 pid = pid2label[pid]
81 |             dataset.append((img_path, pid, camid))
82 | 
83 |         return dataset
84 | 
--------------------------------------------------------------------------------
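RegDB results are conventionally reported as the mean over the ten predefined train/test splits selected by the trial argument; a sketch of that loop (the root path and the evaluation step are placeholders):

    metrics = []
    for trial in range(10):
        rgb = regdb_rgb(root='/path/to/ReIDData/', trial=trial, verbose=False)
        ir = regdb_ir(root='/path/to/ReIDData/', trial=trial, verbose=False)
        # ... train / evaluate on this split, then append mAP / rank-1 to metrics
    # report the average over the ten trials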
/Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/datasets/sysu_rgb.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os.path as osp
3 | import glob
4 | import re
5 | from ..utils.data import BaseImageDataset
6 | 
7 | 
8 | class sysu_rgb(BaseImageDataset):
9 |     """
10 |     SYSU-MM01, visible (RGB) modality (loader adapted from Market1501).
11 |     Reference:
12 |     Wu et al. RGB-Infrared Cross-Modality Person Re-Identification. ICCV 2017.
13 |     URL: https://github.com/wuancong/SYSU-MM01
14 | 
15 |     Dataset statistics:
16 |     # identities: 491
17 |     # cameras: 4 RGB cameras (of 6 total; cameras 3 and 6 are infrared)
18 |     """
19 |     dataset_dir = 'sysu/rgb_modify/'
20 | 
21 |     def __init__(self, root='/data0/ReIDData/', verbose=True, **kwargs):
22 |         super(sysu_rgb, self).__init__()
23 |         # NOTE: 'root' was previously hard-coded to '/data0/ReIDData/' in the body,
24 |         # silently overriding the argument; it is now only the default value.
25 |         self.dataset_dir = osp.join(root, self.dataset_dir)
26 |         self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
27 |         self.query_dir = osp.join(self.dataset_dir, 'query')
28 |         self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')
29 | 
30 |         self._check_before_run()
31 | 
32 |         train = self._process_dir(self.train_dir, relabel=True)
33 |         query = self._process_dir(self.query_dir, relabel=False)
34 |         gallery = self._process_dir(self.gallery_dir, relabel=False)
35 | 
36 |         if verbose:
37 |             print("=> sysu_rgb loaded")
38 |             self.print_dataset_statistics(train, query, gallery)
39 | 
40 |         self.train = train
41 |         self.query = query
42 |         self.gallery = gallery
43 | 
44 |         self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
45 |         self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
46 |         self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
47 | 
48 |     def _check_before_run(self):
49 |         """Check if all files are available before going deeper"""
50 |         if not osp.exists(self.dataset_dir):
51 |             raise RuntimeError("'{}' is not available".format(self.dataset_dir))
52 |         if not osp.exists(self.train_dir):
53 |             raise RuntimeError("'{}' is not available".format(self.train_dir))
54 |         if not osp.exists(self.query_dir):
55 |             raise RuntimeError("'{}' is not available".format(self.query_dir))
56 |         if not osp.exists(self.gallery_dir):
57 |             raise RuntimeError("'{}' is not available".format(self.gallery_dir))
58 | 
59 |     def _process_dir(self, dir_path, relabel=False):
60 |         img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
61 |         pattern = re.compile(r'([-\d]+)_c(\d)')
62 | 
63 |         pid_container = set()
64 |         for img_path in img_paths:
65 |             pid, _ = map(int, pattern.search(img_path).groups())
66 |             if pid == -1:
67 |                 continue  # junk images are just ignored
68 |             pid_container.add(pid)
69 |         pid2label = {pid: label for label, pid in enumerate(pid_container)}
70 |         cid_container = [0, 1, 3, 4]  # 0-indexed ids of the four RGB cameras (cameras 1, 2, 4, 5)
71 |         cid2label = {cid: label for label, cid in enumerate(cid_container)}
72 |         print("cid2label", cid2label)  # e.g. {0: 0, 1: 1, 3: 2, 4: 3}
73 |         dataset = []
74 |         for img_path in img_paths:
75 |             pid, camid = map(int, pattern.search(img_path).groups())
76 |             if pid == -1:
77 |                 continue  # junk images are just ignored
78 |             assert 0 <= pid <= 1501  # sanity bound inherited from the Market1501 loader
79 |             assert 1 <= camid <= 6
80 |             camid -= 1  # index starts from 0
81 |             if relabel:
82 |                 pid = pid2label[pid]
83 |                 camid = cid2label[camid]
84 |             dataset.append((img_path, pid, camid))
85 | 
86 |         return dataset
87 | 
--------------------------------------------------------------------------------
/Transformer-ReID-Survey/UnTransReID_VI_ReID/clustercontrast/datasets/sysu_ir.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os.path as osp
3 | import glob
4 | import re
5 | from ..utils.data import BaseImageDataset
6 | 
7 | 
8 | class sysu_ir(BaseImageDataset):
9 |     """
10 |     SYSU-MM01, infrared (IR) modality (loader adapted from Market1501).
11 |     Reference:
12 |     Wu et al. RGB-Infrared Cross-Modality Person Re-Identification. ICCV 2017.
13 |     URL: https://github.com/wuancong/SYSU-MM01
14 | 
15 |     Dataset statistics:
16 |     # identities: 491
17 |     # cameras: 2 infrared cameras (cameras 3 and 6 of the 6 total)
18 |     """
19 |     dataset_dir = 'sysu/ir_modify/'
20 | 
21 |     def __init__(self, root='/data0/ReIDData/', verbose=True, **kwargs):
22 |         super(sysu_ir, self).__init__()
23 |         # NOTE: 'root' was previously hard-coded to '/data0/ReIDData/' in the body,
24 |         # silently overriding the argument; it is now only the default value.
25 |         self.dataset_dir = osp.join(root, self.dataset_dir)
26 |         self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
27 |         self.query_dir = osp.join(self.dataset_dir, 'query')
28 |         self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')
29 | 
30 |         self._check_before_run()
31 | 
32 |         train = self._process_dir(self.train_dir, relabel=True)
33 |         query = self._process_dir(self.query_dir, relabel=False)
34 |         gallery = self._process_dir(self.gallery_dir, relabel=False)
35 | 
36 |         if verbose:
37 |             print("=> sysu_ir loaded")
38 |             self.print_dataset_statistics(train, query, gallery)
39 | 
40 |         self.train = train
41 |         self.query = query
42 |         self.gallery = gallery
43 | 
44 |         self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
45 |         self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
46 |         self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
47 | 
48 |     def _check_before_run(self):
49 |         """Check if all files are available before going deeper"""
50 |         if not osp.exists(self.dataset_dir):
51 |             raise RuntimeError("'{}' is not available".format(self.dataset_dir))
52 |         if not osp.exists(self.train_dir):
53 |             raise RuntimeError("'{}' is not available".format(self.train_dir))
54 |         if not osp.exists(self.query_dir):
55 |             raise RuntimeError("'{}' is not available".format(self.query_dir))
56 |         if not osp.exists(self.gallery_dir):
57 |             raise RuntimeError("'{}' is not available".format(self.gallery_dir))
58 | 
59 |     def _process_dir(self, dir_path, relabel=False):
60 |         img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
61 |         pattern = re.compile(r'([-\d]+)_c(\d)')
62 | 
63 |         pid_container = set()
64 |         for img_path in img_paths:
65 |             pid, _ = map(int, pattern.search(img_path).groups())
66 |             if pid == -1:
67 |                 continue  # junk images are just ignored
68 |             pid_container.add(pid)
69 |         pid2label = {pid: label for label, pid in enumerate(pid_container)}
70 |         cid_container = [2, 5]  # 0-indexed ids of the two infrared cameras (cameras 3 and 6)
71 |         cid2label = {cid: label for label, cid in enumerate(cid_container)}
72 |         print("cid2label", cid2label)  # e.g. {2: 0, 5: 1}
73 |         dataset = []
74 |         for img_path in img_paths:
75 |             pid, camid = map(int, pattern.search(img_path).groups())
76 |             if pid == -1:
77 |                 continue  # junk images are just ignored
78 |             assert 0 <= pid <= 1501  # sanity bound inherited from the Market1501 loader
79 |             assert 1 <= camid <= 6
80 |             camid -= 1  # index starts from 0
81 |             if relabel:
82 |                 pid = pid2label[pid]
83 |                 camid = cid2label[camid]
84 |             dataset.append((img_path, pid, camid))
85 | 
86 |         return dataset
87 | 
-------------------------------------------------------------------------------- /data/datasets/bases.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | 
3 | import numpy as np
4 | 
5 | 
6 | class BaseDataset(object):
7 |     """
8 |     Base class of reid dataset
9 |     """
10 | 
11 |     def get_imagedata_info(self, data):
12 |         pids, cams = [], []
13 |         for _, pid, camid in data:
14 |             pids += [pid]
15 | cams += [camid] 16 | pids = set(pids) 17 | cams = set(cams) 18 | num_pids = len(pids) 19 | num_cams = len(cams) 20 | num_imgs = len(data) 21 | return num_pids, num_imgs, num_cams 22 | 23 | def get_videodata_info(self, data, return_tracklet_stats=False): 24 | pids, cams, tracklet_stats = [], [], [] 25 | for img_paths, pid, camid in data: 26 | pids += [pid] 27 | cams += [camid] 28 | tracklet_stats += [len(img_paths)] 29 | pids = set(pids) 30 | cams = set(cams) 31 | num_pids = len(pids) 32 | num_cams = len(cams) 33 | num_tracklets = len(data) 34 | if return_tracklet_stats: 35 | return num_pids, num_tracklets, num_cams, tracklet_stats 36 | return num_pids, num_tracklets, num_cams 37 | 38 | def print_dataset_statistics(self): 39 | raise NotImplementedError 40 | 41 | 42 | class BaseImageDataset(BaseDataset): 43 | """ 44 | Base class of image reid dataset 45 | """ 46 | 47 | def print_dataset_statistics(self, train, query, gallery): 48 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train) 49 | num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query) 50 | num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery) 51 | 52 | print("Dataset statistics:") 53 | print(" ----------------------------------------") 54 | print(" subset | # ids | # images | # cameras") 55 | print(" ----------------------------------------") 56 | print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams)) 57 | print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams)) 58 | print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams)) 59 | print(" ----------------------------------------") 60 | 61 | 62 | class BaseVideoDataset(BaseDataset): 63 | """ 64 | Base class of video reid dataset 65 | """ 66 | 67 | def print_dataset_statistics(self, train, query, gallery): 68 | num_train_pids, num_train_tracklets, num_train_cams, train_tracklet_stats = \ 69 | self.get_videodata_info(train, return_tracklet_stats=True) 70 | 71 | num_query_pids, num_query_tracklets, num_query_cams, query_tracklet_stats = \ 72 | self.get_videodata_info(query, return_tracklet_stats=True) 73 | 74 | num_gallery_pids, num_gallery_tracklets, num_gallery_cams, gallery_tracklet_stats = \ 75 | self.get_videodata_info(gallery, return_tracklet_stats=True) 76 | 77 | tracklet_stats = train_tracklet_stats + query_tracklet_stats + gallery_tracklet_stats 78 | min_num = np.min(tracklet_stats) 79 | max_num = np.max(tracklet_stats) 80 | avg_num = np.mean(tracklet_stats) 81 | 82 | print("Dataset statistics:") 83 | print(" -------------------------------------------") 84 | print(" subset | # ids | # tracklets | # cameras") 85 | print(" -------------------------------------------") 86 | print(" train | {:5d} | {:11d} | {:9d}".format(num_train_pids, num_train_tracklets, num_train_cams)) 87 | print(" query | {:5d} | {:11d} | {:9d}".format(num_query_pids, num_query_tracklets, num_query_cams)) 88 | print(" gallery | {:5d} | {:11d} | {:9d}".format(num_gallery_pids, num_gallery_tracklets, num_gallery_cams)) 89 | print(" -------------------------------------------") 90 | print(" number of images per tracklet: {} ~ {}, average {:.2f}".format(min_num, max_num, avg_num)) 91 | print(" -------------------------------------------") 92 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/clustercontrast/models/cm.py: 
--------------------------------------------------------------------------------
1 | import collections
2 | import numpy as np
3 | from abc import ABC
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn, autograd
7 | 
8 | 
9 | class CM(autograd.Function):
10 | 
11 |     @staticmethod
12 |     def forward(ctx, inputs, targets, features, momentum):
13 |         ctx.features = features
14 |         ctx.momentum = momentum
15 |         ctx.save_for_backward(inputs, targets)
16 |         outputs = inputs.mm(ctx.features.t())
17 | 
18 |         return outputs
19 | 
20 |     @staticmethod
21 |     def backward(ctx, grad_outputs):
22 |         inputs, targets = ctx.saved_tensors
23 |         grad_inputs = None
24 |         if ctx.needs_input_grad[0]:
25 |             grad_inputs = grad_outputs.mm(ctx.features)
26 | 
27 |         # momentum update of the cluster centroids
28 |         for x, y in zip(inputs, targets):
29 |             ctx.features[y] = ctx.momentum * ctx.features[y] + (1. - ctx.momentum) * x
30 |             ctx.features[y] /= ctx.features[y].norm()
31 | 
32 |         return grad_inputs, None, None, None
33 | 
34 | 
35 | def cm(inputs, indexes, features, momentum=0.5):
36 |     return CM.apply(inputs, indexes, features, torch.Tensor([momentum]).to(inputs.device))
37 | 
38 | 
39 | class CM_Hard(autograd.Function):
40 | 
41 |     @staticmethod
42 |     def forward(ctx, inputs, targets, features, momentum):
43 |         ctx.features = features
44 |         ctx.momentum = momentum
45 |         ctx.save_for_backward(inputs, targets)
46 |         outputs = inputs.mm(ctx.features.t())
47 | 
48 |         return outputs
49 | 
50 |     @staticmethod
51 |     def backward(ctx, grad_outputs):
52 |         inputs, targets = ctx.saved_tensors
53 |         grad_inputs = None
54 |         if ctx.needs_input_grad[0]:
55 |             grad_inputs = grad_outputs.mm(ctx.features)
56 | 
57 |         batch_centers = collections.defaultdict(list)
58 |         for instance_feature, index in zip(inputs, targets.tolist()):
59 |             batch_centers[index].append(instance_feature)
60 | 
61 |         for index, features in batch_centers.items():
62 |             distances = []
63 |             for feature in features:
64 |                 distance = feature.unsqueeze(0).mm(ctx.features[index].unsqueeze(0).t())[0][0]
65 |                 distances.append(distance.cpu().numpy())
66 | 
67 |             hardest = np.argmin(np.array(distances))  # least similar instance to the current center ('median' in the original was a misnomer)
68 |             ctx.features[index] = ctx.features[index] * ctx.momentum + (1 - ctx.momentum) * features[hardest]
69 |             ctx.features[index] /= ctx.features[index].norm()
70 | 
71 |         return grad_inputs, None, None, None
72 | 
73 | 
74 | def cm_hard(inputs, indexes, features, momentum=0.5):
75 |     return CM_Hard.apply(inputs, indexes, features, torch.Tensor([momentum]).to(inputs.device))
76 | 
77 | 
78 | class Barlow(autograd.Function):
79 | 
80 |     @staticmethod
81 |     def forward(ctx, inputs, targets, features, momentum=None):  # momentum is optional; barlow() below does not pass it
82 |         ctx.features = features
83 |         ctx.momentum = momentum
84 |         ctx.save_for_backward(inputs, targets)
85 |         outputs = inputs.mm(ctx.features.t())
86 | 
87 |         return outputs
88 |     # NOTE: no backward() is implemented, so this op cannot propagate gradients.
89 | 
90 | def barlow(inputs, indexes, features):
91 |     return Barlow.apply(inputs, indexes, features)
92 | 
93 | class ClusterMemory(nn.Module, ABC):
94 |     def __init__(self, num_features, num_samples, temp=0.05, momentum=0.2, use_hard=False):
95 |         super(ClusterMemory, self).__init__()
96 |         self.num_features = num_features
97 |         self.num_samples = num_samples
98 | 
99 |         self.momentum = momentum
100 |         self.temp = temp
101 |         self.use_hard = use_hard
102 | 
103 |         self.register_buffer('features', torch.zeros(num_samples, num_features))
104 | 
105 |     def forward(self, inputs, targets):
106 | 
107 |         inputs = F.normalize(inputs, dim=1).cuda()  # assumes GPU training
108 |         if self.use_hard:
109 |             outputs = cm_hard(inputs, targets, self.features, self.momentum)
110 |         else:
111 |             outputs = cm(inputs, targets, self.features, self.momentum)
112 | 
113 |         outputs /= self.temp
114 |         loss = F.cross_entropy(outputs, targets)
115 | 
116 |         return loss
117 | 
--------------------------------------------------------------------------------
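A minimal sketch of how ClusterMemory is typically driven in cluster-contrast training (the dimensions, the random centroids, and the pseudo labels below are illustrative stand-ins for real backbone features and clustering output):

    import torch
    import torch.nn.functional as F

    num_clusters, feat_dim = 100, 768
    memory = ClusterMemory(feat_dim, num_clusters, temp=0.05, momentum=0.2, use_hard=True).cuda()
    # seed the memory bank with L2-normalized cluster centroids
    memory.features = F.normalize(torch.randn(num_clusters, feat_dim), dim=1).cuda()

    feats = torch.randn(32, feat_dim, device='cuda', requires_grad=True)   # backbone outputs
    pseudo_labels = torch.randint(0, num_clusters, (32,), device='cuda')   # from clustering
    loss = memory(feats, pseudo_labels)
    loss.backward()  # gradients reach feats; centroids are momentum-updated inside backward()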
/data/datasets/dukemtmcreid.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | 
3 | import glob
4 | import re
5 | import urllib.request  # the submodule must be imported explicitly for urlretrieve in Python 3
6 | import zipfile
7 | 
8 | import os.path as osp
9 | 
10 | from utils.iotools import mkdir_if_missing
11 | from .bases import BaseImageDataset
12 | 
13 | 
14 | class DukeMTMCreID(BaseImageDataset):
15 |     """
16 |     DukeMTMC-reID
17 |     Reference:
18 |     1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016.
19 |     2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017.
20 |     URL: https://github.com/layumi/DukeMTMC-reID_evaluation
21 | 
22 |     Dataset statistics:
23 |     # identities: 1404 (train + query)
24 |     # images: 16522 (train) + 2228 (query) + 17661 (gallery)
25 |     # cameras: 8
26 |     """
27 |     dataset_dir = 'dukemtmc-reid'
28 | 
29 |     def __init__(self, root='./toDataset', verbose=True, **kwargs):
30 |         super(DukeMTMCreID, self).__init__()
31 |         self.dataset_dir = osp.join(root, self.dataset_dir)
32 |         self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip'
33 |         self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
34 |         self.query_dir = osp.join(self.dataset_dir, 'query')
35 |         self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')
36 | 
37 |         self._download_data()
38 |         self._check_before_run()
39 | 
40 |         train = self._process_dir(self.train_dir, relabel=True)
41 |         query = self._process_dir(self.query_dir, relabel=False)
42 |         gallery = self._process_dir(self.gallery_dir, relabel=False)
43 | 
44 |         if verbose:
45 |             print("=> DukeMTMC-reID loaded")
46 |             self.print_dataset_statistics(train, query, gallery)
47 | 
48 |         self.train = train
49 |         self.query = query
50 |         self.gallery = gallery
51 | 
52 |         self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
53 |         self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
54 |         self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
55 | 
56 |     def _download_data(self):
57 |         if osp.exists(self.dataset_dir):
58 |             print("This dataset has been downloaded.")
59 |             return
60 | 
61 |         print("Creating directory {}".format(self.dataset_dir))
62 |         mkdir_if_missing(self.dataset_dir)
63 |         fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url))
64 | 
65 |         print("Downloading DukeMTMC-reID dataset")
66 |         urllib.request.urlretrieve(self.dataset_url, fpath)
67 | 
68 |         print("Extracting files")
69 |         zip_ref = zipfile.ZipFile(fpath, 'r')
70 |         zip_ref.extractall(self.dataset_dir)
71 |         zip_ref.close()
72 | 
73 |     def _check_before_run(self):
74 |         """Check if all files are available before going deeper"""
75 |         if not osp.exists(self.dataset_dir):
76 |             raise RuntimeError("'{}' is not available".format(self.dataset_dir))
77 |         if not osp.exists(self.train_dir):
78 |             raise RuntimeError("'{}' is not available".format(self.train_dir))
79 |         if not osp.exists(self.query_dir):
80 |             raise RuntimeError("'{}' is not available".format(self.query_dir))
81 |         if not osp.exists(self.gallery_dir):
82 |             raise RuntimeError("'{}' is not available".format(self.gallery_dir))
83 | 
84 |     def _process_dir(self, dir_path, relabel=False):
85 |         img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
86 |         pattern = re.compile(r'([-\d]+)_c(\d)')
87 | 
88 |         pid_container = set()
89 |         for img_path in img_paths:
90 |             pid, _ = map(int, pattern.search(img_path).groups())
91 |             pid_container.add(pid)
92 |         pid2label = {pid: label for label, pid in enumerate(pid_container)}
93 | 
94 |         dataset = []
95 |         for img_path in img_paths:
96 |             pid, camid = map(int, pattern.search(img_path).groups())
97 |             assert 1 <= camid <= 8
98 |             camid -= 1  # index starts from 0
99 |             if relabel: pid = pid2label[pid]
100 |             dataset.append((img_path, pid, camid))
101 | 
102 |         return dataset
103 | 
--------------------------------------------------------------------------------
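A small illustration of the filename convention the regex in _process_dir assumes (the example filename follows the DukeMTMC-reID pattern but is hypothetical):

    import re
    pattern = re.compile(r'([-\d]+)_c(\d)')
    pid, camid = map(int, pattern.search('0005_c2_f0046985.jpg').groups())
    print(pid, camid)  # -> 5 2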
/Transformer-ReID-Survey/UnTransReID_USL_ReID/clustercontrast/utils/pos_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # Position embedding utils
8 | # --------------------------------------------------------
9 | 
10 | import numpy as np
11 | 
12 | import torch
13 | 
14 | # --------------------------------------------------------
15 | # 2D sine-cosine position embedding
16 | # References:
17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
18 | # MoCo v3: https://github.com/facebookresearch/moco-v3
19 | # --------------------------------------------------------
20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
21 |     """
22 |     grid_size: int of the grid height and width
23 |     return:
24 |     pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
25 |     """
26 |     grid_h = np.arange(grid_size, dtype=np.float32)
27 |     grid_w = np.arange(grid_size, dtype=np.float32)
28 |     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
29 |     grid = np.stack(grid, axis=0)
30 | 
31 |     grid = grid.reshape([2, 1, grid_size, grid_size])
32 |     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
33 |     if cls_token:
34 |         pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
35 |     return pos_embed
36 | 
37 | 
38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
39 |     assert embed_dim % 2 == 0
40 | 
41 |     # use half of dimensions to encode grid_h
42 |     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
43 |     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
44 | 
45 |     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
46 |     return emb
47 | 
48 | 
49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
50 |     """
51 |     embed_dim: output dimension for each position
52 |     pos: a list of positions to be encoded: size (M,)
53 |     out: (M, D)
54 |     """
55 |     assert embed_dim % 2 == 0
56 |     omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float was removed in NumPy >= 1.24
57 |     omega /= embed_dim / 2.
58 |     omega = 1.
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_VI_ReID/solver/cosine_lr.py: -------------------------------------------------------------------------------- 1 | """ Cosine Scheduler 2 | 3 | Cosine LR schedule with warmup, cycle/restarts, noise. 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import logging 8 | import math 9 | import torch 10 | 11 | from .scheduler import Scheduler 12 | 13 | 14 | _logger = logging.getLogger(__name__) 15 | 16 | 17 | class CosineLRScheduler(Scheduler): 18 | """ 19 | Cosine decay with restarts. 20 | This is described in the paper https://arxiv.org/abs/1608.03983. 
21 | 22 | Inspiration from 23 | https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py 24 | """ 25 | 26 | def __init__(self, 27 | optimizer: torch.optim.Optimizer, 28 | t_initial: int, 29 | t_mul: float = 1., 30 | lr_min: float = 0., 31 | decay_rate: float = 1., 32 | warmup_t=0, 33 | warmup_lr_init=0, 34 | warmup_prefix=False, 35 | cycle_limit=0, 36 | t_in_epochs=True, 37 | noise_range_t=None, 38 | noise_pct=0.67, 39 | noise_std=1.0, 40 | noise_seed=42, 41 | initialize=True) -> None: 42 | super().__init__( 43 | optimizer, param_group_field="lr", 44 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 45 | initialize=initialize) 46 | 47 | assert t_initial > 0 48 | assert lr_min >= 0 49 | if t_initial == 1 and t_mul == 1 and decay_rate == 1: 50 | _logger.warning("Cosine annealing scheduler will have no effect on the learning " 51 | "rate since t_initial = t_mul = eta_mul = 1.") 52 | self.t_initial = t_initial 53 | self.t_mul = t_mul 54 | self.lr_min = lr_min 55 | self.decay_rate = decay_rate 56 | self.cycle_limit = cycle_limit 57 | self.warmup_t = warmup_t 58 | self.warmup_lr_init = warmup_lr_init 59 | self.warmup_prefix = warmup_prefix 60 | self.t_in_epochs = t_in_epochs 61 | if self.warmup_t: 62 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 63 | super().update_groups(self.warmup_lr_init) 64 | else: 65 | self.warmup_steps = [1 for _ in self.base_values] 66 | 67 | def _get_lr(self, t): 68 | if t < self.warmup_t: 69 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 70 | else: 71 | if self.warmup_prefix: 72 | t = t - self.warmup_t 73 | 74 | if self.t_mul != 1: 75 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 76 | t_i = self.t_mul ** i * self.t_initial 77 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 78 | else: 79 | i = t // self.t_initial 80 | t_i = self.t_initial 81 | t_curr = t - (self.t_initial * i) 82 | 83 | gamma = self.decay_rate ** i 84 | lr_min = self.lr_min * gamma 85 | lr_max_values = [v * gamma for v in self.base_values] 86 | 87 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 88 | lrs = [ 89 | lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values 90 | ] 91 | else: 92 | lrs = [self.lr_min for _ in self.base_values] 93 | 94 | return lrs 95 | 96 | def get_epoch_values(self, epoch: int): 97 | if self.t_in_epochs: 98 | return self._get_lr(epoch) 99 | else: 100 | return None 101 | 102 | def get_update_values(self, num_updates: int): 103 | if not self.t_in_epochs: 104 | return self._get_lr(num_updates) 105 | else: 106 | return None 107 | 108 | def get_cycle_length(self, cycles=0): 109 | if not cycles: 110 | cycles = self.cycle_limit 111 | cycles = max(1, cycles) 112 | if self.t_mul == 1.0: 113 | return self.t_initial * cycles 114 | else: 115 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 116 | -------------------------------------------------------------------------------- /Transformer-ReID-Survey/UnTransReID_USL_ReID/solver/cosine_lr.py: -------------------------------------------------------------------------------- 1 | """ Cosine Scheduler 2 | 3 | Cosine LR schedule with warmup, cycle/restarts, noise. 
4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import logging 8 | import math 9 | import torch 10 | 11 | from .scheduler import Scheduler 12 | 13 | 14 | _logger = logging.getLogger(__name__) 15 | 16 | 17 | class CosineLRScheduler(Scheduler): 18 | """ 19 | Cosine decay with restarts. 20 | This is described in the paper https://arxiv.org/abs/1608.03983. 21 | 22 | Inspiration from 23 | https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py 24 | """ 25 | 26 | def __init__(self, 27 | optimizer: torch.optim.Optimizer, 28 | t_initial: int, 29 | t_mul: float = 1., 30 | lr_min: float = 0., 31 | decay_rate: float = 1., 32 | warmup_t=0, 33 | warmup_lr_init=0, 34 | warmup_prefix=False, 35 | cycle_limit=0, 36 | t_in_epochs=True, 37 | noise_range_t=None, 38 | noise_pct=0.67, 39 | noise_std=1.0, 40 | noise_seed=42, 41 | initialize=True) -> None: 42 | super().__init__( 43 | optimizer, param_group_field="lr", 44 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 45 | initialize=initialize) 46 | 47 | assert t_initial > 0 48 | assert lr_min >= 0 49 | if t_initial == 1 and t_mul == 1 and decay_rate == 1: 50 | _logger.warning("Cosine annealing scheduler will have no effect on the learning " 51 | "rate since t_initial = t_mul = eta_mul = 1.") 52 | self.t_initial = t_initial 53 | self.t_mul = t_mul 54 | self.lr_min = lr_min 55 | self.decay_rate = decay_rate 56 | self.cycle_limit = cycle_limit 57 | self.warmup_t = warmup_t 58 | self.warmup_lr_init = warmup_lr_init 59 | self.warmup_prefix = warmup_prefix 60 | self.t_in_epochs = t_in_epochs 61 | if self.warmup_t: 62 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 63 | super().update_groups(self.warmup_lr_init) 64 | else: 65 | self.warmup_steps = [1 for _ in self.base_values] 66 | 67 | def _get_lr(self, t): 68 | if t < self.warmup_t: 69 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 70 | else: 71 | if self.warmup_prefix: 72 | t = t - self.warmup_t 73 | 74 | if self.t_mul != 1: 75 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 76 | t_i = self.t_mul ** i * self.t_initial 77 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 78 | else: 79 | i = t // self.t_initial 80 | t_i = self.t_initial 81 | t_curr = t - (self.t_initial * i) 82 | 83 | gamma = self.decay_rate ** i 84 | lr_min = self.lr_min * gamma 85 | lr_max_values = [v * gamma for v in self.base_values] 86 | 87 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 88 | lrs = [ 89 | lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values 90 | ] 91 | else: 92 | lrs = [self.lr_min for _ in self.base_values] 93 | 94 | return lrs 95 | 96 | def get_epoch_values(self, epoch: int): 97 | if self.t_in_epochs: 98 | return self._get_lr(epoch) 99 | else: 100 | return None 101 | 102 | def get_update_values(self, num_updates: int): 103 | if not self.t_in_epochs: 104 | return self._get_lr(num_updates) 105 | else: 106 | return None 107 | 108 | def get_cycle_length(self, cycles=0): 109 | if not cycles: 110 | cycles = self.cycle_limit 111 | cycles = max(1, cycles) 112 | if self.t_mul == 1.0: 113 | return self.t_initial * cycles 114 | else: 115 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 116 | -------------------------------------------------------------------------------- /tools/main.py: 
-------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import argparse 4 | import os 5 | import sys 6 | import torch 7 | 8 | from torch.backends import cudnn 9 | 10 | sys.path.append('.') 11 | from config import cfg 12 | from data import make_data_loader 13 | from modeling import build_model 14 | from utils.lr_scheduler import WarmupMultiStepLR 15 | from utils.logger import setup_logger 16 | from tools.train import do_train 17 | from tools.test import do_test 18 | 19 | 20 | def main(): 21 | parser = argparse.ArgumentParser(description="AGW Re-ID Baseline") 22 | parser.add_argument( 23 | "--config_file", default="", help="path to config file", type=str 24 | ) 25 | parser.add_argument("opts", help="Modify config options using the command-line", default=None, 26 | nargs=argparse.REMAINDER) 27 | 28 | args = parser.parse_args() 29 | 30 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 31 | 32 | if args.config_file != "": 33 | cfg.merge_from_file(args.config_file) 34 | cfg.merge_from_list(args.opts) 35 | cfg.freeze() 36 | 37 | output_dir = cfg.OUTPUT_DIR 38 | if output_dir and not os.path.exists(output_dir): 39 | os.makedirs(output_dir) 40 | 41 | logger = setup_logger("reid_baseline", output_dir, 0) 42 | logger.info("Using {} GPUS".format(num_gpus)) 43 | logger.info(args) 44 | 45 | if args.config_file != "": 46 | logger.info("Loaded configuration file {}".format(args.config_file)) 47 | with open(args.config_file, 'r') as cf: 48 | config_str = "\n" + cf.read() 49 | logger.info(config_str) 50 | logger.info("Running with config:\n{}".format(cfg)) 51 | 52 | if cfg.MODEL.DEVICE == "cuda": 53 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID # new add by gu 54 | cudnn.benchmark = True 55 | 56 | data_loader, num_query, num_classes = make_data_loader(cfg) 57 | model = build_model(cfg, num_classes) 58 | 59 | if 'cpu' not in cfg.MODEL.DEVICE: 60 | if torch.cuda.device_count() > 1: 61 | model = torch.nn.DataParallel(model) 62 | model.to(device=cfg.MODEL.DEVICE) 63 | 64 | if cfg.TEST.EVALUATE_ONLY == 'on': 65 | logger.info("Evaluate Only") 66 | model.load_param(cfg.TEST.WEIGHT) 67 | do_test(cfg, model, data_loader, num_query) 68 | return 69 | 70 | criterion = model.get_creterion(cfg, num_classes) 71 | optimizer = model.get_optimizer(cfg, criterion) 72 | 73 | # Add for using self trained model 74 | if cfg.MODEL.PRETRAIN_CHOICE == 'self': 75 | start_epoch = eval(cfg.MODEL.PRETRAIN_PATH.split('/')[-1].split('.')[0].split('_')[-1]) 76 | print('Start epoch:', start_epoch) 77 | path_to_optimizer = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer') 78 | print('Path to the checkpoint of optimizer:', path_to_optimizer) 79 | path_to_center_param = cfg.MODEL.PRETRAIN_PATH.replace('model', 'center_param') 80 | print('Path to the checkpoint of center_param:', path_to_center_param) 81 | path_to_optimizer_center = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer_center') 82 | print('Path to the checkpoint of optimizer_center:', path_to_optimizer_center) 83 | model.load_state_dict(torch.load(cfg.MODEL.PRETRAIN_PATH)) 84 | optimizer['model'].load_state_dict(torch.load(path_to_optimizer)) 85 | criterion['center'].load_state_dict(torch.load(path_to_center_param)) 86 | optimizer['center'].load_state_dict(torch.load(path_to_optimizer_center)) 87 | scheduler = WarmupMultiStepLR(optimizer['model'], cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 88 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD, start_epoch) 89 
|     elif cfg.MODEL.PRETRAIN_CHOICE == 'imagenet':
90 |         start_epoch = 0
91 |         scheduler = WarmupMultiStepLR(optimizer['model'], cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
92 |                                       cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)
93 | 
94 |     else:
95 |         raise ValueError('Only support pretrain_choice for imagenet and self, but got {}'.format(cfg.MODEL.PRETRAIN_CHOICE))  # previously a bare print(), which fell through to do_train with 'scheduler' undefined
96 | 
97 | 
98 | 
99 |     do_train(cfg,
100 |              model,
101 |              data_loader,
102 |              optimizer,
103 |              scheduler,
104 |              criterion,
105 |              num_query,
106 |              start_epoch
107 |              )
108 | 
109 | if __name__ == '__main__':
110 |     main()
111 | 
--------------------------------------------------------------------------------
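The self-trained resume logic in main() derives the start epoch and the companion checkpoint paths purely from the weight filename; a quick illustration with a hypothetical path (the model_<epoch>.pth naming scheme is inferred from the string handling above, not documented elsewhere in this section):

    path = 'work_dirs/agw/model_120.pth'                                  # hypothetical checkpoint
    start_epoch = int(path.split('/')[-1].split('.')[0].split('_')[-1])   # -> 120 (the code above uses eval())
    path_to_optimizer = path.replace('model', 'optimizer')               # -> 'work_dirs/agw/optimizer_120.pth'
    # Caveat: str.replace rewrites every occurrence of 'model', so a path like
    # 'models/model_120.pth' would become 'optimizers/optimizer_120.pth'.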