├── LICENSE ├── README.md ├── configs ├── Pirt.yml └── eval.yml ├── fastreid ├── __init__.py ├── config │ ├── __init__.py │ ├── config.py │ └── defaults.py ├── data │ ├── __init__.py │ ├── build.py │ ├── common.py │ ├── data_utils.py │ ├── datasets │ │ ├── __init__.py │ │ ├── bases.py │ │ ├── cuhk03.py │ │ ├── dukemtmcreid.py │ │ ├── market1501.py │ │ ├── msmt17.py │ │ └── occluded_duke.py │ ├── samplers │ │ ├── __init__.py │ │ ├── data_sampler.py │ │ └── triplet_sampler.py │ └── transforms │ │ ├── __init__.py │ │ ├── autoaugment.py │ │ ├── build.py │ │ ├── functional.py │ │ └── transforms.py ├── engine │ ├── __init__.py │ ├── defaults.py │ ├── hooks.py │ ├── launch.py │ └── train_loop.py ├── evaluation │ ├── __init__.py │ ├── build.py │ ├── evaluator.py │ ├── rank.py │ ├── rank_cylib │ │ ├── Makefile │ │ ├── __init__.py │ │ ├── rank_cy.pyx │ │ ├── roc_cy.pyx │ │ ├── setup.py │ │ └── test_cython.py │ ├── reid_evaluation.py │ ├── rerank.py │ ├── roc.py │ └── testing.py ├── layers │ ├── __init__.py │ ├── batch_norm.py │ ├── gather_layer.py │ ├── mlp.py │ ├── non_local.py │ ├── pooling.py │ └── stripe_layer.py ├── modeling │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ ├── build.py │ │ └── resnet.py │ ├── heads │ │ ├── __init__.py │ │ ├── agg_head.py │ │ ├── build.py │ │ └── embedding_head.py │ ├── losses │ │ ├── __init__.py │ │ ├── cross_entroy_loss.py │ │ ├── triplet_loss.py │ │ └── utils.py │ ├── meta_arch │ │ ├── __init__.py │ │ ├── baseline.py │ │ ├── build.py │ │ └── pirt.py │ └── posenets │ │ ├── __init__.py │ │ ├── build.py │ │ └── pose_hrnet.py ├── solver │ ├── __init__.py │ ├── build.py │ ├── lr_scheduler.py │ └── optim │ │ ├── __init__.py │ │ ├── adam.py │ │ └── sgd.py └── utils │ ├── __init__.py │ ├── checkpoint.py │ ├── collect_env.py │ ├── comm.py │ ├── compute_dist.py │ ├── env.py │ ├── events.py │ ├── faiss_utils.py │ ├── file_io.py │ ├── history_buffer.py │ ├── logger.py │ ├── precision_bn.py │ ├── registry.py │ ├── summary.py │ ├── 
timer.py │ ├── visualizer.py │ └── weight_init.py └── tools └── train_net.py /README.md: -------------------------------------------------------------------------------- 1 | # Pirt 2 | Official implementation of "Pose-guided Inter- and Intra-part Relational Transformer for Occluded Person Re-Identification" 3 | 4 | ## Introduction 5 | 6 | This repository contains the code for the paper: 7 | [**Pose-guided Inter- and Intra-part Relational Transformer for Occluded Person Re-Identification**](https://arxiv.org/abs/2109.03483) 8 | Zhongxing Ma, Yifan Zhao, Jia Li 9 | ACM International Conference on Multimedia (ACM MM), 2021 10 | 11 | ## Environments 12 | 13 | 1. pytorch 1.6.0 14 | 2. python 3.8 15 | 3. pyyaml | yacs | termcolor | tqdm | faiss-cpu | tabulate | matplotlib | tensorboard 16 | 4. sklearn | einops 17 | 5. RTX 2080 Ti x 2 18 | 6. CUDA 10.1 19 | 20 | ## Getting Started 21 | 22 | Working directory: **/your/path/to/fast-reid/** 23 | 24 | ### Training 25 | 26 | ```bash 27 | python -u tools/train_net.py --config-file configs/Pirt.yml --num-gpus 2 OUTPUT_DIR logs/your/custom/path 28 | ``` 29 | 30 | ### Evaluation 31 | 32 | ```bash 33 | python -u tools/train_net.py --eval-only --config-file configs/eval.yml --num-gpus 2 OUTPUT_DIR logs/your/custom/path 34 | ``` 35 | 36 | The model config is `./configs/Pirt.yml`; `./configs/eval.yml` inherits from it (via `_BASE_`) and sets `MODEL.WEIGHTS` to the checkpoint to evaluate. 37 | 38 | ### Datasets 39 | 40 | Datasets should be placed under `./datasets/` (override with the `FASTREID_DATASETS` environment variable), e.g. `./datasets/OccludedDuke` for Occluded-Duke and `./datasets/Market-1501-v15.09.15` for Market-1501. 41 | 42 | See the `./fastreid/data/datasets` folder for the expected directory layout of each dataset. 43 | 44 | ### Pretrained Models 45 | 46 | The pretrained pose estimator (HRNet) and ResNet-50 backbone weights should be placed in `../models_zoo/`: 47 | 48 | [**resnet50**](https://download.pytorch.org/models/resnet50-19c8e357.pth). 49 | [**resnet50-ibn**](https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet50_ibn_a-d9d0bb7b.pth). 50 | [**pose_hrnet_w48_256x192**](https://drive.google.com/file/d/1GwNajHuDAXa61ioDw0d6F_ic7WhlELDb/view?usp=sharing).
51 | 52 | ### Citing 53 | 54 | ```bibtex 55 | @misc{ma2021poseguided, 56 | title={Pose-guided Inter- and Intra-part Relational Transformer for Occluded Person Re-Identification}, 57 | author={Zhongxing Ma and Yifan Zhao and Jia Li}, 58 | year={2021}, 59 | eprint={2109.03483}, 60 | archivePrefix={arXiv}, 61 | primaryClass={cs.CV} 62 | } 63 | ``` 64 | 65 | ## Acknowledgments 66 | 67 | Our code is based on an early version of [**FAST-REID**](https://github.com/JDAI-CV/fast-reid). 68 | 69 | FAST-REID is an excellent repository for beginners, and it documents the framework in much more detail. 70 | -------------------------------------------------------------------------------- /configs/Pirt.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "Pirt" 3 | TEL: 3 4 | 5 | BACKBONE: 6 | NAME: "build_resnet_backbone" 7 | NORM: "BN" 8 | DEPTH: "50x" 9 | LAST_STRIDE: 1 10 | FEAT_DIM: 2048 11 | WITH_IBN: True 12 | WITH_NL: False 13 | PRETRAIN: True 14 | PRETRAIN_PATH: "../models_zoo/resnet50_ibn_a-d9d0bb7b.pth" 15 | 16 | POSENET: 17 | NAME: "build_pose_hrnet" 18 | PRETRAIN: True 19 | PRETRAIN_PATH: '../models_zoo/pose_hrnet_w48_256x192.pth' 20 | PRETRAINED_LAYERS: ['*'] 21 | STEM_INPLANES: 64 22 | FINAL_CONV_KERNEL: 1 23 | NUM_JOINTS: 17 24 | 25 | JOINTS_GROUPS: [ 26 | [5, 7, 9, 6, 8, 10], [11, 13, 15, 12, 14, 16], [0, 1, 2, 3, 4, 5], 27 | ] 28 | WIDTH: '48x' 29 | 30 | HEADS: 31 | NAME: "AGGHead" 32 | NORM: "BN" 33 | WITH_BNNECK: True 34 | POOL_LAYER: "avgpool" 35 | NECK_FEAT: "before" 36 | CLS_LAYER: "linear" 37 | 38 | LOSSES: 39 | NAME: ("CrossEntropyLoss", "TripletLoss", ) 40 | 41 | CE: 42 | EPSILON: 0.1 43 | SCALE: 1. 44 | 45 | TRI: 46 | MARGIN: 0. 47 | HARD_MINING: True 48 | NORM_FEAT: False 49 | SCALE: 1.
50 | 51 | 52 | INPUT: 53 | SIZE_TRAIN: [384, 128] 54 | SIZE_TEST: [384, 128] 55 | REA: 56 | ENABLED: True 57 | PROB: 0.5 58 | MEAN: [123.675, 116.28, 103.53] 59 | DO_PAD: True 60 | PADDING: 10 61 | CJ: 62 | ENABLED: False 63 | 64 | DATALOADER: 65 | PK_SAMPLER: True 66 | NAIVE_WAY: True 67 | NUM_INSTANCE: 4 68 | NUM_WORKERS: 8 69 | 70 | SOLVER: 71 | OPT: "Adam" 72 | MAX_ITER: 60 73 | BASE_LR: 0.00035 74 | BIAS_LR_FACTOR: 1. 75 | WEIGHT_DECAY: 0.0005 76 | WEIGHT_DECAY_BIAS: 0.0005 77 | IMS_PER_BATCH: 64 78 | 79 | SCHED: "WarmupCosineAnnealingLR" 80 | DELAY_ITERS: 30 81 | ETA_MIN_LR: 0.000001 82 | 83 | WARMUP_FACTOR: 0.01 84 | WARMUP_ITERS: 10 85 | 86 | CHECKPOINT_PERIOD: 30 87 | 88 | TEST: 89 | EVALUATOR: "PirtEvaluator" 90 | EVAL_PERIOD: 10 91 | IMS_PER_BATCH: 128 92 | METRIC: "cosine" 93 | 94 | CUDNN_BENCHMARK: True 95 | 96 | DATASETS: 97 | NAMES: ("OccludedDuke", ) 98 | TESTS: ("OccludedDuke", ) 99 | 100 | OUTPUT_DIR: "logs/test" 101 | -------------------------------------------------------------------------------- /configs/eval.yml: -------------------------------------------------------------------------------- 1 | _BASE_: "Pirt.yml" 2 | 3 | MODEL: 4 | WEIGHTS: "path/to/your/model/weights.pth" 5 | DEVICE: "cuda" 6 | 7 | DATALOADER: 8 | PK_SAMPLER: True 9 | NAIVE_WAY: True 10 | NUM_INSTANCE: 4 11 | NUM_WORKERS: 8 12 | 13 | CUDNN_BENCHMARK: True 14 | 15 | TEST: 16 | IMS_PER_BATCH: 64 17 | EVALUATOR: "PirtEvaluator" 18 | 19 | DATASETS: 20 | NAMES: ("OccludedDuke", ) 21 | TESTS: ("OccludedDuke", ) 22 | 23 | OUTPUT_DIR: "logs/eval" -------------------------------------------------------------------------------- /fastreid/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | 8 | __version__ = "0.2.0" 9 | -------------------------------------------------------------------------------- /fastreid/config/__init__.py: 
-------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: l1aoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .config import CfgNode, get_cfg 8 | from .defaults import _C as cfg 9 | -------------------------------------------------------------------------------- /fastreid/config/config.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: l1aoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import logging 8 | import os 9 | from typing import Any 10 | 11 | import yaml 12 | from yacs.config import CfgNode as _CfgNode 13 | 14 | from ..utils.file_io import PathManager 15 | 16 | BASE_KEY = "_BASE_" 17 | 18 | 19 | class CfgNode(_CfgNode): 20 | """ 21 | Our own extended version of :class:`yacs.config.CfgNode`. 22 | It contains the following extra features: 23 | 1. The :meth:`merge_from_file` method supports the "_BASE_" key, 24 | which allows the new CfgNode to inherit all the attributes from the 25 | base configuration file. 26 | 2. Keys that start with "COMPUTED_" are treated as insertion-only 27 | "computed" attributes. They can be inserted regardless of whether 28 | the CfgNode is frozen or not. 29 | 3. With "allow_unsafe=True", it supports pyyaml tags that evaluate 30 | expressions in config. See examples in 31 | https://pyyaml.org/wiki/PyYAMLDocumentation#yaml-tags-and-python-types 32 | Note that this may lead to arbitrary code execution: you must not 33 | load a config file from untrusted sources before manually inspecting 34 | the content of the file. 35 | """ 36 | 37 | @staticmethod 38 | def load_yaml_with_base(filename: str, allow_unsafe: bool = False): 39 | """ 40 | Just like `yaml.load(open(filename))`, but inherit attributes from its 41 | `_BASE_`. 42 | Args: 43 | filename (str): the file name of the current config. Will be used to 44 | find the base config file. 
45 | allow_unsafe (bool): whether to allow loading the config file with 46 | `yaml.unsafe_load`. 47 | Returns: 48 | (dict): the loaded yaml 49 | """ 50 | with PathManager.open(filename, "r") as f: 51 | try: 52 | cfg = yaml.safe_load(f) 53 | except yaml.constructor.ConstructorError: 54 | if not allow_unsafe: 55 | raise 56 | logger = logging.getLogger(__name__) 57 | logger.warning( 58 | "Loading config {} with yaml.unsafe_load. Your machine may " 59 | "be at risk if the file contains malicious content.".format( 60 | filename 61 | ) 62 | ) 63 | f.close() 64 | with open(filename, "r") as f: 65 | cfg = yaml.unsafe_load(f) 66 | 67 | def merge_a_into_b(a, b): 68 | # merge dict a into dict b. values in a will overwrite b. 69 | for k, v in a.items(): 70 | if isinstance(v, dict) and k in b: 71 | assert isinstance( 72 | b[k], dict 73 | ), "Cannot inherit key '{}' from base!".format(k) 74 | merge_a_into_b(v, b[k]) 75 | else: 76 | b[k] = v 77 | 78 | if BASE_KEY in cfg: 79 | base_cfg_file = cfg[BASE_KEY] 80 | if base_cfg_file.startswith("~"): 81 | base_cfg_file = os.path.expanduser(base_cfg_file) 82 | if not any( 83 | map(base_cfg_file.startswith, ["/", "https://", "http://"]) 84 | ): 85 | # the path to base cfg is relative to the config file itself. 86 | base_cfg_file = os.path.join( 87 | os.path.dirname(filename), base_cfg_file 88 | ) 89 | base_cfg = CfgNode.load_yaml_with_base( 90 | base_cfg_file, allow_unsafe=allow_unsafe 91 | ) 92 | del cfg[BASE_KEY] 93 | 94 | merge_a_into_b(cfg, base_cfg) 95 | return base_cfg 96 | return cfg 97 | 98 | def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = False): 99 | """ 100 | Merge configs from a given yaml file. 101 | Args: 102 | cfg_filename: the file name of the yaml config. 103 | allow_unsafe: whether to allow loading the config file with 104 | `yaml.unsafe_load`. 
105 | """ 106 | loaded_cfg = CfgNode.load_yaml_with_base( 107 | cfg_filename, allow_unsafe=allow_unsafe 108 | ) 109 | loaded_cfg = type(self)(loaded_cfg) 110 | self.merge_from_other_cfg(loaded_cfg) 111 | 112 | # Forward the following calls to base, but with a check on the BASE_KEY. 113 | def merge_from_other_cfg(self, cfg_other): 114 | """ 115 | Args: 116 | cfg_other (CfgNode): configs to merge from. 117 | """ 118 | assert ( 119 | BASE_KEY not in cfg_other 120 | ), "The reserved key '{}' can only be used in files!".format(BASE_KEY) 121 | return super().merge_from_other_cfg(cfg_other) 122 | 123 | def merge_from_list(self, cfg_list: list): 124 | """ 125 | Args: 126 | cfg_list (list): list of configs to merge from. 127 | """ 128 | keys = set(cfg_list[0::2]) 129 | assert ( 130 | BASE_KEY not in keys 131 | ), "The reserved key '{}' can only be used in files!".format(BASE_KEY) 132 | return super().merge_from_list(cfg_list) 133 | 134 | def __setattr__(self, name: str, val: Any): 135 | if name.startswith("COMPUTED_"): 136 | if name in self: 137 | old_val = self[name] 138 | if old_val == val: 139 | return 140 | raise KeyError( 141 | "Computed attributed '{}' already exists " 142 | "with a different value! old={}, new={}.".format( 143 | name, old_val, val 144 | ) 145 | ) 146 | self[name] = val 147 | else: 148 | super().__setattr__(name, val) 149 | 150 | 151 | def get_cfg() -> CfgNode: 152 | """ 153 | Get a copy of the default config. 154 | Returns: 155 | a fastreid CfgNode instance. 
156 | """ 157 | from .defaults import _C 158 | 159 | return _C.clone() 160 | -------------------------------------------------------------------------------- /fastreid/data/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .build import build_reid_train_loader, build_reid_test_loader 8 | -------------------------------------------------------------------------------- /fastreid/data/build.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: l1aoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import os 8 | import torch 9 | from torch._six import container_abcs, string_classes, int_classes 10 | from torch.utils.data import DataLoader 11 | from fastreid.utils import comm 12 | 13 | from . import samplers 14 | from .common import CommDataset 15 | from .datasets import DATASET_REGISTRY 16 | from .transforms import build_transforms 17 | 18 | _root = os.getenv("FASTREID_DATASETS", "datasets") 19 | 20 | 21 | def build_reid_train_loader(cfg): 22 | cfg = cfg.clone() 23 | cfg.defrost() 24 | 25 | train_items = list() 26 | for d in cfg.DATASETS.NAMES: 27 | dataset = DATASET_REGISTRY.get(d)(root=_root, combineall=cfg.DATASETS.COMBINEALL) 28 | if comm.is_main_process(): 29 | dataset.show_train() 30 | train_items.extend(dataset.train) 31 | 32 | iters_per_epoch = len(train_items) // cfg.SOLVER.IMS_PER_BATCH 33 | cfg.SOLVER.MAX_ITER *= iters_per_epoch 34 | train_transforms = build_transforms(cfg, is_train=True) 35 | train_set = CommDataset( 36 | train_items, 37 | train_transforms, 38 | relabel=True, 39 | total_iter=cfg.SOLVER.MAX_ITER 40 | ) 41 | 42 | num_workers = cfg.DATALOADER.NUM_WORKERS 43 | num_instance = cfg.DATALOADER.NUM_INSTANCE 44 | mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size() 45 | 46 | if 
cfg.DATALOADER.PK_SAMPLER: 47 | if cfg.DATALOADER.NAIVE_WAY: 48 | data_sampler = samplers.NaiveIdentitySampler(train_set.img_items, 49 | cfg.SOLVER.IMS_PER_BATCH, num_instance) 50 | else: 51 | data_sampler = samplers.BalancedIdentitySampler(train_set.img_items, 52 | cfg.SOLVER.IMS_PER_BATCH, num_instance) 53 | else: 54 | data_sampler = samplers.TrainingSampler(len(train_set)) 55 | batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size, True) 56 | 57 | train_loader = torch.utils.data.DataLoader( 58 | train_set, 59 | num_workers=num_workers, 60 | batch_sampler=batch_sampler, 61 | collate_fn=fast_batch_collator, 62 | pin_memory=True, 63 | ) 64 | return train_loader 65 | 66 | 67 | def build_reid_test_loader(cfg, dataset_name): 68 | cfg = cfg.clone() 69 | cfg.defrost() 70 | 71 | dataset = DATASET_REGISTRY.get(dataset_name)(root=_root) 72 | if comm.is_main_process(): 73 | dataset.show_test() 74 | test_items = dataset.query + dataset.gallery 75 | 76 | test_transforms = build_transforms(cfg, is_train=False) 77 | test_set = CommDataset(test_items, test_transforms, relabel=False) 78 | 79 | mini_batch_size = cfg.TEST.IMS_PER_BATCH // comm.get_world_size() 80 | data_sampler = samplers.InferenceSampler(len(test_set)) 81 | batch_sampler = torch.utils.data.BatchSampler(data_sampler, mini_batch_size, False) 82 | test_loader = DataLoader( 83 | test_set, 84 | batch_sampler=batch_sampler, 85 | num_workers=0, # save some memory 86 | collate_fn=fast_batch_collator, 87 | pin_memory=True, 88 | ) 89 | return test_loader, len(dataset.query) 90 | 91 | 92 | def trivial_batch_collator(batch): 93 | """ 94 | A batch collator that does nothing. 
95 | """ 96 | return batch 97 | 98 | 99 | def fast_batch_collator(batched_inputs): 100 | """ 101 | A simple batch collator for most common reid tasks 102 | """ 103 | elem = batched_inputs[0] 104 | if isinstance(elem, torch.Tensor): 105 | out = torch.zeros((len(batched_inputs), *elem.size()), dtype=elem.dtype) 106 | for i, tensor in enumerate(batched_inputs): 107 | out[i] += tensor 108 | return out 109 | 110 | elif isinstance(elem, container_abcs.Mapping): 111 | return {key: fast_batch_collator([d[key] for d in batched_inputs]) for key in elem} 112 | 113 | elif isinstance(elem, float): 114 | return torch.tensor(batched_inputs, dtype=torch.float64) 115 | elif isinstance(elem, int_classes): 116 | return torch.tensor(batched_inputs) 117 | elif isinstance(elem, string_classes): 118 | return batched_inputs 119 | -------------------------------------------------------------------------------- /fastreid/data/common.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from torch.utils.data import Dataset 8 | import torch 9 | import torchvision.transforms as T 10 | from fastreid.data.transforms.transforms import * 11 | from fastreid.data.transforms.autoaugment import AutoAugment 12 | import random 13 | 14 | from .data_utils import read_image 15 | 16 | 17 | class CommDataset(Dataset): 18 | """Image Person ReID Dataset""" 19 | 20 | def __init__(self, img_items, transform=None, relabel=True, total_iter=None): 21 | self.img_items = img_items 22 | self.transform = transform 23 | self.relabel = relabel 24 | self.total_iter = total_iter 25 | 26 | pid_set = set() 27 | cam_set = set() 28 | for i in img_items: 29 | pid_set.add(i[1]) 30 | cam_set.add(i[2]) 31 | 32 | self.pids = sorted(list(pid_set)) 33 | self.cams = sorted(list(cam_set)) 34 | if relabel: 35 | self.pid_dict = dict([(p, i) for i, p in enumerate(self.pids)]) 36 | self.cam_dict = 
dict([(p, i) for i, p in enumerate(self.cams)]) 37 | 38 | def __len__(self): 39 | return len(self.img_items) 40 | 41 | def __getitem__(self, index): 42 | img_path, pid, camid = self.img_items[index] 43 | img = read_image(img_path) 44 | if self.transform is not None: img = self.transform(img) 45 | img = ToTensor()(img) 46 | if self.relabel: 47 | pid = self.pid_dict[pid] 48 | camid = self.cam_dict[camid] 49 | 50 | info = { 51 | "images": img, 52 | "targets": pid, 53 | "camids": camid, 54 | "img_paths": img_path, 55 | } 56 | 57 | return info 58 | 59 | 60 | @property 61 | def num_classes(self): 62 | return len(self.pids) 63 | 64 | @property 65 | def num_cameras(self): 66 | return len(self.cams) 67 | -------------------------------------------------------------------------------- /fastreid/data/data_utils.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | import numpy as np 7 | from PIL import Image, ImageOps 8 | 9 | from fastreid.utils.file_io import PathManager 10 | 11 | 12 | def read_image(file_name, format=None): 13 | """ 14 | Read an image into the given format. 15 | Will apply rotation and flipping if the image has such exif information. 
16 | Args: 17 | file_name (str): image file path 18 | format (str): one of the supported image modes in PIL, or "BGR" 19 | Returns: 20 | image (np.ndarray): an HWC image 21 | """ 22 | with PathManager.open(file_name, "rb") as f: 23 | image = Image.open(f) 24 | 25 | # capture and ignore this bug: https://github.com/python-pillow/Pillow/issues/3973 26 | try: 27 | image = ImageOps.exif_transpose(image) 28 | except Exception: 29 | pass 30 | 31 | if format is not None: 32 | # PIL only supports RGB, so convert to RGB and flip channels over below 33 | conversion_format = format 34 | if format == "BGR": 35 | conversion_format = "RGB" 36 | image = image.convert(conversion_format) 37 | image = np.asarray(image) 38 | if format == "BGR": 39 | # flip channels if needed 40 | image = image[:, :, ::-1] 41 | # PIL squeezes out the channel dimension for "L", so make it HWC 42 | if format == "L": 43 | image = np.expand_dims(image, -1) 44 | image = Image.fromarray(image) 45 | return image 46 | -------------------------------------------------------------------------------- /fastreid/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from ...utils.registry import Registry 8 | 9 | DATASET_REGISTRY = Registry("DATASET") 10 | DATASET_REGISTRY.__doc__ = """ 11 | Registry for datasets 12 | It must returns an instance of :class:`Backbone`. 
13 | """ 14 | 15 | # Person re-id datasets 16 | from .cuhk03 import CUHK03 17 | from .dukemtmcreid import DukeMTMC 18 | from .market1501 import Market1501 19 | from .msmt17 import MSMT17 20 | from .occluded_duke import * 21 | 22 | __all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")] 23 | -------------------------------------------------------------------------------- /fastreid/data/datasets/bases.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import copy 8 | import logging 9 | import os 10 | from tabulate import tabulate 11 | from termcolor import colored 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class Dataset(object): 17 | """An abstract class representing a Dataset. 18 | This is the base class for ``ImageDataset`` and ``VideoDataset``. 19 | Args: 20 | train (list): contains tuples of (img_path(s), pid, camid). 21 | query (list): contains tuples of (img_path(s), pid, camid). 22 | gallery (list): contains tuples of (img_path(s), pid, camid). 23 | transform: transform function. 24 | mode (str): 'train', 'query' or 'gallery'. 25 | combineall (bool): combines train, query and gallery in a 26 | dataset for training. 27 | verbose (bool): show information. 28 | """ 29 | _junk_pids = [] # contains useless person IDs, e.g. 
background, false detections 30 | 31 | def __init__(self, train, query, gallery, transform=None, mode='train', 32 | combineall=False, verbose=True, **kwargs): 33 | self.train = train 34 | self.query = query 35 | self.gallery = gallery 36 | self.transform = transform 37 | self.mode = mode 38 | self.combineall = combineall 39 | self.verbose = verbose 40 | 41 | self.num_train_pids = self.get_num_pids(self.train) 42 | self.num_train_cams = self.get_num_cams(self.train) 43 | 44 | if self.combineall: 45 | self.combine_all() 46 | 47 | if self.mode == 'train': 48 | self.data = self.train 49 | elif self.mode == 'query': 50 | self.data = self.query 51 | elif self.mode == 'gallery': 52 | self.data = self.gallery 53 | else: 54 | raise ValueError('Invalid mode. Got {}, but expected to be ' 55 | 'one of [train | query | gallery]'.format(self.mode)) 56 | 57 | def __getitem__(self, index): 58 | raise NotImplementedError 59 | 60 | def __len__(self): 61 | return len(self.data) 62 | 63 | def __radd__(self, other): 64 | """Supports sum([dataset1, dataset2, dataset3]).""" 65 | if other == 0: 66 | return self 67 | else: 68 | return self.__add__(other) 69 | 70 | def parse_data(self, data): 71 | """Parses data list and returns the number of person IDs 72 | and the number of camera views. 
73 | Args: 74 | data (list): contains tuples of (img_path(s), pid, camid) 75 | """ 76 | pids = set() 77 | cams = set() 78 | for _, pid, camid in data: 79 | pids.add(pid) 80 | cams.add(camid) 81 | return len(pids), len(cams) 82 | 83 | def get_num_pids(self, data): 84 | """Returns the number of training person identities.""" 85 | return self.parse_data(data)[0] 86 | 87 | def get_num_cams(self, data): 88 | """Returns the number of training cameras.""" 89 | return self.parse_data(data)[1] 90 | 91 | def show_summary(self): 92 | """Shows dataset statistics.""" 93 | pass 94 | 95 | def combine_all(self): 96 | """Combines train, query and gallery in a dataset for training.""" 97 | combined = copy.deepcopy(self.train) 98 | 99 | def _combine_data(data): 100 | for img_path, pid, camid in data: 101 | if pid in self._junk_pids: 102 | continue 103 | pid = self.dataset_name + "_" + str(pid) 104 | camid = self.dataset_name + "_" + str(camid) 105 | combined.append((img_path, pid, camid)) 106 | 107 | _combine_data(self.query) 108 | _combine_data(self.gallery) 109 | 110 | self.train = combined 111 | self.num_train_pids = self.get_num_pids(self.train) 112 | 113 | def check_before_run(self, required_files): 114 | """Checks if required files exist before going deeper. 115 | Args: 116 | required_files (str or list): string file name(s). 117 | """ 118 | if isinstance(required_files, str): 119 | required_files = [required_files] 120 | 121 | for fpath in required_files: 122 | if not os.path.exists(fpath): 123 | raise RuntimeError('"{}" is not found'.format(fpath)) 124 | 125 | 126 | class ImageDataset(Dataset): 127 | """A base class representing ImageDataset. 128 | All other image datasets should subclass it. 129 | ``__getitem__`` returns an image given index. 130 | It will return ``img``, ``pid``, ``camid`` and ``img_path`` 131 | where ``img`` has shape (channel, height, width). As a result, 132 | data in each batch has shape (batch_size, channel, height, width). 
133 | """ 134 | 135 | def __init__(self, train, query, gallery, **kwargs): 136 | super(ImageDataset, self).__init__(train, query, gallery, **kwargs) 137 | 138 | def show_train(self): 139 | num_train_pids, num_train_cams = self.parse_data(self.train) 140 | 141 | headers = ['subset', '# ids', '# images', '# cameras'] 142 | csv_results = [['train', num_train_pids, len(self.train), num_train_cams]] 143 | 144 | # tabulate it 145 | table = tabulate( 146 | csv_results, 147 | tablefmt="pipe", 148 | headers=headers, 149 | numalign="left", 150 | ) 151 | logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan")) 152 | 153 | def show_test(self): 154 | num_query_pids, num_query_cams = self.parse_data(self.query) 155 | num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery) 156 | 157 | headers = ['subset', '# ids', '# images', '# cameras'] 158 | csv_results = [ 159 | ['query', num_query_pids, len(self.query), num_query_cams], 160 | ['gallery', num_gallery_pids, len(self.gallery), num_gallery_cams], 161 | ] 162 | 163 | # tabulate it 164 | table = tabulate( 165 | csv_results, 166 | tablefmt="pipe", 167 | headers=headers, 168 | numalign="left", 169 | ) 170 | logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan")) 171 | -------------------------------------------------------------------------------- /fastreid/data/datasets/dukemtmcreid.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: liaoxingyu2@jd.com 5 | """ 6 | 7 | import glob 8 | import os.path as osp 9 | import re 10 | 11 | from .bases import ImageDataset 12 | from ..datasets import DATASET_REGISTRY 13 | 14 | 15 | @DATASET_REGISTRY.register() 16 | class DukeMTMC(ImageDataset): 17 | """DukeMTMC-reID. 18 | 19 | Reference: 20 | - Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 
21 | - Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 22 | 23 | URL: ``_ 24 | 25 | Dataset statistics: 26 | - identities: 1404 (train + query). 27 | - images:16522 (train) + 2228 (query) + 17661 (gallery). 28 | - cameras: 8. 29 | """ 30 | dataset_dir = 'DukeMTMC-reID' 31 | dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip' 32 | dataset_name = "dukemtmc" 33 | 34 | def __init__(self, root='datasets', **kwargs): 35 | # self.root = osp.abspath(osp.expanduser(root)) 36 | self.root = root 37 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 38 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 39 | self.query_dir = osp.join(self.dataset_dir, 'query') 40 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 41 | 42 | required_files = [ 43 | self.dataset_dir, 44 | self.train_dir, 45 | self.query_dir, 46 | self.gallery_dir, 47 | ] 48 | self.check_before_run(required_files) 49 | 50 | train = self.process_dir(self.train_dir) 51 | query = self.process_dir(self.query_dir, is_train=False) 52 | gallery = self.process_dir(self.gallery_dir, is_train=False) 53 | 54 | super(DukeMTMC, self).__init__(train, query, gallery, **kwargs) 55 | 56 | def process_dir(self, dir_path, is_train=True): 57 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 58 | pattern = re.compile(r'([-\d]+)_c(\d)') 59 | 60 | data = [] 61 | for img_path in img_paths: 62 | pid, camid = map(int, pattern.search(img_path).groups()) 63 | assert 1 <= camid <= 8 64 | camid -= 1 # index starts from 0 65 | if is_train: 66 | pid = self.dataset_name + "_" + str(pid) 67 | camid = self.dataset_name + "_" + str(camid) 68 | data.append((img_path, pid, camid)) 69 | 70 | return data 71 | -------------------------------------------------------------------------------- /fastreid/data/datasets/market1501.py: -------------------------------------------------------------------------------- 1 | # 
encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import glob 8 | import os.path as osp 9 | import re 10 | import warnings 11 | 12 | from .bases import ImageDataset 13 | from ..datasets import DATASET_REGISTRY 14 | 15 | 16 | @DATASET_REGISTRY.register() 17 | class Market1501(ImageDataset): 18 | """Market1501. 19 | 20 | Reference: 21 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 22 | 23 | URL: ``_ 24 | 25 | Dataset statistics: 26 | - identities: 1501 (+1 for background). 27 | - images: 12936 (train) + 3368 (query) + 15913 (gallery). 28 | """ 29 | _junk_pids = [0, -1] 30 | dataset_dir = '' 31 | dataset_url = 'http://188.138.127.15:81/Datasets/Market-1501-v15.09.15.zip' 32 | dataset_name = "market1501" 33 | 34 | def __init__(self, root='datasets', market1501_500k=False, **kwargs): 35 | # self.root = osp.abspath(osp.expanduser(root)) 36 | self.root = root 37 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 38 | 39 | # allow alternative directory structure 40 | self.data_dir = self.dataset_dir 41 | data_dir = osp.join(self.data_dir, 'Market-1501-v15.09.15') 42 | if osp.isdir(data_dir): 43 | self.data_dir = data_dir 44 | else: 45 | warnings.warn('The current data structure is deprecated. 
Please ' 46 | 'put data folders such as "bounding_box_train" under ' 47 | '"Market-1501-v15.09.15".') 48 | 49 | self.train_dir = osp.join(self.data_dir, 'bounding_box_train') 50 | self.query_dir = osp.join(self.data_dir, 'query') 51 | self.gallery_dir = osp.join(self.data_dir, 'bounding_box_test') 52 | self.extra_gallery_dir = osp.join(self.data_dir, 'images') 53 | self.market1501_500k = market1501_500k 54 | 55 | required_files = [ 56 | self.data_dir, 57 | self.train_dir, 58 | self.query_dir, 59 | self.gallery_dir, 60 | ] 61 | if self.market1501_500k: 62 | required_files.append(self.extra_gallery_dir) 63 | self.check_before_run(required_files) 64 | 65 | train = self.process_dir(self.train_dir) 66 | query = self.process_dir(self.query_dir, is_train=False) 67 | gallery = self.process_dir(self.gallery_dir, is_train=False) 68 | if self.market1501_500k: 69 | gallery += self.process_dir(self.extra_gallery_dir, is_train=False) 70 | 71 | super(Market1501, self).__init__(train, query, gallery, **kwargs) 72 | 73 | def process_dir(self, dir_path, is_train=True): 74 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 75 | pattern = re.compile(r'([-\d]+)_c(\d)') 76 | 77 | data = [] 78 | for img_path in img_paths: 79 | pid, camid = map(int, pattern.search(img_path).groups()) 80 | if pid == -1: 81 | continue # junk images are just ignored 82 | assert 0 <= pid <= 1501 # pid == 0 means background 83 | assert 1 <= camid <= 6 84 | camid -= 1 # index starts from 0 85 | if is_train: 86 | pid = self.dataset_name + "_" + str(pid) 87 | camid = self.dataset_name + "_" + str(camid) 88 | data.append((img_path, pid, camid)) 89 | 90 | return data 91 | -------------------------------------------------------------------------------- /fastreid/data/datasets/msmt17.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: l1aoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import sys 8 | import os 9 | import glob 
10 | import re 11 | import os.path as osp 12 | 13 | from .bases import ImageDataset 14 | from ..datasets import DATASET_REGISTRY 15 | ##### Log ##### 16 | 17 | 18 | @DATASET_REGISTRY.register() 19 | class MSMT17(ImageDataset): 20 | """Market1501. 21 | 22 | Reference: 23 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 24 | 25 | URL: ``_ 26 | 27 | Dataset statistics: 28 | - identities: 1501 (+1 for background). 29 | - images: 12936 (train) + 3368 (query) + 15913 (gallery). 30 | """ 31 | dataset_dir = '' 32 | dataset_name = "msmt17" 33 | 34 | def __init__(self, root='datasets', market1501_500k=False, **kwargs): 35 | # self.root = osp.abspath(osp.expanduser(root)) 36 | self.root = root 37 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 38 | 39 | # allow alternative directory structure 40 | self.data_dir = self.dataset_dir 41 | data_dir = osp.join(self.data_dir, 'MSMT17') 42 | if osp.isdir(data_dir): 43 | self.data_dir = data_dir 44 | 45 | self.train_dir = osp.join(self.data_dir, 'bounding_box_train') 46 | self.query_dir = osp.join(self.data_dir, 'query') 47 | self.gallery_dir = osp.join(self.data_dir, 'bounding_box_test') 48 | required_files = [ 49 | self.data_dir, 50 | self.train_dir, 51 | self.query_dir, 52 | self.gallery_dir, 53 | ] 54 | self.check_before_run(required_files) 55 | 56 | train = self.process_dir(self.train_dir) 57 | query = self.process_dir(self.query_dir, is_train=False) 58 | gallery = self.process_dir(self.gallery_dir, is_train=False) 59 | 60 | super(MSMT17, self).__init__(train, query, gallery, **kwargs) 61 | 62 | def process_dir(self, dir_path, is_train=True): 63 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 64 | pattern = re.compile(r'([\d]+)_c(\d)') 65 | 66 | data = [] 67 | for img_path in img_paths: 68 | pid, camid = map(int, pattern.search(img_path).groups()) 69 | if pid == -1: 70 | continue # junk images are just ignored 71 | camid -= 1 # index starts from 0 72 | if is_train: 73 | pid = 
self.dataset_name + "_" + str(pid) 74 | camid = self.dataset_name + "_" + str(camid) 75 | data.append((img_path, pid, camid)) 76 | 77 | return data 78 | -------------------------------------------------------------------------------- /fastreid/data/datasets/occluded_duke.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: liaoxingyu2@jd.com 5 | """ 6 | 7 | import glob 8 | import os.path as osp 9 | import re 10 | 11 | from .bases import ImageDataset 12 | from ..datasets import DATASET_REGISTRY 13 | 14 | 15 | @DATASET_REGISTRY.register() 16 | class OccludedDuke(ImageDataset): 17 | """ 18 | Occluded Duke 19 | """ 20 | dataset_dir = 'OccludedDuke' 21 | dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip' 22 | dataset_name = "OccludedDuke" 23 | 24 | def __init__(self, root='datasets', **kwargs): 25 | # self.root = osp.abspath(osp.expanduser(root)) 26 | self.root = root 27 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 28 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 29 | self.query_dir = osp.join(self.dataset_dir, 'query') 30 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 31 | 32 | required_files = [ 33 | self.dataset_dir, 34 | self.train_dir, 35 | self.query_dir, 36 | self.gallery_dir, 37 | ] 38 | self.check_before_run(required_files) 39 | 40 | train = self.process_dir(self.train_dir) 41 | query = self.process_dir(self.query_dir, is_train=False) 42 | gallery = self.process_dir(self.gallery_dir, is_train=False) 43 | 44 | super(OccludedDuke, self).__init__(train, query, gallery, **kwargs) 45 | 46 | def process_dir(self, dir_path, is_train=True): 47 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 48 | pattern = re.compile(r'([-\d]+)_c(\d)') 49 | 50 | data = [] 51 | for img_path in img_paths: 52 | pid, camid = map(int, pattern.search(img_path).groups()) 53 | assert 1 <= camid <= 8 54 | 
camid -= 1 # index starts from 0 55 | if is_train: 56 | pid = self.dataset_name + "_" + str(pid) 57 | camid = self.dataset_name + "_" + str(camid) 58 | data.append((img_path, pid, camid)) 59 | 60 | return data 61 | -------------------------------------------------------------------------------- /fastreid/data/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .triplet_sampler import BalancedIdentitySampler, NaiveIdentitySampler 8 | from .data_sampler import TrainingSampler, InferenceSampler 9 | -------------------------------------------------------------------------------- /fastreid/data/samplers/data_sampler.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: l1aoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | import itertools 7 | from typing import Optional 8 | 9 | import numpy as np 10 | from torch.utils.data import Sampler 11 | 12 | from fastreid.utils import comm 13 | 14 | 15 | class TrainingSampler(Sampler): 16 | """ 17 | In training, we only care about the "infinite stream" of training data. 18 | So this sampler produces an infinite stream of indices and 19 | all workers cooperate to correctly shuffle the indices and sample different indices. 
20 | The samplers in each worker effectively produces `indices[worker_id::num_workers]` 21 | where `indices` is an infinite stream of indices consisting of 22 | `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True) 23 | or `range(size) + range(size) + ...` (if shuffle is False) 24 | """ 25 | 26 | def __init__(self, size: int, shuffle: bool = True, seed: Optional[int] = None): 27 | """ 28 | Args: 29 | size (int): the total number of data of the underlying dataset to sample from 30 | shuffle (bool): whether to shuffle the indices or not 31 | seed (int): the initial seed of the shuffle. Must be the same 32 | across all workers. If None, will use a random seed shared 33 | among workers (require synchronization among all workers). 34 | """ 35 | self._size = size 36 | assert size > 0 37 | self._shuffle = shuffle 38 | if seed is None: 39 | seed = comm.shared_random_seed() 40 | self._seed = int(seed) 41 | 42 | self._rank = comm.get_rank() 43 | self._world_size = comm.get_world_size() 44 | 45 | def __iter__(self): 46 | start = self._rank 47 | yield from itertools.islice(self._infinite_indices(), start, None, self._world_size) 48 | 49 | def _infinite_indices(self): 50 | np.random.seed(self._seed) 51 | while True: 52 | if self._shuffle: 53 | yield from np.random.permutation(self._size) 54 | else: 55 | yield from np.arange(self._size) 56 | 57 | 58 | class InferenceSampler(Sampler): 59 | """ 60 | Produce indices for inference. 61 | Inference needs to run on the __exact__ set of samples, 62 | therefore when the total number of samples is not divisible by the number of workers, 63 | this sampler produces different number of samples on different workers. 
64 | """ 65 | 66 | def __init__(self, size: int): 67 | """ 68 | Args: 69 | size (int): the total number of data of the underlying dataset to sample from 70 | """ 71 | self._size = size 72 | assert size > 0 73 | self._rank = comm.get_rank() 74 | self._world_size = comm.get_world_size() 75 | 76 | shard_size = (self._size - 1) // self._world_size + 1 77 | begin = shard_size * self._rank 78 | end = min(shard_size * (self._rank + 1), self._size) 79 | self._local_indices = range(begin, end) 80 | 81 | def __iter__(self): 82 | yield from self._local_indices 83 | 84 | def __len__(self): 85 | return len(self._local_indices) 86 | -------------------------------------------------------------------------------- /fastreid/data/samplers/triplet_sampler.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: liaoxingyu2@jd.com 5 | """ 6 | 7 | import copy 8 | import itertools 9 | from collections import defaultdict 10 | from typing import Optional 11 | 12 | import numpy as np 13 | from torch.utils.data.sampler import Sampler 14 | 15 | from fastreid.utils import comm 16 | 17 | 18 | def no_index(a, b): 19 | assert isinstance(a, list) 20 | return [i for i, j in enumerate(a) if j != b] 21 | 22 | 23 | class BalancedIdentitySampler(Sampler): 24 | def __init__(self, data_source: str, batch_size: int, num_instances: int, seed: Optional[int] = None): 25 | self.data_source = data_source 26 | self.batch_size = batch_size 27 | self.num_instances = num_instances 28 | self.num_pids_per_batch = batch_size // self.num_instances 29 | 30 | self.index_pid = defaultdict(list) 31 | self.pid_cam = defaultdict(list) 32 | self.pid_index = defaultdict(list) 33 | 34 | for index, info in enumerate(data_source): 35 | pid = info[1] 36 | camid = info[2] 37 | self.index_pid[index] = pid 38 | self.pid_cam[pid].append(camid) 39 | self.pid_index[pid].append(index) 40 | 41 | self.pids = 
sorted(list(self.pid_index.keys())) 42 | self.num_identities = len(self.pids) 43 | 44 | if seed is None: 45 | seed = comm.shared_random_seed() 46 | self._seed = int(seed) 47 | 48 | self._rank = comm.get_rank() 49 | self._world_size = comm.get_world_size() 50 | 51 | def __iter__(self): 52 | start = self._rank 53 | yield from itertools.islice(self._infinite_indices(), start, None, self._world_size) 54 | 55 | def _infinite_indices(self): 56 | np.random.seed(self._seed) 57 | while True: 58 | # Shuffle identity list 59 | identities = np.random.permutation(self.num_identities) 60 | 61 | # If remaining identities cannot be enough for a batch, 62 | # just drop the remaining parts 63 | drop_indices = self.num_identities % self.num_pids_per_batch 64 | if drop_indices: identities = identities[:-drop_indices] 65 | 66 | ret = [] 67 | for kid in identities: 68 | i = np.random.choice(self.pid_index[self.pids[kid]]) 69 | _, i_pid, i_cam = self.data_source[i] 70 | ret.append(i) 71 | pid_i = self.index_pid[i] 72 | cams = self.pid_cam[pid_i] 73 | index = self.pid_index[pid_i] 74 | select_cams = no_index(cams, i_cam) 75 | 76 | if select_cams: 77 | if len(select_cams) >= self.num_instances: 78 | cam_indexes = np.random.choice(select_cams, size=self.num_instances - 1, replace=False) 79 | else: 80 | cam_indexes = np.random.choice(select_cams, size=self.num_instances - 1, replace=True) 81 | for kk in cam_indexes: 82 | ret.append(index[kk]) 83 | else: 84 | select_indexes = no_index(index, i) 85 | if not select_indexes: 86 | # Only one image for this identity 87 | ind_indexes = [0] * (self.num_instances - 1) 88 | elif len(select_indexes) >= self.num_instances: 89 | ind_indexes = np.random.choice(select_indexes, size=self.num_instances - 1, replace=False) 90 | else: 91 | ind_indexes = np.random.choice(select_indexes, size=self.num_instances - 1, replace=True) 92 | 93 | for kk in ind_indexes: 94 | ret.append(index[kk]) 95 | 96 | if len(ret) == self.batch_size: 97 | yield from ret 98 | ret = 
[] 99 | 100 | 101 | class NaiveIdentitySampler(Sampler): 102 | """ 103 | Randomly sample N identities, then for each identity, 104 | randomly sample K instances, therefore batch size is N*K. 105 | Args: 106 | - data_source (list): list of (img_path, pid, camid). 107 | - num_instances (int): number of instances per identity in a batch. 108 | - batch_size (int): number of examples in a batch. 109 | """ 110 | 111 | def __init__(self, data_source: str, batch_size: int, num_instances: int, seed: Optional[int] = None): 112 | self.data_source = data_source 113 | self.batch_size = batch_size 114 | self.num_instances = num_instances 115 | self.num_pids_per_batch = batch_size // self.num_instances 116 | 117 | self.index_pid = defaultdict(list) 118 | self.pid_cam = defaultdict(list) 119 | self.pid_index = defaultdict(list) 120 | 121 | for index, info in enumerate(data_source): 122 | pid = info[1] 123 | camid = info[2] 124 | self.index_pid[index] = pid 125 | self.pid_cam[pid].append(camid) 126 | self.pid_index[pid].append(index) 127 | 128 | self.pids = sorted(list(self.pid_index.keys())) 129 | self.num_identities = len(self.pids) 130 | 131 | if seed is None: 132 | seed = comm.shared_random_seed() 133 | self._seed = int(seed) 134 | 135 | self._rank = comm.get_rank() 136 | self._world_size = comm.get_world_size() 137 | 138 | def __iter__(self): 139 | start = self._rank 140 | yield from itertools.islice(self._infinite_indices(), start, None, self._world_size) 141 | 142 | def _infinite_indices(self): 143 | np.random.seed(self._seed) 144 | while True: 145 | avai_pids = copy.deepcopy(self.pids) 146 | batch_idxs_dict = {} 147 | 148 | batch_indices = [] 149 | while len(avai_pids) >= self.num_pids_per_batch: 150 | selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False).tolist() 151 | for pid in selected_pids: 152 | # Register pid in batch_idxs_dict if not 153 | if pid not in batch_idxs_dict: 154 | idxs = copy.deepcopy(self.pid_index[pid]) 155 | if len(idxs) 
< self.num_instances: 156 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True).tolist() 157 | np.random.shuffle(idxs) 158 | batch_idxs_dict[pid] = idxs 159 | 160 | avai_idxs = batch_idxs_dict[pid] 161 | for _ in range(self.num_instances): 162 | batch_indices.append(avai_idxs.pop(0)) 163 | 164 | if len(avai_idxs) < self.num_instances: avai_pids.remove(pid) 165 | 166 | assert len(batch_indices) == self.batch_size, f"batch indices have wrong " \ 167 | f"length with {len(batch_indices)}!" 168 | yield from batch_indices 169 | batch_indices = [] 170 | -------------------------------------------------------------------------------- /fastreid/data/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | 8 | from .build import build_transforms 9 | from .transforms import * 10 | from .autoaugment import * 11 | -------------------------------------------------------------------------------- /fastreid/data/transforms/build.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torchvision.transforms as T 8 | 9 | from .transforms import * 10 | from .autoaugment import AutoAugment 11 | 12 | 13 | def build_transforms(cfg, is_train=True): 14 | res = [] 15 | 16 | if is_train: 17 | size_train = cfg.INPUT.SIZE_TRAIN 18 | 19 | # auto augmentation 20 | do_autoaug = cfg.INPUT.DO_AUTOAUG 21 | total_iter = cfg.SOLVER.MAX_ITER 22 | 23 | # horizontal filp 24 | do_flip = cfg.INPUT.DO_FLIP 25 | flip_prob = cfg.INPUT.FLIP_PROB 26 | 27 | # padding 28 | do_pad = cfg.INPUT.DO_PAD 29 | padding = cfg.INPUT.PADDING 30 | padding_mode = cfg.INPUT.PADDING_MODE 31 | 32 | # color jitter 33 | do_cj = cfg.INPUT.CJ.ENABLED 34 | cj_prob = cfg.INPUT.CJ.PROB 35 | cj_brightness = 
cfg.INPUT.CJ.BRIGHTNESS 36 | cj_contrast = cfg.INPUT.CJ.CONTRAST 37 | cj_saturation = cfg.INPUT.CJ.SATURATION 38 | cj_hue = cfg.INPUT.CJ.HUE 39 | 40 | # random erasing 41 | do_rea = cfg.INPUT.REA.ENABLED 42 | rea_prob = cfg.INPUT.REA.PROB 43 | rea_mean = cfg.INPUT.REA.MEAN 44 | 45 | 46 | if do_autoaug: 47 | res.append(AutoAugment(total_iter)) 48 | res.append(T.Resize(size_train, interpolation=3)) 49 | if do_flip: 50 | res.append(T.RandomHorizontalFlip(p=flip_prob)) 51 | if do_pad: 52 | res.extend([T.Pad(padding, padding_mode=padding_mode), 53 | T.RandomCrop(size_train)]) 54 | if do_cj: 55 | res.append(T.RandomApply([T.ColorJitter(cj_brightness, cj_contrast, cj_saturation, cj_hue)], p=cj_prob)) 56 | if do_rea: 57 | res.append(RandomErasing(probability=rea_prob, mean=rea_mean)) 58 | 59 | else: 60 | size_test = cfg.INPUT.SIZE_TEST 61 | 62 | res.append(T.Resize(size_test, interpolation=3)) 63 | return T.Compose(res) 64 | -------------------------------------------------------------------------------- /fastreid/data/transforms/functional.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import numpy as np 8 | import torch 9 | from PIL import Image, ImageOps, ImageEnhance 10 | 11 | 12 | def to_tensor(pic): 13 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. 14 | 15 | See ``ToTensor`` for more details. 16 | 17 | Args: 18 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 19 | 20 | Returns: 21 | Tensor: Converted image. 
22 | """ 23 | if isinstance(pic, np.ndarray): 24 | assert len(pic.shape) in (2, 3) 25 | # handle numpy array 26 | if pic.ndim == 2: 27 | pic = pic[:, :, None] 28 | 29 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 30 | # backward compatibility 31 | if isinstance(img, torch.ByteTensor): 32 | return img.float() 33 | else: 34 | return img 35 | 36 | # handle PIL Image 37 | if pic.mode == 'I': 38 | img = torch.from_numpy(np.array(pic, np.int32, copy=False)) 39 | elif pic.mode == 'I;16': 40 | img = torch.from_numpy(np.array(pic, np.int16, copy=False)) 41 | elif pic.mode == 'F': 42 | img = torch.from_numpy(np.array(pic, np.float32, copy=False)) 43 | elif pic.mode == '1': 44 | img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False)) 45 | else: 46 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 47 | # PIL image mode: L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK 48 | if pic.mode == 'YCbCr': 49 | nchannel = 3 50 | elif pic.mode == 'I;16': 51 | nchannel = 1 52 | else: 53 | nchannel = len(pic.mode) 54 | img = img.view(pic.size[1], pic.size[0], nchannel) 55 | # put it from HWC to CHW format 56 | # yikes, this transpose takes 80% of the loading time/CPU 57 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 58 | if isinstance(img, torch.ByteTensor): 59 | return img.float() 60 | else: 61 | return img 62 | 63 | 64 | def int_parameter(level, maxval): 65 | """Helper function to scale `val` between 0 and maxval . 66 | Args: 67 | level: Level of the operation that will be between [0, `PARAMETER_MAX`]. 68 | maxval: Maximum value that the operation can have. This will be scaled to 69 | level/PARAMETER_MAX. 70 | Returns: 71 | An int that results from scaling `maxval` according to `level`. 72 | """ 73 | return int(level * maxval / 10) 74 | 75 | 76 | def float_parameter(level, maxval): 77 | """Helper function to scale `val` between 0 and maxval. 78 | Args: 79 | level: Level of the operation that will be between [0, `PARAMETER_MAX`]. 
80 | maxval: Maximum value that the operation can have. This will be scaled to 81 | level/PARAMETER_MAX. 82 | Returns: 83 | A float that results from scaling `maxval` according to `level`. 84 | """ 85 | return float(level) * maxval / 10. 86 | 87 | 88 | def sample_level(n): 89 | return np.random.uniform(low=0.1, high=n) 90 | 91 | 92 | def autocontrast(pil_img, *args): 93 | return ImageOps.autocontrast(pil_img) 94 | 95 | 96 | def equalize(pil_img, *args): 97 | return ImageOps.equalize(pil_img) 98 | 99 | 100 | def posterize(pil_img, level, *args): 101 | level = int_parameter(sample_level(level), 4) 102 | return ImageOps.posterize(pil_img, 4 - level) 103 | 104 | 105 | def rotate(pil_img, level, *args): 106 | degrees = int_parameter(sample_level(level), 30) 107 | if np.random.uniform() > 0.5: 108 | degrees = -degrees 109 | return pil_img.rotate(degrees, resample=Image.BILINEAR) 110 | 111 | 112 | def solarize(pil_img, level, *args): 113 | level = int_parameter(sample_level(level), 256) 114 | return ImageOps.solarize(pil_img, 256 - level) 115 | 116 | 117 | def shear_x(pil_img, level, image_size): 118 | level = float_parameter(sample_level(level), 0.3) 119 | if np.random.uniform() > 0.5: 120 | level = -level 121 | return pil_img.transform(image_size, 122 | Image.AFFINE, (1, level, 0, 0, 1, 0), 123 | resample=Image.BILINEAR) 124 | 125 | 126 | def shear_y(pil_img, level, image_size): 127 | level = float_parameter(sample_level(level), 0.3) 128 | if np.random.uniform() > 0.5: 129 | level = -level 130 | return pil_img.transform(image_size, 131 | Image.AFFINE, (1, 0, 0, level, 1, 0), 132 | resample=Image.BILINEAR) 133 | 134 | 135 | def translate_x(pil_img, level, image_size): 136 | level = int_parameter(sample_level(level), image_size[0] / 3) 137 | if np.random.random() > 0.5: 138 | level = -level 139 | return pil_img.transform(image_size, 140 | Image.AFFINE, (1, 0, level, 0, 1, 0), 141 | resample=Image.BILINEAR) 142 | 143 | 144 | def translate_y(pil_img, level, image_size): 145 
| level = int_parameter(sample_level(level), image_size[1] / 3) 146 | if np.random.random() > 0.5: 147 | level = -level 148 | return pil_img.transform(image_size, 149 | Image.AFFINE, (1, 0, 0, 0, 1, level), 150 | resample=Image.BILINEAR) 151 | 152 | 153 | # operation that overlaps with ImageNet-C's test set 154 | def color(pil_img, level, *args): 155 | level = float_parameter(sample_level(level), 1.8) + 0.1 156 | return ImageEnhance.Color(pil_img).enhance(level) 157 | 158 | 159 | # operation that overlaps with ImageNet-C's test set 160 | def contrast(pil_img, level, *args): 161 | level = float_parameter(sample_level(level), 1.8) + 0.1 162 | return ImageEnhance.Contrast(pil_img).enhance(level) 163 | 164 | 165 | # operation that overlaps with ImageNet-C's test set 166 | def brightness(pil_img, level, *args): 167 | level = float_parameter(sample_level(level), 1.8) + 0.1 168 | return ImageEnhance.Brightness(pil_img).enhance(level) 169 | 170 | 171 | # operation that overlaps with ImageNet-C's test set 172 | def sharpness(pil_img, level, *args): 173 | level = float_parameter(sample_level(level), 1.8) + 0.1 174 | return ImageEnhance.Sharpness(pil_img).enhance(level) 175 | 176 | 177 | augmentations_reid = [ 178 | autocontrast, equalize, posterize, shear_x, shear_y, 179 | color, contrast, brightness, sharpness 180 | ] 181 | 182 | augmentations = [ 183 | autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y, 184 | translate_x, translate_y 185 | ] 186 | 187 | augmentations_all = [ 188 | autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y, 189 | translate_x, translate_y, color, contrast, brightness, sharpness 190 | ] 191 | -------------------------------------------------------------------------------- /fastreid/data/transforms/transforms.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | __all__ = 
['ToTensor', 'RandomErasing',] 8 | 9 | import math 10 | import random 11 | from collections import deque 12 | 13 | import torch 14 | import numpy as np 15 | from PIL import Image 16 | import torchvision.transforms as T 17 | 18 | from .functional import to_tensor, augmentations_reid 19 | 20 | 21 | class ToTensor(object): 22 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. 23 | 24 | Converts a PIL Image or numpy.ndarray (H x W x C) in the range 25 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 255.0] 26 | if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) 27 | or if the numpy.ndarray has dtype = np.uint8 28 | 29 | In the other cases, tensors are returned without scaling. 30 | """ 31 | 32 | def __call__(self, pic): 33 | """ 34 | Args: 35 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 36 | 37 | Returns: 38 | Tensor: Converted image. 39 | """ 40 | return to_tensor(pic) 41 | 42 | def __repr__(self): 43 | return self.__class__.__name__ + '()' 44 | 45 | 46 | class RandomErasing(object): 47 | """ Randomly selects a rectangle region in an image and erases its pixels. 48 | 'Random Erasing Data Augmentation' by Zhong et al. 49 | See https://arxiv.org/pdf/1708.04896.pdf 50 | Args: 51 | probability: The probability that the Random Erasing operation will be performed. 52 | sl: Minimum proportion of erased area against input image. 53 | sh: Maximum proportion of erased area against input image. 54 | r1: Minimum aspect ratio of erased area. 55 | mean: Erasing value. 
56 | """ 57 | 58 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=255 * (0.49735, 0.4822, 0.4465)): 59 | self.probability = probability 60 | self.mean = mean 61 | self.sl = sl 62 | self.sh = sh 63 | self.r1 = r1 64 | 65 | def __call__(self, img): 66 | img = np.asarray(img, dtype=np.float32).copy() 67 | if random.uniform(0, 1) > self.probability: 68 | return img 69 | 70 | for attempt in range(100): 71 | area = img.shape[0] * img.shape[1] 72 | target_area = random.uniform(self.sl, self.sh) * area 73 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 74 | 75 | h = int(round(math.sqrt(target_area * aspect_ratio))) 76 | w = int(round(math.sqrt(target_area / aspect_ratio))) 77 | 78 | if w < img.shape[1] and h < img.shape[0]: 79 | x1 = random.randint(0, img.shape[0] - h) 80 | y1 = random.randint(0, img.shape[1] - w) 81 | if img.shape[2] == 3: 82 | img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0] 83 | img[x1:x1 + h, y1:y1 + w, 1] = self.mean[1] 84 | img[x1:x1 + h, y1:y1 + w, 2] = self.mean[2] 85 | else: 86 | img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0] 87 | return img 88 | return img 89 | 90 | -------------------------------------------------------------------------------- /fastreid/engine/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | from .train_loop import * 7 | 8 | __all__ = [k for k in globals().keys() if not k.startswith("_")] 9 | 10 | 11 | # prefer to let hooks and defaults live in separate namespaces (therefore not in __all__) 12 | # but still make them available here 13 | from .hooks import * 14 | from .defaults import * 15 | from .launch import * 16 | -------------------------------------------------------------------------------- /fastreid/engine/launch.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: xingyu liao 4 | 
@contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | # based on: 8 | # https://github.com/facebookresearch/detectron2/blob/master/detectron2/engine/launch.py 9 | 10 | 11 | import logging 12 | 13 | import torch 14 | import torch.distributed as dist 15 | import torch.multiprocessing as mp 16 | 17 | from fastreid.utils import comm 18 | 19 | __all__ = ["launch"] 20 | 21 | 22 | def _find_free_port(): 23 | import socket 24 | 25 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 26 | # Binding to port 0 will cause the OS to find an available port for us 27 | sock.bind(("", 0)) 28 | port = sock.getsockname()[1] 29 | sock.close() 30 | # NOTE: there is still a chance the port could be taken by other processes. 31 | return port 32 | 33 | 34 | def launch(main_func, num_gpus_per_machine, num_machines=1, machine_rank=0, dist_url=None, args=()): 35 | """ 36 | Launch multi-gpu or distributed training. 37 | This function must be called on all machines involved in the training. 38 | It will spawn child processes (defined by ``num_gpus_per_machine`) on each machine. 39 | Args: 40 | main_func: a function that will be called by `main_func(*args)` 41 | num_gpus_per_machine (int): number of GPUs per machine 42 | num_machines (int): the total number of machines 43 | machine_rank (int): the rank of this machine 44 | dist_url (str): url to connect to for distributed jobs, including protocol 45 | e.g. "tcp://127.0.0.1:8686". 46 | Can be set to "auto" to automatically select a free port on localhost 47 | args (tuple): arguments passed to main_func 48 | """ 49 | world_size = num_machines * num_gpus_per_machine 50 | if world_size > 1: 51 | # https://github.com/pytorch/pytorch/pull/14391 52 | # TODO prctl in spawned processes 53 | 54 | if dist_url == "auto": 55 | assert num_machines == 1, "dist_url=auto not supported in multi-machine jobs." 
56 | port = _find_free_port() 57 | dist_url = f"tcp://127.0.0.1:{port}" 58 | if num_machines > 1 and dist_url.startswith("file://"): 59 | logger = logging.getLogger(__name__) 60 | logger.warning( 61 | "file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://" 62 | ) 63 | 64 | mp.spawn( 65 | _distributed_worker, 66 | nprocs=num_gpus_per_machine, 67 | args=(main_func, world_size, num_gpus_per_machine, machine_rank, dist_url, args), 68 | daemon=False, 69 | ) 70 | else: 71 | main_func(*args) 72 | 73 | 74 | def _distributed_worker( 75 | local_rank, main_func, world_size, num_gpus_per_machine, machine_rank, dist_url, args 76 | ): 77 | assert torch.cuda.is_available(), "cuda is not available. Please check your installation." 78 | global_rank = machine_rank * num_gpus_per_machine + local_rank 79 | try: 80 | dist.init_process_group( 81 | backend="NCCL", init_method=dist_url, world_size=world_size, rank=global_rank 82 | ) 83 | except Exception as e: 84 | logger = logging.getLogger(__name__) 85 | logger.error("Process group URL: {}".format(dist_url)) 86 | raise e 87 | # synchronize is needed here to prevent a possible timeout after calling init_process_group 88 | # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172 89 | comm.synchronize() 90 | 91 | assert num_gpus_per_machine <= torch.cuda.device_count() 92 | torch.cuda.set_device(local_rank) 93 | 94 | # Setup the local process group (which contains ranks within the same machine) 95 | assert comm._LOCAL_PROCESS_GROUP is None 96 | num_machines = world_size // num_gpus_per_machine 97 | for i in range(num_machines): 98 | ranks_on_i = list(range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine)) 99 | pg = dist.new_group(ranks_on_i) 100 | if i == machine_rank: 101 | comm._LOCAL_PROCESS_GROUP = pg 102 | 103 | main_func(*args) 104 | -------------------------------------------------------------------------------- /fastreid/evaluation/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from .evaluator import DatasetEvaluator, inference_context, inference_on_dataset 3 | from .rank import evaluate_rank 4 | from .roc import evaluate_roc 5 | from .reid_evaluation import * 6 | from .testing import print_csv_format, verify_results 7 | 8 | from .build import build_sp_evaluator, EVALUATOR_REGISTRY 9 | 10 | __all__ = [k for k in globals().keys() if not k.startswith("_")] 11 | -------------------------------------------------------------------------------- /fastreid/evaluation/build.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from ..utils.registry import Registry 8 | 9 | EVALUATOR_REGISTRY = Registry("EVALUATOR") 10 | EVALUATOR_REGISTRY.__doc__ = """ 11 | Registry for evaluators, which are used for different metric & methods 12 | The registered object must be a callable that accepts two arguments: 13 | 1. A :class:`detectron2.config.CfgNode` 14 | It must returns an instance of :class:`DefaultEvaluator`. 15 | """ 16 | 17 | 18 | def build_sp_evaluator(cfg, num_query, output_dir=None): 19 | """ 20 | Build a evaluator from `cfg.TEST.EVALUATOR`. 21 | Returns: 22 | an instance of :class:`DefaultEvaluator` 23 | """ 24 | 25 | evaluator_name = cfg.TEST.EVALUATOR 26 | evaluator = EVALUATOR_REGISTRY.get(evaluator_name)(cfg, num_query, output_dir) 27 | return evaluator 28 | -------------------------------------------------------------------------------- /fastreid/evaluation/evaluator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import datetime 3 | import logging 4 | import time 5 | from contextlib import contextmanager 6 | from copy import deepcopy 7 | 8 | import torch 9 | 10 | from fastreid.utils.logger import log_every_n_seconds 11 | 12 | 13 | class DatasetEvaluator: 14 | """ 15 | Base class for a dataset evaluator. 16 | The function :func:`inference_on_dataset` runs the model over 17 | all samples in the dataset, and have a DatasetEvaluator to process the inputs/outputs. 18 | This class will accumulate information of the inputs/outputs (by :meth:`process`), 19 | and produce evaluation results in the end (by :meth:`evaluate`). 20 | """ 21 | def __init__(self): 22 | self.model = None 23 | 24 | def load_module(self, model): 25 | """ 26 | Preparation for generation module like "Graph Matching", "DSR distance" 27 | model: could be a dict or list or nn.Module 28 | """ 29 | if isinstance(model, tuple) or isinstance(model, list): 30 | model = [module.cpu() for module in model] 31 | elif isinstance(model, dict): 32 | for k, v in model.items(): 33 | model[k] = v.cpu() 34 | elif model is not None: 35 | model = model.cpu() 36 | else: 37 | pass 38 | self.model = model 39 | 40 | def reset(self): 41 | """ 42 | Preparation for a new round of evaluation. 43 | Should be called before starting a round of evaluation. 44 | """ 45 | pass 46 | 47 | def preprocess_inputs(self, inputs): 48 | pass 49 | 50 | def process(self, inputs, outputs): 51 | """ 52 | Process an input/output pair. 53 | Args: 54 | inputs: the inputs that's used to call the model. 55 | outputs: the return value of `model(input)` 56 | """ 57 | pass 58 | 59 | def evaluate(self): 60 | """ 61 | Evaluate/summarize the performance, after processing all input/output pairs. 62 | Returns: 63 | dict: 64 | A new evaluator class can return a dict of arbitrary format 65 | as long as the user can process the results. 
66 | In our train_net.py, we expect the following format: 67 | * key: the name of the task (e.g., bbox) 68 | * value: a dict of {metric name: score}, e.g.: {"AP50": 80} 69 | """ 70 | pass 71 | 72 | def inference_on_dataset(model, data_loader, evaluator): 73 | """ 74 | Run model on the data_loader and evaluate the metrics with evaluator. 75 | The model will be used in eval mode. 76 | Args: 77 | model (nn.Module): a module which accepts an object from 78 | `data_loader` and returns some outputs. It will be temporarily set to `eval` mode. 79 | If you wish to evaluate a model in `training` mode instead, you can 80 | wrap the given model and override its behavior of `.eval()` and `.train()`. 81 | data_loader: an iterable object with a length. 82 | The elements it generates will be the inputs to the model. 83 | evaluator (DatasetEvaluator): the evaluator to run. Use 84 | :class:`DatasetEvaluators([])` if you only want to benchmark, but 85 | don't want to do any evaluation. 86 | Returns: 87 | The return value of `evaluator.evaluate()` 88 | """ 89 | logger = logging.getLogger(__name__) 90 | logger.info("Start inference on {} images".format(len(data_loader.dataset))) 91 | 92 | total = len(data_loader) # inference data loader must have a fixed length 93 | evaluator.reset() 94 | 95 | num_warmup = min(5, total - 1) 96 | start_time = time.perf_counter() 97 | total_compute_time = 0 98 | with inference_context(model), torch.no_grad(): 99 | for idx, inputs in enumerate(data_loader): 100 | if idx == num_warmup: 101 | start_time = time.perf_counter() 102 | total_compute_time = 0 103 | 104 | start_compute_time = time.perf_counter() 105 | outputs = model(inputs) 106 | total_compute_time += time.perf_counter() - start_compute_time 107 | evaluator.process(inputs, outputs) 108 | 109 | idx += 1 110 | iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup) 111 | seconds_per_batch = total_compute_time / iters_after_start 112 | if idx >= num_warmup * 2 or seconds_per_batch > 30: 
113 | total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start 114 | eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1)))  # remaining batches x average wall-clock seconds per timed batch 115 | log_every_n_seconds( 116 | logging.INFO, 117 | "Inference done {}/{}. {:.4f} s / batch. ETA={}".format( 118 | idx + 1, total, seconds_per_batch, str(eta) 119 | ), 120 | n=30, 121 | ) 122 | 123 | 124 | # Measure the time only for this worker (before the synchronization barrier) 125 | total_time = time.perf_counter() - start_time 126 | total_time_str = str(datetime.timedelta(seconds=total_time)) 127 | # NOTE this format is parsed by grep 128 | logger.info( 129 | "Total inference time: {} ({:.6f} s / batch per device)".format( 130 | total_time_str, total_time / (total - num_warmup) 131 | ) 132 | ) 133 | total_compute_time_str = str(datetime.timedelta(seconds=int(total_compute_time))) 134 | logger.info( 135 | "Total inference pure compute time: {} ({:.6f} s / batch per device)".format( 136 | total_compute_time_str, total_compute_time / (total - num_warmup) 137 | ) 138 | ) 139 | results = evaluator.evaluate() 140 | # An evaluator may return None when not in main process. 141 | # Replace it by an empty dict instead to make it easier for downstream code to handle 142 | if results is None: 143 | results = {} 144 | return results 145 | 146 | 147 | @contextmanager 148 | def inference_context(model): 149 | """ 150 | A context where the model is temporarily changed to eval mode, 151 | and restored to previous mode afterwards.
152 | Args: 153 | model: a torch Module 154 | """ 155 | training_mode = model.training 156 | model.eval() 157 | yield 158 | model.train(training_mode) 159 | -------------------------------------------------------------------------------- /fastreid/evaluation/rank.py: -------------------------------------------------------------------------------- 1 | # credits: https://github.com/KaiyangZhou/deep-person-reid/blob/master/torchreid/metrics/rank.py 2 | 3 | import warnings 4 | from collections import defaultdict 5 | 6 | import numpy as np 7 | from tqdm import tqdm 8 | 9 | try: 10 | from .rank_cylib.rank_cy import evaluate_cy 11 | 12 | IS_CYTHON_AVAI = True 13 | except ImportError: 14 | IS_CYTHON_AVAI = False 15 | # warnings.warn( 16 | # 'Cython rank evaluation (very fast so highly recommended) is ' 17 | # 'unavailable, now use python evaluation.' 18 | # ) 19 | 20 | 21 | def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank): 22 | """Evaluation with cuhk03 metric 23 | Key: one image for each gallery identity is randomly sampled for each query identity. 24 | Random sampling is performed num_repeats times. 25 | """ 26 | num_repeats = 10 27 | 28 | num_q, num_g = distmat.shape 29 | 30 | indices = np.argsort(distmat, axis=1) 31 | 32 | if num_g < max_rank: 33 | max_rank = num_g 34 | print( 35 | 'Note: number of gallery samples is quite small, got {}'. 36 | format(num_g) 37 | ) 38 | 39 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 40 | 41 | # compute cmc curve for each query 42 | all_cmc = [] 43 | all_AP = [] 44 | num_valid_q = 0. 
# number of valid query 45 | 46 | for q_idx in tqdm(range(num_q)): 47 | # get query pid and camid 48 | q_pid = q_pids[q_idx] 49 | q_camid = q_camids[q_idx] 50 | 51 | # remove gallery samples that have the same pid and camid with query 52 | order = indices[q_idx] 53 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 54 | keep = np.invert(remove) 55 | 56 | # compute cmc curve 57 | raw_cmc = matches[q_idx][ 58 | keep] # binary vector, positions with value 1 are correct matches 59 | if not np.any(raw_cmc): 60 | # this condition is true when query identity does not appear in gallery 61 | continue 62 | 63 | kept_g_pids = g_pids[order][keep] 64 | g_pids_dict = defaultdict(list) 65 | for idx, pid in enumerate(kept_g_pids): 66 | g_pids_dict[pid].append(idx) 67 | 68 | cmc = 0. 69 | for repeat_idx in range(num_repeats): 70 | mask = np.zeros(len(raw_cmc), dtype=bool)  # builtin bool: the np.bool alias was removed in NumPy 1.24 71 | for _, idxs in g_pids_dict.items(): 72 | # randomly sample one image for each gallery person 73 | rnd_idx = np.random.choice(idxs) 74 | mask[rnd_idx] = True 75 | masked_raw_cmc = raw_cmc[mask] 76 | _cmc = masked_raw_cmc.cumsum() 77 | _cmc[_cmc > 1] = 1 78 | cmc += _cmc[:max_rank].astype(np.float32) 79 | 80 | cmc /= num_repeats 81 | all_cmc.append(cmc) 82 | # compute AP 83 | num_rel = raw_cmc.sum() 84 | tmp_cmc = raw_cmc.cumsum() 85 | tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)] 86 | tmp_cmc = np.asarray(tmp_cmc) * raw_cmc 87 | AP = tmp_cmc.sum() / num_rel 88 | all_AP.append(AP) 89 | num_valid_q += 1. 90 | 91 | assert num_valid_q > 0, 'Error: all query identities do not appear in gallery' 92 | 93 | all_cmc = np.asarray(all_cmc).astype(np.float32) 94 | all_cmc = all_cmc.sum(0) / num_valid_q 95 | mAP = np.mean(all_AP) 96 | 97 | return all_cmc, mAP  # NOTE(review): 2-tuple, unlike eval_market1501 which returns (all_cmc, all_AP, all_INP) 98 | 99 | 100 | def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank): 101 | """Evaluation with market1501 metric 102 | Key: for each query identity, its gallery images from the same camera view are discarded.
103 | """ 104 | num_q, num_g = distmat.shape 105 | 106 | if num_g < max_rank: 107 | max_rank = num_g 108 | print('Note: number of gallery samples is quite small, got {}'.format(num_g)) 109 | 110 | indices = np.argsort(distmat, axis=1) 111 | 112 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 113 | 114 | # compute cmc curve for each query 115 | all_cmc = [] 116 | all_AP = [] 117 | all_INP = [] 118 | num_valid_q = 0. # number of valid query 119 | 120 | for q_idx in range(num_q): 121 | # get query pid and camid 122 | q_pid = q_pids[q_idx] 123 | q_camid = q_camids[q_idx] 124 | 125 | # remove gallery samples that have the same pid and camid with query 126 | order = indices[q_idx] 127 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 128 | keep = np.invert(remove) 129 | 130 | # compute cmc curve 131 | raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 132 | if not np.any(raw_cmc): 133 | # this condition is true when query identity does not appear in gallery 134 | continue 135 | 136 | cmc = raw_cmc.cumsum() 137 | 138 | pos_idx = np.where(raw_cmc == 1) 139 | max_pos_idx = np.max(pos_idx) 140 | inp = cmc[max_pos_idx] / (max_pos_idx + 1.0)  # INP: number of correct matches / rank of the last (hardest) correct match 141 | all_INP.append(inp) 142 | 143 | cmc[cmc > 1] = 1 144 | 145 | all_cmc.append(cmc[:max_rank]) 146 | num_valid_q += 1. 147 | 148 | # compute average precision 149 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 150 | num_rel = raw_cmc.sum() 151 | tmp_cmc = raw_cmc.cumsum() 152 | tmp_cmc = [x / (i + 1.)
for i, x in enumerate(tmp_cmc)] 153 | tmp_cmc = np.asarray(tmp_cmc) * raw_cmc 154 | AP = tmp_cmc.sum() / num_rel 155 | all_AP.append(AP) 156 | 157 | assert num_valid_q > 0, 'Error: all query identities do not appear in gallery' 158 | 159 | all_cmc = np.asarray(all_cmc).astype(np.float32) 160 | all_cmc = all_cmc.sum(0) / num_valid_q 161 | 162 | return all_cmc, all_AP, all_INP 163 | 164 | 165 | def evaluate_py(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03): 166 | if use_metric_cuhk03: 167 | return eval_cuhk03(distmat, g_pids, q_camids, g_camids, max_rank) 168 | else: 169 | return eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank) 170 | 171 | 172 | def evaluate_rank( 173 | distmat, 174 | q_pids, 175 | g_pids, 176 | q_camids, 177 | g_camids, 178 | max_rank=50, 179 | use_metric_cuhk03=False, 180 | use_cython=True 181 | ): 182 | """Evaluates CMC rank. 183 | Args: 184 | distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery). 185 | q_pids (numpy.ndarray): 1-D array containing person identities 186 | of each query instance. 187 | g_pids (numpy.ndarray): 1-D array containing person identities 188 | of each gallery instance. 189 | q_camids (numpy.ndarray): 1-D array containing camera views under 190 | which each query instance is captured. 191 | g_camids (numpy.ndarray): 1-D array containing camera views under 192 | which each gallery instance is captured. 193 | max_rank (int, optional): maximum CMC rank to be computed. Default is 50. 194 | use_metric_cuhk03 (bool, optional): use single-gallery-shot setting for cuhk03. 195 | Default is False. This should be enabled when using cuhk03 classic split. 196 | use_cython (bool, optional): use cython code for evaluation. Default is True. 197 | This is highly recommended as the cython code can speed up the cmc computation 198 | by more than 10x. This requires Cython to be installed. 
199 | """ 200 | if use_cython and IS_CYTHON_AVAI: 201 | return evaluate_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03) 202 | else: 203 | return evaluate_py(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03) 204 | -------------------------------------------------------------------------------- /fastreid/evaluation/rank_cylib/Makefile: -------------------------------------------------------------------------------- 1 | all: 2 | python3 setup.py build_ext --inplace 3 | rm -rf build 4 | python3 test_cython.py 5 | clean: 6 | rm -rf build 7 | rm -f rank_cy.c *.so 8 | -------------------------------------------------------------------------------- /fastreid/evaluation/rank_cylib/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ -------------------------------------------------------------------------------- /fastreid/evaluation/rank_cylib/roc_cy.pyx: -------------------------------------------------------------------------------- 1 | # cython: boundscheck=False, wraparound=False, nonecheck=False, cdivision=True 2 | # credits: https://github.com/KaiyangZhou/deep-person-reid/blob/master/torchreid/metrics/rank_cylib/rank_cy.pyx 3 | 4 | import cython 5 | import faiss 6 | import numpy as np 7 | cimport numpy as np 8 | 9 | 10 | """ 11 | Compiler directives: 12 | https://github.com/cython/cython/wiki/enhancements-compilerdirectives 13 | Cython tutorial: 14 | https://cython.readthedocs.io/en/latest/src/userguide/numpy_tutorial.html 15 | Credit to https://github.com/luzai 16 | """ 17 | 18 | 19 | # Main interface 20 | cpdef evaluate_roc_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids, 21 | long[:]q_camids, long[:]g_camids): 22 | 23 | distmat = np.asarray(distmat, dtype=np.float32) 24 | q_pids = np.asarray(q_pids, dtype=np.int64) 25 | g_pids = np.asarray(g_pids, dtype=np.int64) 26 | 
q_camids = np.asarray(q_camids, dtype=np.int64) 27 | g_camids = np.asarray(g_camids, dtype=np.int64) 28 | 29 | cdef long num_q = distmat.shape[0] 30 | cdef long num_g = distmat.shape[1] 31 | 32 | cdef: 33 | long[:,:] indices = np.argsort(distmat, axis=1) 34 | long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64) 35 | 36 | float[:] pos = np.zeros(num_q*num_g, dtype=np.float32) 37 | float[:] neg = np.zeros(num_q*num_g, dtype=np.float32) 38 | 39 | long valid_pos = 0 40 | long valid_neg = 0 41 | long ind  # NOTE(review): unused 42 | 43 | long q_idx, q_pid, q_camid, g_idx 44 | long[:] order = np.zeros(num_g, dtype=np.int64) 45 | 46 | float[:] raw_cmc = np.zeros(num_g, dtype=np.float32) # binary vector, positions with value 1 are correct matches 47 | long[:] sort_idx = np.zeros(num_g, dtype=np.int64) 48 | 49 | long idx  # NOTE(review): unused 50 | 51 | for q_idx in range(num_q): 52 | # get query pid and camid 53 | q_pid = q_pids[q_idx] 54 | q_camid = q_camids[q_idx] 55 | 56 | for g_idx in range(num_g): 57 | order[g_idx] = indices[q_idx, g_idx] 58 | num_g_real = 0  # NOTE(review): not cdef-declared, so it is not a typed C long 59 | 60 | # remove gallery samples that have the same pid and camid with query 61 | for g_idx in range(num_g): 62 | if (g_pids[order[g_idx]] != q_pid) or (g_camids[order[g_idx]] != q_camid): 63 | raw_cmc[num_g_real] = matches[q_idx][g_idx] 64 | sort_idx[num_g_real] = order[g_idx] 65 | num_g_real += 1 66 | 67 | q_dist = distmat[q_idx]  # NOTE(review): q_dist and valid_idx are not cdef-typed either 68 | 69 | for valid_idx in range(num_g_real): 70 | if raw_cmc[valid_idx] == 1: 71 | pos[valid_pos] = q_dist[sort_idx[valid_idx]] 72 | valid_pos += 1 73 | elif raw_cmc[valid_idx] == 0: 74 | neg[valid_neg] = q_dist[sort_idx[valid_idx]] 75 | valid_neg += 1 76 | 77 | cdef float[:] scores = np.hstack((pos[:valid_pos], neg[:valid_neg])) 78 | cdef float[:] labels = np.hstack((np.zeros(valid_pos, dtype=np.float32), 79 | np.ones(valid_neg, dtype=np.float32)))  # label 0 = same-identity pair, 1 = different-identity pair 80 | return np.asarray(scores), np.asarray(labels) 81 | 82 | 83 | # Compute the cumulative sum 84 | cdef void
function_cumsum(cython.numeric[:] src, cython.numeric[:] dst, long n):  # NOTE(review): not called anywhere in roc_cy.pyx 85 | cdef long i 86 | dst[0] = src[0] 87 | for i in range(1, n): 88 | dst[i] = src[i] + dst[i - 1] -------------------------------------------------------------------------------- /fastreid/evaluation/rank_cylib/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup  # NOTE(review): distutils was removed in Python 3.12; setuptools offers the same setup/Extension 2 | from distutils.extension import Extension 3 | 4 | import numpy as np 5 | from Cython.Build import cythonize 6 | 7 | 8 | def numpy_include(): 9 | try: 10 | numpy_include = np.get_include() 11 | except AttributeError: 12 | numpy_include = np.get_numpy_include()  # fallback only when np.get_include is missing 13 | return numpy_include 14 | 15 | 16 | ext_modules = [ 17 | Extension( 18 | 'rank_cy', 19 | ['rank_cy.pyx'], 20 | include_dirs=[numpy_include()], 21 | ), 22 | Extension( 23 | 'roc_cy', 24 | ['roc_cy.pyx'], 25 | include_dirs=[numpy_include()], 26 | ) 27 | ] 28 | 29 | setup( 30 | name='Cython-based reid evaluation code', 31 | ext_modules=cythonize(ext_modules) 32 | ) 33 | -------------------------------------------------------------------------------- /fastreid/evaluation/rank_cylib/test_cython.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import timeit 3 | import numpy as np 4 | import os.path as osp 5 | 6 | sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../../..') 7 | 8 | from fastreid.evaluation import evaluate_rank 9 | from fastreid.evaluation import evaluate_roc 10 | 11 | """ 12 | Test the speed of cython-based evaluation code. The speed improvements 13 | can be much bigger when using the real reid data, which contains a larger 14 | amount of query and gallery images. 15 | Note: you might encounter the following error: 16 | 'AssertionError: Error: all query identities do not appear in gallery'. 17 | This is normal because the inputs are random numbers. Just try again.
18 | """ 19 | 20 | print('*** Compare running time ***') 21 | # code executed by timeit before timing: random L2-normalised features, pids and camids 22 | setup = ''' 23 | import sys 24 | import os.path as osp 25 | import numpy as np 26 | sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../../..') 27 | from fastreid.evaluation import evaluate_rank 28 | from fastreid.evaluation import evaluate_roc 29 | num_q = 30 30 | num_g = 300 31 | dim = 512 32 | max_rank = 5 33 | q_feats = np.random.rand(num_q, dim).astype(np.float32) * 20 34 | q_feats = q_feats / np.linalg.norm(q_feats, ord=2, axis=1, keepdims=True) 35 | g_feats = np.random.rand(num_g, dim).astype(np.float32) * 20 36 | g_feats = g_feats / np.linalg.norm(g_feats, ord=2, axis=1, keepdims=True) 37 | distmat = 1 - np.dot(q_feats, g_feats.transpose()) 38 | q_pids = np.random.randint(0, num_q, size=num_q) 39 | g_pids = np.random.randint(0, num_g, size=num_g) 40 | q_camids = np.random.randint(0, 5, size=num_q) 41 | g_camids = np.random.randint(0, 5, size=num_g) 42 | ''' 43 | 44 | print('=> Using CMC metric') 45 | pytime = timeit.timeit(  # 20 runs of the pure-Python path 46 | 'evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False)', 47 | setup=setup, 48 | number=20 49 | ) 50 | cytime = timeit.timeit(  # 20 runs of the Cython path (falls back to Python if the extension is missing) 51 | 'evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True)', 52 | setup=setup, 53 | number=20 54 | ) 55 | print('Python time: {} s'.format(pytime)) 56 | print('Cython time: {} s'.format(cytime)) 57 | print('CMC Cython is {} times faster than python\n'.format(pytime / cytime)) 58 | 59 | print('=> Using ROC metric') 60 | pytime = timeit.timeit( 61 | 'evaluate_roc(distmat, q_pids, g_pids, q_camids, g_camids, use_cython=False)', 62 | setup=setup, 63 | number=20 64 | ) 65 | cytime = timeit.timeit( 66 | 'evaluate_roc(distmat, q_pids, g_pids, q_camids, g_camids, use_cython=True)', 67 | setup=setup, 68 | number=20 69 | ) 70 | print('Python time: {} s'.format(pytime)) 71 | print('Cython time: {} s'.format(cytime)) 72 | print('ROC Cython is {} times faster than
python\n'.format(pytime / cytime)) 73 | 74 | print("=> Check precision") 75 | num_q = 30 76 | num_g = 300 77 | dim = 512 78 | max_rank = 5 79 | q_feats = np.random.rand(num_q, dim).astype(np.float32) * 20 80 | q_feats = q_feats / np.linalg.norm(q_feats, ord=2, axis=1, keepdims=True) 81 | g_feats = np.random.rand(num_g, dim).astype(np.float32) * 20 82 | g_feats = g_feats / np.linalg.norm(g_feats, ord=2, axis=1, keepdims=True) 83 | distmat = 1 - np.dot(q_feats, g_feats.transpose()) 84 | q_pids = np.random.randint(0, num_q, size=num_q) 85 | g_pids = np.random.randint(0, num_g, size=num_g) 86 | q_camids = np.random.randint(0, 5, size=num_q) 87 | g_camids = np.random.randint(0, 5, size=num_g) 88 | # run both implementations on the identical random inputs and compare 89 | cmc_py, mAP_py, mINP_py = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False) 90 | 91 | cmc_cy, mAP_cy, mINP_cy = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True) 92 | 93 | np.testing.assert_allclose(cmc_py, cmc_cy, rtol=1e-3, atol=1e-6) 94 | np.testing.assert_allclose(mAP_py, mAP_cy, rtol=1e-3, atol=1e-6) 95 | np.testing.assert_allclose(mINP_py, mINP_cy, rtol=1e-3, atol=1e-6) 96 | print('Rank results between python and cython are the same!') 97 | 98 | scores_cy, labels_cy = evaluate_roc(distmat, q_pids, g_pids, q_camids, g_camids, use_cython=True) 99 | scores_py, labels_py = evaluate_roc(distmat, q_pids, g_pids, q_camids, g_camids, use_cython=False) 100 | 101 | np.testing.assert_allclose(scores_cy, scores_py, rtol=1e-3, atol=1e-6) 102 | np.testing.assert_allclose(labels_cy, labels_py, rtol=1e-3, atol=1e-6) 103 | print('ROC results between python and cython are the same!\n') 104 | 105 | print("=> Check exact values") 106 | print("mAP = {} \ncmc = {}\nmINP = {}\nScores = {}".format(np.array(mAP_cy), cmc_cy, np.array(mINP_cy), scores_cy)) 107 | -------------------------------------------------------------------------------- /fastreid/evaluation/rerank.py:
-------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # based on: 4 | # https://github.com/zhunzhong07/person-re-ranking 5 | 6 | __all__ = ['re_ranking'] 7 | 8 | import numpy as np 9 | 10 | 11 | def re_ranking(q_g_dist, q_q_dist, g_g_dist, k1: int = 20, k2: int = 6, lambda_value: float = 0.3): 12 | original_dist = np.concatenate( 13 | [np.concatenate([q_q_dist, q_g_dist], axis=1), 14 | np.concatenate([q_g_dist.T, g_g_dist], axis=1)], 15 | axis=0) 16 | original_dist = np.power(original_dist, 2).astype(np.float32)  # squared distances over all (query + gallery) samples 17 | original_dist = np.transpose(1. * original_dist / np.max(original_dist, axis=0))  # divide each column by its max, then transpose 18 | V = np.zeros_like(original_dist).astype(np.float32) 19 | initial_rank = np.argsort(original_dist).astype(np.int32) 20 | 21 | query_num = q_g_dist.shape[0] 22 | gallery_num = q_g_dist.shape[0] + q_g_dist.shape[1]  # NOTE: despite the name, this is the total number of samples (queries + gallery) 23 | all_num = gallery_num 24 | 25 | for i in range(all_num): 26 | # k-reciprocal neighbors 27 | forward_k_neigh_index = initial_rank[i, :k1 + 1] 28 | backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1] 29 | fi = np.where(backward_k_neigh_index == i)[0] 30 | k_reciprocal_index = forward_k_neigh_index[fi] 31 | k_reciprocal_expansion_index = k_reciprocal_index 32 | for j in range(len(k_reciprocal_index)): 33 | candidate = k_reciprocal_index[j] 34 | candidate_forward_k_neigh_index = initial_rank[candidate, 35 | :int(np.around(k1 / 2.)) + 1] 36 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index, 37 | :int(np.around(k1 / 2.)) + 1] 38 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0] 39 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 40 | if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2.
/ 3 * len( 41 | candidate_k_reciprocal_index): 42 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index) 43 | 44 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) 45 | weight = np.exp(-original_dist[i, k_reciprocal_expansion_index]) 46 | V[i, k_reciprocal_expansion_index] = 1. * weight / np.sum(weight)  # exp(-dist) weights normalised to sum to 1 47 | original_dist = original_dist[:query_num, ]  # only query rows are needed from here on 48 | if k2 != 1: 49 | V_qe = np.zeros_like(V, dtype=np.float32) 50 | for i in range(all_num): 51 | V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)  # average each row over its k2 nearest neighbours 52 | V = V_qe 53 | del V_qe 54 | del initial_rank 55 | invIndex = [] 56 | for i in range(gallery_num): 57 | invIndex.append(np.where(V[:, i] != 0)[0])  # inverted index: rows with a non-zero weight in column i 58 | 59 | jaccard_dist = np.zeros_like(original_dist, dtype=np.float32) 60 | 61 | for i in range(query_num): 62 | temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float32) 63 | indNonZero = np.where(V[i, :] != 0)[0] 64 | indImages = [invIndex[ind] for ind in indNonZero] 65 | for j in range(len(indNonZero)): 66 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]], 67 | V[indImages[j], indNonZero[j]]) 68 | jaccard_dist[i] = 1 - temp_min / (2.
- temp_min) 69 | 70 | final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value  # blend Jaccard and original distances 71 | del original_dist, V, jaccard_dist 72 | final_dist = final_dist[:query_num, query_num:]  # keep the query-vs-gallery block 73 | return final_dist 74 | -------------------------------------------------------------------------------- /fastreid/evaluation/roc.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: l1aoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import warnings 8 | 9 | import faiss  # NOTE(review): not used anywhere in this file 10 | import numpy as np 11 | 12 | try: 13 | from .rank_cylib.roc_cy import evaluate_roc_cy 14 | 15 | IS_CYTHON_AVAI = True 16 | except ImportError: 17 | IS_CYTHON_AVAI = False 18 | # warnings.warn( 19 | # 'Cython roc evaluation (very fast so highly recommended) is ' 20 | # 'unavailable, now use python evaluation.' 21 | # ) 22 | 23 | 24 | def evaluate_roc_py(distmat, q_pids, g_pids, q_camids, g_camids): 25 | r"""Evaluation with ROC curve. 26 | Key: for each query identity, its gallery images from the same camera view are discarded.
27 | 28 | Args: 29 | distmat (np.ndarray): cosine distance matrix of shape (num_query, num_gallery). Returns (scores, labels): pair distances with label 0 for same identity and 1 otherwise 30 | """ 31 | num_q, num_g = distmat.shape 32 | 33 | indices = np.argsort(distmat, axis=1) 34 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 35 | 36 | pos = [] 37 | neg = [] 38 | for q_idx in range(num_q): 39 | # get query pid and camid 40 | q_pid = q_pids[q_idx] 41 | q_camid = q_camids[q_idx] 42 | 43 | # Remove gallery samples that have the same pid and camid with query 44 | order = indices[q_idx] 45 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 46 | keep = np.invert(remove) 47 | raw_cmc = matches[q_idx][keep] 48 | 49 | sort_idx = order[keep] 50 | 51 | q_dist = distmat[q_idx] 52 | ind_pos = np.where(raw_cmc == 1)[0] 53 | pos.extend(q_dist[sort_idx[ind_pos]]) 54 | 55 | ind_neg = np.where(raw_cmc == 0)[0] 56 | neg.extend(q_dist[sort_idx[ind_neg]]) 57 | 58 | scores = np.hstack((pos, neg)) 59 | 60 | labels = np.hstack((np.zeros(len(pos)), np.ones(len(neg))))  # 0 = same-identity pair, 1 = different-identity pair 61 | return scores, labels 62 | 63 | 64 | def evaluate_roc( 65 | distmat, 66 | q_pids, 67 | g_pids, 68 | q_camids, 69 | g_camids, 70 | use_cython=True 71 | ): 72 | """Collects ROC inputs: query-gallery distances (scores) and labels (0 = same identity, 1 = different identity). 73 | Args: 74 | distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery). 75 | q_pids (numpy.ndarray): 1-D array containing person identities 76 | of each query instance. 77 | g_pids (numpy.ndarray): 1-D array containing person identities 78 | of each gallery instance. 79 | q_camids (numpy.ndarray): 1-D array containing camera views under 80 | which each query instance is captured. 81 | g_camids (numpy.ndarray): 1-D array containing camera views under 82 | which each gallery instance is captured. 83 | use_cython (bool, optional): use cython code for evaluation. Default is True. 84 | This is highly recommended as the cython code can speed up the ROC computation 85 | by more than 10x. This requires Cython to be installed.
86 | """ 87 | if use_cython and IS_CYTHON_AVAI: 88 | return evaluate_roc_cy(distmat, q_pids, g_pids, q_camids, g_camids) 89 | else: 90 | return evaluate_roc_py(distmat, q_pids, g_pids, q_camids, g_camids) 91 | -------------------------------------------------------------------------------- /fastreid/evaluation/testing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import logging 3 | import pprint 4 | import sys 5 | from collections import Mapping, OrderedDict 6 | 7 | import numpy as np 8 | from tabulate import tabulate 9 | from termcolor import colored 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def print_csv_format(results): 15 | """ 16 | Print main metrics in a format similar to Detectron, 17 | so that they are easy to copypaste into a spreadsheet. 18 | Args: 19 | results (OrderedDict[dict]): task_name -> {metric -> score} 20 | """ 21 | assert isinstance(results, OrderedDict), results # unordered results cannot be properly printed 22 | task = list(results.keys())[0] 23 | metrics = ["Datasets"] + [k for k in results[task]] 24 | 25 | csv_results = [] 26 | for task, res in results.items(): 27 | csv_results.append((task, *list(res.values()))) 28 | 29 | # tabulate it 30 | table = tabulate( 31 | csv_results, 32 | tablefmt="pipe", 33 | floatfmt=".2%", 34 | headers=metrics, 35 | numalign="left", 36 | ) 37 | 38 | logger.info("Evaluation results in csv format: \n" + colored(table, "cyan")) 39 | 40 | 41 | def verify_results(cfg, results): 42 | """ 43 | Args: 44 | results (OrderedDict[dict]): task_name -> {metric -> score} 45 | Returns: 46 | bool: whether the verification succeeds or not 47 | """ 48 | expected_results = cfg.TEST.EXPECTED_RESULTS 49 | if not len(expected_results): 50 | return True 51 | 52 | ok = True 53 | for task, metric, expected, tolerance in expected_results: 54 | actual = results[task][metric] 55 | if not 
np.isfinite(actual): 56 | ok = False 57 | diff = abs(actual - expected) 58 | if diff > tolerance: 59 | ok = False 60 | 61 | logger = logging.getLogger(__name__) 62 | if not ok: 63 | logger.error("Result verification failed!") 64 | logger.error("Expected Results: " + str(expected_results)) 65 | logger.error("Actual Results: " + pprint.pformat(results)) 66 | 67 | sys.exit(1)  # NOTE: a failed verification terminates the process, so this function never returns False 68 | else: 69 | logger.info("Results verification passed.") 70 | return ok 71 | 72 | 73 | def flatten_results_dict(results): 74 | """ 75 | Expand a hierarchical dict of scalars into a flat dict of scalars. 76 | If results[k1][k2][k3] = v, the returned dict will have the entry 77 | {"k1/k2/k3": v}. 78 | Args: 79 | results (dict): a possibly nested mapping with string keys; nested keys are joined with "/". 80 | """ 81 | r = {} 82 | for k, v in results.items(): 83 | if isinstance(v, Mapping): 84 | v = flatten_results_dict(v) 85 | for kk, vv in v.items(): 86 | r[k + "/" + kk] = vv 87 | else: 88 | r[k] = v 89 | return r 90 | -------------------------------------------------------------------------------- /fastreid/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .batch_norm import * 8 | from .non_local import Non_local 9 | from .pooling import * 10 | from .gather_layer import GatherLayer 11 | from .stripe_layer import * 12 | from .mlp import MLP -------------------------------------------------------------------------------- /fastreid/layers/batch_norm.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import logging 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn 12 | 13 | __all__ = [ 14 | "BatchNorm", 15 | "IBN", 16 | "GhostBatchNorm", 17 | "FrozenBatchNorm", 18 | "SyncBatchNorm", 19 | "get_norm", 20 | ] 21 | 22 | class
BatchNorm(nn.BatchNorm2d): 24 | def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0, 25 | bias_init=0.0, **kwargs): 26 | super().__init__(num_features, eps=eps, momentum=momentum) 27 | if weight_init is not None: nn.init.constant_(self.weight, weight_init) 28 | if bias_init is not None: nn.init.constant_(self.bias, bias_init) 29 | self.weight.requires_grad_(not weight_freeze) 30 | self.bias.requires_grad_(not bias_freeze) 31 | 32 | 33 | class SyncBatchNorm(nn.SyncBatchNorm): 34 | def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0, 35 | bias_init=0.0): 36 | super().__init__(num_features, eps=eps, momentum=momentum) 37 | if weight_init is not None: nn.init.constant_(self.weight, weight_init) 38 | if bias_init is not None: nn.init.constant_(self.bias, bias_init) 39 | self.weight.requires_grad_(not weight_freeze) 40 | self.bias.requires_grad_(not bias_freeze) 41 | 42 | 43 | class IBN(nn.Module): 44 | def __init__(self, planes, bn_norm='BN', **kwargs): 45 | super(IBN, self).__init__() 46 | half1 = int(planes / 2) 47 | self.half = half1 48 | half2 = planes - half1 49 | self.IN = nn.InstanceNorm2d(half1, affine=True) 50 | self.BN = get_norm(bn_norm, half2, **kwargs) 51 | 52 | def forward(self, x): 53 | split = torch.split(x, self.half, 1) 54 | out1 = self.IN(split[0].contiguous()) 55 | out2 = self.BN(split[1].contiguous()) 56 | out = torch.cat((out1, out2), 1) 57 | return out 58 | 59 | 60 | class GhostBatchNorm(BatchNorm): 61 | def __init__(self, num_features, num_splits=1, **kwargs): 62 | super().__init__(num_features, **kwargs) 63 | self.num_splits = num_splits 64 | self.register_buffer('running_mean', torch.zeros(num_features)) 65 | self.register_buffer('running_var', torch.ones(num_features)) 66 | 67 | def forward(self, input): 68 | N, C, H, W = input.shape 69 | if self.training or not self.track_running_stats: 70 | self.running_mean = 
self.running_mean.repeat(self.num_splits) 71 | self.running_var = self.running_var.repeat(self.num_splits) 72 | outputs = F.batch_norm( 73 | input.view(-1, C * self.num_splits, H, W), self.running_mean, self.running_var, 74 | self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits), 75 | True, self.momentum, self.eps).view(N, C, H, W) 76 | self.running_mean = torch.mean(self.running_mean.view(self.num_splits, self.num_features), dim=0) 77 | self.running_var = torch.mean(self.running_var.view(self.num_splits, self.num_features), dim=0) 78 | return outputs 79 | else: 80 | return F.batch_norm( 81 | input, self.running_mean, self.running_var, 82 | self.weight, self.bias, False, self.momentum, self.eps) 83 | 84 | 85 | class FrozenBatchNorm(BatchNorm): 86 | """ 87 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 88 | It contains non-trainable buffers called 89 | "weight" and "bias", "running_mean", "running_var", 90 | initialized to perform identity transformation. 91 | The pre-trained backbone models from Caffe2 only contain "weight" and "bias", 92 | which are computed from the original four parameters of BN. 93 | The affine transform `x * weight + bias` will perform the equivalent 94 | computation of `(x - running_mean) / sqrt(running_var) * weight + bias`. 95 | When loading a backbone model from Caffe2, "running_mean" and "running_var" 96 | will be left unchanged as identity transformation. 97 | Other pre-trained backbone models may contain all 4 parameters. 98 | The forward is implemented by `F.batch_norm(..., training=False)`. 
99 | """ 100 | 101 | _version = 3 102 | 103 | def __init__(self, num_features, eps=1e-5, **kwargs): 104 | super().__init__(num_features, weight_freeze=True, bias_freeze=True, **kwargs) 105 | self.num_features = num_features 106 | self.eps = eps 107 | 108 | def forward(self, x): 109 | if x.requires_grad: 110 | # When gradients are needed, F.batch_norm will use extra memory 111 | # because its backward op computes gradients for weight/bias as well. 112 | scale = self.weight * (self.running_var + self.eps).rsqrt() 113 | bias = self.bias - self.running_mean * scale 114 | scale = scale.reshape(1, -1, 1, 1) 115 | bias = bias.reshape(1, -1, 1, 1) 116 | return x * scale + bias 117 | else: 118 | # When gradients are not needed, F.batch_norm is a single fused op 119 | # and provide more optimization opportunities. 120 | return F.batch_norm( 121 | x, 122 | self.running_mean, 123 | self.running_var, 124 | self.weight, 125 | self.bias, 126 | training=False, 127 | eps=self.eps, 128 | ) 129 | 130 | def _load_from_state_dict( 131 | self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs 132 | ): 133 | version = local_metadata.get("version", None) 134 | 135 | if version is None or version < 2: 136 | # No running_mean/var in early versions 137 | # This will silent the warnings 138 | if prefix + "running_mean" not in state_dict: 139 | state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean) 140 | if prefix + "running_var" not in state_dict: 141 | state_dict[prefix + "running_var"] = torch.ones_like(self.running_var) 142 | 143 | if version is not None and version < 3: 144 | logger = logging.getLogger(__name__) 145 | logger.info("FrozenBatchNorm {} is upgraded to version 3.".format(prefix.rstrip("."))) 146 | # In version < 3, running_var are used without +eps. 
147 | state_dict[prefix + "running_var"] -= self.eps 148 | 149 | super()._load_from_state_dict( 150 | state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs 151 | ) 152 | 153 | def __repr__(self): 154 | return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps) 155 | 156 | @classmethod 157 | def convert_frozen_batchnorm(cls, module): 158 | """ 159 | Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm. 160 | Args: 161 | module (torch.nn.Module): 162 | Returns: 163 | If module is BatchNorm/SyncBatchNorm, returns a new module. 164 | Otherwise, in-place convert module and return it. 165 | Similar to convert_sync_batchnorm in 166 | https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py 167 | """ 168 | bn_module = nn.modules.batchnorm 169 | bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm) 170 | res = module 171 | if isinstance(module, bn_module): 172 | res = cls(module.num_features) 173 | if module.affine: 174 | res.weight.data = module.weight.data.clone().detach() 175 | res.bias.data = module.bias.data.clone().detach() 176 | res.running_mean.data = module.running_mean.data 177 | res.running_var.data = module.running_var.data 178 | res.eps = module.eps 179 | else: 180 | for name, child in module.named_children(): 181 | new_child = cls.convert_frozen_batchnorm(child) 182 | if new_child is not child: 183 | res.add_module(name, new_child) 184 | return res 185 | 186 | 187 | def get_norm(norm, out_channels, **kwargs): 188 | """ 189 | Args: 190 | norm (str or callable): either one of BN, GhostBN, FrozenBN, GN or SyncBN; 191 | or a callable that thakes a channel number and returns 192 | the normalization layer as a nn.Module 193 | out_channels: number of channels for normalization layer 194 | 195 | Returns: 196 | nn.Module or None: the normalization layer 197 | """ 198 | 199 | if isinstance(norm, str): 200 | if len(norm) == 0: 201 | return None 202 | norm = { 203 | "BN": 
BatchNorm, 204 | "GhostBN": GhostBatchNorm, 205 | "FrozenBN": FrozenBatchNorm, 206 | "GN": lambda channels, **args: nn.GroupNorm(32, channels), 207 | "syncBN": SyncBatchNorm, 208 | "IBN": IBN, 209 | }[norm] 210 | return norm(out_channels, **kwargs) 211 | -------------------------------------------------------------------------------- /fastreid/layers/gather_layer.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: xingyu liao 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | # based on: https://github.com/open-mmlab/OpenSelfSup/blob/master/openselfsup/models/utils/gather_layer.py 8 | 9 | import torch 10 | import torch.distributed as dist 11 | 12 | 13 | class GatherLayer(torch.autograd.Function): 14 | """Gather tensors from all process, supporting backward propagation. 15 | """ 16 | 17 | @staticmethod 18 | def forward(ctx, input): 19 | ctx.save_for_backward(input) 20 | output = [torch.zeros_like(input) \ 21 | for _ in range(dist.get_world_size())] 22 | dist.all_gather(output, input) 23 | return tuple(output) 24 | 25 | @staticmethod 26 | def backward(ctx, *grads): 27 | input, = ctx.saved_tensors 28 | grad_out = torch.zeros_like(input) 29 | grad_out[:] = grads[dist.get_rank()] 30 | return grad_out 31 | -------------------------------------------------------------------------------- /fastreid/layers/mlp.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from copy import deepcopy 5 | 6 | class MLP(nn.Module): 7 | def __init__(self, channels: list, do_bn=True, xd='1'): 8 | super(MLP, self).__init__() 9 | n = len(channels) 10 | layers = [] 11 | conv = nn.Conv1d if xd == '1' else nn.Conv2d 12 | bn = nn.InstanceNorm1d if xd == '1' else nn.InstanceNorm2d 13 | for i in range(1, n): 14 | layers.append( 15 | conv(channels[i - 1], channels[i], kernel_size=1, bias=True) 16 | ) 17 | if i < n - 1: 18 | if do_bn: 19 | 
layers.append(bn(channels[i])) 20 | layers.append(nn.ReLU()) 21 | self.layers = nn.Sequential(*layers) 22 | 23 | def forward(self, x): 24 | x = self.layers(x) 25 | return x 26 | 27 | 28 | def MLP(channels: list, do_bn=True): 29 | n = len(channels) 30 | layers = [] 31 | for i in range(1, n): 32 | layers.append( 33 | nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True) 34 | ) 35 | if i < n-1: 36 | if do_bn: 37 | layers.append(nn.InstanceNorm1d(channels[i])) 38 | layers.append(nn.ReLU()) 39 | return nn.Sequential(*layers) 40 | -------------------------------------------------------------------------------- /fastreid/layers/non_local.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | 4 | import torch 5 | from torch import nn 6 | from .batch_norm import get_norm 7 | 8 | 9 | class Non_local(nn.Module): 10 | def __init__(self, in_channels, bn_norm, reduc_ratio=2): 11 | super(Non_local, self).__init__() 12 | 13 | self.in_channels = in_channels 14 | self.inter_channels = in_channels // reduc_ratio 15 | 16 | self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 17 | kernel_size=1, stride=1, padding=0) 18 | 19 | self.W = nn.Sequential( 20 | nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, 21 | kernel_size=1, stride=1, padding=0), 22 | get_norm(bn_norm, self.in_channels), 23 | ) 24 | nn.init.constant_(self.W[1].weight, 0.0) 25 | nn.init.constant_(self.W[1].bias, 0.0) 26 | 27 | self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 28 | kernel_size=1, stride=1, padding=0) 29 | 30 | self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 31 | kernel_size=1, stride=1, padding=0) 32 | 33 | def forward(self, x): 34 | """ 35 | :param x: (b, t, h, w) 36 | :return x: (b, t, h, w) 37 | """ 38 | batch_size = x.size(0) 39 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 40 | g_x = 
g_x.permute(0, 2, 1) 41 | 42 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 43 | theta_x = theta_x.permute(0, 2, 1) 44 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 45 | f = torch.matmul(theta_x, phi_x) 46 | N = f.size(-1) 47 | f_div_C = f / N 48 | 49 | y = torch.matmul(f_div_C, g_x) 50 | y = y.permute(0, 2, 1).contiguous() 51 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 52 | W_y = self.W(y) 53 | z = W_y + x 54 | return z 55 | -------------------------------------------------------------------------------- /fastreid/layers/pooling.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: l1aoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torch import nn 10 | 11 | __all__ = ["Flatten", 12 | "GeneralizedMeanPooling", 13 | "GeneralizedMeanPoolingP", 14 | "FastGlobalAvgPool2d", 15 | "AdaptiveAvgMaxPool2d", 16 | "ClipGlobalAvgPool2d", 17 | ] 18 | 19 | 20 | class Flatten(nn.Module): 21 | def forward(self, input): 22 | return input.view(input.size(0), -1) 23 | 24 | 25 | class GeneralizedMeanPooling(nn.Module): 26 | r"""Applies a 2D power-average adaptive pooling over an input signal composed of several input planes. 27 | The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)` 28 | - At p = infinity, one gets Max Pooling 29 | - At p = 1, one gets Average Pooling 30 | The output is of size H x W, for any input size. 31 | The number of output features is equal to the number of input planes. 32 | Args: 33 | output_size: the target output size of the image of the form H x W. 34 | Can be a tuple (H, W) or a single H for a square image H x H 35 | H and W can be either a ``int``, or ``None`` which means the size will 36 | be the same as that of the input. 
37 | """ 38 | 39 | def __init__(self, norm=3, output_size=1, eps=1e-6): 40 | super(GeneralizedMeanPooling, self).__init__() 41 | assert norm > 0 42 | self.p = float(norm) 43 | self.output_size = output_size 44 | self.eps = eps 45 | 46 | def forward(self, x): 47 | x = x.clamp(min=self.eps).pow(self.p) 48 | return torch.nn.functional.adaptive_avg_pool2d(x, self.output_size).pow(1. / self.p) 49 | 50 | def __repr__(self): 51 | return self.__class__.__name__ + '(' \ 52 | + str(self.p) + ', ' \ 53 | + 'output_size=' + str(self.output_size) + ')' 54 | 55 | 56 | class GeneralizedMeanPoolingP(GeneralizedMeanPooling): 57 | """ Same, but norm is trainable 58 | """ 59 | 60 | def __init__(self, norm=3, output_size=1, eps=1e-6): 61 | super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps) 62 | self.p = nn.Parameter(torch.ones(1) * norm) 63 | 64 | 65 | class AdaptiveAvgMaxPool2d(nn.Module): 66 | def __init__(self): 67 | super(AdaptiveAvgMaxPool2d, self).__init__() 68 | self.gap = FastGlobalAvgPool2d() 69 | self.gmp = nn.AdaptiveMaxPool2d(1) 70 | 71 | def forward(self, x): 72 | avg_feat = self.gap(x) 73 | max_feat = self.gmp(x) 74 | feat = avg_feat + max_feat 75 | return feat 76 | 77 | 78 | class FastGlobalAvgPool2d(nn.Module): 79 | def __init__(self, flatten=False): 80 | super(FastGlobalAvgPool2d, self).__init__() 81 | self.flatten = flatten 82 | 83 | def forward(self, x): 84 | if self.flatten: 85 | in_size = x.size() 86 | return x.view((in_size[0], in_size[1], -1)).mean(dim=2) 87 | else: 88 | return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1) 89 | 90 | 91 | class ClipGlobalAvgPool2d(nn.Module): 92 | def __init__(self): 93 | super().__init__() 94 | self.avgpool = FastGlobalAvgPool2d() 95 | 96 | def forward(self, x): 97 | x = self.avgpool(x) 98 | x = torch.clamp(x, min=0., max=1.) 
99 | return x 100 | -------------------------------------------------------------------------------- /fastreid/layers/stripe_layer.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch 8 | from torch import nn, sigmoid, einsum 9 | from einops import rearrange 10 | from fastreid.layers.batch_norm import IBN 11 | 12 | class StripeAttention(nn.Module): 13 | 14 | def __init__(self, dim, heads=4, dimqk=128, dimv=128): 15 | super(StripeAttention, self).__init__() 16 | self.scale = dimqk ** -0.5 17 | self.heads = heads 18 | out_dimqk = heads * dimqk 19 | out_dimv = heads * dimv 20 | 21 | self.q = nn.Conv2d(dim, out_dimqk, 1, bias=False) 22 | self.k = nn.Conv2d(dim, out_dimqk, 1, bias=False) 23 | self.v = nn.Conv2d(dim, out_dimv, 1, bias=False) 24 | self.softmax = nn.Softmax(dim=-1) 25 | 26 | def forward(self, query_features, key_features, value_features, prob=False): 27 | heads = self.heads 28 | B, C, H, W = query_features.shape 29 | q = self.q(query_features) 30 | k = self.k(key_features) 31 | v = self.v(value_features) 32 | q, k, v = map(lambda x: rearrange(x, 'B (h d) H W -> B h (H W) d', h=heads), (q, k, v)) 33 | 34 | q *= self.scale 35 | 36 | logits = einsum('bhxd,bhyd->bhxy',q, k) 37 | 38 | weights = self.softmax(logits) 39 | out = einsum('bhxy,bhyd->bhxd', weights, v) 40 | out = rearrange(out, 'B h (H W) d -> B (h d) H W', H=H) 41 | if prob is True: 42 | return out, weights 43 | else: 44 | return out 45 | 46 | class StripeLayer(nn.Module): 47 | 48 | def __init__(self, inplanes, planes, num_stripe=8, p=0.3): 49 | super(StripeLayer, self).__init__() 50 | self.num_stripe = num_stripe 51 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=True) 52 | self.bn1 = IBN(planes, 'BN') 53 | self.dropout = nn.Dropout(p=p) 54 | 55 | self.sp = StripeAttention(planes, 4, planes // 4, planes // 4) 56 | self.bn2 = 
nn.BatchNorm2d(planes) 57 | 58 | self.conv3 = nn.Conv2d(planes, inplanes, kernel_size=1, bias=True) 59 | self.bn3 = nn.BatchNorm2d(inplanes) 60 | self.relu = nn.ReLU(inplace=True) 61 | 62 | def forward(self, x, mask = None): 63 | residual = x if mask is None else x * mask 64 | 65 | out = self.conv1(x) 66 | out = self.bn1(out) 67 | out = self.relu(out) 68 | 69 | # stripe attention 70 | num = self.num_stripe 71 | res_out = out 72 | stripes = out.chunk(num, dim=2) 73 | outs = [self.sp(stripes[i], stripes[i], stripes[i]) for i in range(len(stripes))] 74 | out = torch.cat(outs, dim=2).contiguous() 75 | 76 | out = out + self.dropout(res_out) 77 | out = self.bn2(out) 78 | out = self.relu(out) 79 | 80 | out = self.conv3(out) 81 | out = self.bn3(out) 82 | 83 | out += residual 84 | out = self.relu(out) 85 | 86 | return out -------------------------------------------------------------------------------- /fastreid/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .meta_arch import build_model 8 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .build import build_backbone, BACKBONE_REGISTRY 8 | 9 | from .resnet import build_resnet_backbone 10 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/build.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from ...utils.registry import Registry 8 | 9 | BACKBONE_REGISTRY = Registry("BACKBONE") 10 | 
BACKBONE_REGISTRY.__doc__ = """ 11 | Registry for backbones, which extract feature maps from images 12 | The registered object must be a callable that accepts two arguments: 13 | 1. A :class:`detectron2.config.CfgNode` 14 | 2. A :class:`detectron2.layers.ShapeSpec`, which contains the input shape specification. 15 | It must returns an instance of :class:`Backbone`. 16 | """ 17 | 18 | 19 | def build_backbone(cfg): 20 | """ 21 | Build a backbone from `cfg.MODEL.BACKBONE.NAME`. 22 | Returns: 23 | an instance of :class:`Backbone` 24 | """ 25 | 26 | backbone_name = cfg.MODEL.BACKBONE.NAME 27 | backbone = BACKBONE_REGISTRY.get(backbone_name)(cfg) 28 | return backbone 29 | -------------------------------------------------------------------------------- /fastreid/modeling/heads/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .build import REID_HEADS_REGISTRY, build_heads 8 | 9 | # import all the meta_arch, so they will be registered 10 | from .embedding_head import EmbeddingHead 11 | from .agg_head import AGGHead -------------------------------------------------------------------------------- /fastreid/modeling/heads/agg_head.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch.nn.functional as F 8 | from torch import nn 9 | import torch 10 | 11 | from fastreid.utils.comm import get_local_rank 12 | from fastreid.layers import * 13 | from fastreid.utils.weight_init import weights_init_kaiming, weights_init_classifier 14 | from .build import REID_HEADS_REGISTRY 15 | 16 | class BNC(nn.Module): 17 | def __init__(self, inplane, num_classes, num_heads): 18 | super().__init__() 19 | self.bnnecks = nn.ModuleList([ 20 | get_norm('BN', inplane, bias_freeze=True) for _ in 
range(num_heads) 21 | ]) 22 | self.classifiers = nn.ModuleList([ 23 | nn.Linear(inplane, num_classes, bias=False) for _ in range(num_heads) 24 | ]) 25 | self.bnnecks.apply(weights_init_kaiming) 26 | self.classifiers.apply(weights_init_classifier) 27 | 28 | def forward(self, x): 29 | # x [p, b, c] 30 | b, c, p = x.shape 31 | out_list = [] 32 | for i in range(p): 33 | out = x[:, :, i].reshape(b, c, 1, 1) 34 | out = self.bnnecks[i](out).reshape(b, c) 35 | out = self.classifiers[i](out) 36 | out_list.append(out) 37 | out = torch.stack(out_list, dim=-1).contiguous() 38 | return out 39 | 40 | # back up code for basic piguhead 41 | @REID_HEADS_REGISTRY.register() 42 | class AGGHead(nn.Module): 43 | def __init__(self, cfg): 44 | super().__init__() 45 | # fmt: off 46 | num_classes = cfg.MODEL.HEADS.NUM_CLASSES 47 | 48 | self.pool_layer = nn.AdaptiveAvgPool2d(1) 49 | 50 | self.count = 0 51 | self.bnc = BNC(2048, num_classes, 3) 52 | 53 | self.slassifier = nn.Linear(2048, num_classes, bias=False) 54 | self.sbn = get_norm('BN', 2048, bias_freeze=True) 55 | self.slassifier.apply(weights_init_classifier) 56 | self.sbn.apply(weights_init_kaiming) 57 | 58 | self.plassifier = nn.Linear(2048, num_classes, bias=False) 59 | self.pbn = get_norm('BN', 2048, bias_freeze=True) 60 | self.plassifier.apply(weights_init_classifier) 61 | self.pbn.apply(weights_init_kaiming) 62 | 63 | def forward(self, seatures, peatures, cls_feats, targets=None, confs= None): 64 | """ 65 | See :class:`ReIDHeads.forward`. 
66 | """ 67 | b, c, h, w = seatures.shape 68 | self.count = (self.count + 1) % 100 69 | score = confs 70 | 71 | seatures = self.pool_layer(seatures) 72 | bn_seatures = self.sbn(seatures.reshape(b, c, 1, 1)).squeeze() 73 | peatures = self.pool_layer(peatures) 74 | bn_peatures = self.pbn(peatures.reshape(b, c, 1, 1)).squeeze() 75 | 76 | # Evaluation 77 | # fmt: off 78 | if not self.training: return bn_seatures, bn_peatures, cls_feats, score, 79 | # fmt: on 80 | 81 | # Training 82 | stp_outputs = self.slassifier(bn_seatures) 83 | pth_outputs = self.plassifier(bn_peatures) 84 | key_outputs = self.bnc(cls_feats) 85 | 86 | return { 87 | "stp_outputs": stp_outputs, 88 | "pth_outputs": pth_outputs, 89 | "key_outputs": key_outputs, 90 | "global_features": seatures.squeeze(), 91 | "key_feats": cls_feats, 92 | } 93 | -------------------------------------------------------------------------------- /fastreid/modeling/heads/build.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from ...utils.registry import Registry 8 | 9 | REID_HEADS_REGISTRY = Registry("HEADS") 10 | REID_HEADS_REGISTRY.__doc__ = """ 11 | Registry for ROI heads in a generalized R-CNN model. 12 | ROIHeads take feature maps and region proposals, and 13 | perform per-region computation. 14 | The registered object will be called with `obj(cfg, input_shape)`. 15 | The call is expected to return an :class:`ROIHeads`. 16 | """ 17 | 18 | 19 | def build_heads(cfg): 20 | """ 21 | Build REIDHeads defined by `cfg.MODEL.REID_HEADS.NAME`. 
22 | """ 23 | head = cfg.MODEL.HEADS.NAME 24 | return REID_HEADS_REGISTRY.get(head)(cfg) 25 | -------------------------------------------------------------------------------- /fastreid/modeling/heads/embedding_head.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch.nn.functional as F 8 | from torch import nn 9 | 10 | from fastreid.layers import * 11 | from fastreid.utils.weight_init import weights_init_kaiming, weights_init_classifier 12 | from .build import REID_HEADS_REGISTRY 13 | 14 | 15 | @REID_HEADS_REGISTRY.register() 16 | class EmbeddingHead(nn.Module): 17 | def __init__(self, cfg): 18 | super().__init__() 19 | # fmt: off 20 | feat_dim = cfg.MODEL.BACKBONE.FEAT_DIM 21 | embedding_dim = cfg.MODEL.HEADS.EMBEDDING_DIM 22 | num_classes = cfg.MODEL.HEADS.NUM_CLASSES 23 | neck_feat = cfg.MODEL.HEADS.NECK_FEAT 24 | pool_type = cfg.MODEL.HEADS.POOL_LAYER 25 | cls_type = cfg.MODEL.HEADS.CLS_LAYER 26 | with_bnneck = cfg.MODEL.HEADS.WITH_BNNECK 27 | norm_type = cfg.MODEL.HEADS.NORM 28 | 29 | if pool_type == 'fastavgpool': self.pool_layer = FastGlobalAvgPool2d() 30 | elif pool_type == 'avgpool': self.pool_layer = nn.AdaptiveAvgPool2d(1) 31 | elif pool_type == 'maxpool': self.pool_layer = nn.AdaptiveMaxPool2d(1) 32 | elif pool_type == 'gempoolP': self.pool_layer = GeneralizedMeanPoolingP() 33 | elif pool_type == 'gempool': self.pool_layer = GeneralizedMeanPooling() 34 | elif pool_type == "avgmaxpool": self.pool_layer = AdaptiveAvgMaxPool2d() 35 | elif pool_type == 'clipavgpool': self.pool_layer = ClipGlobalAvgPool2d() 36 | elif pool_type == "identity": self.pool_layer = nn.Identity() 37 | elif pool_type == "flatten": self.pool_layer = Flatten() 38 | else: raise KeyError(f"{pool_type} is not supported!") 39 | # fmt: on 40 | 41 | self.neck_feat = neck_feat 42 | 43 | bottleneck = [] 44 | if embedding_dim > 0: 45 | 
bottleneck.append(nn.Conv2d(feat_dim, embedding_dim, 1, 1, bias=False)) 46 | feat_dim = embedding_dim 47 | 48 | if with_bnneck: 49 | bottleneck.append(get_norm(norm_type, feat_dim, bias_freeze=True)) 50 | 51 | self.bottleneck = nn.Sequential(*bottleneck) 52 | 53 | # identity classification layer 54 | # fmt: off 55 | if cls_type == 'linear': self.classifier = nn.Linear(feat_dim, num_classes, bias=False) 56 | elif cls_type == 'arcSoftmax': self.classifier = ArcSoftmax(cfg, feat_dim, num_classes) 57 | elif cls_type == 'circleSoftmax': self.classifier = CircleSoftmax(cfg, feat_dim, num_classes) 58 | elif cls_type == 'amSoftmax': self.classifier = AMSoftmax(cfg, feat_dim, num_classes) 59 | else: raise KeyError(f"{cls_type} is not supported!") 60 | # fmt: on 61 | 62 | self.bottleneck.apply(weights_init_kaiming) 63 | self.classifier.apply(weights_init_classifier) 64 | 65 | def forward(self, features, targets=None): 66 | """ 67 | See :class:`ReIDHeads.forward`. 68 | """ 69 | global_feat = self.pool_layer(features) 70 | bn_feat = self.bottleneck(global_feat) 71 | bn_feat = bn_feat[..., 0, 0] 72 | 73 | # Evaluation 74 | # fmt: off 75 | if not self.training: return bn_feat 76 | # fmt: on 77 | 78 | # Training 79 | if self.classifier.__class__.__name__ == 'Linear': 80 | cls_outputs = self.classifier(bn_feat) 81 | pred_class_logits = F.linear(bn_feat, self.classifier.weight) 82 | else: 83 | cls_outputs = self.classifier(bn_feat, targets) 84 | pred_class_logits = self.classifier.s * F.linear(F.normalize(bn_feat), 85 | F.normalize(self.classifier.weight)) 86 | 87 | # fmt: off 88 | if self.neck_feat == "before": feat = global_feat[..., 0, 0] 89 | elif self.neck_feat == "after": feat = bn_feat 90 | else: raise KeyError(f"{self.neck_feat} is invalid for MODEL.HEADS.NECK_FEAT") 91 | # fmt: on 92 | 93 | return { 94 | "cls_outputs": cls_outputs, 95 | "pred_class_logits": pred_class_logits, 96 | "features": feat, 97 | } 98 | 
-------------------------------------------------------------------------------- /fastreid/modeling/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: l1aoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .cross_entroy_loss import cross_entropy_loss, log_accuracy 8 | from .triplet_loss import triplet_loss 9 | -------------------------------------------------------------------------------- /fastreid/modeling/losses/cross_entroy_loss.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: l1aoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | from fastreid.utils.comm import get_local_rank 10 | from fastreid.utils.events import get_event_storage 11 | 12 | 13 | def log_accuracy(pred_class_logits, gt_classes, topk=(1,)): 14 | """ 15 | Log the accuracy metrics to EventStorage. 16 | """ 17 | bsz = pred_class_logits.size(0) 18 | maxk = max(topk) 19 | _, pred_class = pred_class_logits.topk(maxk, 1, True, True) 20 | pred_class = pred_class.t() 21 | correct = pred_class.eq(gt_classes.view(1, -1).expand_as(pred_class)) 22 | 23 | ret = [] 24 | for k in topk: 25 | correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True) 26 | ret.append(correct_k.mul_(1. 
/ bsz)) 27 | 28 | storage = get_event_storage() 29 | storage.put_scalar("cls_accuracy", ret[0]) 30 | 31 | 32 | def cross_entropy_loss(pred_class_outputs, gt_classes, eps, alpha=0.2, conf=None): 33 | num_classes = pred_class_outputs.size(1) 34 | 35 | if eps >= 0: 36 | smooth_param = eps 37 | else: 38 | # Adaptive label smooth regularization 39 | soft_label = F.softmax(pred_class_outputs, dim=1) 40 | smooth_param = alpha * soft_label[torch.arange(soft_label.size(0)), gt_classes].unsqueeze(1) 41 | 42 | log_probs = F.log_softmax(pred_class_outputs, dim=1) 43 | with torch.no_grad(): 44 | targets = torch.ones_like(log_probs) 45 | targets *= smooth_param / (num_classes - 1) 46 | targets.scatter_(1, gt_classes.data.unsqueeze(1), (1 - smooth_param)) 47 | 48 | loss = (-targets * log_probs).sum(dim=1) 49 | 50 | """ 51 | # confidence penalty 52 | conf_penalty = 0.3 53 | probs = F.softmax(pred_class_logits, dim=1) 54 | entropy = torch.sum(-probs * log_probs, dim=1) 55 | loss = torch.clamp_min(loss - conf_penalty * entropy, min=0.) 
56 | """ 57 | 58 | with torch.no_grad(): 59 | non_zero_cnt = max(loss.nonzero(as_tuple=False).size(0), 1) 60 | 61 | if conf is not None: 62 | loss *= conf 63 | loss = loss.sum() / non_zero_cnt 64 | 65 | return loss 66 | -------------------------------------------------------------------------------- /fastreid/modeling/losses/triplet_loss.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from fastreid.utils import comm 11 | from fastreid.layers import GatherLayer 12 | from .utils import concat_all_gather, euclidean_dist, normalize 13 | 14 | 15 | def softmax_weights(dist, mask): 16 | max_v = torch.max(dist * mask, dim=1, keepdim=True)[0] 17 | diff = dist - max_v 18 | Z = torch.sum(torch.exp(diff) * mask, dim=1, keepdim=True) + 1e-6 # avoid division by zero 19 | W = torch.exp(diff) * mask / Z 20 | return W 21 | 22 | 23 | def hard_example_mining(dist_mat, is_pos, is_neg): 24 | """For each anchor, find the hardest positive and negative sample. 25 | Args: 26 | dist_mat: pair wise distance between samples, shape [N, M] 27 | is_pos: positive index with shape [N, M] 28 | is_neg: negative index with shape [N, M] 29 | Returns: 30 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N] 31 | dist_an: pytorch Variable, distance(anchor, negative); shape [N] 32 | p_inds: pytorch LongTensor, with shape [N]; 33 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1 34 | n_inds: pytorch LongTensor, with shape [N]; 35 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1 36 | NOTE: Only consider the case in which all labels have same num of samples, 37 | thus we can cope with all anchors in parallel. 
38 | """ 39 | 40 | assert len(dist_mat.size()) == 2 41 | N = dist_mat.size(0) 42 | 43 | # `dist_ap` means distance(anchor, positive) 44 | # both `dist_ap` and `relative_p_inds` with shape [N, 1] 45 | dist_ap, relative_p_inds = torch.max( 46 | dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True) 47 | # `dist_an` means distance(anchor, negative) 48 | # both `dist_an` and `relative_n_inds` with shape [N, 1] 49 | dist_an, relative_n_inds = torch.min( 50 | dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True) 51 | 52 | # shape [N] 53 | dist_ap = dist_ap.squeeze(1) 54 | dist_an = dist_an.squeeze(1) 55 | 56 | return dist_ap, dist_an 57 | 58 | 59 | def weighted_example_mining(dist_mat, is_pos, is_neg): 60 | """For each anchor, find the weighted positive and negative sample. 61 | Args: 62 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N] 63 | is_pos: 64 | is_neg: 65 | Returns: 66 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N] 67 | dist_an: pytorch Variable, distance(anchor, negative); shape [N] 68 | """ 69 | assert len(dist_mat.size()) == 2 70 | 71 | is_pos = is_pos.float() 72 | is_neg = is_neg.float() 73 | dist_ap = dist_mat * is_pos 74 | dist_an = dist_mat * is_neg 75 | 76 | weights_ap = softmax_weights(dist_ap, is_pos) 77 | weights_an = softmax_weights(-dist_an, is_neg) 78 | 79 | dist_ap = torch.sum(dist_ap * weights_ap, dim=1) 80 | dist_an = torch.sum(dist_an * weights_an, dim=1) 81 | 82 | return dist_ap, dist_an 83 | 84 | 85 | def triplet_loss(embedding, targets, margin, norm_feat, hard_mining): 86 | r"""Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid). 87 | Related Triplet Loss theory can be found in paper 'In Defense of the Triplet 88 | Loss for Person Re-Identification'.""" 89 | 90 | if norm_feat: embedding = normalize(embedding, axis=-1) 91 | 92 | # # For distributed training, gather all features from different process. 
93 | if comm.get_world_size() > 1: 94 | all_embedding = torch.cat(GatherLayer.apply(embedding), dim=0) 95 | all_targets = concat_all_gather(targets) 96 | else: 97 | all_embedding = embedding 98 | all_targets = targets 99 | 100 | dist_mat = euclidean_dist(all_embedding, all_embedding) 101 | N, N = dist_mat.size() 102 | 103 | is_pos = all_targets.view(N, 1).expand(N, N).eq(all_targets.view(N, 1).expand(N, N).t()) 104 | is_neg = all_targets.view(N, 1).expand(N, N).ne(all_targets.view(N, 1).expand(N, N).t()) 105 | 106 | if hard_mining: 107 | dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg) 108 | else: 109 | dist_ap, dist_an = weighted_example_mining(dist_mat, is_pos, is_neg) 110 | 111 | y = dist_an.new().resize_as_(dist_an).fill_(1) 112 | 113 | if margin > 0: 114 | loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=margin) 115 | else: 116 | loss = F.soft_margin_loss(dist_an - dist_ap, y) 117 | # fmt: off 118 | if loss == float('Inf'): loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3) 119 | # fmt: on 120 | 121 | return loss 122 | -------------------------------------------------------------------------------- /fastreid/modeling/losses/utils.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: xingyu liao 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch 8 | 9 | 10 | def concat_all_gather(tensor): 11 | """ 12 | Performs all_gather operation on the provided tensors. 13 | *** Warning ***: torch.distributed.all_gather has no gradient. 14 | """ 15 | tensors_gather = [torch.ones_like(tensor) 16 | for _ in range(torch.distributed.get_world_size())] 17 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 18 | 19 | output = torch.cat(tensors_gather, dim=0) 20 | return output 21 | 22 | 23 | def normalize(x, axis=-1): 24 | """Normalizing to unit length along the specified dimension. 
25 | Args: 26 | x: pytorch Variable 27 | Returns: 28 | x: pytorch Variable, same shape as input 29 | """ 30 | x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12) 31 | return x 32 | 33 | 34 | def euclidean_dist(x, y): 35 | m, n = x.size(0), y.size(0) 36 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 37 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 38 | dist = xx + yy - 2 * torch.matmul(x, y.t()) 39 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 40 | return dist 41 | 42 | 43 | def cosine_dist(x, y): 44 | bs1, bs2 = x.size(0), y.size(0) 45 | frac_up = torch.matmul(x, y.transpose(0, 1)) 46 | frac_down = (torch.sqrt(torch.sum(torch.pow(x, 2), 1))).view(bs1, 1).repeat(1, bs2) * \ 47 | (torch.sqrt(torch.sum(torch.pow(y, 2), 1))).view(1, bs2).repeat(bs1, 1) 48 | cosine = frac_up / frac_down 49 | return 1 - cosine 50 | -------------------------------------------------------------------------------- /fastreid/modeling/meta_arch/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .build import META_ARCH_REGISTRY, build_model 8 | 9 | 10 | # import all the meta_arch, so they will be registered 11 | from .baseline import Baseline 12 | from .pirt import Pirt 13 | -------------------------------------------------------------------------------- /fastreid/modeling/meta_arch/baseline.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch 8 | from torch import nn 9 | 10 | from fastreid.modeling.backbones import build_backbone 11 | from fastreid.modeling.heads import build_heads 12 | from fastreid.modeling.losses import * 13 | from .build import META_ARCH_REGISTRY 14 | 15 | 16 | @META_ARCH_REGISTRY.register() 17 | class 
Baseline(nn.Module): 18 | def __init__(self, cfg): 19 | super().__init__() 20 | self._cfg = cfg 21 | assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD) 22 | self.register_buffer("pixel_mean", torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1)) 23 | self.register_buffer("pixel_std", torch.tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1)) 24 | 25 | # backbone 26 | self.backbone = build_backbone(cfg) 27 | 28 | # head 29 | self.heads = build_heads(cfg) 30 | 31 | @property 32 | def device(self): 33 | return self.pixel_mean.device 34 | 35 | def forward(self, batched_inputs): 36 | images = self.preprocess_image(batched_inputs) 37 | features = self.backbone(images) 38 | 39 | if self.training: 40 | assert "targets" in batched_inputs, "Person ID annotation are missing in training!" 41 | targets = batched_inputs["targets"].to(self.device) 42 | if targets.sum() < 0: targets.zero_() 43 | 44 | outputs = self.heads(features, targets) 45 | return { 46 | "outputs": outputs, 47 | "targets": targets, 48 | } 49 | else: 50 | outputs = self.heads(features) 51 | return outputs 52 | 53 | def preprocess_image(self, batched_inputs): 54 | r""" 55 | Normalize and batch the input images. 56 | """ 57 | if isinstance(batched_inputs, dict): 58 | images = batched_inputs["images"].to(self.device) 59 | elif isinstance(batched_inputs, torch.Tensor): 60 | images = batched_inputs.to(self.device) 61 | else: 62 | raise TypeError("batched_inputs must be dict or torch.Tensor, but get {}".format(type(batched_inputs))) 63 | 64 | # images.sub_(self.pixel_mean).div_(self.pixel_std) 65 | images = images.sub(self.pixel_mean).div(self.pixel_std) 66 | return images 67 | 68 | def losses(self, outs): 69 | r""" 70 | Compute loss from modeling's outputs, the loss function input arguments 71 | must be the same as the outputs of the model forwarding. 
72 | """ 73 | # fmt: off 74 | outputs = outs["outputs"] 75 | gt_labels = outs["targets"] 76 | # model predictions 77 | pred_class_logits = outputs['pred_class_logits'].detach() 78 | cls_outputs = outputs['cls_outputs'] 79 | 80 | loss_dict = {} 81 | loss_names = self._cfg.MODEL.LOSSES.NAME 82 | 83 | if "CrossEntropyLoss" in loss_names: 84 | loss_dict['loss_cls'] = cross_entropy_loss( 85 | cls_outputs, 86 | gt_labels, 87 | self._cfg.MODEL.LOSSES.CE.EPSILON, 88 | self._cfg.MODEL.LOSSES.CE.ALPHA, 89 | ) * self._cfg.MODEL.LOSSES.CE.SCALE 90 | 91 | return loss_dict 92 | -------------------------------------------------------------------------------- /fastreid/modeling/meta_arch/build.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | import torch 7 | 8 | from fastreid.utils.registry import Registry 9 | 10 | META_ARCH_REGISTRY = Registry("META_ARCH") # noqa F401 isort:skip 11 | META_ARCH_REGISTRY.__doc__ = """ 12 | Registry for meta-architectures, i.e. the whole model. 13 | The registered object will be called with `obj(cfg)` 14 | and expected to return a `nn.Module` object. 15 | """ 16 | 17 | 18 | def build_model(cfg): 19 | """ 20 | Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``. 21 | Note that it does not load any weights from ``cfg``. 
22 | """ 23 | meta_arch = cfg.MODEL.META_ARCHITECTURE 24 | model = META_ARCH_REGISTRY.get(meta_arch)(cfg) 25 | model.to(torch.device(cfg.MODEL.DEVICE)) 26 | return model 27 | -------------------------------------------------------------------------------- /fastreid/modeling/posenets/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .build import build_posenet, POSENET_REGISTRY 8 | from .pose_hrnet import build_pose_hrnet -------------------------------------------------------------------------------- /fastreid/modeling/posenets/build.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from ...utils.registry import Registry 8 | 9 | POSENET_REGISTRY = Registry("POSENET") 10 | POSENET_REGISTRY.__doc__ = """ 11 | Registry for posenets, which extract feature maps from images 12 | The registered object must be a callable that accepts two arguments: 13 | 1. A :class:`detectron2.config.CfgNode` 14 | 2. A :class:`detectron2.layers.ShapeSpec`, which contains the input shape specification. 15 | It must returns an instance of :class:`Posenet`. 16 | """ 17 | 18 | def build_posenet(cfg): 19 | """ 20 | Build a posenet from `cfg.MODEL.POSENET.NAME`. 
21 | Returns: 22 | an instance of :class:`posenet` 23 | """ 24 | 25 | posenet_name = cfg.MODEL.POSENET.NAME 26 | posenet = POSENET_REGISTRY.get(posenet_name)(cfg) 27 | return posenet 28 | -------------------------------------------------------------------------------- /fastreid/solver/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | 8 | from .build import build_lr_scheduler, build_optimizer -------------------------------------------------------------------------------- /fastreid/solver/build.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from . import lr_scheduler 8 | from . import optim 9 | 10 | 11 | def build_optimizer(cfg, model): 12 | params = [] 13 | for key, value in model.named_parameters(): 14 | if not value.requires_grad: continue 15 | 16 | lr = cfg.SOLVER.BASE_LR 17 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 18 | if "heads" in key: 19 | lr *= cfg.SOLVER.HEADS_LR_FACTOR 20 | if "bias" in key: 21 | lr *= cfg.SOLVER.BIAS_LR_FACTOR 22 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 23 | params += [{"name": key, "params": [value], "lr": lr, "weight_decay": weight_decay, "freeze": False}] 24 | 25 | solver_opt = cfg.SOLVER.OPT 26 | # fmt: off 27 | if solver_opt == "SGD": opt_fns = getattr(optim, solver_opt)(params, momentum=cfg.SOLVER.MOMENTUM) 28 | else: opt_fns = getattr(optim, solver_opt)(params) 29 | # fmt: on 30 | return opt_fns 31 | 32 | 33 | def build_lr_scheduler(cfg, optimizer): 34 | scheduler_args = { 35 | "optimizer": optimizer, 36 | 37 | # warmup options 38 | "warmup_factor": cfg.SOLVER.WARMUP_FACTOR, 39 | "warmup_iters": cfg.SOLVER.WARMUP_ITERS, 40 | "warmup_method": cfg.SOLVER.WARMUP_METHOD, 41 | 42 | # multi-step lr scheduler options 43 | "milestones": 
cfg.SOLVER.STEPS, 44 | "gamma": cfg.SOLVER.GAMMA, 45 | 46 | # cosine annealing lr scheduler options 47 | "max_iters": cfg.SOLVER.MAX_ITER, 48 | "delay_iters": cfg.SOLVER.DELAY_ITERS, 49 | "eta_min_lr": cfg.SOLVER.ETA_MIN_LR, 50 | 51 | } 52 | return getattr(lr_scheduler, cfg.SOLVER.SCHED)(**scheduler_args) 53 | -------------------------------------------------------------------------------- /fastreid/solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import math 8 | from bisect import bisect_right 9 | from typing import List 10 | 11 | import torch 12 | from torch.optim.lr_scheduler import _LRScheduler 13 | 14 | __all__ = ["WarmupMultiStepLR", "WarmupCosineAnnealingLR"] 15 | 16 | 17 | class WarmupMultiStepLR(_LRScheduler): 18 | def __init__( 19 | self, 20 | optimizer: torch.optim.Optimizer, 21 | milestones: List[int], 22 | gamma: float = 0.1, 23 | warmup_factor: float = 0.001, 24 | warmup_iters: int = 1000, 25 | warmup_method: str = "linear", 26 | last_epoch: int = -1, 27 | **kwargs, 28 | ): 29 | if not list(milestones) == sorted(milestones): 30 | raise ValueError( 31 | "Milestones should be a list of" " increasing integers. 
Got {}", milestones 32 | ) 33 | self.milestones = milestones 34 | self.gamma = gamma 35 | self.warmup_factor = warmup_factor 36 | self.warmup_iters = warmup_iters 37 | self.warmup_method = warmup_method 38 | super().__init__(optimizer, last_epoch) 39 | 40 | def get_lr(self) -> List[float]: 41 | warmup_factor = _get_warmup_factor_at_iter( 42 | self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor 43 | ) 44 | return [ 45 | base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch) 46 | for base_lr in self.base_lrs 47 | ] 48 | 49 | def _compute_values(self) -> List[float]: 50 | # The new interface 51 | return self.get_lr() 52 | 53 | 54 | class WarmupCosineAnnealingLR(_LRScheduler): 55 | r"""Set the learning rate of each parameter group using a cosine annealing 56 | schedule, where :math:`\eta_{max}` is set to the initial lr and 57 | :math:`T_{cur}` is the number of epochs since the last restart in SGDR: 58 | 59 | .. math:: 60 | \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 + 61 | \cos(\frac{T_{cur}}{T_{max}}\pi)) 62 | 63 | When last_epoch=-1, sets initial lr as lr. 64 | 65 | It has been proposed in 66 | `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only 67 | implements the cosine annealing part of SGDR, and not the restarts. 68 | 69 | Args: 70 | optimizer (Optimizer): Wrapped optimizer. 71 | T_max (int): Maximum number of iterations. 72 | eta_min (float): Minimum learning rate. Default: 0. 73 | last_epoch (int): The index of last epoch. Default: -1. 74 | 75 | .. 
_SGDR\: Stochastic Gradient Descent with Warm Restarts: 76 | https://arxiv.org/abs/1608.03983 77 | """ 78 | 79 | def __init__( 80 | self, 81 | optimizer: torch.optim.Optimizer, 82 | max_iters: int, 83 | delay_iters: int = 0, 84 | eta_min_lr: int = 0, 85 | warmup_factor: float = 0.001, 86 | warmup_iters: int = 1000, 87 | warmup_method: str = "linear", 88 | last_epoch=-1, 89 | **kwargs 90 | ): 91 | self.max_iters = max_iters 92 | self.delay_iters = delay_iters 93 | self.eta_min_lr = eta_min_lr 94 | self.warmup_factor = warmup_factor 95 | self.warmup_iters = warmup_iters 96 | self.warmup_method = warmup_method 97 | assert self.delay_iters >= self.warmup_iters, "Scheduler delay iters must be larger than warmup iters" 98 | super(WarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch) 99 | 100 | def get_lr(self) -> List[float]: 101 | if self.last_epoch <= self.warmup_iters: 102 | warmup_factor = _get_warmup_factor_at_iter( 103 | self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor, 104 | ) 105 | return [ 106 | base_lr * warmup_factor for base_lr in self.base_lrs 107 | ] 108 | elif self.last_epoch <= self.delay_iters: 109 | return self.base_lrs 110 | 111 | else: 112 | return [ 113 | self.eta_min_lr + (base_lr - self.eta_min_lr) * 114 | (1 + math.cos( 115 | math.pi * (self.last_epoch - self.delay_iters) / (self.max_iters - self.delay_iters))) / 2 116 | for base_lr in self.base_lrs] 117 | 118 | 119 | def _get_warmup_factor_at_iter( 120 | method: str, iter: int, warmup_iters: int, warmup_factor: float 121 | ) -> float: 122 | """ 123 | Return the learning rate warmup factor at a specific iteration. 124 | See https://arxiv.org/abs/1706.02677 for more details. 125 | Args: 126 | method (str): warmup method; either "constant" or "linear". 127 | iter (int): iteration at which to calculate the warmup factor. 128 | warmup_iters (int): the number of warmup iterations. 
129 | warmup_factor (float): the base warmup factor (the meaning changes according 130 | to the method used). 131 | Returns: 132 | float: the effective warmup factor at the given iteration. 133 | """ 134 | if iter >= warmup_iters: 135 | return 1.0 136 | 137 | if method == "constant": 138 | return warmup_factor 139 | elif method == "linear": 140 | alpha = iter / warmup_iters 141 | return warmup_factor * (1 - alpha) + alpha 142 | else: 143 | raise ValueError("Unknown warmup method: {}".format(method)) 144 | -------------------------------------------------------------------------------- /fastreid/solver/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .adam import Adam 2 | from .sgd import SGD 3 | 4 | -------------------------------------------------------------------------------- /fastreid/solver/optim/adam.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.optim.optimizer import Optimizer 5 | 6 | 7 | class Adam(Optimizer): 8 | r"""Implements Adam algorithm. 9 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 10 | The implementation of the L2 penalty follows changes proposed in 11 | `Decoupled Weight Decay Regularization`_. 12 | Arguments: 13 | params (iterable): iterable of parameters to optimize or dicts defining 14 | parameter groups 15 | lr (float, optional): learning rate (default: 1e-3) 16 | betas (Tuple[float, float], optional): coefficients used for computing 17 | running averages of gradient and its square (default: (0.9, 0.999)) 18 | eps (float, optional): term added to the denominator to improve 19 | numerical stability (default: 1e-8) 20 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 21 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 22 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 23 | (default: False) 24 | .. 
_Adam\: A Method for Stochastic Optimization: 25 | https://arxiv.org/abs/1412.6980 26 | .. _Decoupled Weight Decay Regularization: 27 | https://arxiv.org/abs/1711.05101 28 | .. _On the Convergence of Adam and Beyond: 29 | https://openreview.net/forum?id=ryQu7f-RZ 30 | """ 31 | 32 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 33 | weight_decay=0, amsgrad=False): 34 | if not 0.0 <= lr: 35 | raise ValueError("Invalid learning rate: {}".format(lr)) 36 | if not 0.0 <= eps: 37 | raise ValueError("Invalid epsilon value: {}".format(eps)) 38 | if not 0.0 <= betas[0] < 1.0: 39 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 40 | if not 0.0 <= betas[1] < 1.0: 41 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 42 | if not 0.0 <= weight_decay: 43 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 44 | defaults = dict(lr=lr, betas=betas, eps=eps, 45 | weight_decay=weight_decay, amsgrad=amsgrad) 46 | super(Adam, self).__init__(params, defaults) 47 | 48 | def __setstate__(self, state): 49 | super(Adam, self).__setstate__(state) 50 | for group in self.param_groups: 51 | group.setdefault('amsgrad', False) 52 | 53 | @torch.no_grad() 54 | def step(self, closure=None): 55 | """Performs a single optimization step. 56 | Arguments: 57 | closure (callable, optional): A closure that reevaluates the model 58 | and returns the loss. 
59 | """ 60 | loss = None 61 | if closure is not None: 62 | with torch.enable_grad(): 63 | loss = closure() 64 | 65 | for group in self.param_groups: 66 | if group['freeze']: continue 67 | 68 | for p in group['params']: 69 | if p.grad is None: 70 | continue 71 | grad = p.grad 72 | if grad.is_sparse: 73 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 74 | amsgrad = group['amsgrad'] 75 | 76 | state = self.state[p] 77 | 78 | # State initialization 79 | if len(state) == 0: 80 | state['step'] = 0 81 | # Exponential moving average of gradient values 82 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 83 | # Exponential moving average of squared gradient values 84 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 85 | if amsgrad: 86 | # Maintains max of all exp. moving avg. of sq. grad. values 87 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 88 | 89 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 90 | if amsgrad: 91 | max_exp_avg_sq = state['max_exp_avg_sq'] 92 | beta1, beta2 = group['betas'] 93 | 94 | state['step'] += 1 95 | bias_correction1 = 1 - beta1 ** state['step'] 96 | bias_correction2 = 1 - beta2 ** state['step'] 97 | 98 | if group['weight_decay'] != 0: 99 | grad = grad.add(p, alpha=group['weight_decay']) 100 | 101 | # Decay the first and second moment running average coefficient 102 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 103 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 104 | if amsgrad: 105 | # Maintains the maximum of all 2nd moment running avg. till now 106 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 107 | # Use the max. for normalizing running avg. 
of gradient 108 | denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 109 | else: 110 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 111 | 112 | step_size = group['lr'] / bias_correction1 113 | 114 | p.addcdiv_(exp_avg, denom, value=-step_size) 115 | 116 | return loss 117 | -------------------------------------------------------------------------------- /fastreid/solver/optim/sgd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer, required 3 | 4 | 5 | class SGD(Optimizer): 6 | r"""Implements stochastic gradient descent (optionally with momentum). 7 | Nesterov momentum is based on the formula from 8 | `On the importance of initialization and momentum in deep learning`__. 9 | Args: 10 | params (iterable): iterable of parameters to optimize or dicts defining 11 | parameter groups 12 | lr (float): learning rate 13 | momentum (float, optional): momentum factor (default: 0) 14 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 15 | dampening (float, optional): dampening for momentum (default: 0) 16 | nesterov (bool, optional): enables Nesterov momentum (default: False) 17 | Example: 18 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 19 | >>> optimizer.zero_grad() 20 | >>> loss_fn(model(input), target).backward() 21 | >>> optimizer.step() 22 | __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf 23 | .. note:: 24 | The implementation of SGD with Momentum/Nesterov subtly differs from 25 | Sutskever et. al. and implementations in some other frameworks. 26 | Considering the specific case of Momentum, the update can be written as 27 | .. 
math:: 28 | \begin{aligned} 29 | v_{t+1} & = \mu * v_{t} + g_{t+1}, \\ 30 | p_{t+1} & = p_{t} - \text{lr} * v_{t+1}, 31 | \end{aligned} 32 | where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the 33 | parameters, gradient, velocity, and momentum respectively. 34 | This is in contrast to Sutskever et. al. and 35 | other frameworks which employ an update of the form 36 | .. math:: 37 | \begin{aligned} 38 | v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\ 39 | p_{t+1} & = p_{t} - v_{t+1}. 40 | \end{aligned} 41 | The Nesterov version is analogously modified. 42 | """ 43 | 44 | def __init__(self, params, lr=required, momentum=0, dampening=0, 45 | weight_decay=0, nesterov=False): 46 | if lr is not required and lr < 0.0: 47 | raise ValueError("Invalid learning rate: {}".format(lr)) 48 | if momentum < 0.0: 49 | raise ValueError("Invalid momentum value: {}".format(momentum)) 50 | if weight_decay < 0.0: 51 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 52 | 53 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, 54 | weight_decay=weight_decay, nesterov=nesterov) 55 | if nesterov and (momentum <= 0 or dampening != 0): 56 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 57 | super(SGD, self).__init__(params, defaults) 58 | 59 | def __setstate__(self, state): 60 | super(SGD, self).__setstate__(state) 61 | for group in self.param_groups: 62 | group.setdefault('nesterov', False) 63 | 64 | @torch.no_grad() 65 | def step(self, closure=None): 66 | """Performs a single optimization step. 67 | Arguments: 68 | closure (callable, optional): A closure that reevaluates the model 69 | and returns the loss. 
70 | """ 71 | loss = None 72 | if closure is not None: 73 | with torch.enable_grad(): 74 | loss = closure() 75 | 76 | for group in self.param_groups: 77 | if group['freeze']: continue 78 | 79 | weight_decay = group['weight_decay'] 80 | momentum = group['momentum'] 81 | dampening = group['dampening'] 82 | nesterov = group['nesterov'] 83 | 84 | for p in group['params']: 85 | if p.grad is None: 86 | continue 87 | d_p = p.grad 88 | if weight_decay != 0: 89 | d_p = d_p.add(p, alpha=weight_decay) 90 | if momentum != 0: 91 | param_state = self.state[p] 92 | if 'momentum_buffer' not in param_state: 93 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() 94 | else: 95 | buf = param_state['momentum_buffer'] 96 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening) 97 | if nesterov: 98 | d_p = d_p.add(buf, alpha=momentum) 99 | else: 100 | d_p = buf 101 | 102 | p.add_(d_p, alpha=-group['lr']) 103 | 104 | return loss 105 | -------------------------------------------------------------------------------- /fastreid/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | -------------------------------------------------------------------------------- /fastreid/utils/collect_env.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: xingyu liao 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | # based on 8 | # https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/collect_env.py 9 | import importlib 10 | import os 11 | import re 12 | import subprocess 13 | import sys 14 | from collections import defaultdict 15 | 16 | import PIL 17 | import numpy as np 18 | import torch 19 | import torchvision 20 | from tabulate import tabulate 21 | 22 | __all__ = ["collect_env_info"] 23 | 24 | 25 | def collect_torch_env(): 26 | try: 27 | 
import torch.__config__ 28 | 29 | return torch.__config__.show() 30 | except ImportError: 31 | # compatible with older versions of pytorch 32 | from torch.utils.collect_env import get_pretty_env_info 33 | 34 | return get_pretty_env_info() 35 | 36 | 37 | def get_env_module(): 38 | var_name = "FASTREID_ENV_MODULE" 39 | return var_name, os.environ.get(var_name, "") 40 | 41 | 42 | def detect_compute_compatibility(CUDA_HOME, so_file): 43 | try: 44 | cuobjdump = os.path.join(CUDA_HOME, "bin", "cuobjdump") 45 | if os.path.isfile(cuobjdump): 46 | output = subprocess.check_output( 47 | "'{}' --list-elf '{}'".format(cuobjdump, so_file), shell=True 48 | ) 49 | output = output.decode("utf-8").strip().split("\n") 50 | sm = [] 51 | for line in output: 52 | line = re.findall(r"\.sm_[0-9]*\.", line)[0] 53 | sm.append(line.strip(".")) 54 | sm = sorted(set(sm)) 55 | return ", ".join(sm) 56 | else: 57 | return so_file + "; cannot find cuobjdump" 58 | except Exception: 59 | # unhandled failure 60 | return so_file 61 | 62 | 63 | def collect_env_info(): 64 | has_gpu = torch.cuda.is_available() # true for both CUDA & ROCM 65 | torch_version = torch.__version__ 66 | 67 | # NOTE: the use of CUDA_HOME and ROCM_HOME requires the CUDA/ROCM build deps, though in 68 | # theory detectron2 should be made runnable with only the corresponding runtimes 69 | from torch.utils.cpp_extension import CUDA_HOME 70 | 71 | has_rocm = False 72 | if tuple(map(int, torch_version.split(".")[:2])) >= (1, 5): 73 | from torch.utils.cpp_extension import ROCM_HOME 74 | 75 | if (getattr(torch.version, "hip", None) is not None) and (ROCM_HOME is not None): 76 | has_rocm = True 77 | has_cuda = has_gpu and (not has_rocm) 78 | 79 | data = [] 80 | data.append(("sys.platform", sys.platform)) 81 | data.append(("Python", sys.version.replace("\n", ""))) 82 | data.append(("numpy", np.__version__)) 83 | 84 | try: 85 | import fastreid # noqa 86 | 87 | data.append( 88 | ("fastreid", fastreid.__version__ + " @" + 
os.path.dirname(fastreid.__file__)) 89 | ) 90 | except ImportError: 91 | data.append(("fastreid", "failed to import")) 92 | 93 | data.append(get_env_module()) 94 | data.append(("PyTorch", torch_version + " @" + os.path.dirname(torch.__file__))) 95 | data.append(("PyTorch debug build", torch.version.debug)) 96 | 97 | data.append(("GPU available", has_gpu)) 98 | if has_gpu: 99 | devices = defaultdict(list) 100 | for k in range(torch.cuda.device_count()): 101 | devices[torch.cuda.get_device_name(k)].append(str(k)) 102 | for name, devids in devices.items(): 103 | data.append(("GPU " + ",".join(devids), name)) 104 | 105 | if has_rocm: 106 | data.append(("ROCM_HOME", str(ROCM_HOME))) 107 | else: 108 | data.append(("CUDA_HOME", str(CUDA_HOME))) 109 | 110 | cuda_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None) 111 | if cuda_arch_list: 112 | data.append(("TORCH_CUDA_ARCH_LIST", cuda_arch_list)) 113 | data.append(("Pillow", PIL.__version__)) 114 | 115 | try: 116 | data.append( 117 | ( 118 | "torchvision", 119 | str(torchvision.__version__) + " @" + os.path.dirname(torchvision.__file__), 120 | ) 121 | ) 122 | if has_cuda: 123 | try: 124 | torchvision_C = importlib.util.find_spec("torchvision._C").origin 125 | msg = detect_compute_compatibility(CUDA_HOME, torchvision_C) 126 | data.append(("torchvision arch flags", msg)) 127 | except ImportError: 128 | data.append(("torchvision._C", "failed to find")) 129 | except AttributeError: 130 | data.append(("torchvision", "unknown")) 131 | 132 | try: 133 | import fvcore 134 | 135 | data.append(("fvcore", fvcore.__version__)) 136 | except ImportError: 137 | pass 138 | 139 | try: 140 | import cv2 141 | 142 | data.append(("cv2", cv2.__version__)) 143 | except ImportError: 144 | pass 145 | env_str = tabulate(data) + "\n" 146 | env_str += collect_torch_env() 147 | return env_str 148 | 149 | 150 | if __name__ == "__main__": 151 | try: 152 | import detectron2 # noqa 153 | except ImportError: 154 | print(collect_env_info()) 155 | else: 
# [tail of /fastreid/utils/collect_env.py, original lines 156-158; the statement
#  enclosing these lines starts before this section and is kept here as text]
#         from fastreid.utils.collect_env import collect_env_info
#
#         print(collect_env_info())

# ------------------------------------------------------------------------------
# /fastreid/utils/comm.py:
# ------------------------------------------------------------------------------
"""
This file contains primitives for multi-gpu communication.
This is useful when doing distributed training.
"""

import functools
import logging
import numpy as np
import pickle
import torch
import torch.distributed as dist

_LOCAL_PROCESS_GROUP = None
"""
A torch process group which only includes processes that are on the same machine as the current process.
This variable is set when processes are spawned by `launch()` in "engine/launch.py".
"""


def get_world_size() -> int:
    """Return the number of processes in the default group (1 when not distributed)."""
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_rank() -> int:
    """Return the global rank of the current process (0 when not distributed)."""
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def get_local_rank() -> int:
    """
    Returns:
        The rank of the current process within the local (per-machine) process group,
        or 0 when torch.distributed is unavailable or not initialized.
    """
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    assert _LOCAL_PROCESS_GROUP is not None
    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)


def get_local_size() -> int:
    """
    Returns:
        The size of the per-machine process group,
        i.e. the number of processes per machine,
        or 1 when torch.distributed is unavailable or not initialized.
    """
    # NOTE(review): unlike get_local_rank(), this does not assert that
    # _LOCAL_PROCESS_GROUP was set; while it is None, group=None selects the
    # default (global) group and the global world size is returned.
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)


def is_main_process() -> bool:
    """Return True on the process whose global rank is 0."""
    return get_rank() == 0


def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training. A no-op when not distributed or when only
    one process exists.
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()


@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a process group based on gloo backend, containing all the ranks.
    The result is cached.
    """
    # The object collectives below exchange pickled bytes as CPU tensors when
    # the group is gloo (see _serialize_to_tensor), so an NCCL job gets an
    # extra gloo group for them.
    if dist.get_backend() == "nccl":
        return dist.new_group(backend="gloo")
    else:
        return dist.group.WORLD


def _serialize_to_tensor(data, group):
    """Pickle ``data`` into a 1-D uint8 tensor.

    The tensor lives on the CPU for a gloo group and on CUDA for an nccl group.
    A warning is logged when the payload exceeds 1 GiB.
    """
    backend = dist.get_backend(group)
    assert backend in ["gloo", "nccl"]
    device = torch.device("cpu" if backend == "gloo" else "cuda")

    buffer = pickle.dumps(data)
    if len(buffer) > 1024 ** 3:
        logger = logging.getLogger(__name__)
        logger.warning(
            "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
                get_rank(), len(buffer) / (1024 ** 3), device
            )
        )
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to(device=device)
    return tensor


def _pad_to_largest_tensor(tensor, group):
    """
    Returns:
        list[int]: size of the tensor, on each rank
        Tensor: padded tensor that has the max size
    """
    world_size = dist.get_world_size(group=group)
    assert (
        world_size >= 1
    ), "comm.gather/all_gather must be called from ranks within the given group!"
    local_numel = tensor.numel()
    local_size = torch.tensor([local_numel], dtype=torch.int64, device=tensor.device)
    size_list = [
        torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
    ]
    dist.all_gather(size_list, local_size, group=group)
    size_list = [int(size.item()) for size in size_list]

    max_size = max(size_list)

    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes; compare/size with a plain int
    if local_numel != max_size:
        padding = torch.zeros((max_size - local_numel,), dtype=torch.uint8, device=tensor.device)
        tensor = torch.cat((tensor, padding), dim=0)
    return size_list, tensor


def all_gather(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).
    Args:
        data: any picklable object
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.
    Returns:
        list[data]: list of data gathered from each rank
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return [data]

    tensor = _serialize_to_tensor(data, group)

    size_list, tensor = _pad_to_largest_tensor(tensor, group)
    max_size = max(size_list)

    # receiving Tensor from all ranks
    tensor_list = [
        torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
    ]
    dist.all_gather(tensor_list, tensor, group=group)

    # Drop each rank's padding before unpickling its payload.
    data_list = []
    for size, received in zip(size_list, tensor_list):
        buffer = received.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list


def gather(data, dst=0, group=None):
    """
    Run gather on arbitrary picklable data (not necessarily tensors).
    Args:
        data: any picklable object
        dst (int): destination rank
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.
    Returns:
        list[data]: on dst, a list of data gathered from each rank. Otherwise,
            an empty list.
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group=group) == 1:
        return [data]
    # NOTE(review): `rank` is the rank within `group` while `dst` is passed to
    # dist.gather unchanged; both coincide for the default all-ranks group.
    rank = dist.get_rank(group=group)

    tensor = _serialize_to_tensor(data, group)
    size_list, tensor = _pad_to_largest_tensor(tensor, group)

    # receiving Tensor from all ranks
    if rank == dst:
        max_size = max(size_list)
        tensor_list = [
            torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
        ]
        dist.gather(tensor, tensor_list, dst=dst, group=group)

        data_list = []
        for size, received in zip(size_list, tensor_list):
            buffer = received.cpu().numpy().tobytes()[:size]
            data_list.append(pickle.loads(buffer))
        return data_list
    else:
        dist.gather(tensor, [], dst=dst, group=group)
        return []


def shared_random_seed():
    """
    Returns:
        int: a random number that is the same across all workers.
            If workers need a shared RNG, they can use this shared seed to
            create one.
    All workers must call this function, otherwise it will deadlock.
    """
    ints = np.random.randint(2 ** 31)
    all_ints = all_gather(ints)
    return all_ints[0]


def reduce_dict(input_dict, average=True):
    """
    Reduce the values in the dictionary from all processes so that process with rank
    0 has the reduced results.
    Args:
        input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
        average (bool): whether to do average or sum
    Returns:
        a dict with the same keys as input_dict, after reduction. Only rank 0's
        values are the reduced result.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.reduce(values, dst=0)
        if dist.get_rank() == 0 and average:
            # only main process gets accumulated, so only divide by
            # world_size in this case
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict

# ------------------------------------------------------------------------------
# /fastreid/utils/compute_dist.py:
# ------------------------------------------------------------------------------
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

# Modified from: https://github.com/open-mmlab/OpenUnReID/blob/66bb2ae0b00575b80fbe8915f4d4f4739cc21206/openunreid/core/utils/compute_dist.py


import faiss
import numpy as np
import torch
import torch.nn.functional as F

from .faiss_utils import (
    index_init_cpu,
    index_init_gpu,
    search_index_pytorch,
    search_raw_array_pytorch,
)

__all__ = [
    "build_dist",
    "compute_jaccard_distance",
    "compute_euclidean_distance",
    "compute_cosine_distance",
    "compute_pose_distance",
]


@torch.no_grad()
def build_dist(feat_1: torch.Tensor, feat_2: torch.Tensor, metric: str = "euclidean", **kwargs) -> np.ndarray:
    r"""Compute distance between two feature embeddings.

    Args:
        feat_1 (torch.Tensor): 2-D feature with batch dimension
            (3-D ``[n, dim, parts]`` for ``"pose"``).
        feat_2 (torch.Tensor): 2-D feature with batch dimension
            (3-D ``[m, dim, parts]`` for ``"pose"``).
        metric (str): one of ``"cosine"``, ``"euclidean"``, ``"jaccard"``, ``"pose"``.
        **kwargs: ``k1`` and ``k2`` are required for ``"jaccard"``;
            ``qconf`` and ``gconf`` are required for ``"pose"``.

    Returns:
        numpy.ndarray: distance matrix of shape ``(feat_1.size(0), feat_2.size(0))``.
    """
    assert metric in ["cosine", "euclidean", "jaccard", "pose"], (
        "Expected metrics are cosine, euclidean, jaccard and pose, "
        "but got {}".format(metric)
    )

    if metric == "euclidean":
        return compute_euclidean_distance(feat_1, feat_2)

    elif metric == "cosine":
        return compute_cosine_distance(feat_1, feat_2)

    elif metric == "jaccard":
        # Jaccard is computed over the union of both sets; keep the
        # feat_1 x feat_2 quadrant.
        feat = torch.cat((feat_1, feat_2), dim=0)
        dist = compute_jaccard_distance(feat, k1=kwargs["k1"], k2=kwargs["k2"], search_option=0)
        return dist[: feat_1.size(0), feat_1.size(0):]

    elif metric == 'pose':
        return compute_pose_distance(feat_1, feat_2, qconf=kwargs["qconf"], gconf=kwargs["gconf"])


def k_reciprocal_neigh(initial_rank, i, k1):
    """Return the k-reciprocal neighbours of sample ``i``.

    ``j`` is kept when it is among the first ``k1 + 1`` entries of row ``i`` of
    ``initial_rank`` and ``i`` is among the first ``k1 + 1`` entries of row ``j``.
    """
    forward_k_neigh_index = initial_rank[i, : k1 + 1]
    backward_k_neigh_index = initial_rank[forward_k_neigh_index, : k1 + 1]
    fi = np.where(backward_k_neigh_index == i)[0]
    return forward_k_neigh_index[fi]


@torch.no_grad()
def compute_jaccard_distance(features, k1=20, k2=6, search_option=0, fp16=False):
    """Compute the k-reciprocal Jaccard distance among all rows of ``features``.

    Args:
        features (torch.Tensor): 2-D feature matrix ``(N, dim)``.
        k1 (int): neighbourhood size for the k-reciprocal sets.
        k2 (int): neighbourhood size for local query expansion (skipped when 1).
        search_option (int): nearest-neighbour backend:
            0 - faiss brute-force kNN on PyTorch CUDA tensors,
            1 - faiss ``GpuIndexFlatL2`` searched with PyTorch CUDA tensors,
            2 - faiss GPU index sharded over all visible GPUs,
            otherwise - faiss CPU ``IndexFlatL2``.
            Options below 3 move ``features`` to CUDA.
        fp16 (bool): store the intermediate matrices as float16.

    Returns:
        numpy.ndarray: ``(N, N)`` Jaccard distance matrix, clipped at 0.
    """
    if search_option < 3:
        features = features.cuda()

    ngpus = faiss.get_num_gpus()
    N = features.size(0)
    mat_type = np.float16 if fp16 else np.float32

    if search_option == 0:
        # GPU + PyTorch CUDA Tensors (1)
        res = faiss.StandardGpuResources()
        res.setDefaultNullStreamAllDevices()
        _, initial_rank = search_raw_array_pytorch(res, features, features, k1)
        initial_rank = initial_rank.cpu().numpy()
    elif search_option == 1:
        # GPU + PyTorch CUDA Tensors (2)
        res = faiss.StandardGpuResources()
        index = faiss.GpuIndexFlatL2(res, features.size(-1))
        index.add(features.cpu().numpy())
        _, initial_rank = search_index_pytorch(index, features, k1)
        res.syncDefaultStreamCurrentDevice()
        initial_rank = initial_rank.cpu().numpy()
    elif search_option == 2:
        # GPU
        index = index_init_gpu(ngpus, features.size(-1))
        index.add(features.cpu().numpy())
        _, initial_rank = index.search(features.cpu().numpy(), k1)
    else:
        # CPU
        index = index_init_cpu(features.size(-1))
        index.add(features.cpu().numpy())
        _, initial_rank = index.search(features.cpu().numpy(), k1)

    nn_k1 = []
    nn_k1_half = []
    for i in range(N):
        nn_k1.append(k_reciprocal_neigh(initial_rank, i, k1))
        nn_k1_half.append(k_reciprocal_neigh(initial_rank, i, int(np.around(k1 / 2))))

    V = np.zeros((N, N), dtype=mat_type)
    for i in range(N):
        k_reciprocal_index = nn_k1[i]
        k_reciprocal_expansion_index = k_reciprocal_index
        for candidate in k_reciprocal_index:
            candidate_k_reciprocal_index = nn_k1_half[candidate]
            # Merge in the candidate's (k1/2)-reciprocal set when more than 2/3
            # of it already lies inside i's k1-reciprocal set.
            if len(
                np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)
            ) > 2 / 3 * len(candidate_k_reciprocal_index):
                k_reciprocal_expansion_index = np.append(
                    k_reciprocal_expansion_index, candidate_k_reciprocal_index
                )

        k_reciprocal_expansion_index = np.unique(
            k_reciprocal_expansion_index
        )  # element-wise unique

        # Squared Euclidean distance from sample i to its expanded neighbours,
        # turned into weights that sum to 1 by a softmax over -distance.
        x = features[i].unsqueeze(0).contiguous()
        y = features[k_reciprocal_expansion_index]
        m, n = x.size(0), y.size(0)
        dist = (
            torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n)
            + torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
        )
        dist.addmm_(x, y.t(), beta=1, alpha=-2)

        if fp16:
            V[i, k_reciprocal_expansion_index] = (
                F.softmax(-dist, dim=1).view(-1).cpu().numpy().astype(mat_type)
            )
        else:
            V[i, k_reciprocal_expansion_index] = (
                F.softmax(-dist, dim=1).view(-1).cpu().numpy()
            )

    del nn_k1, nn_k1_half, x, y
    features = features.cpu()

    if k2 != 1:
        # Local query expansion: each row becomes the mean of the rows of its
        # k2 nearest neighbours.
        V_qe = np.zeros_like(V, dtype=mat_type)
        for i in range(N):
            V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
        V = V_qe
        del V_qe

    del initial_rank

    # Inverted index: for each column, the rows holding a non-zero weight.
    invIndex = []
    for i in range(N):
        invIndex.append(np.where(V[:, i] != 0)[0])  # len(invIndex)=all_num

    jaccard_dist = np.zeros((N, N), dtype=mat_type)
    for i in range(N):
        temp_min = np.zeros((1, N), dtype=mat_type)
        indNonZero = np.where(V[i, :] != 0)[0]
        indImages = [invIndex[ind] for ind in indNonZero]
        for j in range(len(indNonZero)):
            temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(
                V[i, indNonZero[j]], V[indImages[j], indNonZero[j]]
            )

        # Rows of V sum to 1, so sum(max) = 2 - sum(min).
        jaccard_dist[i] = 1 - temp_min / (2 - temp_min)

    del invIndex, V

    pos_bool = jaccard_dist < 0
    jaccard_dist[pos_bool] = 0.0

    return jaccard_dist


@torch.no_grad()
def compute_euclidean_distance(features, others):
    """Computes squared Euclidean distance.

    Args:
        features (torch.Tensor): 2-D feature matrix ``(m, dim)``.
        others (torch.Tensor): 2-D feature matrix ``(n, dim)``.
    Returns:
        numpy.ndarray: ``(m, n)`` matrix of squared Euclidean distances.
    """
    m, n = features.size(0), others.size(0)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b
    dist_m = (
        torch.pow(features, 2).sum(dim=1, keepdim=True).expand(m, n)
        + torch.pow(others, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    )
    # Keyword beta/alpha: the positional addmm_(beta, alpha, mat1, mat2)
    # overload is deprecated in PyTorch.
    dist_m.addmm_(features, others.t(), beta=1, alpha=-2)

    return dist_m.cpu().numpy()


@torch.no_grad()
def compute_cosine_distance(features, others):
    """Computes cosine distance.
    Args:
        features (torch.Tensor): 2-D feature matrix.
        others (torch.Tensor): 2-D feature matrix.
    Returns:
        numpy.ndarray: distance matrix, ``1 - cosine similarity``.
    """
    features = F.normalize(features, p=2, dim=1)
    others = F.normalize(others, p=2, dim=1)
    dist_m = 1 - torch.mm(features, others.t())
    return dist_m.cpu().numpy()


@torch.no_grad()
def compute_pose_distance(qfeats, gfeats, qconf, gconf):
    """Mean over body parts of the per-part squared Euclidean distance.

    Args:
        qfeats (torch.Tensor): query part features ``[n, dim, p]`` (e.g. ``[n, 2048, p]``).
        gfeats (torch.Tensor): gallery part features ``[m, dim, p]``.
        qconf: query part confidences ``[n, p]``; currently unused.
        gconf: gallery part confidences; currently unused.

    Returns:
        numpy.ndarray: ``(n, m)`` distance matrix.
    """
    n, m = qfeats.shape[0], gfeats.shape[0]
    num_parts = qfeats.shape[-1]
    # Accumulate on the features' device; a CPU accumulator cannot be added
    # to CUDA part distances.
    dist = torch.zeros(n, m, dtype=torch.float32, device=qfeats.device)
    for i in range(num_parts):

        qfeat = qfeats[:, :, i]
        gfeat = gfeats[:, :, i]

        dist_m = (
            torch.pow(qfeat, 2).sum(dim=1, keepdim=True).expand(n, m)
            + torch.pow(gfeat, 2).sum(dim=1, keepdim=True).expand(m, n).t()
        )
        dist_m.addmm_(qfeat, gfeat.t(), beta=1, alpha=-2)
        # NOTE: confidence weighting is disabled; the original hint was
        #   * torch.sqrt(torch.einsum('n,m->nm', qconf[:, i], qconf[:, i]))
        dist += dist_m * (1. / num_parts)

    return dist.cpu().numpy()

# ------------------------------------------------------------------------------
# /fastreid/utils/env.py:
# ------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import importlib
import importlib.util
import logging
import numpy as np
import os
import random
import sys
from datetime import datetime
import torch
from fastreid.utils.comm import get_local_rank

__all__ = ["seed_all_rng"]


TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2])
"""
PyTorch version as a tuple of 2 ints. Useful for comparison.
"""


def seed_all_rng(seed=None):
    """
    Set the random seed for the RNG in torch, numpy and python.
    Args:
        seed (int): if None, a deterministic seed derived only from the local
            rank (see get_local_rank) is used, so a given local rank always
            gets the same seed.
    """
    if seed is None:
        # Not random: this depends only on the local rank.
        local_rank = get_local_rank()
        seed = (local_rank * 22123 - local_rank + 541) * 343
        logger = logging.getLogger(__name__)
        logger.info("Using a generated random seed {}".format(seed))
    np.random.seed(seed)
    torch.set_rng_state(torch.manual_seed(seed).get_state())
    random.seed(seed)


# from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
def _import_file(module_name, file_path, make_importable=False):
    """Import ``file_path`` as a module named ``module_name``.

    When ``make_importable`` is True the module is also registered in
    ``sys.modules`` so later ``import module_name`` statements find it.
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    if make_importable:
        sys.modules[module_name] = module
    return module


def _configure_libraries():
    """
    Configurations for some libraries.
    """
    # An environment option to disable `import cv2` globally,
    # in case it leads to negative performance impact
    disable_cv2 = int(os.environ.get("DETECTRON2_DISABLE_CV2", False))
    if disable_cv2:
        sys.modules["cv2"] = None
    else:
        # Disable opencl in opencv since its interaction with cuda often has negative effects
        # This envvar is supported after OpenCV 3.4.0
        os.environ["OPENCV_OPENCL_RUNTIME"] = "disabled"
        try:
            import cv2

            if int(cv2.__version__.split(".")[0]) >= 3:
                cv2.ocl.setUseOpenCL(False)
        except ImportError:
            pass

    def get_version(module, digit=2):
        return tuple(map(int, module.__version__.split(".")[:digit]))

    # fmt: off
    assert get_version(torch) >= (1, 4), "Requires torch>=1.4"
    import yaml
    assert get_version(yaml) >= (5, 1), "Requires pyyaml>=5.1"
    # fmt: on


_ENV_SETUP_DONE = False


def setup_environment():
    """Perform environment setup work. The default setup is a no-op, but this
    function allows the user to specify a Python source file or a module in
    the $FASTREID_ENV_MODULE environment variable, that performs
    custom setup work that may be necessary to their computing environment.
    Only the first call in a process does any work.
    """
    global _ENV_SETUP_DONE
    if _ENV_SETUP_DONE:
        return
    _ENV_SETUP_DONE = True

    _configure_libraries()

    custom_module_path = os.environ.get("FASTREID_ENV_MODULE")

    if custom_module_path:
        setup_custom_environment(custom_module_path)
    else:
        # The default setup is a no-op
        pass


def setup_custom_environment(custom_module):
    """
    Load custom environment setup by importing a Python source file or a
    module, and run the setup function.
    """
    if custom_module.endswith(".py"):
        module = _import_file("fastreid.utils.env.custom_module", custom_module)
    else:
        module = importlib.import_module(custom_module)
    assert hasattr(module, "setup_environment") and callable(module.setup_environment), (
        "Custom environment module defined in {} does not have the "
        "required callable attribute 'setup_environment'."
    ).format(custom_module)
    module.setup_environment()

# ------------------------------------------------------------------------------
# /fastreid/utils/faiss_utils.py:
# ------------------------------------------------------------------------------
# encoding: utf-8
# copy from: https://github.com/open-mmlab/OpenUnReID/blob/66bb2ae0b00575b80fbe8915f4d4f4739cc21206/openunreid/core/utils/faiss_utils.py

import faiss
import torch


def swig_ptr_from_FloatTensor(x):
    """Return a faiss ``float*`` to the first element of contiguous float32 tensor ``x``."""
    assert x.is_contiguous()
    assert x.dtype == torch.float32
    # storage_offset counts elements; a float32 is 4 bytes.
    return faiss.cast_integer_to_float_ptr(
        x.storage().data_ptr() + x.storage_offset() * 4
    )


def swig_ptr_from_LongTensor(x):
    """Return a faiss ``long*`` to the first element of contiguous int64 tensor ``x``."""
    assert x.is_contiguous()
    assert x.dtype == torch.int64, "dtype=%s" % x.dtype
    # storage_offset counts elements; an int64 is 8 bytes.
    return faiss.cast_integer_to_long_ptr(
        x.storage().data_ptr() + x.storage_offset() * 8
    )


def search_index_pytorch(index, x, k, D=None, I=None):
    """call the search function of an index with pytorch tensor I/O (CPU
    and GPU supported)

    Args:
        index: a faiss index whose dimension equals ``x.size(1)``.
        x (torch.Tensor): contiguous float32 queries ``(n, d)``.
        k (int): number of neighbours.
        D, I (torch.Tensor, optional): preallocated ``(n, k)`` outputs.
    Returns:
        (D, I): float32 distances and int64 ids, each ``(n, k)``.
    """
    assert x.is_contiguous()
    n, d = x.size()
    assert d == index.d

    if D is None:
        D = torch.empty((n, k), dtype=torch.float32, device=x.device)
    else:
        assert D.size() == (n, k)

    if I is None:
        I = torch.empty((n, k), dtype=torch.int64, device=x.device)
    else:
        assert I.size() == (n, k)
    # faiss works on the raw buffers, so pending CUDA work must finish before
    # and after the search. Only synchronize when CUDA tensors are involved:
    # torch.cuda.synchronize() raises on machines without CUDA, which would
    # break the CPU path this function advertises.
    needs_sync = x.is_cuda or D.is_cuda or I.is_cuda
    if needs_sync:
        torch.cuda.synchronize()
    xptr = swig_ptr_from_FloatTensor(x)
    Iptr = swig_ptr_from_LongTensor(I)
    Dptr = swig_ptr_from_FloatTensor(D)
    index.search_c(n, xptr, k, Dptr, Iptr)
    if needs_sync:
        torch.cuda.synchronize()
    return D, I


def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None, metric=faiss.METRIC_L2):
    """Brute-force kNN of queries ``xq`` against database ``xb`` via ``faiss.bruteForceKnn``.

    Both inputs must be float32, on the same device, and either row- or
    column-major contiguous. Returns ``(D, I)`` of shape ``(nq, k)`` on
    ``xb``'s device.
    """
    assert xb.device == xq.device

    nq, d = xq.size()
    if xq.is_contiguous():
        xq_row_major = True
    elif xq.t().is_contiguous():
        # column-major input: pass the transposed (contiguous) view to faiss
        xq = xq.t()
        xq_row_major = False
    else:
        raise TypeError("matrix should be row or column-major")

    xq_ptr = swig_ptr_from_FloatTensor(xq)

    nb, d2 = xb.size()
    assert d2 == d
    if xb.is_contiguous():
        xb_row_major = True
    elif xb.t().is_contiguous():
        xb = xb.t()
        xb_row_major = False
    else:
        raise TypeError("matrix should be row or column-major")
    xb_ptr = swig_ptr_from_FloatTensor(xb)

    if D is None:
        D = torch.empty(nq, k, device=xb.device, dtype=torch.float32)
    else:
        assert D.shape == (nq, k)
        assert D.device == xb.device

    if I is None:
        I = torch.empty(nq, k, device=xb.device, dtype=torch.int64)
    else:
        assert I.shape == (nq, k)
        assert I.device == xb.device

    D_ptr = swig_ptr_from_FloatTensor(D)
    I_ptr = swig_ptr_from_LongTensor(I)

    faiss.bruteForceKnn(
        res,
        metric,
        xb_ptr,
        xb_row_major,
        nb,
        xq_ptr,
        xq_row_major,
        nq,
        d,
        k,
        D_ptr,
        I_ptr,
    )

    return D, I


def index_init_gpu(ngpus, feat_dim):
    """Build an L2 flat index sharded across ``ngpus`` GPUs (float32 storage)."""
    flat_config = []
    for i in range(ngpus):
        cfg = faiss.GpuIndexFlatConfig()
        cfg.useFloat16 = False
        cfg.device = i
        flat_config.append(cfg)

    res = [faiss.StandardGpuResources() for _ in range(ngpus)]
    indexes = [
        faiss.GpuIndexFlatL2(res[i], feat_dim, flat_config[i]) for i in range(ngpus)
    ]
    index = faiss.IndexShards(feat_dim)
    for sub_index in indexes:
        index.add_shard(sub_index)
    index.reset()
    return index


def index_init_cpu(feat_dim):
    """Build a CPU L2 flat index."""
    return faiss.IndexFlatL2(feat_dim)

# ------------------------------------------------------------------------------
# /fastreid/utils/history_buffer.py:
# ------------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All Rights Reserved.

import numpy as np
from typing import List, Optional, Tuple


class HistoryBuffer:
    """
    Track a series of scalar values and provide access to smoothed values over a
    window or the global average of the series.
    """

    def __init__(self, max_length: int = 1000000):
        """
        Args:
            max_length: maximal number of values that can be stored in the
                buffer. When the capacity of the buffer is exhausted, old
                values will be removed.
        """
        self._max_length: int = max_length
        self._data: List[Tuple[float, float]] = []  # (value, iteration) pairs
        self._count: int = 0
        self._global_avg: float = 0

    def update(self, value: float, iteration: Optional[float] = None):
        """
        Add a new scalar value produced at certain iteration. If the length
        of the buffer exceeds self._max_length, the oldest element will be
        removed from the buffer.

        Args:
            value: the scalar to record.
            iteration: the iteration it belongs to; defaults to the number of
                values recorded so far.
        """
        if iteration is None:
            iteration = self._count
        if len(self._data) == self._max_length:
            self._data.pop(0)
        self._data.append((value, iteration))

        # Running mean over every value ever recorded, including evicted ones.
        self._count += 1
        self._global_avg += (value - self._global_avg) / self._count

    def latest(self):
        """
        Return the latest scalar value added to the buffer.
        """
        return self._data[-1][0]

    def median(self, window_size: int):
        """
        Return the median of the latest `window_size` values in the buffer.
        """
        return np.median([x[0] for x in self._data[-window_size:]])

    def avg(self, window_size: int):
        """
        Return the mean of the latest `window_size` values in the buffer.
        """
        return np.mean([x[0] for x in self._data[-window_size:]])

    def global_avg(self):
        """
        Return the mean of all the elements in the buffer. Note that this
        includes those getting removed due to limited buffer storage.
        """
        return self._global_avg

    def values(self):
        """
        Returns:
            list[(number, iteration)]: content of the current buffer.
        """
        return self._data

# ------------------------------------------------------------------------------
# /fastreid/utils/logger.py:
# ------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import functools
import logging
import os
import sys
import time
from collections import Counter
from .file_io import PathManager
from termcolor import colored


class _ColorfulFormatter(logging.Formatter):
    """Console formatter that shortens the root logger name to ``abbrev_name``
    and prefixes WARNING / ERROR / CRITICAL records with a colored tag."""

    def __init__(self, *args, **kwargs):
        self._root_name = kwargs.pop("root_name") + "."
        self._abbrev_name = kwargs.pop("abbrev_name", "")
        if len(self._abbrev_name):
            self._abbrev_name = self._abbrev_name + "."
        super(_ColorfulFormatter, self).__init__(*args, **kwargs)

    def formatMessage(self, record):
        # NOTE: record.name is rewritten in place, so handlers that format the
        # same record after this one also see the abbreviated name.
        record.name = record.name.replace(self._root_name, self._abbrev_name)
        log = super(_ColorfulFormatter, self).formatMessage(record)
        if record.levelno == logging.WARNING:
            prefix = colored("WARNING", "red", attrs=["blink"])
        elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
            prefix = colored("ERROR", "red", attrs=["blink", "underline"])
        else:
            return log
        return prefix + " " + log


@functools.lru_cache()  # so that calling setup_logger multiple times won't add many handlers
def setup_logger(
    output=None, distributed_rank=0, *, color=True, name="fastreid", abbrev_name=None
):
    """
    Args:
        output (str): a file name or a directory to save log. If None, will not save log file.
            If ends with ".txt" or ".log", assumed to be a file name.
            Otherwise, logs will be saved to `output/log.txt`.
        distributed_rank (int): rank of this process. Only rank 0 logs to stdout;
            ranks > 0 write to `<file>.rank<N>`.
        color (bool): use the colored stdout formatter.
        name (str): the root module name of this logger
        abbrev_name (str): an abbreviation of the module, to avoid long names in logs.
            Set to "" to not log the root module in logs.
            By default, will abbreviate "detectron2" to "d2" and leave other
            modules unchanged.
    Returns:
        logging.Logger: the configured logger named `name`.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.propagate = False

    if abbrev_name is None:
        abbrev_name = "d2" if name == "detectron2" else name

    plain_formatter = logging.Formatter(
        "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
    )
    # stdout logging: master only
    if distributed_rank == 0:
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.DEBUG)
        if color:
            formatter = _ColorfulFormatter(
                colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
                datefmt="%m/%d %H:%M:%S",
                root_name=name,
                abbrev_name=str(abbrev_name),
            )
        else:
            formatter = plain_formatter
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    # file logging: all workers
    if output is not None:
        if output.endswith(".txt") or output.endswith(".log"):
            filename = output
        else:
            filename = os.path.join(output, "log.txt")
        if distributed_rank > 0:
            filename = filename + ".rank{}".format(distributed_rank)
        PathManager.mkdirs(os.path.dirname(filename))

        fh = logging.StreamHandler(_cached_log_stream(filename))
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)

    return logger


# cache the opened file object, so that different calls to `setup_logger`
# with the same file name can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
    return PathManager.open(filename, "a")


"""
Below are some other convenient logging methods.
They are mainly adopted from
https://github.com/abseil/abseil-py/blob/master/absl/logging/__init__.py
"""


def _find_caller():
    """
    Returns:
        str: module name of the caller
        tuple: a hashable key to be used to identify different callers
    """
    frame = sys._getframe(2)
    while frame:
        code = frame.f_code
        # Skip frames that belong to this logging module itself.
        if os.path.join("utils", "logger.") not in code.co_filename:
            mod_name = frame.f_globals["__name__"]
            if mod_name == "__main__":
                # Route script-level calls to the logger that setup_logger
                # configures by default in this project ("fastreid"); the
                # inherited "detectron2" name is never configured here, so
                # its INFO messages were silently dropped.
                mod_name = "fastreid"
            return mod_name, (code.co_filename, frame.f_lineno, code.co_name)
        frame = frame.f_back


_LOG_COUNTER = Counter()
_LOG_TIMER = {}


def log_first_n(lvl, msg, n=1, *, name=None, key="caller"):
    """
    Log only for the first n times.
    Args:
        lvl (int): the logging level
        msg (str):
        n (int):
        name (str): name of the logger to use. Will use the caller's module by default.
        key (str or tuple[str]): the string(s) can be one of "caller" or
            "message", which defines how to identify duplicated logs.
            For example, if called with `n=1, key="caller"`, this function
            will only log the first call from the same caller, regardless of
            the message content.
            If called with `n=1, key="message"`, this function will log the
            same content only once, even if they are called from different places.
            If called with `n=1, key=("caller", "message")`, this function
            will not log only if the same caller has logged the same message before.
    """
    if isinstance(key, str):
        key = (key,)
    assert len(key) > 0

    caller_module, caller_key = _find_caller()
    hash_key = ()
    if "caller" in key:
        hash_key = hash_key + caller_key
    if "message" in key:
        hash_key = hash_key + (msg,)

    _LOG_COUNTER[hash_key] += 1
    if _LOG_COUNTER[hash_key] <= n:
        logging.getLogger(name or caller_module).log(lvl, msg)


def log_every_n(lvl, msg, n=1, *, name=None):
    """
    Log once per n times.
    Args:
        lvl (int): the logging level
        msg (str):
        n (int):
        name (str): name of the logger to use. Will use the caller's module by default.
    """
    caller_module, key = _find_caller()
    _LOG_COUNTER[key] += 1
    # Logs on the 1st, (n+1)-th, (2n+1)-th ... call from the same caller.
    if n == 1 or _LOG_COUNTER[key] % n == 1:
        logging.getLogger(name or caller_module).log(lvl, msg)


def log_every_n_seconds(lvl, msg, n=1, *, name=None):
    """
    Log no more than once per n seconds.
    Args:
        lvl (int): the logging level
        msg (str):
        n (int):
        name (str): name of the logger to use. Will use the caller's module by default.
    """
    caller_module, key = _find_caller()
    last_logged = _LOG_TIMER.get(key, None)
    current_time = time.time()
    if last_logged is None or current_time - last_logged >= n:
        logging.getLogger(name or caller_module).log(lvl, msg)
        _LOG_TIMER[key] = current_time

# def create_small_table(small_dict):
#     """
#     Create a small table using the keys of small_dict as headers. This is only
#     suitable for small dictionaries.
#     Args:
#         small_dict (dict): a result dictionary of only a few items.
#     Returns:
#         str: the table as a string.
#     """
#     keys, values = tuple(zip(*small_dict.items()))
#     table = tabulate(
#         [values],
#         headers=keys,
#         tablefmt="pipe",
#         floatfmt=".3f",
#         stralign="center",
#         numalign="center",
#     )
#     return table

# --------------------------------------------------------------------------------
# /fastreid/utils/precision_bn.py:
# --------------------------------------------------------------------------------
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import itertools

import torch

BN_MODULE_TYPES = (
    torch.nn.BatchNorm1d,
    torch.nn.BatchNorm2d,
    torch.nn.BatchNorm3d,
    torch.nn.SyncBatchNorm,
)


@torch.no_grad()
def update_bn_stats(model, data_loader, num_iters: int = 200):
    """
    Recompute and update the batch norm stats to make them more precise. During
    training both BN stats and the weight are changing after every iteration, so
    the running average can not precisely reflect the actual stats of the
    current model.
    In this function, the BN stats are recomputed with fixed weights, to make
    the running average more precise. Specifically, it computes the true average
    of per-batch mean/variance instead of the running average.

    Args:
        model (nn.Module): the model whose bn stats will be recomputed.
            Note that:
            1. This function will not alter the training mode of the given model.
               Users are responsible for setting the layers that need
               precise-BN to training mode, prior to calling this function.
            2. Be careful if your models contain other stateful layers in
               addition to BN, i.e. layers whose state can change in forward
               iterations. This function will alter their state. If you wish
               them unchanged, you need to either pass in a submodule without
               those layers, or back up the states.
        data_loader (iterator): an iterator. Produce data as inputs to the model.
            Each item is passed to ``model`` as-is after its ``'targets'`` entry
            is filled with -1 in place, so items must be mappings holding a
            ``'targets'`` tensor.
        num_iters (int): number of iterations to compute the stats. A
            non-positive value makes this function a no-op.

    Raises:
        AssertionError: if ``data_loader`` yields fewer than ``num_iters``
            items. The original BN momentum values are restored in that case
            too; the running stats are then whatever the last forward pass left.
    """
    bn_layers = get_bn_modules(model)
    # Nothing to do: no BN layer in training mode, or no batch requested
    # (averaging over zero batches would zero out the running statistics).
    if not bn_layers or num_iters <= 0:
        return

    # In order to make the running stats only reflect the current batch, the
    # momentum is disabled.
    # bn.running_mean = (1 - momentum) * bn.running_mean + momentum * batch_mean
    # Setting the momentum to 1.0 to compute the stats without momentum.
    momentum_actual = [bn.momentum for bn in bn_layers]
    for bn in bn_layers:
        bn.momentum = 1.0

    # Note that running_var actually means "running average of variance"
    running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
    running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]

    try:
        # Count batches explicitly: an empty loader must reach the check below
        # rather than fail on an unbound loop variable.
        num_batches = 0
        for inputs in itertools.islice(data_loader, num_iters):
            inputs['targets'].fill_(-1)
            model(inputs)  # gradients already disabled by the decorator
            num_batches += 1
            for i, bn in enumerate(bn_layers):
                # Incremental mean: avg_k = avg_{k-1} + (x_k - avg_{k-1}) / k
                running_mean[i] += (bn.running_mean - running_mean[i]) / num_batches
                running_var[i] += (bn.running_var - running_var[i]) / num_batches
        # We compute the "average of variance" across iterations.
        if num_batches != num_iters:
            # Explicit raise (not ``assert``) so the check survives ``python -O``.
            raise AssertionError(
                "update_bn_stats is meant to run for {} iterations, "
                "but the dataloader stops at {} iterations.".format(num_iters, num_batches)
            )

        for i, bn in enumerate(bn_layers):
            # Sets the precise bn stats.
            bn.running_mean = running_mean[i]
            bn.running_var = running_var[i]
    finally:
        # Restore the caller's momentum whether or not recomputation succeeded.
        for bn, momentum in zip(bn_layers, momentum_actual):
            bn.momentum = momentum


def get_bn_modules(model):
    """
    Find all BatchNorm (BN) modules that are in training mode. See
    ``BN_MODULE_TYPES`` in this module for the list of module types that are
    included in this search.

    Args:
        model (nn.Module): a model possibly containing BN modules.

    Returns:
        list[nn.Module]: every module of ``model`` (including ``model`` itself)
        that is an instance of ``BN_MODULE_TYPES`` and has ``training`` set;
        BN modules in eval mode are skipped.
    """
    # Finds all the bn layers.
    return [
        m for m in model.modules() if m.training and isinstance(m, BN_MODULE_TYPES)
    ]

# --------------------------------------------------------------------------------
# /fastreid/utils/registry.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Dict, Optional


class Registry(object):
    """
    The registry that provides name -> object mapping, to support third-party
    users' custom modules.

    To create a registry (e.g. a backbone registry):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...

    Or:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name: str) -> None:
        """
        Args:
            name (str): the name of this registry
        """
        self._name: str = name
        self._obj_map: Dict[str, object] = {}

    def _do_register(self, name: str, obj: object) -> None:
        # Names must be unique within one registry; a second registration
        # under the same name is rejected.
        assert (
            name not in self._obj_map
        ), "An object named '{}' was already registered in '{}' registry!".format(
            name, self._name
        )
        self._obj_map[name] = obj

    def register(self, obj: object = None) -> Optional[object]:
        """
        Register the given object under the name `obj.__name__`.
        Can be used as either a decorator or not. See docstring of this class for usage.

        Returns:
            When called without ``obj``: a decorator that registers its
            argument and returns it unchanged. When called with ``obj``: None.
        """
        if obj is None:
            # used as a decorator
            def deco(func_or_class: object) -> object:
                name = func_or_class.__name__  # pyre-ignore
                self._do_register(name, func_or_class)
                return func_or_class

            return deco

        # used as a function call
        name = obj.__name__  # pyre-ignore
        self._do_register(name, obj)

    def get(self, name: str) -> object:
        """
        Return the object registered under ``name``.

        Raises:
            KeyError: if nothing is registered under ``name`` (an object
                registered as None is treated the same way).
        """
        ret = self._obj_map.get(name)
        if ret is None:
            raise KeyError(
                "No object named '{}' found in '{}' registry!".format(
                    name, self._name
                )
            )
        return ret

# --------------------------------------------------------------------------------
# /fastreid/utils/summary.py:
# --------------------------------------------------------------------------------
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch
import torch.nn as nn
from torch.autograd import Variable

from collections import OrderedDict
import numpy as np


def _num_elements(shape):
    """Return the absolute element count of a recorded shape.

    ``shape`` is either a flat list of ints (one output) or a list of such
    lists (a module returning a list/tuple). Batch entries may be -1, so each
    product is taken in absolute value before summing.
    """
    if shape and isinstance(shape[0], (list, tuple)):
        return sum(_num_elements(s) for s in shape)
    return abs(int(np.prod(shape)))


def summary(model, input_size, batch_size=-1, device="cuda"):
    """
    Print a per-layer summary of ``model`` (output shape and parameter count).

    A forward hook is attached to every submodule except ``nn.Sequential`` /
    ``nn.ModuleList`` containers and ``model`` itself. One forward pass is then
    run on random inputs with a batch of 2, and the recorded rows plus rough
    memory estimates (assuming 4 bytes per number) are printed.

    Args:
        model (nn.Module): the model to inspect; called as ``model(*x)`` with
            one random tensor per entry of ``input_size``.
        input_size (tuple or list[tuple]): per-sample input shape(s) without
            the batch dimension; a single tuple means a single input.
        batch_size (int): batch size written into the shapes and used for the
            input-size estimate; -1 (default) leaves it unspecified.
        device (str): "cuda" or "cpu" (case-insensitive); picks the tensor type
            of the random inputs. CPU tensors are used when CUDA is unavailable.

    Returns:
        None; the summary is printed to stdout.
    """
    device = device.lower()
    assert device in [
        "cuda",
        "cpu",
    ], "Input device is not valid, please specify 'cuda' or 'cpu'"

    if device == "cuda" and torch.cuda.is_available():
        dtype = torch.cuda.FloatTensor
    else:
        dtype = torch.FloatTensor

    # multiple inputs to the network
    if isinstance(input_size, tuple):
        input_size = [input_size]

    # batch_size of 2 for batchnorm
    x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]

    # Per-layer records keyed "<ClassName>-<call index>", in forward-call order.
    layer_info = OrderedDict()
    hooks = []

    def register_hook(module):

        def hook(module, input, output):
            class_name = str(module.__class__).split(".")[-1].split("'")[0]
            m_key = "%s-%i" % (class_name, len(layer_info) + 1)
            info = layer_info[m_key] = OrderedDict()
            info["input_shape"] = list(input[0].size())
            info["input_shape"][0] = batch_size
            if isinstance(output, (list, tuple)):
                info["output_shape"] = [[-1] + list(o.size())[1:] for o in output]
            else:
                info["output_shape"] = list(output.size())
                info["output_shape"][0] = batch_size

            # Plain-int counts keep the totals below valid even when no layer
            # owns a weight/bias tensor.
            params = 0
            if hasattr(module, "weight") and hasattr(module.weight, "size"):
                params += module.weight.numel()
                info["trainable"] = module.weight.requires_grad
            if hasattr(module, "bias") and hasattr(module.bias, "size"):
                params += module.bias.numel()
            info["nb_params"] = params

        if (
            not isinstance(module, nn.Sequential)
            and not isinstance(module, nn.ModuleList)
            and module is not model
        ):
            hooks.append(module.register_forward_hook(hook))

    # Register hooks and make a forward pass; the hooks are always removed so
    # a failing pass does not leave them attached to the model.
    try:
        model.apply(register_hook)
        model(*x)
    finally:
        for h in hooks:
            h.remove()

    print("----------------------------------------------------------------")
    line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
    print(line_new)
    print("================================================================")
    total_params = 0
    total_output = 0
    trainable_params = 0
    for layer, info in layer_info.items():
        # input_shape, output_shape, trainable, nb_params
        line_new = "{:>20} {:>25} {:>15}".format(
            layer,
            str(info["output_shape"]),
            "{0:,}".format(info["nb_params"]),
        )
        total_params += info["nb_params"]
        total_output += _num_elements(info["output_shape"])
        if info.get("trainable"):
            trainable_params += info["nb_params"]
        print(line_new)

    # assume 4 bytes/number (float on cuda).
    # Inputs are summed per input tensor (not multiplied across inputs).
    total_input_elems = sum(abs(int(np.prod(s))) for s in input_size)
    total_input_size = abs(total_input_elems * batch_size * 4. / (1024 ** 2.))
    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
    total_params_size = abs(total_params * 4. / (1024 ** 2.))
    total_size = total_params_size + total_output_size + total_input_size

    print("================================================================")
    print("Total params: {0:,}".format(total_params))
    print("Trainable params: {0:,}".format(trainable_params))
    print("Non-trainable params: {0:,}".format(total_params - trainable_params))
    print("----------------------------------------------------------------")
    print("Input size (MB): %0.2f" % total_input_size)
    print("Forward/backward pass size (MB): %0.2f" % total_output_size)
    print("Params size (MB): %0.2f" % total_params_size)
    print("Estimated Total Size (MB): %0.2f" % total_size)
    print("----------------------------------------------------------------")

# --------------------------------------------------------------------------------
# /fastreid/utils/timer.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# -*- coding: utf-8 -*-

from time import perf_counter
from typing import Optional


class Timer:
    """
    A timer which computes the time elapsed since the start/reset of the timer.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """
        Reset the timer.
        """
        self._start = perf_counter()
        self._paused: Optional[float] = None
        self._total_paused = 0
        self._count_start = 1

    def pause(self):
        """
        Pause the timer.
        """
        if self._paused is not None:
            raise ValueError("Trying to pause a Timer that is already paused!")
        self._paused = perf_counter()

    def is_paused(self) -> bool:
        """
        Returns:
            bool: whether the timer is currently paused
        """
        return self._paused is not None

    def resume(self):
        """
        Resume the timer.
        """
        if self._paused is None:
            raise ValueError("Trying to resume a Timer that is not paused!")
        self._total_paused += perf_counter() - self._paused
        self._paused = None
        self._count_start += 1

    def seconds(self) -> float:
        """
        Returns:
            (float): the total number of seconds since the start/reset of the
                timer, excluding the time when the timer is paused.
        """
        if self._paused is not None:
            end_time: float = self._paused  # type: ignore
        else:
            end_time = perf_counter()
        return end_time - self._start - self._total_paused

    def avg_seconds(self) -> float:
        """
        Returns:
            (float): the average number of seconds between every start/reset and
            pause.
        """
        return self.seconds() / self._count_start

# --------------------------------------------------------------------------------
# /fastreid/utils/weight_init.py:
# --------------------------------------------------------------------------------
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

import math
from torch import nn

__all__ = [
    'weights_init_classifier',
    'weights_init_kaiming',
]


def weights_init_kaiming(m):
    """
    Initialize module ``m`` in place, dispatching on its class name.

    Takes a single module, so it can be passed to ``nn.Module.apply``.
    - names containing 'Linear': weight ~ N(0, 0.01), bias zeroed if present;
    - names containing 'Conv': Kaiming-normal weight (fan_out, relu), bias
      zeroed if present;
    - names containing 'BatchNorm' with ``affine``: weight ~ N(1.0, 0.02),
      bias zeroed.
    Other modules are left untouched.
    """
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight, 0, 0.01)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('BatchNorm') != -1:
        if m.affine:
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.constant_(m.bias, 0.0)


def weights_init_classifier(m):
    """
    Initialize a 'Linear'-named module ``m`` in place: weight ~ N(0, 0.001),
    bias zeroed if present. Other modules are left untouched.
    """
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight, std=0.001)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)

# --------------------------------------------------------------------------------
# /tools/train_net.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""

import sys

sys.path.append('.')

from fastreid.config import get_cfg
from fastreid.engine import DefaultTrainer, default_argument_parser, default_setup, launch
from fastreid.utils.checkpoint import Checkpointer


def setup(args):
    """
    Create configs and perform basic setups.

    Merges ``args.config_file`` and then ``args.opts`` into the default config,
    freezes it, and passes it to ``default_setup``.
    """
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg


def main(args):
    """
    Entry point run by ``launch`` in every process.

    With ``args.eval_only``: build the model without backbone pretraining,
    load ``cfg.MODEL.WEIGHTS`` and return the test results. Otherwise train
    (optionally resuming) and return the trainer's result.
    """
    cfg = setup(args)

    if args.eval_only:
        cfg.defrost()
        cfg.MODEL.BACKBONE.PRETRAIN = False
        # Re-freeze so the evaluation path sees an immutable config, as the
        # training path does.
        cfg.freeze()
        model = DefaultTrainer.build_model(cfg)

        Checkpointer(model).load(cfg.MODEL.WEIGHTS)  # load trained model

        res = DefaultTrainer.test(cfg, model)
        return res

    trainer = DefaultTrainer(cfg)

    trainer.resume_or_load(resume=args.resume)
    return trainer.train()


if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )

# --------------------------------------------------------------------------------