├── COAT_pt171.yml ├── README.md ├── __pycache__ ├── defaults.cpython-38.pyc ├── engine.cpython-38.pyc └── eval_func.cpython-38.pyc ├── configs ├── cuhk_sysu.yaml └── prw.yaml ├── datasets ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── base.cpython-38.pyc │ ├── build.cpython-38.pyc │ ├── cuhk_sysu.cpython-38.pyc │ └── prw.cpython-38.pyc ├── base.py ├── build.py ├── cuhk_sysu.py └── prw.py ├── defaults.py ├── doc └── framework.png ├── engine.py ├── eval_func.py ├── loss ├── __pycache__ │ ├── oim.cpython-38.pyc │ └── softmax_loss.cpython-38.pyc ├── oim.py └── softmax_loss.py ├── models ├── __pycache__ │ ├── coat.cpython-38.pyc │ ├── resnet.cpython-38.pyc │ └── transformer.cpython-38.pyc ├── coat.py ├── resnet.py └── transformer.py ├── train.py └── utils ├── __pycache__ ├── km.cpython-38.pyc ├── mask.cpython-38.pyc ├── transforms.cpython-38.pyc └── utils.cpython-38.pyc ├── km.py ├── mask.py ├── transforms.py └── utils.py /COAT_pt171.yml: -------------------------------------------------------------------------------- 1 | name: coat 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - cudatoolkit=11.0 8 | - numpy=1.19.2 9 | - pillow=8.2.0 10 | - pip=21.0.1 11 | - python=3.8.8 12 | - pytorch=1.7.1 13 | - scipy=1.6.2 14 | - torchvision=0.8.2 15 | - tqdm=4.60.0 16 | - scikit-learn=0.24.1 17 | - black=21.5b0 18 | - flake8=3.9.0 19 | - isort=5.8.0 20 | - tabulate=0.8.9 21 | - future=0.18.2 22 | - tensorboard=2.4.1 23 | - tensorboardx=2.2 24 | - pip: 25 | - ipython==7.5.0 26 | - yacs==0.1.8 27 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repository hosts the source code of our paper: [[CVPR 2022] Cascade Transformers for End-to-End Person Search](https://arxiv.org/abs/2203.09642). In this work, we developed a novel Cascaded Occlusion-Aware Transformer (COAT) model for end-to-end person search. The COAT model outperforms **state-of-the-art** methods on the PRW benchmark dataset by a large margin and achieves state-of-the-art performance on the CUHK-SYSU dataset. 2 | 3 | | Dataset | mAP | Top-1 | Model | 4 | | --------- | ---- | ----- | ------------------------------------------------------------ | 5 | | CUHK-SYSU | 94.2 | 94.7 | [model](https://drive.google.com/file/d/1LkEwXYaJg93yk4Kfhyk3m6j8v3i9s1B7/view?usp=sharing) | 6 | | PRW | 53.3 | 87.4 | [model](https://drive.google.com/file/d/1vEd_zzFN88RgxbRMG5-WfJZgD3vmP0Xg/view?usp=sharing) | 7 | 8 | **Abstract**: The goal of person search is to localize a target person from a gallery set of scene images, which is extremely challenging due to large scale variations, pose/viewpoint changes, and occlusions. In this paper, we propose the Cascade Occluded Attention Transformer (COAT) for end-to-end person search. Specifically, our three-stage cascade design focuses on detecting people at the first stage, then progressively refines the representation for person detection and re-identification simultaneously at the following stages. The occluded attention transformer at each stage applies tighter intersection over union thresholds, forcing the network to learn coarse-to-fine pose/scale invariant features. Meanwhile, we calculate the occluded attention across instances in a mini-batch to differentiate tokens from other people or the background. In this way, we simulate the effect of other objects occluding a person of interest at the token-level. 
Through comprehensive experiments, we demonstrate the benefits of our method by achieving state-of-the-art performance on two benchmark datasets. 9 | 10 | ![COAT](doc/framework.png) 11 | 12 | 13 | ## Installation 14 | 1. Download the datasets into your path `$DATA_DIR`. Change the dataset paths on line 4 (`DATA_ROOT`) of [cuhk_sysu.yaml](configs/cuhk_sysu.yaml) and [prw.yaml](configs/prw.yaml). 15 | 16 | **PRW**: 17 | 18 | ``` 19 | cd $DATA_DIR 20 | pip install gdown 21 | gdown https://drive.google.com/uc?id=0B6tjyrV1YrHeYnlhNnhEYTh5MUU 22 | unzip PRW-v16.04.20.zip 23 | mv PRW-v16.04.20 PRW 24 | ``` 25 | 26 | **CUHK-SYSU**: 27 | 28 | ``` 29 | cd $DATA_DIR 30 | gdown https://drive.google.com/uc?id=1z3LsFrJTUeEX3-XjSEJMOBrslxD2T5af 31 | tar -xzvf cuhk_sysu.tar.gz 32 | mv cuhk_sysu CUHK-SYSU 33 | ``` 34 | 35 | 2. Our method is tested with PyTorch 1.7.1. You can install the required packages with anaconda/miniconda using the following commands: 36 | 37 | ``` 38 | cd COAT 39 | conda env create -f COAT_pt171.yml 40 | conda activate coat 41 | ``` 42 | 43 | If you want to install another version of PyTorch, you can modify the versions in `COAT_pt171.yml`. Just make sure the dependencies have compatible versions. 44 | 45 | 46 | ## Experiments on CUHK-SYSU 47 | **Training**: The code currently supports only a single GPU. The default training script for CUHK-SYSU is as follows: 48 | 49 | ``` 50 | cd COAT 51 | python train.py --cfg configs/cuhk_sysu.yaml INPUT.BATCH_SIZE_TRAIN 3 SOLVER.BASE_LR 0.003 SOLVER.MAX_EPOCHS 14 SOLVER.LR_DECAY_MILESTONES [11] MODEL.LOSS.USE_SOFTMAX True SOLVER.LW_RCNN_SOFTMAX_2ND 0.1 SOLVER.LW_RCNN_SOFTMAX_3RD 0.1 OUTPUT_DIR ./logs/cuhk-sysu 52 | ``` 53 | 54 | Note that the dataset-specific parameters are defined in `configs/cuhk_sysu.yaml`. When the batch size (`INPUT.BATCH_SIZE_TRAIN`) is 3, training takes about 23GB of GPU memory, which is suitable for GPUs like the RTX6000. When the batch size is 5, training takes about 38GB of GPU memory, which fits on an A100 GPU. A larger batch size usually yields better performance on CUHK-SYSU. 55 | 56 | For the CUHK-SYSU dataset, we use a relatively low weight for the softmax loss (`SOLVER.LW_RCNN_SOFTMAX_2ND` 0.1 and `SOLVER.LW_RCNN_SOFTMAX_3RD` 0.1). The trained models and TensorBoard logs will be saved in the folder `OUTPUT_DIR`. Other important training parameters can be found in the file `COAT/defaults.py`. For example, `CKPT_PERIOD` is how often (in epochs) a checkpoint is saved. A minimal sketch of how the defaults, the YAML file, and the command-line overrides are combined is given at the end of this README. 57 | 58 | **Testing**: The test script is very simple. You just need to add the flag `--eval` and pass the path of the saved [model](https://drive.google.com/file/d/1LkEwXYaJg93yk4Kfhyk3m6j8v3i9s1B7/view?usp=sharing) via `--ckpt`. 59 | 60 | ``` 61 | python train.py --cfg ./configs/cuhk-sysu/config.yaml --eval --ckpt ./logs/cuhk-sysu/cuhk_COAT.pth 62 | ``` 63 | 64 | **Testing with CBGM**: Context Bipartite Graph Matching ([CBGM](https://github.com/serend1p1ty/SeqNet)) is an optimized matching algorithm used in the test phase. Details can be found in the paper [[AAAI 2021] Sequential End-to-end Network for Efficient Person Search](https://arxiv.org/abs/2103.10148). CBGM can be used to further improve person search accuracy. In the test script, just set the flag `EVAL_USE_CBGM` to True (default is False). 65 | 66 | ``` 67 | python train.py --cfg ./configs/cuhk-sysu/config.yaml --eval --ckpt ./logs/cuhk-sysu/cuhk_COAT.pth EVAL_USE_CBGM True 68 | ``` 69 | 70 | **Testing with different gallery sizes on CUHK-SYSU**: The default gallery size for evaluating CUHK-SYSU is 100.
If you want to test with other pre-defined gallery sizes (50, 100, 500, 1000, 2000, 4000), e.g. for drawing the CUHK-SYSU gallery-size curve, set the parameter `EVAL_GALLERY_SIZE` to the desired gallery size. 71 | 72 | ``` 73 | python train.py --cfg ./configs/cuhk-sysu/config.yaml --eval --ckpt ./logs/cuhk-sysu/cuhk_COAT.pth EVAL_GALLERY_SIZE 500 74 | ``` 75 | 76 | ## Experiments on PRW 77 | **Training**: The script is similar to CUHK-SYSU. The code currently supports only a single GPU. The default training script for PRW is as follows: 78 | 79 | ``` 80 | cd COAT 81 | python train.py --cfg ./configs/prw.yaml INPUT.BATCH_SIZE_TRAIN 3 SOLVER.BASE_LR 0.003 SOLVER.MAX_EPOCHS 13 MODEL.LOSS.USE_SOFTMAX True OUTPUT_DIR ./logs/prw 82 | ``` 83 | 84 | The dataset-specific parameters are defined in `configs/prw.yaml`. When the batch size (`INPUT.BATCH_SIZE_TRAIN`) is 3, training takes about 19GB of GPU memory, which is suitable for GPUs like the RTX6000. A larger batch size does not necessarily result in better accuracy on the PRW dataset. 85 | The softmax loss is effective on PRW. The default weights of the softmax loss at Stage 2 and Stage 3 (`SOLVER.LW_RCNN_SOFTMAX_2ND` and `SOLVER.LW_RCNN_SOFTMAX_3RD`) are 0.5; they are defined in the file `COAT/defaults.py`. If you want to run a model without the softmax loss for comparison, just set `MODEL.LOSS.USE_SOFTMAX` to False in the script. 86 | 87 | 88 | **Testing**: The test script is similar to CUHK-SYSU. Make sure the path of the pre-trained [model](https://drive.google.com/file/d/1vEd_zzFN88RgxbRMG5-WfJZgD3vmP0Xg/view?usp=sharing) is correct. 89 | 90 | ``` 91 | python train.py --cfg ./logs/prw/config.yaml --eval --ckpt ./logs/prw/prw_COAT.pth 92 | 93 | ``` 94 | 95 | **Testing with CBGM**: Similar to CUHK-SYSU, set the flag `EVAL_USE_CBGM` to True (default is False). 96 | 97 | ``` 98 | python train.py --cfg ./logs/prw/config.yaml --eval --ckpt ./logs/prw/prw_COAT.pth EVAL_USE_CBGM True 99 | ``` 100 | 101 | 102 | ## Acknowledgement 103 | This code borrows from [SeqNet](https://github.com/serend1p1ty/SeqNet), [TransReID](https://github.com/damo-cv/TransReID), and [DSTT](https://github.com/ruiliu-ai/DSTT). 104 | 105 | ## Citation 106 | If you use this code in your research, please cite this project as follows: 107 | 108 | ``` 109 | @inproceedings{yu2022coat, 110 | title = {Cascade Transformers for End-to-End Person Search}, 111 | author = {Rui Yu and 112 | Dawei Du and 113 | Rodney LaLonde and 114 | Daniel Davila and 115 | Christopher Funk and 116 | Anthony Hoogs and 117 | Brian Clipp}, 118 | booktitle = {{IEEE} Conference on Computer Vision and Pattern Recognition}, 119 | year = {2022} 120 | } 121 | ``` 122 | 123 | ## License 124 | This work is distributed under the OSI-approved BSD 3-Clause [License](https://github.com/Kitware/COAT/blob/master/LICENSE).
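**Configuration overrides (reference sketch)**: The training and testing commands above pass `KEY VALUE` pairs after the `--cfg` YAML file. The snippet below is a minimal, illustrative sketch of the usual yacs pattern for combining the defaults in `defaults.py`, a YAML config, and command-line overrides; the argument handling shown here is an assumption and may differ from the actual `train.py`.

```
# Illustrative only: assumes defaults.py (get_default_cfg) is on the Python path.
import argparse
from defaults import get_default_cfg

def load_cfg():
    parser = argparse.ArgumentParser(description="Assemble a COAT config (sketch)")
    parser.add_argument("--cfg", dest="cfg_file", help="e.g. configs/cuhk_sysu.yaml")
    parser.add_argument("opts", nargs=argparse.REMAINDER, help="overrides as KEY VALUE pairs")
    args = parser.parse_args()

    cfg = get_default_cfg()                  # base values from defaults.py
    if args.cfg_file:
        cfg.merge_from_file(args.cfg_file)   # dataset-specific settings from the YAML file
    if args.opts:
        cfg.merge_from_list(args.opts)       # command-line overrides, e.g. SOLVER.BASE_LR 0.003
    cfg.freeze()
    return cfg

if __name__ == "__main__":
    cfg = load_cfg()
    print(cfg.INPUT.DATASET, cfg.SOLVER.BASE_LR, cfg.OUTPUT_DIR)
```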
125 | -------------------------------------------------------------------------------- /__pycache__/defaults.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kitware/COAT/369d93418716a4a28a96203ebd012020877db550/__pycache__/defaults.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/engine.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kitware/COAT/369d93418716a4a28a96203ebd012020877db550/__pycache__/engine.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/eval_func.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kitware/COAT/369d93418716a4a28a96203ebd012020877db550/__pycache__/eval_func.cpython-38.pyc -------------------------------------------------------------------------------- /configs/cuhk_sysu.yaml: -------------------------------------------------------------------------------- 1 | OUTPUT_DIR: "./logs/cuhk_coat" 2 | INPUT: 3 | DATASET: "CUHK-SYSU" 4 | DATA_ROOT: "../../datasets/CUHK-SYSU" 5 | BATCH_SIZE_TRAIN: 4 6 | SOLVER: 7 | MAX_EPOCHS: 14 8 | BASE_LR: 0.003 9 | LW_RCNN_SOFTMAX_2ND: 0.1 10 | LW_RCNN_SOFTMAX_3RD: 0.1 11 | MODEL: 12 | LOSS: 13 | LUT_SIZE: 5532 14 | CQ_SIZE: 5000 15 | DISP_PERIOD: 100 16 | -------------------------------------------------------------------------------- /configs/prw.yaml: -------------------------------------------------------------------------------- 1 | OUTPUT_DIR: "./logs/prw_coat" 2 | INPUT: 3 | DATASET: "PRW" 4 | DATA_ROOT: "../../datasets/PRW" 5 | BATCH_SIZE_TRAIN: 3 6 | SOLVER: 7 | MAX_EPOCHS: 13 8 | BASE_LR: 0.003 9 | MODEL: 10 | LOSS: 11 | LUT_SIZE: 482 12 | CQ_SIZE: 500 13 | DISP_PERIOD: 100 14 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # This file is part of COAT, and is distributed under the 2 | # OSI-approved BSD 3-Clause License. See top-level LICENSE file or 3 | # https://github.com/Kitware/COAT/blob/master/LICENSE for details. 
4 | 5 | from .build import build_test_loader, build_train_loader 6 | -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kitware/COAT/369d93418716a4a28a96203ebd012020877db550/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/base.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kitware/COAT/369d93418716a4a28a96203ebd012020877db550/datasets/__pycache__/base.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/build.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kitware/COAT/369d93418716a4a28a96203ebd012020877db550/datasets/__pycache__/build.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/cuhk_sysu.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kitware/COAT/369d93418716a4a28a96203ebd012020877db550/datasets/__pycache__/cuhk_sysu.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/prw.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kitware/COAT/369d93418716a4a28a96203ebd012020877db550/datasets/__pycache__/prw.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/base.py: -------------------------------------------------------------------------------- 1 | # This file is part of COAT, and is distributed under the 2 | # OSI-approved BSD 3-Clause License. See top-level LICENSE file or 3 | # https://github.com/Kitware/COAT/blob/master/LICENSE for details. 4 | 5 | import torch 6 | from PIL import Image 7 | 8 | class BaseDataset: 9 | """ 10 | Base class of person search dataset. 
11 | """ 12 | 13 | def __init__(self, root, transforms, split): 14 | self.root = root 15 | self.transforms = transforms 16 | self.split = split 17 | assert self.split in ("train", "gallery", "query") 18 | self.annotations = self._load_annotations() 19 | 20 | def _load_annotations(self): 21 | """ 22 | For each image, load its annotation that is a dictionary with the following keys: 23 | img_name (str): image name 24 | img_path (str): image path 25 | boxes (np.array[N, 4]): ground-truth boxes in (x1, y1, x2, y2) format 26 | pids (np.array[N]): person IDs corresponding to these boxes 27 | cam_id (int): camera ID (only for PRW dataset) 28 | """ 29 | raise NotImplementedError 30 | 31 | def __getitem__(self, index): 32 | anno = self.annotations[index] 33 | img = Image.open(anno["img_path"]).convert("RGB") 34 | boxes = torch.as_tensor(anno["boxes"], dtype=torch.float32) 35 | labels = torch.as_tensor(anno["pids"], dtype=torch.int64) 36 | target = {"img_name": anno["img_name"], "boxes": boxes, "labels": labels} 37 | if self.transforms is not None: 38 | img, target = self.transforms(img, target) 39 | return img, target 40 | 41 | def __len__(self): 42 | return len(self.annotations) 43 | -------------------------------------------------------------------------------- /datasets/build.py: -------------------------------------------------------------------------------- 1 | # This file is part of COAT, and is distributed under the 2 | # OSI-approved BSD 3-Clause License. See top-level LICENSE file or 3 | # https://github.com/Kitware/COAT/blob/master/LICENSE for details. 4 | 5 | import torch 6 | from utils.transforms import build_transforms 7 | from utils.utils import create_small_table 8 | from .cuhk_sysu import CUHKSYSU 9 | from .prw import PRW 10 | 11 | def print_statistics(dataset): 12 | """ 13 | Print dataset statistics. 
14 | """ 15 | num_imgs = len(dataset.annotations) 16 | num_boxes = 0 17 | pid_set = set() 18 | for anno in dataset.annotations: 19 | num_boxes += anno["boxes"].shape[0] 20 | for pid in anno["pids"]: 21 | pid_set.add(pid) 22 | statistics = { 23 | "dataset": dataset.name, 24 | "split": dataset.split, 25 | "num_images": num_imgs, 26 | "num_boxes": num_boxes, 27 | } 28 | if dataset.name != "CUHK-SYSU" or dataset.split != "query": 29 | pid_list = sorted(list(pid_set)) 30 | if dataset.split == "query": 31 | num_pids, min_pid, max_pid = len(pid_list), min(pid_list), max(pid_list) 32 | statistics.update( 33 | { 34 | "num_labeled_pids": num_pids, 35 | "min_labeled_pid": int(min_pid), 36 | "max_labeled_pid": int(max_pid), 37 | } 38 | ) 39 | else: 40 | unlabeled_pid = pid_list[-1] 41 | pid_list = pid_list[:-1] # remove unlabeled pid 42 | num_pids, min_pid, max_pid = len(pid_list), min(pid_list), max(pid_list) 43 | statistics.update( 44 | { 45 | "num_labeled_pids": num_pids, 46 | "min_labeled_pid": int(min_pid), 47 | "max_labeled_pid": int(max_pid), 48 | "unlabeled_pid": int(unlabeled_pid), 49 | } 50 | ) 51 | print(f"=> {dataset.name}-{dataset.split} loaded:\n" + create_small_table(statistics)) 52 | 53 | 54 | def build_dataset(dataset_name, root, transforms, split, verbose=True): 55 | if dataset_name == "CUHK-SYSU": 56 | dataset = CUHKSYSU(root, transforms, split) 57 | elif dataset_name == "PRW": 58 | dataset = PRW(root, transforms, split) 59 | else: 60 | raise NotImplementedError(f"Unknow dataset: {dataset_name}") 61 | if verbose: 62 | print_statistics(dataset) 63 | return dataset 64 | 65 | 66 | def collate_fn(batch): 67 | return tuple(zip(*batch)) 68 | 69 | 70 | def build_train_loader(cfg): 71 | transforms = build_transforms(cfg, is_train=True) 72 | dataset = build_dataset(cfg.INPUT.DATASET, cfg.INPUT.DATA_ROOT, transforms, "train") 73 | return torch.utils.data.DataLoader( 74 | dataset, 75 | batch_size=cfg.INPUT.BATCH_SIZE_TRAIN, 76 | shuffle=True, 77 | num_workers=cfg.INPUT.NUM_WORKERS_TRAIN, 78 | pin_memory=True, 79 | drop_last=True, 80 | collate_fn=collate_fn, 81 | ) 82 | 83 | 84 | def build_test_loader(cfg): 85 | transforms = build_transforms(cfg, is_train=False) 86 | gallery_set = build_dataset(cfg.INPUT.DATASET, cfg.INPUT.DATA_ROOT, transforms, "gallery") 87 | query_set = build_dataset(cfg.INPUT.DATASET, cfg.INPUT.DATA_ROOT, transforms, "query") 88 | gallery_loader = torch.utils.data.DataLoader( 89 | gallery_set, 90 | batch_size=cfg.INPUT.BATCH_SIZE_TEST, 91 | shuffle=False, 92 | num_workers=cfg.INPUT.NUM_WORKERS_TEST, 93 | pin_memory=True, 94 | collate_fn=collate_fn, 95 | ) 96 | query_loader = torch.utils.data.DataLoader( 97 | query_set, 98 | batch_size=cfg.INPUT.BATCH_SIZE_TEST, 99 | shuffle=False, 100 | num_workers=cfg.INPUT.NUM_WORKERS_TEST, 101 | pin_memory=True, 102 | collate_fn=collate_fn, 103 | ) 104 | return gallery_loader, query_loader 105 | -------------------------------------------------------------------------------- /datasets/cuhk_sysu.py: -------------------------------------------------------------------------------- 1 | # This file is part of COAT, and is distributed under the 2 | # OSI-approved BSD 3-Clause License. See top-level LICENSE file or 3 | # https://github.com/Kitware/COAT/blob/master/LICENSE for details. 
4 | 5 | import os.path as osp 6 | import numpy as np 7 | from scipy.io import loadmat 8 | from .base import BaseDataset 9 | 10 | class CUHKSYSU(BaseDataset): 11 | def __init__(self, root, transforms, split): 12 | self.name = "CUHK-SYSU" 13 | self.img_prefix = osp.join(root, "Image", "SSM") 14 | super(CUHKSYSU, self).__init__(root, transforms, split) 15 | 16 | def _load_queries(self): 17 | # TestG50: a test protocol, 50 gallery images per query 18 | protoc = loadmat(osp.join(self.root, "annotation/test/train_test/TestG50.mat")) 19 | protoc = protoc["TestG50"].squeeze() 20 | queries = [] 21 | for item in protoc["Query"]: 22 | img_name = str(item["imname"][0, 0][0]) 23 | roi = item["idlocate"][0, 0][0].astype(np.int32) 24 | roi[2:] += roi[:2] 25 | queries.append( 26 | { 27 | "img_name": img_name, 28 | "img_path": osp.join(self.img_prefix, img_name), 29 | "boxes": roi[np.newaxis, :], 30 | "pids": np.array([-100]), # dummy pid 31 | } 32 | ) 33 | return queries 34 | 35 | def _load_split_img_names(self): 36 | """ 37 | Load the image names for the specific split. 38 | """ 39 | assert self.split in ("train", "gallery") 40 | # gallery images 41 | gallery_imgs = loadmat(osp.join(self.root, "annotation", "pool.mat")) 42 | gallery_imgs = gallery_imgs["pool"].squeeze() 43 | gallery_imgs = [str(a[0]) for a in gallery_imgs] 44 | if self.split == "gallery": 45 | return gallery_imgs 46 | # all images 47 | all_imgs = loadmat(osp.join(self.root, "annotation", "Images.mat")) 48 | all_imgs = all_imgs["Img"].squeeze() 49 | all_imgs = [str(a[0][0]) for a in all_imgs] 50 | # training images = all images - gallery images 51 | training_imgs = sorted(list(set(all_imgs) - set(gallery_imgs))) 52 | return training_imgs 53 | 54 | def _load_annotations(self): 55 | if self.split == "query": 56 | return self._load_queries() 57 | 58 | # load all images and build a dict from image to boxes 59 | all_imgs = loadmat(osp.join(self.root, "annotation", "Images.mat")) 60 | all_imgs = all_imgs["Img"].squeeze() 61 | name_to_boxes = {} 62 | name_to_pids = {} 63 | unlabeled_pid = 5555 # default pid for unlabeled people 64 | for img_name, _, boxes in all_imgs: 65 | img_name = str(img_name[0]) 66 | boxes = np.asarray([b[0] for b in boxes[0]]) 67 | boxes = boxes.reshape(boxes.shape[0], 4) # (x1, y1, w, h) 68 | valid_index = np.where((boxes[:, 2] > 0) & (boxes[:, 3] > 0))[0] 69 | assert valid_index.size > 0, "Warning: {} has no valid boxes.".format(img_name) 70 | boxes = boxes[valid_index] 71 | name_to_boxes[img_name] = boxes.astype(np.int32) 72 | name_to_pids[img_name] = unlabeled_pid * np.ones(boxes.shape[0], dtype=np.int32) 73 | 74 | def set_box_pid(boxes, box, pids, pid): 75 | for i in range(boxes.shape[0]): 76 | if np.all(boxes[i] == box): 77 | pids[i] = pid 78 | return 79 | 80 | # assign a unique pid from 1 to N for each identity 81 | if self.split == "train": 82 | train = loadmat(osp.join(self.root, "annotation/test/train_test/Train.mat")) 83 | train = train["Train"].squeeze() 84 | for index, item in enumerate(train): 85 | scenes = item[0, 0][2].squeeze() 86 | for img_name, box, _ in scenes: 87 | img_name = str(img_name[0]) 88 | box = box.squeeze().astype(np.int32) 89 | set_box_pid(name_to_boxes[img_name], box, name_to_pids[img_name], index + 1) 90 | else: 91 | protoc = loadmat(osp.join(self.root, "annotation/test/train_test/TestG50.mat")) 92 | protoc = protoc["TestG50"].squeeze() 93 | for index, item in enumerate(protoc): 94 | # query 95 | im_name = str(item["Query"][0, 0][0][0]) 96 | box = item["Query"][0, 
0][1].squeeze().astype(np.int32) 97 | set_box_pid(name_to_boxes[im_name], box, name_to_pids[im_name], index + 1) 98 | # gallery 99 | gallery = item["Gallery"].squeeze() 100 | for im_name, box, _ in gallery: 101 | im_name = str(im_name[0]) 102 | if box.size == 0: 103 | break 104 | box = box.squeeze().astype(np.int32) 105 | set_box_pid(name_to_boxes[im_name], box, name_to_pids[im_name], index + 1) 106 | 107 | annotations = [] 108 | imgs = self._load_split_img_names() 109 | for img_name in imgs: 110 | boxes = name_to_boxes[img_name] 111 | boxes[:, 2:] += boxes[:, :2] # (x1, y1, w, h) -> (x1, y1, x2, y2) 112 | pids = name_to_pids[img_name] 113 | annotations.append( 114 | { 115 | "img_name": img_name, 116 | "img_path": osp.join(self.img_prefix, img_name), 117 | "boxes": boxes, 118 | "pids": pids, 119 | } 120 | ) 121 | return annotations 122 | -------------------------------------------------------------------------------- /datasets/prw.py: -------------------------------------------------------------------------------- 1 | # This file is part of COAT, and is distributed under the 2 | # OSI-approved BSD 3-Clause License. See top-level LICENSE file or 3 | # https://github.com/Kitware/COAT/blob/master/LICENSE for details. 4 | 5 | import os.path as osp 6 | import re 7 | 8 | import numpy as np 9 | from scipy.io import loadmat 10 | 11 | from .base import BaseDataset 12 | 13 | 14 | class PRW(BaseDataset): 15 | def __init__(self, root, transforms, split): 16 | self.name = "PRW" 17 | self.img_prefix = osp.join(root, "frames") 18 | super(PRW, self).__init__(root, transforms, split) 19 | 20 | def _get_cam_id(self, img_name): 21 | match = re.search(r"c\d", img_name).group().replace("c", "") 22 | return int(match) 23 | 24 | def _load_queries(self): 25 | query_info = osp.join(self.root, "query_info.txt") 26 | with open(query_info, "rb") as f: 27 | raw = f.readlines() 28 | 29 | queries = [] 30 | for line in raw: 31 | linelist = str(line, "utf-8").split(" ") 32 | pid = int(linelist[0]) 33 | x, y, w, h = ( 34 | float(linelist[1]), 35 | float(linelist[2]), 36 | float(linelist[3]), 37 | float(linelist[4]), 38 | ) 39 | roi = np.array([x, y, x + w, y + h]).astype(np.int32) 40 | roi = np.clip(roi, 0, None) # several coordinates are negative 41 | img_name = linelist[5][:-2] + ".jpg" 42 | queries.append( 43 | { 44 | "img_name": img_name, 45 | "img_path": osp.join(self.img_prefix, img_name), 46 | "boxes": roi[np.newaxis, :], 47 | "pids": np.array([pid]), 48 | "cam_id": self._get_cam_id(img_name), 49 | } 50 | ) 51 | return queries 52 | 53 | def _load_split_img_names(self): 54 | """ 55 | Load the image names for the specific split. 
56 | """ 57 | assert self.split in ("train", "gallery") 58 | if self.split == "train": 59 | imgs = loadmat(osp.join(self.root, "frame_train.mat"))["img_index_train"] 60 | else: 61 | imgs = loadmat(osp.join(self.root, "frame_test.mat"))["img_index_test"] 62 | return [img[0][0] + ".jpg" for img in imgs] 63 | 64 | def _load_annotations(self): 65 | if self.split == "query": 66 | return self._load_queries() 67 | 68 | annotations = [] 69 | imgs = self._load_split_img_names() 70 | for img_name in imgs: 71 | anno_path = osp.join(self.root, "annotations", img_name) 72 | anno = loadmat(anno_path) 73 | box_key = "box_new" 74 | if box_key not in anno.keys(): 75 | box_key = "anno_file" 76 | if box_key not in anno.keys(): 77 | box_key = "anno_previous" 78 | 79 | rois = anno[box_key][:, 1:] 80 | ids = anno[box_key][:, 0] 81 | rois = np.clip(rois, 0, None) # several coordinates are negative 82 | 83 | assert len(rois) == len(ids) 84 | 85 | rois[:, 2:] += rois[:, :2] 86 | ids[ids == -2] = 5555 # assign pid = 5555 for unlabeled people 87 | annotations.append( 88 | { 89 | "img_name": img_name, 90 | "img_path": osp.join(self.img_prefix, img_name), 91 | "boxes": rois.astype(np.int32), 92 | # (training pids) 1, 2,..., 478, 480, 481, 482, 483, 932, 5555 93 | "pids": ids.astype(np.int32), 94 | "cam_id": self._get_cam_id(img_name), 95 | } 96 | ) 97 | return annotations 98 | -------------------------------------------------------------------------------- /defaults.py: -------------------------------------------------------------------------------- 1 | # This file is part of COAT, and is distributed under the 2 | # OSI-approved BSD 3-Clause License. See top-level LICENSE file or 3 | # https://github.com/Kitware/COAT/blob/master/LICENSE for details. 4 | 5 | from yacs.config import CfgNode as CN 6 | 7 | _C = CN() 8 | 9 | # -------------------------------------------------------- # 10 | # Input # 11 | # -------------------------------------------------------- # 12 | _C.INPUT = CN() 13 | _C.INPUT.DATASET = "CUHK-SYSU" 14 | _C.INPUT.DATA_ROOT = "data/CUHK-SYSU" 15 | 16 | # Size of the smallest side of the image 17 | _C.INPUT.MIN_SIZE = 900 18 | # Maximum size of the side of the image 19 | _C.INPUT.MAX_SIZE = 1500 20 | 21 | # Number of images per batch 22 | _C.INPUT.BATCH_SIZE_TRAIN = 5 23 | _C.INPUT.BATCH_SIZE_TEST = 1 24 | 25 | # Number of data loading threads 26 | _C.INPUT.NUM_WORKERS_TRAIN = 5 27 | _C.INPUT.NUM_WORKERS_TEST = 1 28 | 29 | # Image augmentation 30 | _C.INPUT.IMAGE_CUTOUT = False 31 | _C.INPUT.IMAGE_ERASE = False 32 | _C.INPUT.IMAGE_MIXUP = False 33 | 34 | # -------------------------------------------------------- # 35 | # GRID # 36 | # -------------------------------------------------------- # 37 | _C.INPUT.IMAGE_GRID = False 38 | _C.GRID = CN() 39 | _C.GRID.ROTATE = 1 40 | _C.GRID.OFFSET = 0 41 | _C.GRID.RATIO = 0.5 42 | _C.GRID.MODE = 1 43 | _C.GRID.PROB = 0.5 44 | 45 | # -------------------------------------------------------- # 46 | # Solver # 47 | # -------------------------------------------------------- # 48 | _C.SOLVER = CN() 49 | _C.SOLVER.MAX_EPOCHS = 13 50 | 51 | # Learning rate settings 52 | _C.SOLVER.BASE_LR = 0.003 53 | 54 | # The epoch milestones to decrease the learning rate by GAMMA 55 | _C.SOLVER.LR_DECAY_MILESTONES = [10, 14] 56 | _C.SOLVER.GAMMA = 0.1 57 | 58 | _C.SOLVER.WEIGHT_DECAY = 0.0005 59 | _C.SOLVER.SGD_MOMENTUM = 0.9 60 | 61 | # Loss weight of RPN regression 62 | _C.SOLVER.LW_RPN_REG = 1 63 | # Loss weight of RPN classification 64 | _C.SOLVER.LW_RPN_CLS = 1 65 | 66 | # Loss 
weight of Cascade R-CNN and Re-ID (OIM) 67 | _C.SOLVER.LW_RCNN_REG_1ST = 10 68 | _C.SOLVER.LW_RCNN_CLS_1ST = 1 69 | _C.SOLVER.LW_RCNN_REG_2ND = 10 70 | _C.SOLVER.LW_RCNN_CLS_2ND = 1 71 | _C.SOLVER.LW_RCNN_REG_3RD = 10 72 | _C.SOLVER.LW_RCNN_CLS_3RD = 1 73 | _C.SOLVER.LW_RCNN_REID_2ND = 0.5 74 | _C.SOLVER.LW_RCNN_REID_3RD = 0.5 75 | # Loss weight of box reid, softmax loss 76 | _C.SOLVER.LW_RCNN_SOFTMAX_2ND = 0.5 77 | _C.SOLVER.LW_RCNN_SOFTMAX_3RD = 0.5 78 | 79 | # Set to negative value to disable gradient clipping 80 | _C.SOLVER.CLIP_GRADIENTS = 10.0 81 | 82 | # -------------------------------------------------------- # 83 | # RPN # 84 | # -------------------------------------------------------- # 85 | _C.MODEL = CN() 86 | _C.MODEL.RPN = CN() 87 | # NMS threshold used on RoIs 88 | _C.MODEL.RPN.NMS_THRESH = 0.7 89 | # Number of anchors per image used to train RPN 90 | _C.MODEL.RPN.BATCH_SIZE_TRAIN = 256 91 | # Target fraction of foreground examples per RPN minibatch 92 | _C.MODEL.RPN.POS_FRAC_TRAIN = 0.5 93 | # Overlap threshold for an anchor to be considered foreground (if >= POS_THRESH_TRAIN) 94 | _C.MODEL.RPN.POS_THRESH_TRAIN = 0.7 95 | # Overlap threshold for an anchor to be considered background (if < NEG_THRESH_TRAIN) 96 | _C.MODEL.RPN.NEG_THRESH_TRAIN = 0.3 97 | # Number of top scoring RPN RoIs to keep before applying NMS 98 | _C.MODEL.RPN.PRE_NMS_TOPN_TRAIN = 12000 99 | _C.MODEL.RPN.PRE_NMS_TOPN_TEST = 6000 100 | # Number of top scoring RPN RoIs to keep after applying NMS 101 | _C.MODEL.RPN.POST_NMS_TOPN_TRAIN = 2000 102 | _C.MODEL.RPN.POST_NMS_TOPN_TEST = 300 103 | 104 | # -------------------------------------------------------- # 105 | # RoI head # 106 | # -------------------------------------------------------- # 107 | _C.MODEL.ROI_HEAD = CN() 108 | # Whether to use bn neck (i.e. batch normalization after linear) 109 | _C.MODEL.ROI_HEAD.BN_NECK = True 110 | # Number of RoIs per image used to train RoI head 111 | _C.MODEL.ROI_HEAD.BATCH_SIZE_TRAIN = 128 112 | # Target fraction of foreground examples per RoI minibatch 113 | _C.MODEL.ROI_HEAD.POS_FRAC_TRAIN = 0.25 # 0.5 114 | 115 | _C.MODEL.ROI_HEAD.USE_DIFF_THRESH = True 116 | # Overlap threshold for an RoI to be considered foreground (if >= POS_THRESH_TRAIN) 117 | _C.MODEL.ROI_HEAD.POS_THRESH_TRAIN = 0.5 118 | _C.MODEL.ROI_HEAD.POS_THRESH_TRAIN_2ND = 0.6 119 | _C.MODEL.ROI_HEAD.POS_THRESH_TRAIN_3RD = 0.7 120 | # Overlap threshold for an RoI to be considered background (if < NEG_THRESH_TRAIN) 121 | _C.MODEL.ROI_HEAD.NEG_THRESH_TRAIN = 0.5 122 | _C.MODEL.ROI_HEAD.NEG_THRESH_TRAIN_2ND = 0.6 123 | _C.MODEL.ROI_HEAD.NEG_THRESH_TRAIN_3RD = 0.7 124 | # Minimum score threshold 125 | _C.MODEL.ROI_HEAD.SCORE_THRESH_TEST = 0.5 126 | # NMS threshold used on boxes 127 | _C.MODEL.ROI_HEAD.NMS_THRESH_TEST = 0.4 128 | _C.MODEL.ROI_HEAD.NMS_THRESH_TEST_1ST = 0.4 129 | _C.MODEL.ROI_HEAD.NMS_THRESH_TEST_2ND = 0.4 130 | _C.MODEL.ROI_HEAD.NMS_THRESH_TEST_3RD = 0.5 131 | # Maximum number of detected objects 132 | _C.MODEL.ROI_HEAD.DETECTIONS_PER_IMAGE_TEST = 300 133 | 134 | # -------------------------------------------------------- # 135 | # Transformer head # 136 | # -------------------------------------------------------- # 137 | _C.MODEL.TRANSFORMER = CN() 138 | _C.MODEL.TRANSFORMER.DIM_MODEL = 512 139 | _C.MODEL.TRANSFORMER.ENCODER_LAYERS = 1 140 | _C.MODEL.TRANSFORMER.N_HEAD = 8 141 | _C.MODEL.TRANSFORMER.USE_OUTPUT_LAYER = False 142 | _C.MODEL.TRANSFORMER.DROPOUT = 0. 
143 | _C.MODEL.TRANSFORMER.USE_LOCAL_SHORTCUT = True 144 | _C.MODEL.TRANSFORMER.USE_GLOBAL_SHORTCUT = True 145 | 146 | _C.MODEL.TRANSFORMER.USE_DIFF_SCALE = True 147 | _C.MODEL.TRANSFORMER.NAMES_1ST = ['scale1','scale2'] 148 | _C.MODEL.TRANSFORMER.NAMES_2ND = ['scale1','scale2'] 149 | _C.MODEL.TRANSFORMER.NAMES_3RD = ['scale1','scale2'] 150 | _C.MODEL.TRANSFORMER.KERNEL_SIZE_1ST = [(1,1),(3,3)] 151 | _C.MODEL.TRANSFORMER.KERNEL_SIZE_2ND = [(1,1),(3,3)] 152 | _C.MODEL.TRANSFORMER.KERNEL_SIZE_3RD = [(1,1),(3,3)] 153 | _C.MODEL.TRANSFORMER.USE_MASK_1ST = False 154 | _C.MODEL.TRANSFORMER.USE_MASK_2ND = True 155 | _C.MODEL.TRANSFORMER.USE_MASK_3RD = True 156 | _C.MODEL.TRANSFORMER.USE_PATCH2VEC = True 157 | 158 | #### 159 | _C.MODEL.USE_FEATURE_MASK = True 160 | _C.MODEL.FEATURE_AUG_TYPE = 'exchange_token' # 'exchange_token', 'jigsaw_token', 'cutout_patch', 'erase_patch', 'mixup_patch', 'jigsaw_patch' 161 | _C.MODEL.FEATURE_MASK_SIZE = 4 162 | _C.MODEL.MASK_SHAPE = 'stripe' # 'square', 'random' 163 | _C.MODEL.MASK_SIZE = 1 164 | _C.MODEL.MASK_MODE = 'random_direction' # 'horizontal', 'vertical' for stripe; 'random_size' for square 165 | _C.MODEL.MASK_PERCENT = 0.1 166 | #### 167 | _C.MODEL.EMBEDDING_DIM = 256 168 | 169 | # -------------------------------------------------------- # 170 | # Loss # 171 | # -------------------------------------------------------- # 172 | _C.MODEL.LOSS = CN() 173 | # Size of the lookup table in OIM 174 | _C.MODEL.LOSS.LUT_SIZE = 5532 175 | # Size of the circular queue in OIM 176 | _C.MODEL.LOSS.CQ_SIZE = 5000 177 | _C.MODEL.LOSS.OIM_MOMENTUM = 0.5 178 | _C.MODEL.LOSS.OIM_SCALAR = 30.0 179 | 180 | _C.MODEL.LOSS.USE_SOFTMAX = True 181 | 182 | # -------------------------------------------------------- # 183 | # Evaluation # 184 | # -------------------------------------------------------- # 185 | # The period to evaluate the model during training 186 | _C.EVAL_PERIOD = 1 187 | # Evaluation with GT boxes to verify the upper bound of person search performance 188 | _C.EVAL_USE_GT = False 189 | # Fast evaluation with cached features 190 | _C.EVAL_USE_CACHE = False 191 | # Evaluation with Context Bipartite Graph Matching (CBGM) algorithm 192 | _C.EVAL_USE_CBGM = False 193 | # Gallery size in evaluation, only for CUHK-SYSU 194 | _C.EVAL_GALLERY_SIZE = 100 195 | # Feature used for evaluation 196 | _C.EVAL_FEATURE = 'concat' # 'stage2', 'stage3' 197 | 198 | # -------------------------------------------------------- # 199 | # Miscs # 200 | # -------------------------------------------------------- # 201 | # Save a checkpoint after every this number of epochs 202 | _C.CKPT_PERIOD = 1 203 | # The period (in terms of iterations) to display training losses 204 | _C.DISP_PERIOD = 10 205 | # Whether to use tensorboard for visualization 206 | _C.TF_BOARD = True 207 | # The device loading the model 208 | _C.DEVICE = "cuda" 209 | # Set seed to negative to fully randomize everything 210 | _C.SEED = 1 211 | # Directory where output files are written 212 | _C.OUTPUT_DIR = "./output" 213 | 214 | 215 | def get_default_cfg(): 216 | """ 217 | Get a copy of the default config. 
218 | """ 219 | return _C.clone() 220 | -------------------------------------------------------------------------------- /doc/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kitware/COAT/369d93418716a4a28a96203ebd012020877db550/doc/framework.png -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | # This file is part of COAT, and is distributed under the 2 | # OSI-approved BSD 3-Clause License. See top-level LICENSE file or 3 | # https://github.com/Kitware/COAT/blob/master/LICENSE for details. 4 | 5 | import math 6 | import sys 7 | from copy import deepcopy 8 | 9 | import torch 10 | from torch.nn.utils import clip_grad_norm_ 11 | from tqdm import tqdm 12 | 13 | from eval_func import eval_detection, eval_search_cuhk, eval_search_prw 14 | from utils.utils import MetricLogger, SmoothedValue, mkdir, reduce_dict, warmup_lr_scheduler 15 | from utils.transforms import mixup_data 16 | 17 | 18 | def to_device(images, targets, device): 19 | images = [image.to(device) for image in images] 20 | for t in targets: 21 | t["boxes"] = t["boxes"].to(device) 22 | t["labels"] = t["labels"].to(device) 23 | return images, targets 24 | 25 | 26 | def train_one_epoch(cfg, model, optimizer, data_loader, device, epoch, tfboard, softmax_criterion_s2, softmax_criterion_s3): 27 | model.train() 28 | metric_logger = MetricLogger(delimiter=" ") 29 | metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}")) 30 | header = "Epoch: [{}]".format(epoch) 31 | 32 | # warmup learning rate in the first epoch 33 | if epoch == 0: 34 | warmup_factor = 1.0 / 1000 35 | warmup_iters = len(data_loader) - 1 36 | warmup_scheduler = warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor) 37 | 38 | for i, (images, targets) in enumerate( 39 | metric_logger.log_every(data_loader, cfg.DISP_PERIOD, header) 40 | ): 41 | images, targets = to_device(images, targets, device) 42 | 43 | # if using image based data augmentation 44 | if cfg.INPUT.IMAGE_MIXUP: 45 | images = mixup_data(images, alpha=0.8) 46 | 47 | loss_dict, feats_reid_2nd, targets_reid_2nd, feats_reid_3rd, targets_reid_3rd = model(images, targets) 48 | 49 | if cfg.MODEL.LOSS.USE_SOFTMAX: 50 | softmax_loss_2nd = cfg.SOLVER.LW_RCNN_SOFTMAX_2ND * softmax_criterion_s2(feats_reid_2nd, targets_reid_2nd) 51 | softmax_loss_3rd = cfg.SOLVER.LW_RCNN_SOFTMAX_3RD * softmax_criterion_s3(feats_reid_3rd, targets_reid_3rd) 52 | loss_dict.update(loss_box_softmax_2nd=softmax_loss_2nd) 53 | loss_dict.update(loss_box_softmax_3rd=softmax_loss_3rd) 54 | 55 | losses = sum(loss for loss in loss_dict.values()) 56 | 57 | # reduce losses over all GPUs for logging purposes 58 | loss_dict_reduced = reduce_dict(loss_dict) 59 | losses_reduced = sum(loss for loss in loss_dict_reduced.values()) 60 | loss_value = losses_reduced.item() 61 | 62 | if not math.isfinite(loss_value): 63 | print(f"Loss is {loss_value}, stopping training") 64 | print(loss_dict_reduced) 65 | sys.exit(1) 66 | 67 | optimizer.zero_grad() 68 | losses.backward() 69 | if cfg.SOLVER.CLIP_GRADIENTS > 0: 70 | clip_grad_norm_(model.parameters(), cfg.SOLVER.CLIP_GRADIENTS) 71 | optimizer.step() 72 | 73 | if epoch == 0: 74 | warmup_scheduler.step() 75 | 76 | metric_logger.update(loss=loss_value, **loss_dict_reduced) 77 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 78 | if tfboard: 79 | iter = epoch * len(data_loader) + 
i 80 | for k, v in loss_dict_reduced.items(): 81 | tfboard.add_scalars("train", {k: v}, iter) 82 | 83 | 84 | @torch.no_grad() 85 | def evaluate_performance( 86 | model, gallery_loader, query_loader, device, use_gt=False, use_cache=False, use_cbgm=False, gallery_size=100): 87 | """ 88 | Args: 89 | use_gt (bool, optional): Whether to use GT as detection results to verify the upper 90 | bound of person search performance. Defaults to False. 91 | use_cache (bool, optional): Whether to use the cached features. Defaults to False. 92 | use_cbgm (bool, optional): Whether to use Context Bipartite Graph Matching algorithm. 93 | Defaults to False. 94 | """ 95 | model.eval() 96 | if use_cache: 97 | eval_cache = torch.load("data/eval_cache/eval_cache.pth") 98 | gallery_dets = eval_cache["gallery_dets"] 99 | gallery_feats = eval_cache["gallery_feats"] 100 | query_dets = eval_cache["query_dets"] 101 | query_feats = eval_cache["query_feats"] 102 | query_box_feats = eval_cache["query_box_feats"] 103 | else: 104 | gallery_dets, gallery_feats = [], [] 105 | for images, targets in tqdm(gallery_loader, ncols=0): 106 | images, targets = to_device(images, targets, device) 107 | if not use_gt: 108 | outputs = model(images) 109 | else: 110 | boxes = targets[0]["boxes"] 111 | n_boxes = boxes.size(0) 112 | embeddings = model(images, targets) 113 | outputs = [ 114 | { 115 | "boxes": boxes, 116 | "embeddings": torch.cat(embeddings), 117 | "labels": torch.ones(n_boxes).to(device), 118 | "scores": torch.ones(n_boxes).to(device), 119 | } 120 | ] 121 | 122 | for output in outputs: 123 | box_w_scores = torch.cat([output["boxes"], output["scores"].unsqueeze(1)], dim=1) 124 | gallery_dets.append(box_w_scores.cpu().numpy()) 125 | gallery_feats.append(output["embeddings"].cpu().numpy()) 126 | 127 | # regarding query image as gallery to detect all people 128 | # i.e. 
query person + surrounding people (context information) 129 | query_dets, query_feats = [], [] 130 | for images, targets in tqdm(query_loader, ncols=0): 131 | images, targets = to_device(images, targets, device) 132 | # targets will be modified in the model, so deepcopy it 133 | outputs = model(images, deepcopy(targets), query_img_as_gallery=True) 134 | 135 | # consistency check 136 | gt_box = targets[0]["boxes"].squeeze() 137 | 138 | assert ( 139 | gt_box - outputs[0]["boxes"][0] 140 | ).sum() <= 0.001, "GT box must be the first one in the detected boxes of query image" 141 | 142 | for output in outputs: 143 | box_w_scores = torch.cat([output["boxes"], output["scores"].unsqueeze(1)], dim=1) 144 | query_dets.append(box_w_scores.cpu().numpy()) 145 | query_feats.append(output["embeddings"].cpu().numpy()) 146 | 147 | # extract the features of query boxes 148 | query_box_feats = [] 149 | for images, targets in tqdm(query_loader, ncols=0): 150 | images, targets = to_device(images, targets, device) 151 | embeddings = model(images, targets) 152 | assert len(embeddings) == 1, "batch size in test phase should be 1" 153 | query_box_feats.append(embeddings[0].cpu().numpy()) 154 | 155 | mkdir("data/eval_cache") 156 | save_dict = { 157 | "gallery_dets": gallery_dets, 158 | "gallery_feats": gallery_feats, 159 | "query_dets": query_dets, 160 | "query_feats": query_feats, 161 | "query_box_feats": query_box_feats, 162 | } 163 | torch.save(save_dict, "data/eval_cache/eval_cache.pth") 164 | 165 | eval_detection(gallery_loader.dataset, gallery_dets, det_thresh=0.01) 166 | eval_search_func = ( 167 | eval_search_cuhk if gallery_loader.dataset.name == "CUHK-SYSU" else eval_search_prw 168 | ) 169 | eval_search_func( 170 | gallery_loader.dataset, 171 | query_loader.dataset, 172 | gallery_dets, 173 | gallery_feats, 174 | query_box_feats, 175 | query_dets, 176 | query_feats, 177 | cbgm=use_cbgm, 178 | gallery_size=gallery_size, 179 | ) 180 | -------------------------------------------------------------------------------- /eval_func.py: -------------------------------------------------------------------------------- 1 | # This file is part of COAT, and is distributed under the 2 | # OSI-approved BSD 3-Clause License. See top-level LICENSE file or 3 | # https://github.com/Kitware/COAT/blob/master/LICENSE for details. 
4 | 5 | import os.path as osp 6 | import numpy as np 7 | from scipy.io import loadmat 8 | from sklearn.metrics import average_precision_score 9 | 10 | from utils.km import run_kuhn_munkres 11 | from utils.utils import write_json 12 | 13 | 14 | def _compute_iou(a, b): 15 | x1 = max(a[0], b[0]) 16 | y1 = max(a[1], b[1]) 17 | x2 = min(a[2], b[2]) 18 | y2 = min(a[3], b[3]) 19 | inter = max(0, x2 - x1) * max(0, y2 - y1) 20 | union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter 21 | return inter * 1.0 / union 22 | 23 | 24 | def eval_detection( 25 | gallery_dataset, gallery_dets, det_thresh=0.5, iou_thresh=0.5, labeled_only=False 26 | ): 27 | """ 28 | gallery_det (list of ndarray): n_det x [x1, y1, x2, y2, score] per image 29 | det_thresh (float): filter out gallery detections whose scores below this 30 | iou_thresh (float): treat as true positive if IoU is above this threshold 31 | labeled_only (bool): filter out unlabeled background people 32 | """ 33 | assert len(gallery_dataset) == len(gallery_dets) 34 | annos = gallery_dataset.annotations 35 | 36 | y_true, y_score = [], [] 37 | count_gt, count_tp = 0, 0 38 | for anno, det in zip(annos, gallery_dets): 39 | gt_boxes = anno["boxes"] 40 | if labeled_only: 41 | # exclude the unlabeled people (pid == 5555) 42 | inds = np.where(anno["pids"].ravel() != 5555)[0] 43 | if len(inds) == 0: 44 | continue 45 | gt_boxes = gt_boxes[inds] 46 | num_gt = gt_boxes.shape[0] 47 | 48 | if det != []: 49 | det = np.asarray(det) 50 | inds = np.where(det[:, 4].ravel() >= det_thresh)[0] 51 | det = det[inds] 52 | num_det = det.shape[0] 53 | else: 54 | num_det = 0 55 | if num_det == 0: 56 | count_gt += num_gt 57 | continue 58 | 59 | ious = np.zeros((num_gt, num_det), dtype=np.float32) 60 | for i in range(num_gt): 61 | for j in range(num_det): 62 | ious[i, j] = _compute_iou(gt_boxes[i], det[j, :4]) 63 | tfmat = ious >= iou_thresh 64 | # for each det, keep only the largest iou of all the gt 65 | for j in range(num_det): 66 | largest_ind = np.argmax(ious[:, j]) 67 | for i in range(num_gt): 68 | if i != largest_ind: 69 | tfmat[i, j] = False 70 | # for each gt, keep only the largest iou of all the det 71 | for i in range(num_gt): 72 | largest_ind = np.argmax(ious[i, :]) 73 | for j in range(num_det): 74 | if j != largest_ind: 75 | tfmat[i, j] = False 76 | for j in range(num_det): 77 | y_score.append(det[j, -1]) 78 | y_true.append(tfmat[:, j].any()) 79 | count_tp += tfmat.sum() 80 | count_gt += num_gt 81 | 82 | det_rate = count_tp * 1.0 / count_gt 83 | ap = average_precision_score(y_true, y_score) * det_rate 84 | 85 | print("{} detection:".format("labeled only" if labeled_only else "all")) 86 | print(" recall = {:.2%}".format(det_rate)) 87 | if not labeled_only: 88 | print(" ap = {:.2%}".format(ap)) 89 | return det_rate, ap 90 | 91 | 92 | def eval_search_cuhk( 93 | gallery_dataset, 94 | query_dataset, 95 | gallery_dets, 96 | gallery_feats, 97 | query_box_feats, 98 | query_dets, 99 | query_feats, 100 | k1=10, 101 | k2=3, 102 | det_thresh=0.5, 103 | cbgm=False, 104 | gallery_size=100, 105 | ): 106 | """ 107 | gallery_dataset/query_dataset: an instance of BaseDataset 108 | gallery_det (list of ndarray): n_det x [x1, x2, y1, y2, score] per image 109 | gallery_feat (list of ndarray): n_det x D features per image 110 | query_feat (list of ndarray): D dimensional features per query image 111 | det_thresh (float): filter out gallery detections whose scores below this 112 | gallery_size (int): gallery size [-1, 50, 100, 500, 1000, 2000, 4000] 113 | -1 for using 
full set 114 | """ 115 | assert len(gallery_dataset) == len(gallery_dets) 116 | assert len(gallery_dataset) == len(gallery_feats) 117 | assert len(query_dataset) == len(query_box_feats) 118 | 119 | use_full_set = gallery_size == -1 120 | fname = "TestG{}".format(gallery_size if not use_full_set else 50) 121 | protoc = loadmat(osp.join(gallery_dataset.root, "annotation/test/train_test", fname + ".mat")) 122 | protoc = protoc[fname].squeeze() 123 | 124 | # mapping from gallery image to (det, feat) 125 | annos = gallery_dataset.annotations 126 | name_to_det_feat = {} 127 | for anno, det, feat in zip(annos, gallery_dets, gallery_feats): 128 | name = anno["img_name"] 129 | if det != []: 130 | scores = det[:, 4].ravel() 131 | inds = np.where(scores >= det_thresh)[0] 132 | if len(inds) > 0: 133 | name_to_det_feat[name] = (det[inds], feat[inds]) 134 | 135 | aps = [] 136 | accs = [] 137 | topk = [1, 5, 10] 138 | ret = {"image_root": gallery_dataset.img_prefix, "results": []} 139 | for i in range(len(query_dataset)): 140 | y_true, y_score = [], [] 141 | imgs, rois = [], [] 142 | count_gt, count_tp = 0, 0 143 | # get L2-normalized feature vector 144 | feat_q = query_box_feats[i].ravel() 145 | # ignore the query image 146 | query_imname = str(protoc["Query"][i]["imname"][0, 0][0]) 147 | query_roi = protoc["Query"][i]["idlocate"][0, 0][0].astype(np.int32) 148 | query_roi[2:] += query_roi[:2] 149 | query_gt = [] 150 | tested = set([query_imname]) 151 | 152 | name2sim = {} 153 | name2gt = {} 154 | sims = [] 155 | imgs_cbgm = [] 156 | # 1. Go through the gallery samples defined by the protocol 157 | for item in protoc["Gallery"][i].squeeze(): 158 | gallery_imname = str(item[0][0]) 159 | # some contain the query (gt not empty), some not 160 | gt = item[1][0].astype(np.int32) 161 | count_gt += gt.size > 0 162 | # compute distance between query and gallery dets 163 | if gallery_imname not in name_to_det_feat: 164 | continue 165 | det, feat_g = name_to_det_feat[gallery_imname] 166 | # no detection in this gallery, skip it 167 | if det.shape[0] == 0: 168 | continue 169 | # get L2-normalized feature matrix NxD 170 | assert feat_g.size == np.prod(feat_g.shape[:2]) 171 | feat_g = feat_g.reshape(feat_g.shape[:2]) 172 | # compute cosine similarities 173 | sim = feat_g.dot(feat_q).ravel() 174 | 175 | if gallery_imname in name2sim: 176 | continue 177 | name2sim[gallery_imname] = sim 178 | name2gt[gallery_imname] = gt 179 | sims.extend(list(sim)) 180 | imgs_cbgm.extend([gallery_imname] * len(sim)) 181 | # 2. 
Go through the remaining gallery images if using full set 182 | if use_full_set: 183 | for gallery_imname in gallery_dataset.imgs: 184 | if gallery_imname in tested: 185 | continue 186 | if gallery_imname not in name_to_det_feat: 187 | continue 188 | det, feat_g = name_to_det_feat[gallery_imname] 189 | # get L2-normalized feature matrix NxD 190 | assert feat_g.size == np.prod(feat_g.shape[:2]) 191 | feat_g = feat_g.reshape(feat_g.shape[:2]) 192 | # compute cosine similarities 193 | sim = feat_g.dot(feat_q).ravel() 194 | # guaranteed no target query in these gallery images 195 | label = np.zeros(len(sim), dtype=np.int32) 196 | y_true.extend(list(label)) 197 | y_score.extend(list(sim)) 198 | imgs.extend([gallery_imname] * len(sim)) 199 | rois.extend(list(det)) 200 | 201 | if cbgm: 202 | # -------- Context Bipartite Graph Matching (CBGM) ------- # 203 | sims = np.array(sims) 204 | imgs_cbgm = np.array(imgs_cbgm) 205 | # only process the top-k1 gallery images for efficiency 206 | inds = np.argsort(sims)[-k1:] 207 | imgs_cbgm = set(imgs_cbgm[inds]) 208 | for img in imgs_cbgm: 209 | sim = name2sim[img] 210 | det, feat_g = name_to_det_feat[img] 211 | # only regard the people with top-k2 detection confidence 212 | # in the query image as context information 213 | qboxes = query_dets[i][:k2] 214 | qfeats = query_feats[i][:k2] 215 | assert ( 216 | query_roi - qboxes[0][:4] 217 | ).sum() <= 0.001, "query_roi must be the first one in pboxes" 218 | 219 | # build the bipartite graph and run Kuhn-Munkres (K-M) algorithm 220 | # to find the best match 221 | graph = [] 222 | for indx_i, pfeat in enumerate(qfeats): 223 | for indx_j, gfeat in enumerate(feat_g): 224 | graph.append((indx_i, indx_j, (pfeat * gfeat).sum())) 225 | km_res, max_val = run_kuhn_munkres(graph) 226 | 227 | # revise the similarity between query person and its matching 228 | for indx_i, indx_j, _ in km_res: 229 | # 0 denotes the query roi 230 | if indx_i == 0: 231 | sim[indx_j] = max_val 232 | break 233 | for gallery_imname, sim in name2sim.items(): 234 | gt = name2gt[gallery_imname] 235 | det, feat_g = name_to_det_feat[gallery_imname] 236 | # assign label for each det 237 | label = np.zeros(len(sim), dtype=np.int32) 238 | if gt.size > 0: 239 | w, h = gt[2], gt[3] 240 | gt[2:] += gt[:2] 241 | query_gt.append({"img": str(gallery_imname), "roi": list(map(float, list(gt)))}) 242 | iou_thresh = min(0.5, (w * h * 1.0) / ((w + 10) * (h + 10))) 243 | inds = np.argsort(sim)[::-1] 244 | sim = sim[inds] 245 | det = det[inds] 246 | # only set the first matched det as true positive 247 | for j, roi in enumerate(det[:, :4]): 248 | if _compute_iou(roi, gt) >= iou_thresh: 249 | label[j] = 1 250 | count_tp += 1 251 | break 252 | y_true.extend(list(label)) 253 | y_score.extend(list(sim)) 254 | imgs.extend([gallery_imname] * len(sim)) 255 | rois.extend(list(det)) 256 | tested.add(gallery_imname) 257 | # 3. Compute AP for this query (need to scale by recall rate) 258 | y_score = np.asarray(y_score) 259 | y_true = np.asarray(y_true) 260 | assert count_tp <= count_gt 261 | recall_rate = count_tp * 1.0 / count_gt 262 | ap = 0 if count_tp == 0 else average_precision_score(y_true, y_score) * recall_rate 263 | aps.append(ap) 264 | inds = np.argsort(y_score)[::-1] 265 | y_score = y_score[inds] 266 | y_true = y_true[inds] 267 | accs.append([min(1, sum(y_true[:k])) for k in topk]) 268 | # 4. 
Save result for JSON dump 269 | new_entry = { 270 | "query_img": str(query_imname), 271 | "query_roi": list(map(float, list(query_roi))), 272 | "query_gt": query_gt, 273 | "gallery": [], 274 | } 275 | # only record wrong results 276 | if int(y_true[0]): 277 | continue 278 | # only save top-10 predictions 279 | for k in range(10): 280 | new_entry["gallery"].append( 281 | { 282 | "img": str(imgs[inds[k]]), 283 | "roi": list(map(float, list(rois[inds[k]]))), 284 | "score": float(y_score[k]), 285 | "correct": int(y_true[k]), 286 | } 287 | ) 288 | ret["results"].append(new_entry) 289 | 290 | print("search ranking:") 291 | print(" mAP = {:.2%}".format(np.mean(aps))) 292 | accs = np.mean(accs, axis=0) 293 | for i, k in enumerate(topk): 294 | print(" top-{:2d} = {:.2%}".format(k, accs[i])) 295 | 296 | write_json(ret, "vis/results.json") 297 | 298 | ret["mAP"] = np.mean(aps) 299 | ret["accs"] = accs 300 | return ret 301 | 302 | 303 | def eval_search_prw( 304 | gallery_dataset, 305 | query_dataset, 306 | gallery_dets, 307 | gallery_feats, 308 | query_box_feats, 309 | query_dets, 310 | query_feats, 311 | k1=30, 312 | k2=4, 313 | det_thresh=0.5, 314 | cbgm=False, 315 | gallery_size=None, # not used in PRW 316 | ignore_cam_id=True, 317 | ): 318 | """ 319 | gallery_det (list of ndarray): n_det x [x1, x2, y1, y2, score] per image 320 | gallery_feat (list of ndarray): n_det x D features per image 321 | query_feat (list of ndarray): D dimensional features per query image 322 | det_thresh (float): filter out gallery detections whose scores below this 323 | gallery_size (int): -1 for using full set 324 | ignore_cam_id (bool): Set to True acoording to CUHK-SYSU, 325 | although it's a common practice to focus on cross-cam match only. 326 | """ 327 | assert len(gallery_dataset) == len(gallery_dets) 328 | assert len(gallery_dataset) == len(gallery_feats) 329 | assert len(query_dataset) == len(query_box_feats) 330 | 331 | annos = gallery_dataset.annotations 332 | name_to_det_feat = {} 333 | for anno, det, feat in zip(annos, gallery_dets, gallery_feats): 334 | name = anno["img_name"] 335 | scores = det[:, 4].ravel() 336 | inds = np.where(scores >= det_thresh)[0] 337 | if len(inds) > 0: 338 | name_to_det_feat[name] = (det[inds], feat[inds]) 339 | 340 | aps = [] 341 | accs = [] 342 | topk = [1, 5, 10] 343 | ret = {"image_root": gallery_dataset.img_prefix, "results": []} 344 | for i in range(len(query_dataset)): 345 | y_true, y_score = [], [] 346 | imgs, rois = [], [] 347 | count_gt, count_tp = 0, 0 348 | 349 | feat_p = query_box_feats[i].ravel() 350 | 351 | query_imname = query_dataset.annotations[i]["img_name"] 352 | query_roi = query_dataset.annotations[i]["boxes"] 353 | query_pid = query_dataset.annotations[i]["pids"] 354 | query_cam = query_dataset.annotations[i]["cam_id"] 355 | 356 | # Find all occurence of this query 357 | gallery_imgs = [] 358 | for x in annos: 359 | if query_pid in x["pids"] and x["img_name"] != query_imname: 360 | gallery_imgs.append(x) 361 | query_gts = {} 362 | for item in gallery_imgs: 363 | query_gts[item["img_name"]] = item["boxes"][item["pids"] == query_pid] 364 | 365 | # Construct gallery set for this query 366 | if ignore_cam_id: 367 | gallery_imgs = [] 368 | for x in annos: 369 | if x["img_name"] != query_imname: 370 | gallery_imgs.append(x) 371 | else: 372 | gallery_imgs = [] 373 | for x in annos: 374 | if x["img_name"] != query_imname and x["cam_id"] != query_cam: 375 | gallery_imgs.append(x) 376 | 377 | name2sim = {} 378 | sims = [] 379 | imgs_cbgm = [] 380 | # 1. 
Go through all gallery samples 381 | for item in gallery_imgs: 382 | gallery_imname = item["img_name"] 383 | # some contain the query (gt not empty), some not 384 | count_gt += gallery_imname in query_gts 385 | # compute distance between query and gallery dets 386 | if gallery_imname not in name_to_det_feat: 387 | continue 388 | det, feat_g = name_to_det_feat[gallery_imname] 389 | # get L2-normalized feature matrix NxD 390 | assert feat_g.size == np.prod(feat_g.shape[:2]) 391 | feat_g = feat_g.reshape(feat_g.shape[:2]) 392 | # compute cosine similarities 393 | sim = feat_g.dot(feat_p).ravel() 394 | 395 | if gallery_imname in name2sim: 396 | continue 397 | name2sim[gallery_imname] = sim 398 | sims.extend(list(sim)) 399 | imgs_cbgm.extend([gallery_imname] * len(sim)) 400 | 401 | if cbgm: 402 | sims = np.array(sims) 403 | imgs_cbgm = np.array(imgs_cbgm) 404 | inds = np.argsort(sims)[-k1:] 405 | imgs_cbgm = set(imgs_cbgm[inds]) 406 | for img in imgs_cbgm: 407 | sim = name2sim[img] 408 | det, feat_g = name_to_det_feat[img] 409 | qboxes = query_dets[i][:k2] 410 | qfeats = query_feats[i][:k2] 411 | # assert ( 412 | # query_roi - qboxes[0][:4] 413 | # ).sum() <= 0.001, "query_roi must be the first one in pboxes" 414 | 415 | graph = [] 416 | for indx_i, pfeat in enumerate(qfeats): 417 | for indx_j, gfeat in enumerate(feat_g): 418 | graph.append((indx_i, indx_j, (pfeat * gfeat).sum())) 419 | km_res, max_val = run_kuhn_munkres(graph) 420 | 421 | for indx_i, indx_j, _ in km_res: 422 | if indx_i == 0: 423 | sim[indx_j] = max_val 424 | break 425 | for gallery_imname, sim in name2sim.items(): 426 | det, feat_g = name_to_det_feat[gallery_imname] 427 | # assign label for each det 428 | label = np.zeros(len(sim), dtype=np.int32) 429 | if gallery_imname in query_gts: 430 | gt = query_gts[gallery_imname].ravel() 431 | w, h = gt[2] - gt[0], gt[3] - gt[1] 432 | iou_thresh = min(0.5, (w * h * 1.0) / ((w + 10) * (h + 10))) 433 | inds = np.argsort(sim)[::-1] 434 | sim = sim[inds] 435 | det = det[inds] 436 | # only set the first matched det as true positive 437 | for j, roi in enumerate(det[:, :4]): 438 | if _compute_iou(roi, gt) >= iou_thresh: 439 | label[j] = 1 440 | count_tp += 1 441 | break 442 | y_true.extend(list(label)) 443 | y_score.extend(list(sim)) 444 | imgs.extend([gallery_imname] * len(sim)) 445 | rois.extend(list(det)) 446 | 447 | # 2. Compute AP for this query (need to scale by recall rate) 448 | y_score = np.asarray(y_score) 449 | y_true = np.asarray(y_true) 450 | assert count_tp <= count_gt 451 | recall_rate = count_tp * 1.0 / count_gt 452 | ap = 0 if count_tp == 0 else average_precision_score(y_true, y_score) * recall_rate 453 | aps.append(ap) 454 | inds = np.argsort(y_score)[::-1] 455 | y_score = y_score[inds] 456 | y_true = y_true[inds] 457 | accs.append([min(1, sum(y_true[:k])) for k in topk]) 458 | # 4. 
Save result for JSON dump 459 | new_entry = { 460 | "query_img": str(query_imname), 461 | "query_roi": list(map(float, list(query_roi.squeeze()))), 462 | "query_gt": query_gts, 463 | "gallery": [], 464 | } 465 | # only save top-10 predictions 466 | for k in range(10): 467 | new_entry["gallery"].append( 468 | { 469 | "img": str(imgs[inds[k]]), 470 | "roi": list(map(float, list(rois[inds[k]]))), 471 | "score": float(y_score[k]), 472 | "correct": int(y_true[k]), 473 | } 474 | ) 475 | ret["results"].append(new_entry) 476 | 477 | print("search ranking:") 478 | mAP = np.mean(aps) 479 | print(" mAP = {:.2%}".format(mAP)) 480 | accs = np.mean(accs, axis=0) 481 | for i, k in enumerate(topk): 482 | print(" top-{:2d} = {:.2%}".format(k, accs[i])) 483 | 484 | # write_json(ret, "vis/results.json") 485 | 486 | ret["mAP"] = np.mean(aps) 487 | ret["accs"] = accs 488 | return ret 489 | -------------------------------------------------------------------------------- /loss/__pycache__/oim.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kitware/COAT/369d93418716a4a28a96203ebd012020877db550/loss/__pycache__/oim.cpython-38.pyc -------------------------------------------------------------------------------- /loss/__pycache__/softmax_loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kitware/COAT/369d93418716a4a28a96203ebd012020877db550/loss/__pycache__/softmax_loss.cpython-38.pyc -------------------------------------------------------------------------------- /loss/oim.py: -------------------------------------------------------------------------------- 1 | # This file is part of COAT, and is distributed under the 2 | # OSI-approved BSD 3-Clause License. See top-level LICENSE file or 3 | # https://github.com/Kitware/COAT/blob/master/LICENSE for details. 
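# OIM (Online Instance Matching) keeps a lookup table (lut) with one prototype per
# labeled identity and a circular queue (cq) for unlabeled people; the loss is a
# scaled cross-entropy over the similarities to both memories, and the prototypes
# are updated with momentum inside the backward pass.
#
# Minimal usage sketch (hypothetical, not part of the repository). It assumes unlabeled
# identities carry the placeholder id 5555, which is consistent with ignore_index=5554
# after the "targets - 1" shift below; background boxes use id 0.
#
#   import torch
#   import torch.nn.functional as F
#   from loss.oim import OIMLoss
#
#   criterion = OIMLoss(num_features=256, num_pids=482, num_cq_size=500,
#                       oim_momentum=0.5, oim_scalar=30.0)
#   feats = F.normalize(torch.randn(6, 256), dim=1)           # one feature per sampled RoI
#   roi_labels = [torch.tensor([0, 1, 2, 5555, 3, 0])]        # per-image pid labels
#   loss, fg_feats, fg_labels = criterion(feats, roi_labels)  # background RoIs are dropped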
4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import autograd, nn 8 | 9 | class OIM(autograd.Function): 10 | @staticmethod 11 | def forward(ctx, inputs, targets, lut, cq, header, momentum): 12 | ctx.save_for_backward(inputs, targets, lut, cq, header, momentum) 13 | outputs_labeled = inputs.mm(lut.t()) 14 | outputs_unlabeled = inputs.mm(cq.t()) 15 | return torch.cat([outputs_labeled, outputs_unlabeled], dim=1) 16 | 17 | @staticmethod 18 | def backward(ctx, grad_outputs): 19 | inputs, targets, lut, cq, header, momentum = ctx.saved_tensors 20 | 21 | grad_inputs = None 22 | if ctx.needs_input_grad[0]: 23 | grad_inputs = grad_outputs.mm(torch.cat([lut, cq], dim=0)) 24 | if grad_inputs.dtype == torch.float16: 25 | grad_inputs = grad_inputs.to(torch.float32) 26 | 27 | for x, y in zip(inputs, targets): 28 | if y < len(lut): 29 | lut[y] = momentum * lut[y] + (1.0 - momentum) * x 30 | lut[y] /= lut[y].norm() 31 | else: 32 | cq[header] = x 33 | header = (header + 1) % cq.size(0) 34 | return grad_inputs, None, None, None, None, None 35 | 36 | 37 | def oim(inputs, targets, lut, cq, header, momentum=0.5): 38 | return OIM.apply(inputs, targets, lut, cq, torch.tensor(header), torch.tensor(momentum)) 39 | 40 | 41 | class OIMLoss(nn.Module): 42 | def __init__(self, num_features, num_pids, num_cq_size, oim_momentum, oim_scalar): 43 | super(OIMLoss, self).__init__() 44 | self.num_features = num_features 45 | self.num_pids = num_pids 46 | self.num_unlabeled = num_cq_size 47 | self.momentum = oim_momentum 48 | self.oim_scalar = oim_scalar 49 | 50 | self.register_buffer("lut", torch.zeros(self.num_pids, self.num_features)) 51 | self.register_buffer("cq", torch.zeros(self.num_unlabeled, self.num_features)) 52 | 53 | self.header_cq = 0 54 | 55 | def forward(self, inputs, roi_label): 56 | # merge into one batch, background label = 0 57 | targets = torch.cat(roi_label) 58 | label = targets - 1 # background label = -1 59 | 60 | inds = label >= 0 61 | 62 | label = label[inds] 63 | inputs = inputs[inds.unsqueeze(1).expand_as(inputs)].view(-1, self.num_features) 64 | 65 | projected = oim(inputs, label, self.lut, self.cq, self.header_cq, momentum=self.momentum) 66 | # projected - Tensor [M, lut+cq], e.g., [M, 482+500]=[M, 982] 67 | 68 | projected *= self.oim_scalar 69 | 70 | self.header_cq = ( 71 | self.header_cq + (label >= self.num_pids).long().sum().item() 72 | ) % self.num_unlabeled 73 | 74 | loss_oim = F.cross_entropy(projected, label, ignore_index=5554) 75 | 76 | return loss_oim, inputs, label 77 | -------------------------------------------------------------------------------- /loss/softmax_loss.py: -------------------------------------------------------------------------------- 1 | # This file is part of COAT, and is distributed under the 2 | # OSI-approved BSD 3-Clause License. See top-level LICENSE file or 3 | # https://github.com/Kitware/COAT/blob/master/LICENSE for details. 
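# The stage-wise softmax re-id loss below classifies box embeddings over the LUT_SIZE
# labeled identities through a BatchNorm bottleneck and a bias-free linear classifier;
# labels >= LUT_SIZE (unlabeled people) are remapped to 5554 and ignored.
#
# Minimal usage sketch (hypothetical, not part of the repository; the config values are
# illustrative, not the project defaults):
#
#   import torch
#   from yacs.config import CfgNode as CN
#   from loss.softmax_loss import SoftmaxLoss
#
#   cfg = CN()
#   cfg.MODEL = CN()
#   cfg.MODEL.EMBEDDING_DIM = 256
#   cfg.MODEL.LOSS = CN()
#   cfg.MODEL.LOSS.LUT_SIZE = 482
#
#   criterion = SoftmaxLoss(cfg)
#   feats = torch.randn(8, 256)                          # box embeddings from one mini-batch
#   pids = torch.tensor([0, 1, 5, 42, 481, 5554, 3, 7])  # 5554 = unlabeled, ignored
#   loss = criterion(feats, pids)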
4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | 9 | class SoftmaxLoss(nn.Module): 10 | def __init__(self, cfg): 11 | super(SoftmaxLoss, self).__init__() 12 | 13 | self.feat_dim = cfg.MODEL.EMBEDDING_DIM 14 | self.num_classes = cfg.MODEL.LOSS.LUT_SIZE 15 | 16 | self.bottleneck = nn.BatchNorm1d(self.feat_dim) 17 | self.bottleneck.bias.requires_grad_(False) # no shift 18 | self.classifier = nn.Linear(self.feat_dim, self.num_classes, bias=False) 19 | 20 | self.bottleneck.apply(weights_init_kaiming) 21 | self.classifier.apply(weights_init_classifier) 22 | 23 | def forward(self, inputs, labels): 24 | """ 25 | Args: 26 | inputs: feature matrix with shape (batch_size, feat_dim). 27 | labels: ground truth labels with shape (num_classes). 28 | """ 29 | assert inputs.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)" 30 | 31 | target = labels.clone() 32 | target[target >= self.num_classes] = 5554 33 | 34 | feat = self.bottleneck(inputs) 35 | score = self.classifier(feat) 36 | loss = F.cross_entropy(score, target, ignore_index=5554) 37 | 38 | return loss 39 | 40 | 41 | def weights_init_kaiming(m): 42 | classname = m.__class__.__name__ 43 | if classname.find('Linear') != -1: 44 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 45 | nn.init.constant_(m.bias, 0.0) 46 | elif classname.find('Conv') != -1: 47 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 48 | if m.bias is not None: 49 | nn.init.constant_(m.bias, 0.0) 50 | elif classname.find('BatchNorm') != -1: 51 | if m.affine: 52 | nn.init.constant_(m.weight, 1.0) 53 | nn.init.constant_(m.bias, 0.0) 54 | 55 | 56 | def weights_init_classifier(m): 57 | classname = m.__class__.__name__ 58 | if classname.find('Linear') != -1: 59 | nn.init.normal_(m.weight, std=0.001) 60 | if m.bias: 61 | nn.init.constant_(m.bias, 0.0) 62 | 63 | -------------------------------------------------------------------------------- /models/__pycache__/coat.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kitware/COAT/369d93418716a4a28a96203ebd012020877db550/models/__pycache__/coat.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kitware/COAT/369d93418716a4a28a96203ebd012020877db550/models/__pycache__/resnet.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kitware/COAT/369d93418716a4a28a96203ebd012020877db550/models/__pycache__/transformer.cpython-38.pyc -------------------------------------------------------------------------------- /models/coat.py: -------------------------------------------------------------------------------- 1 | # This file is part of COAT, and is distributed under the 2 | # OSI-approved BSD 3-Clause License. See top-level LICENSE file or 3 | # https://github.com/Kitware/COAT/blob/master/LICENSE for details. 
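# Hypothetical smoke-test sketch (not part of the repository): build the three-stage
# COAT model from the project's default config and run one dummy image through the
# inference path (gallery mode). It assumes defaults.get_default_cfg() supplies every
# MODEL.* key referenced below.
#
#   import torch
#   from defaults import get_default_cfg
#   from models.coat import COAT
#
#   cfg = get_default_cfg()
#   model = COAT(cfg).eval()
#   with torch.no_grad():
#       detections = model([torch.rand(3, 900, 1500)])
#   # detections[0] holds "boxes", "labels", "scores" and the re-id "embeddings"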
4 | 5 | from copy import deepcopy 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.nn import init 10 | from torchvision.models.detection.faster_rcnn import FastRCNNPredictor 11 | from torchvision.models.detection.roi_heads import RoIHeads 12 | from torchvision.models.detection.rpn import AnchorGenerator, RegionProposalNetwork, RPNHead 13 | from torchvision.models.detection.transform import GeneralizedRCNNTransform 14 | from torchvision.ops import MultiScaleRoIAlign 15 | from torchvision.ops import boxes as box_ops 16 | from torchvision.models.detection import _utils as det_utils 17 | 18 | from loss.oim import OIMLoss 19 | from models.resnet import build_resnet 20 | from models.transformer import TransformerHead 21 | 22 | 23 | class COAT(nn.Module): 24 | def __init__(self, cfg): 25 | super(COAT, self).__init__() 26 | 27 | backbone, _ = build_resnet(name="resnet50", pretrained=True) 28 | anchor_generator = AnchorGenerator( 29 | sizes=((32, 64, 128, 256, 512),), aspect_ratios=((0.5, 1.0, 2.0),) 30 | ) 31 | head = RPNHead( 32 | in_channels=backbone.out_channels, 33 | num_anchors=anchor_generator.num_anchors_per_location()[0], 34 | ) 35 | pre_nms_top_n = dict( 36 | training=cfg.MODEL.RPN.PRE_NMS_TOPN_TRAIN, testing=cfg.MODEL.RPN.PRE_NMS_TOPN_TEST 37 | ) 38 | post_nms_top_n = dict( 39 | training=cfg.MODEL.RPN.POST_NMS_TOPN_TRAIN, testing=cfg.MODEL.RPN.POST_NMS_TOPN_TEST 40 | ) 41 | rpn = RegionProposalNetwork( 42 | anchor_generator=anchor_generator, 43 | head=head, 44 | fg_iou_thresh=cfg.MODEL.RPN.POS_THRESH_TRAIN, 45 | bg_iou_thresh=cfg.MODEL.RPN.NEG_THRESH_TRAIN, 46 | batch_size_per_image=cfg.MODEL.RPN.BATCH_SIZE_TRAIN, 47 | positive_fraction=cfg.MODEL.RPN.POS_FRAC_TRAIN, 48 | pre_nms_top_n=pre_nms_top_n, 49 | post_nms_top_n=post_nms_top_n, 50 | nms_thresh=cfg.MODEL.RPN.NMS_THRESH, 51 | ) 52 | 53 | box_head = TransformerHead( 54 | cfg=cfg, 55 | trans_names=cfg.MODEL.TRANSFORMER.NAMES_1ST, 56 | kernel_size=cfg.MODEL.TRANSFORMER.KERNEL_SIZE_1ST, 57 | use_feature_mask=cfg.MODEL.TRANSFORMER.USE_MASK_1ST, 58 | ) 59 | box_head_2nd = TransformerHead( 60 | cfg=cfg, 61 | trans_names=cfg.MODEL.TRANSFORMER.NAMES_2ND, 62 | kernel_size=cfg.MODEL.TRANSFORMER.KERNEL_SIZE_2ND, 63 | use_feature_mask=cfg.MODEL.TRANSFORMER.USE_MASK_2ND, 64 | ) 65 | box_head_3rd = TransformerHead( 66 | cfg=cfg, 67 | trans_names=cfg.MODEL.TRANSFORMER.NAMES_3RD, 68 | kernel_size=cfg.MODEL.TRANSFORMER.KERNEL_SIZE_3RD, 69 | use_feature_mask=cfg.MODEL.TRANSFORMER.USE_MASK_3RD, 70 | ) 71 | 72 | faster_rcnn_predictor = FastRCNNPredictor(2048, 2) 73 | box_roi_pool = MultiScaleRoIAlign( 74 | featmap_names=["feat_res4"], output_size=14, sampling_ratio=2 75 | ) 76 | box_predictor = BBoxRegressor(2048, num_classes=2, bn_neck=cfg.MODEL.ROI_HEAD.BN_NECK) 77 | roi_heads = CascadedROIHeads( 78 | cfg=cfg, 79 | # Cascade Transformer Head 80 | faster_rcnn_predictor=faster_rcnn_predictor, 81 | box_head_2nd=box_head_2nd, 82 | box_head_3rd=box_head_3rd, 83 | # parent class 84 | box_roi_pool=box_roi_pool, 85 | box_head=box_head, 86 | box_predictor=box_predictor, 87 | fg_iou_thresh=cfg.MODEL.ROI_HEAD.POS_THRESH_TRAIN, 88 | bg_iou_thresh=cfg.MODEL.ROI_HEAD.NEG_THRESH_TRAIN, 89 | batch_size_per_image=cfg.MODEL.ROI_HEAD.BATCH_SIZE_TRAIN, 90 | positive_fraction=cfg.MODEL.ROI_HEAD.POS_FRAC_TRAIN, 91 | bbox_reg_weights=None, 92 | score_thresh=cfg.MODEL.ROI_HEAD.SCORE_THRESH_TEST, 93 | nms_thresh=cfg.MODEL.ROI_HEAD.NMS_THRESH_TEST, 94 | detections_per_img=cfg.MODEL.ROI_HEAD.DETECTIONS_PER_IMAGE_TEST, 95 | ) 96 | 97 
| transform = GeneralizedRCNNTransform( 98 | min_size=cfg.INPUT.MIN_SIZE, 99 | max_size=cfg.INPUT.MAX_SIZE, 100 | image_mean=[0.485, 0.456, 0.406], 101 | image_std=[0.229, 0.224, 0.225], 102 | ) 103 | 104 | self.backbone = backbone 105 | self.rpn = rpn 106 | self.roi_heads = roi_heads 107 | self.transform = transform 108 | self.eval_feat = cfg.EVAL_FEATURE 109 | 110 | # loss weights 111 | self.lw_rpn_reg = cfg.SOLVER.LW_RPN_REG 112 | self.lw_rpn_cls = cfg.SOLVER.LW_RPN_CLS 113 | self.lw_rcnn_reg_1st = cfg.SOLVER.LW_RCNN_REG_1ST 114 | self.lw_rcnn_cls_1st = cfg.SOLVER.LW_RCNN_CLS_1ST 115 | self.lw_rcnn_reg_2nd = cfg.SOLVER.LW_RCNN_REG_2ND 116 | self.lw_rcnn_cls_2nd = cfg.SOLVER.LW_RCNN_CLS_2ND 117 | self.lw_rcnn_reg_3rd = cfg.SOLVER.LW_RCNN_REG_3RD 118 | self.lw_rcnn_cls_3rd = cfg.SOLVER.LW_RCNN_CLS_3RD 119 | self.lw_rcnn_reid_2nd = cfg.SOLVER.LW_RCNN_REID_2ND 120 | self.lw_rcnn_reid_3rd = cfg.SOLVER.LW_RCNN_REID_3RD 121 | 122 | def inference(self, images, targets=None, query_img_as_gallery=False): 123 | original_image_sizes = [img.shape[-2:] for img in images] 124 | images, targets = self.transform(images, targets) 125 | features = self.backbone(images.tensors) 126 | 127 | if query_img_as_gallery: 128 | assert targets is not None 129 | 130 | if targets is not None and not query_img_as_gallery: 131 | # query 132 | boxes = [t["boxes"] for t in targets] 133 | box_features = self.roi_heads.box_roi_pool(features, boxes, images.image_sizes) 134 | box_features_2nd = self.roi_heads.box_head_2nd(box_features) 135 | embeddings_2nd, _ = self.roi_heads.embedding_head_2nd(box_features_2nd) 136 | box_features_3rd = self.roi_heads.box_head_3rd(box_features) 137 | embeddings_3rd, _ = self.roi_heads.embedding_head_3rd(box_features_3rd) 138 | if self.eval_feat == 'concat': 139 | embeddings = torch.cat((embeddings_2nd, embeddings_3rd), dim=1) 140 | elif self.eval_feat == 'stage2': 141 | embeddings = embeddings_2nd 142 | elif self.eval_feat == 'stage3': 143 | embeddings = embeddings_3rd 144 | else: 145 | raise Exception("Unknown evaluation feature name") 146 | return embeddings.split(1, 0) 147 | else: 148 | # gallery 149 | boxes, _ = self.rpn(images, features, targets) 150 | detections = self.roi_heads(features, boxes, images.image_sizes, targets, query_img_as_gallery)[0] 151 | detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) 152 | return detections 153 | 154 | def forward(self, images, targets=None, query_img_as_gallery=False): 155 | if not self.training: 156 | return self.inference(images, targets, query_img_as_gallery) 157 | images, targets = self.transform(images, targets) 158 | features = self.backbone(images.tensors) 159 | boxes, rpn_losses = self.rpn(images, features, targets) 160 | 161 | _, rcnn_losses, feats_reid_2nd, targets_reid_2nd, feats_reid_3rd, targets_reid_3rd = self.roi_heads(features, boxes, images.image_sizes, targets) 162 | 163 | # rename rpn losses to be consistent with detection losses 164 | rpn_losses["loss_rpn_reg"] = rpn_losses.pop("loss_rpn_box_reg") 165 | rpn_losses["loss_rpn_cls"] = rpn_losses.pop("loss_objectness") 166 | 167 | losses = {} 168 | losses.update(rcnn_losses) 169 | losses.update(rpn_losses) 170 | 171 | # apply loss weights 172 | losses["loss_rpn_reg"] *= self.lw_rpn_reg 173 | losses["loss_rpn_cls"] *= self.lw_rpn_cls 174 | losses["loss_rcnn_reg_1st"] *= self.lw_rcnn_reg_1st 175 | losses["loss_rcnn_cls_1st"] *= self.lw_rcnn_cls_1st 176 | losses["loss_rcnn_reg_2nd"] *= self.lw_rcnn_reg_2nd 177 | losses["loss_rcnn_cls_2nd"] *= 
self.lw_rcnn_cls_2nd 178 | losses["loss_rcnn_reg_3rd"] *= self.lw_rcnn_reg_3rd 179 | losses["loss_rcnn_cls_3rd"] *= self.lw_rcnn_cls_3rd 180 | losses["loss_rcnn_reid_2nd"] *= self.lw_rcnn_reid_2nd 181 | losses["loss_rcnn_reid_3rd"] *= self.lw_rcnn_reid_3rd 182 | 183 | return losses, feats_reid_2nd, targets_reid_2nd, feats_reid_3rd, targets_reid_3rd 184 | 185 | class CascadedROIHeads(RoIHeads): 186 | ''' 187 | https://github.com/pytorch/vision/blob/master/torchvision/models/detection/roi_heads.py 188 | ''' 189 | def __init__( 190 | self, 191 | cfg, 192 | faster_rcnn_predictor, 193 | box_head_2nd, 194 | box_head_3rd, 195 | *args, 196 | **kwargs 197 | ): 198 | super(CascadedROIHeads, self).__init__(*args, **kwargs) 199 | 200 | # ROI head 201 | self.use_diff_thresh=cfg.MODEL.ROI_HEAD.USE_DIFF_THRESH 202 | self.nms_thresh_1st = cfg.MODEL.ROI_HEAD.NMS_THRESH_TEST_1ST 203 | self.nms_thresh_2nd = cfg.MODEL.ROI_HEAD.NMS_THRESH_TEST_2ND 204 | self.nms_thresh_3rd = cfg.MODEL.ROI_HEAD.NMS_THRESH_TEST_3RD 205 | self.fg_iou_thresh_1st = cfg.MODEL.ROI_HEAD.POS_THRESH_TRAIN 206 | self.bg_iou_thresh_1st = cfg.MODEL.ROI_HEAD.NEG_THRESH_TRAIN 207 | self.fg_iou_thresh_2nd = cfg.MODEL.ROI_HEAD.POS_THRESH_TRAIN_2ND 208 | self.bg_iou_thresh_2nd = cfg.MODEL.ROI_HEAD.NEG_THRESH_TRAIN_2ND 209 | self.fg_iou_thresh_3rd = cfg.MODEL.ROI_HEAD.POS_THRESH_TRAIN_3RD 210 | self.bg_iou_thresh_3rd = cfg.MODEL.ROI_HEAD.NEG_THRESH_TRAIN_3RD 211 | 212 | # Regression head 213 | self.box_predictor_1st = faster_rcnn_predictor 214 | self.box_predictor_2nd = self.box_predictor 215 | self.box_predictor_3rd = deepcopy(self.box_predictor) 216 | 217 | # Transformer head 218 | self.box_head_1st = self.box_head 219 | self.box_head_2nd = box_head_2nd 220 | self.box_head_3rd = box_head_3rd 221 | 222 | # feature mask 223 | self.use_feature_mask = cfg.MODEL.USE_FEATURE_MASK 224 | self.feature_mask_size = cfg.MODEL.FEATURE_MASK_SIZE 225 | 226 | # Feature embedding 227 | embedding_dim = cfg.MODEL.EMBEDDING_DIM 228 | self.embedding_head_2nd = NormAwareEmbedding(featmap_names=["before_trans", "after_trans"], in_channels=[1024, 2048], dim=embedding_dim) 229 | self.embedding_head_3rd = deepcopy(self.embedding_head_2nd) 230 | 231 | # OIM 232 | num_pids = cfg.MODEL.LOSS.LUT_SIZE 233 | num_cq_size = cfg.MODEL.LOSS.CQ_SIZE 234 | oim_momentum = cfg.MODEL.LOSS.OIM_MOMENTUM 235 | oim_scalar = cfg.MODEL.LOSS.OIM_SCALAR 236 | self.reid_loss_2nd = OIMLoss(embedding_dim, num_pids, num_cq_size, oim_momentum, oim_scalar) 237 | self.reid_loss_3rd = deepcopy(self.reid_loss_2nd) 238 | 239 | # rename the method inherited from parent class 240 | self.postprocess_proposals = self.postprocess_detections 241 | 242 | # evaluation 243 | self.eval_feat = cfg.EVAL_FEATURE 244 | 245 | def forward(self, features, boxes, image_shapes, targets=None, query_img_as_gallery=False): 246 | """ 247 | Arguments: 248 | features (List[Tensor]) 249 | boxes (List[Tensor[N, 4]]) 250 | image_shapes (List[Tuple[H, W]]) 251 | targets (List[Dict]) 252 | """ 253 | cws = True 254 | gt_det_2nd = None 255 | gt_det_3rd = None 256 | feats_reid_2nd = None 257 | feats_reid_3rd = None 258 | targets_reid_2nd = None 259 | targets_reid_3rd = None 260 | 261 | if self.training: 262 | if self.use_diff_thresh: 263 | self.proposal_matcher = det_utils.Matcher( 264 | self.fg_iou_thresh_1st, 265 | self.bg_iou_thresh_1st, 266 | allow_low_quality_matches=False) 267 | boxes, _, box_pid_labels_1st, box_reg_targets_1st = self.select_training_samples( 268 | boxes, targets 269 | ) 270 | 271 | # ------------------- The 
first stage ------------------ # 272 | box_features_1st = self.box_roi_pool(features, boxes, image_shapes) 273 | box_features_1st = self.box_head_1st(box_features_1st) 274 | box_cls_scores_1st, box_regs_1st = self.box_predictor_1st(box_features_1st["after_trans"]) 275 | 276 | if self.training: 277 | boxes = self.get_boxes(box_regs_1st, boxes, image_shapes) 278 | boxes = [boxes_per_image.detach() for boxes_per_image in boxes] 279 | if self.use_diff_thresh: 280 | self.proposal_matcher = det_utils.Matcher( 281 | self.fg_iou_thresh_2nd, 282 | self.bg_iou_thresh_2nd, 283 | allow_low_quality_matches=False) 284 | boxes, _, box_pid_labels_2nd, box_reg_targets_2nd = self.select_training_samples(boxes, targets) 285 | else: 286 | orig_thresh = self.nms_thresh # 0.4 287 | self.nms_thresh = self.nms_thresh_1st 288 | boxes, scores, _ = self.postprocess_proposals( 289 | box_cls_scores_1st, box_regs_1st, boxes, image_shapes 290 | ) 291 | 292 | if not self.training and query_img_as_gallery: 293 | # When regarding the query image as gallery, GT boxes may be excluded 294 | # from detected boxes. To avoid this, we compulsorily include GT in the 295 | # detection results. Additionally, CWS should be disabled as the 296 | # confidences of these people in query image are 1 297 | cws = False 298 | gt_box = [targets[0]["boxes"]] 299 | gt_box_features = self.box_roi_pool(features, gt_box, image_shapes) 300 | gt_box_features = self.box_head_2nd(gt_box_features) 301 | embeddings, _ = self.embedding_head_2nd(gt_box_features) 302 | gt_det_2nd = {"boxes": targets[0]["boxes"], "embeddings": embeddings} 303 | 304 | # no detection predicted by Faster R-CNN head in test phase 305 | if boxes[0].shape[0] == 0: 306 | assert not self.training 307 | boxes = gt_det_2nd["boxes"] if gt_det_2nd else torch.zeros(0, 4) 308 | labels = torch.ones(1).type_as(boxes) if gt_det_2nd else torch.zeros(0) 309 | scores = torch.ones(1).type_as(boxes) if gt_det_2nd else torch.zeros(0) 310 | if self.eval_feat == 'concat': 311 | embeddings = torch.cat((gt_det_2nd["embeddings"], gt_det_2nd["embeddings"]), dim=1) if gt_det_2nd else torch.zeros(0, 512) 312 | elif self.eval_feat == 'stage2' or self.eval_feat == 'stage3': 313 | embeddings = gt_det_2nd["embeddings"] if gt_det_2nd else torch.zeros(0, 256) 314 | else: 315 | raise Exception("Unknown evaluation feature name") 316 | return [dict(boxes=boxes, labels=labels, scores=scores, embeddings=embeddings)], [] 317 | 318 | # --------------------- The second stage -------------------- # 319 | box_features = self.box_roi_pool(features, boxes, image_shapes) 320 | box_features = self.box_head_2nd(box_features) 321 | box_regs_2nd = self.box_predictor_2nd(box_features["after_trans"]) 322 | box_embeddings_2nd, box_cls_scores_2nd = self.embedding_head_2nd(box_features) 323 | if box_cls_scores_2nd.dim() == 0: 324 | box_cls_scores_2nd = box_cls_scores_2nd.unsqueeze(0) 325 | 326 | if self.training: 327 | boxes = self.get_boxes(box_regs_2nd, boxes, image_shapes) 328 | boxes = [boxes_per_image.detach() for boxes_per_image in boxes] 329 | if self.use_diff_thresh: 330 | self.proposal_matcher = det_utils.Matcher( 331 | self.fg_iou_thresh_3rd, 332 | self.bg_iou_thresh_3rd, 333 | allow_low_quality_matches=False) 334 | boxes, _, box_pid_labels_3rd, box_reg_targets_3rd = self.select_training_samples(boxes, targets) 335 | else: 336 | self.nms_thresh = self.nms_thresh_2nd 337 | if self.eval_feat != 'stage2': 338 | boxes, scores, _, _ = self.postprocess_boxes( 339 | box_cls_scores_2nd, 340 | box_regs_2nd, 341 | 
box_embeddings_2nd, 342 | boxes, 343 | image_shapes, 344 | fcs=scores, 345 | gt_det=None, 346 | cws=cws, 347 | ) 348 | 349 | if not self.training and query_img_as_gallery and self.eval_feat != 'stage2': 350 | cws = False 351 | gt_box = [targets[0]["boxes"]] 352 | gt_box_features = self.box_roi_pool(features, gt_box, image_shapes) 353 | gt_box_features = self.box_head_3rd(gt_box_features) 354 | embeddings, _ = self.embedding_head_3rd(gt_box_features) 355 | gt_det_3rd = {"boxes": targets[0]["boxes"], "embeddings": embeddings} 356 | 357 | # no detection predicted by Faster R-CNN head in test phase 358 | if boxes[0].shape[0] == 0 and self.eval_feat != 'stage2': 359 | assert not self.training 360 | boxes = gt_det_3rd["boxes"] if gt_det_3rd else torch.zeros(0, 4) 361 | labels = torch.ones(1).type_as(boxes) if gt_det_3rd else torch.zeros(0) 362 | scores = torch.ones(1).type_as(boxes) if gt_det_3rd else torch.zeros(0) 363 | if self.eval_feat == 'concat': 364 | embeddings = torch.cat((gt_det_2nd["embeddings"], gt_det_3rd["embeddings"]), dim=1) if gt_det_3rd else torch.zeros(0, 512) 365 | elif self.eval_feat == 'stage3': 366 | embeddings = gt_det_2nd["embeddings"] if gt_det_3rd else torch.zeros(0, 256) 367 | else: 368 | raise Exception("Unknown evaluation feature name") 369 | return [dict(boxes=boxes, labels=labels, scores=scores, embeddings=embeddings)], [] 370 | 371 | # --------------------- The third stage -------------------- # 372 | box_features = self.box_roi_pool(features, boxes, image_shapes) 373 | 374 | if not self.training: 375 | box_features_2nd = self.box_head_2nd(box_features) 376 | box_embeddings_2nd, _ = self.embedding_head_2nd(box_features_2nd) 377 | 378 | box_features = self.box_head_3rd(box_features) 379 | box_regs_3rd = self.box_predictor_3rd(box_features["after_trans"]) 380 | box_embeddings_3rd, box_cls_scores_3rd = self.embedding_head_3rd(box_features) 381 | if box_cls_scores_3rd.dim() == 0: 382 | box_cls_scores_3rd = box_cls_scores_3rd.unsqueeze(0) 383 | 384 | result, losses = [], {} 385 | if self.training: 386 | box_labels_1st = [y.clamp(0, 1) for y in box_pid_labels_1st] 387 | box_labels_2nd = [y.clamp(0, 1) for y in box_pid_labels_2nd] 388 | box_labels_3rd = [y.clamp(0, 1) for y in box_pid_labels_3rd] 389 | losses = detection_losses( 390 | box_cls_scores_1st, 391 | box_regs_1st, 392 | box_labels_1st, 393 | box_reg_targets_1st, 394 | box_cls_scores_2nd, 395 | box_regs_2nd, 396 | box_labels_2nd, 397 | box_reg_targets_2nd, 398 | box_cls_scores_3rd, 399 | box_regs_3rd, 400 | box_labels_3rd, 401 | box_reg_targets_3rd, 402 | ) 403 | 404 | loss_rcnn_reid_2nd, feats_reid_2nd, targets_reid_2nd = self.reid_loss_2nd(box_embeddings_2nd, box_pid_labels_2nd) 405 | loss_rcnn_reid_3rd, feats_reid_3rd, targets_reid_3rd = self.reid_loss_3rd(box_embeddings_3rd, box_pid_labels_3rd) 406 | losses.update(loss_rcnn_reid_2nd=loss_rcnn_reid_2nd) 407 | losses.update(loss_rcnn_reid_3rd=loss_rcnn_reid_3rd) 408 | else: 409 | if self.eval_feat == 'stage2': 410 | boxes, scores, embeddings_2nd, labels = self.postprocess_boxes( 411 | box_cls_scores_2nd, 412 | box_regs_2nd, 413 | box_embeddings_2nd, 414 | boxes, 415 | image_shapes, 416 | fcs=scores, 417 | gt_det=gt_det_2nd, 418 | cws=cws, 419 | ) 420 | else: 421 | self.nms_thresh = self.nms_thresh_3rd 422 | _, _, embeddings_2nd, _ = self.postprocess_boxes( 423 | box_cls_scores_3rd, 424 | box_regs_3rd, 425 | box_embeddings_2nd, 426 | boxes, 427 | image_shapes, 428 | fcs=scores, 429 | gt_det=gt_det_2nd, 430 | cws=cws, 431 | ) 432 | boxes, scores, 
embeddings_3rd, labels = self.postprocess_boxes( 433 | box_cls_scores_3rd, 434 | box_regs_3rd, 435 | box_embeddings_3rd, 436 | boxes, 437 | image_shapes, 438 | fcs=scores, 439 | gt_det=gt_det_3rd, 440 | cws=cws, 441 | ) 442 | # set to original thresh after finishing postprocess 443 | self.nms_thresh = orig_thresh 444 | 445 | num_images = len(boxes) 446 | for i in range(num_images): 447 | if self.eval_feat == 'concat': 448 | embeddings = torch.cat((embeddings_2nd[i],embeddings_3rd[i]), dim=1) 449 | elif self.eval_feat == 'stage2': 450 | embeddings = embeddings_2nd[i] 451 | elif self.eval_feat == 'stage3': 452 | embeddings = embeddings_3rd[i] 453 | else: 454 | raise Exception("Unknown evaluation feature name") 455 | result.append( 456 | dict( 457 | boxes=boxes[i], 458 | labels=labels[i], 459 | scores=scores[i], 460 | embeddings=embeddings 461 | ) 462 | ) 463 | 464 | return result, losses, feats_reid_2nd, targets_reid_2nd, feats_reid_3rd, targets_reid_3rd 465 | 466 | def get_boxes(self, box_regression, proposals, image_shapes): 467 | """ 468 | Get boxes from proposals. 469 | """ 470 | boxes_per_image = [len(boxes_in_image) for boxes_in_image in proposals] 471 | pred_boxes = self.box_coder.decode(box_regression, proposals) 472 | pred_boxes = pred_boxes.split(boxes_per_image, 0) 473 | 474 | all_boxes = [] 475 | for boxes, image_shape in zip(pred_boxes, image_shapes): 476 | boxes = box_ops.clip_boxes_to_image(boxes, image_shape) 477 | # remove predictions with the background label 478 | boxes = boxes[:, 1:].reshape(-1, 4) 479 | all_boxes.append(boxes) 480 | 481 | return all_boxes 482 | 483 | def postprocess_boxes( 484 | self, 485 | class_logits, 486 | box_regression, 487 | embeddings, 488 | proposals, 489 | image_shapes, 490 | fcs=None, 491 | gt_det=None, 492 | cws=True, 493 | ): 494 | """ 495 | Similar to RoIHeads.postprocess_detections, but can handle embeddings and implement 496 | First Classification Score (FCS). 
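FCS reuses the classification scores produced by the first-stage Faster R-CNN head as
the confidence of the boxes refined at the later stages, and CWS (Confidence Weighted
Similarity) multiplies each embedding by that confidence so that low-confidence
detections are down-weighted during similarity ranking.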
497 | """ 498 | device = class_logits.device 499 | 500 | boxes_per_image = [len(boxes_in_image) for boxes_in_image in proposals] 501 | pred_boxes = self.box_coder.decode(box_regression, proposals) 502 | 503 | if fcs is not None: 504 | # Fist Classification Score (FCS) 505 | pred_scores = fcs[0] 506 | else: 507 | pred_scores = torch.sigmoid(class_logits) 508 | if cws: 509 | # Confidence Weighted Similarity (CWS) 510 | embeddings = embeddings * pred_scores.view(-1, 1) 511 | 512 | # split boxes and scores per image 513 | pred_boxes = pred_boxes.split(boxes_per_image, 0) 514 | pred_scores = pred_scores.split(boxes_per_image, 0) 515 | pred_embeddings = embeddings.split(boxes_per_image, 0) 516 | 517 | all_boxes = [] 518 | all_scores = [] 519 | all_labels = [] 520 | all_embeddings = [] 521 | for boxes, scores, embeddings, image_shape in zip( 522 | pred_boxes, pred_scores, pred_embeddings, image_shapes 523 | ): 524 | boxes = box_ops.clip_boxes_to_image(boxes, image_shape) 525 | 526 | # create labels for each prediction 527 | labels = torch.ones(scores.size(0), device=device) 528 | 529 | # remove predictions with the background label 530 | boxes = boxes[:, 1:] 531 | scores = scores.unsqueeze(1) 532 | labels = labels.unsqueeze(1) 533 | 534 | # batch everything, by making every class prediction be a separate instance 535 | boxes = boxes.reshape(-1, 4) 536 | scores = scores.flatten() 537 | labels = labels.flatten() 538 | embeddings = embeddings.reshape(-1, self.embedding_head_2nd.dim) 539 | 540 | # remove low scoring boxes 541 | inds = torch.nonzero(scores > self.score_thresh).squeeze(1) 542 | boxes, scores, labels, embeddings = ( 543 | boxes[inds], 544 | scores[inds], 545 | labels[inds], 546 | embeddings[inds], 547 | ) 548 | 549 | # remove empty boxes 550 | keep = box_ops.remove_small_boxes(boxes, min_size=1e-2) 551 | boxes, scores, labels, embeddings = ( 552 | boxes[keep], 553 | scores[keep], 554 | labels[keep], 555 | embeddings[keep], 556 | ) 557 | 558 | if gt_det is not None: 559 | # include GT into the detection results 560 | boxes = torch.cat((boxes, gt_det["boxes"]), dim=0) 561 | labels = torch.cat((labels, torch.tensor([1.0]).to(device)), dim=0) 562 | scores = torch.cat((scores, torch.tensor([1.0]).to(device)), dim=0) 563 | embeddings = torch.cat((embeddings, gt_det["embeddings"]), dim=0) 564 | 565 | # non-maximum suppression, independently done per class 566 | keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh) 567 | # keep only topk scoring predictions 568 | keep = keep[: self.detections_per_img] 569 | boxes, scores, labels, embeddings = ( 570 | boxes[keep], 571 | scores[keep], 572 | labels[keep], 573 | embeddings[keep], 574 | ) 575 | 576 | all_boxes.append(boxes) 577 | all_scores.append(scores) 578 | all_labels.append(labels) 579 | all_embeddings.append(embeddings) 580 | 581 | return all_boxes, all_scores, all_embeddings, all_labels 582 | 583 | 584 | class NormAwareEmbedding(nn.Module): 585 | """ 586 | Implements the Norm-Aware Embedding proposed in 587 | Chen, Di, et al. "Norm-aware embedding for efficient person search." CVPR 2020. 
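The direction of the projected feature is used as the re-id embedding, while its L2
norm is rescaled by a BatchNorm layer and returned as the per-box classification logit.
When several feature maps are given, the embedding dimension is split across one
projector per map and the projected parts are concatenated before normalization.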
588 | """ 589 | 590 | def __init__(self, featmap_names=["feat_res4", "feat_res5"], in_channels=[1024, 2048], dim=256): 591 | super(NormAwareEmbedding, self).__init__() 592 | self.featmap_names = featmap_names 593 | self.in_channels = in_channels 594 | self.dim = dim 595 | 596 | self.projectors = nn.ModuleDict() 597 | indv_dims = self._split_embedding_dim() 598 | for ftname, in_channel, indv_dim in zip(self.featmap_names, self.in_channels, indv_dims): 599 | proj = nn.Sequential(nn.Linear(in_channel, indv_dim), nn.BatchNorm1d(indv_dim)) 600 | init.normal_(proj[0].weight, std=0.01) 601 | init.normal_(proj[1].weight, std=0.01) 602 | init.constant_(proj[0].bias, 0) 603 | init.constant_(proj[1].bias, 0) 604 | self.projectors[ftname] = proj 605 | 606 | self.rescaler = nn.BatchNorm1d(1, affine=True) 607 | 608 | def forward(self, featmaps): 609 | """ 610 | Arguments: 611 | featmaps: OrderedDict[Tensor], and in featmap_names you can choose which 612 | featmaps to use 613 | Returns: 614 | tensor of size (BatchSize, dim), L2 normalized embeddings. 615 | tensor of size (BatchSize, ) rescaled norm of embeddings, as class_logits. 616 | """ 617 | assert len(featmaps) == len(self.featmap_names) 618 | if len(featmaps) == 1: 619 | k, v = featmaps.items()[0] 620 | v = self._flatten_fc_input(v) 621 | embeddings = self.projectors[k](v) 622 | norms = embeddings.norm(2, 1, keepdim=True) 623 | embeddings = embeddings / norms.expand_as(embeddings).clamp(min=1e-12) 624 | norms = self.rescaler(norms).squeeze() 625 | return embeddings, norms 626 | else: 627 | outputs = [] 628 | for k, v in featmaps.items(): 629 | v = self._flatten_fc_input(v) 630 | outputs.append(self.projectors[k](v)) 631 | embeddings = torch.cat(outputs, dim=1) 632 | norms = embeddings.norm(2, 1, keepdim=True) 633 | embeddings = embeddings / norms.expand_as(embeddings).clamp(min=1e-12) 634 | norms = self.rescaler(norms).squeeze() 635 | return embeddings, norms 636 | 637 | def _flatten_fc_input(self, x): 638 | if x.ndimension() == 4: 639 | assert list(x.shape[2:]) == [1, 1] 640 | return x.flatten(start_dim=1) 641 | return x 642 | 643 | def _split_embedding_dim(self): 644 | parts = len(self.in_channels) 645 | tmp = [self.dim // parts] * parts 646 | if sum(tmp) == self.dim: 647 | return tmp 648 | else: 649 | res = self.dim % parts 650 | for i in range(1, res + 1): 651 | tmp[-i] += 1 652 | assert sum(tmp) == self.dim 653 | return tmp 654 | 655 | 656 | class BBoxRegressor(nn.Module): 657 | """ 658 | Bounding box regression layer. 659 | """ 660 | 661 | def __init__(self, in_channels, num_classes=2, bn_neck=True): 662 | """ 663 | Args: 664 | in_channels (int): Input channels. 665 | num_classes (int, optional): Defaults to 2 (background and pedestrian). 666 | bn_neck (bool, optional): Whether to use BN after Linear. Defaults to True. 
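The head outputs 4 * num_classes box deltas per RoI, which are decoded against the
proposals by the RoIHeads box coder.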
667 | """ 668 | super(BBoxRegressor, self).__init__() 669 | if bn_neck: 670 | self.bbox_pred = nn.Sequential( 671 | nn.Linear(in_channels, 4 * num_classes), nn.BatchNorm1d(4 * num_classes) 672 | ) 673 | init.normal_(self.bbox_pred[0].weight, std=0.01) 674 | init.normal_(self.bbox_pred[1].weight, std=0.01) 675 | init.constant_(self.bbox_pred[0].bias, 0) 676 | init.constant_(self.bbox_pred[1].bias, 0) 677 | else: 678 | self.bbox_pred = nn.Linear(in_channels, 4 * num_classes) 679 | init.normal_(self.bbox_pred.weight, std=0.01) 680 | init.constant_(self.bbox_pred.bias, 0) 681 | 682 | def forward(self, x): 683 | if x.ndimension() == 4: 684 | if list(x.shape[2:]) != [1, 1]: 685 | x = F.adaptive_avg_pool2d(x, output_size=1) 686 | x = x.flatten(start_dim=1) 687 | bbox_deltas = self.bbox_pred(x) 688 | return bbox_deltas 689 | 690 | 691 | def detection_losses( 692 | box_cls_scores_1st, 693 | box_regs_1st, 694 | box_labels_1st, 695 | box_reg_targets_1st, 696 | box_cls_scores_2nd, 697 | box_regs_2nd, 698 | box_labels_2nd, 699 | box_reg_targets_2nd, 700 | box_cls_scores_3rd, 701 | box_regs_3rd, 702 | box_labels_3rd, 703 | box_reg_targets_3rd, 704 | ): 705 | # --------------------- The first stage -------------------- # 706 | box_labels_1st = torch.cat(box_labels_1st, dim=0) 707 | box_reg_targets_1st = torch.cat(box_reg_targets_1st, dim=0) 708 | loss_rcnn_cls_1st = F.cross_entropy(box_cls_scores_1st, box_labels_1st) 709 | 710 | # get indices that correspond to the regression targets for the 711 | # corresponding ground truth labels, to be used with advanced indexing 712 | sampled_pos_inds_subset = torch.nonzero(box_labels_1st > 0).squeeze(1) 713 | labels_pos = box_labels_1st[sampled_pos_inds_subset] 714 | N = box_cls_scores_1st.size(0) 715 | box_regs_1st = box_regs_1st.reshape(N, -1, 4) 716 | 717 | loss_rcnn_reg_1st = F.smooth_l1_loss( 718 | box_regs_1st[sampled_pos_inds_subset, labels_pos], 719 | box_reg_targets_1st[sampled_pos_inds_subset], 720 | reduction="sum", 721 | ) 722 | loss_rcnn_reg_1st = loss_rcnn_reg_1st / box_labels_1st.numel() 723 | 724 | # --------------------- The second stage -------------------- # 725 | box_labels_2nd = torch.cat(box_labels_2nd, dim=0) 726 | box_reg_targets_2nd = torch.cat(box_reg_targets_2nd, dim=0) 727 | loss_rcnn_cls_2nd = F.binary_cross_entropy_with_logits(box_cls_scores_2nd, box_labels_2nd.float()) 728 | 729 | sampled_pos_inds_subset = torch.nonzero(box_labels_2nd > 0).squeeze(1) 730 | labels_pos = box_labels_2nd[sampled_pos_inds_subset] 731 | N = box_cls_scores_2nd.size(0) 732 | box_regs_2nd = box_regs_2nd.reshape(N, -1, 4) 733 | 734 | loss_rcnn_reg_2nd = F.smooth_l1_loss( 735 | box_regs_2nd[sampled_pos_inds_subset, labels_pos], 736 | box_reg_targets_2nd[sampled_pos_inds_subset], 737 | reduction="sum", 738 | ) 739 | loss_rcnn_reg_2nd = loss_rcnn_reg_2nd / box_labels_2nd.numel() 740 | 741 | # --------------------- The third stage -------------------- # 742 | box_labels_3rd = torch.cat(box_labels_3rd, dim=0) 743 | box_reg_targets_3rd = torch.cat(box_reg_targets_3rd, dim=0) 744 | loss_rcnn_cls_3rd = F.binary_cross_entropy_with_logits(box_cls_scores_3rd, box_labels_3rd.float()) 745 | 746 | sampled_pos_inds_subset = torch.nonzero(box_labels_3rd > 0).squeeze(1) 747 | labels_pos = box_labels_3rd[sampled_pos_inds_subset] 748 | N = box_cls_scores_3rd.size(0) 749 | box_regs_3rd = box_regs_3rd.reshape(N, -1, 4) 750 | 751 | loss_rcnn_reg_3rd = F.smooth_l1_loss( 752 | box_regs_3rd[sampled_pos_inds_subset, labels_pos], 753 | box_reg_targets_3rd[sampled_pos_inds_subset], 754 | 
reduction="sum", 755 | ) 756 | loss_rcnn_reg_3rd = loss_rcnn_reg_3rd / box_labels_3rd.numel() 757 | 758 | return dict( 759 | loss_rcnn_cls_1st=loss_rcnn_cls_1st, 760 | loss_rcnn_reg_1st=loss_rcnn_reg_1st, 761 | loss_rcnn_cls_2nd=loss_rcnn_cls_2nd, 762 | loss_rcnn_reg_2nd=loss_rcnn_reg_2nd, 763 | loss_rcnn_cls_3rd=loss_rcnn_cls_3rd, 764 | loss_rcnn_reg_3rd=loss_rcnn_reg_3rd, 765 | ) 766 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | # This file is part of COAT, and is distributed under the 2 | # OSI-approved BSD 3-Clause License. See top-level LICENSE file or 3 | # https://github.com/Kitware/COAT/blob/master/LICENSE for details. 4 | 5 | from collections import OrderedDict 6 | import torch.nn.functional as F 7 | import torchvision 8 | from torch import nn 9 | 10 | class Backbone(nn.Sequential): 11 | def __init__(self, resnet): 12 | super(Backbone, self).__init__( 13 | OrderedDict( 14 | [ 15 | ["conv1", resnet.conv1], 16 | ["bn1", resnet.bn1], 17 | ["relu", resnet.relu], 18 | ["maxpool", resnet.maxpool], 19 | ["layer1", resnet.layer1], # res2 20 | ["layer2", resnet.layer2], # res3 21 | ["layer3", resnet.layer3], # res4 22 | ] 23 | ) 24 | ) 25 | self.out_channels = 1024 26 | 27 | def forward(self, x): 28 | # using the forward method from nn.Sequential 29 | feat = super(Backbone, self).forward(x) 30 | return OrderedDict([["feat_res4", feat]]) 31 | 32 | 33 | class Res5Head(nn.Sequential): 34 | def __init__(self, resnet): 35 | super(Res5Head, self).__init__(OrderedDict([["layer4", resnet.layer4]])) # res5 36 | self.out_channels = [1024, 2048] 37 | 38 | def forward(self, x): 39 | feat = super(Res5Head, self).forward(x) 40 | x = F.adaptive_max_pool2d(x, 1) 41 | feat = F.adaptive_max_pool2d(feat, 1) 42 | return OrderedDict([["feat_res4", x], ["feat_res5", feat]]) 43 | 44 | 45 | def build_resnet(name="resnet50", pretrained=True): 46 | resnet = torchvision.models.resnet.__dict__[name](pretrained=pretrained) 47 | 48 | # freeze layers 49 | resnet.conv1.weight.requires_grad_(False) 50 | resnet.bn1.weight.requires_grad_(False) 51 | resnet.bn1.bias.requires_grad_(False) 52 | 53 | return Backbone(resnet), Res5Head(resnet) 54 | -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | # This file is part of COAT, and is distributed under the 2 | # OSI-approved BSD 3-Clause License. See top-level LICENSE file or 3 | # https://github.com/Kitware/COAT/blob/master/LICENSE for details. 
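# Hypothetical sanity check (not part of the repository) for the backbone split defined
# in models/resnet.py above: conv1-layer3 form the shared "feat_res4" backbone
# (1024 channels, stride 16), while layer4 is wrapped as Res5Head (returned by
# build_resnet but not used by COAT, which plugs in TransformerHead instead).
#
#   import torch
#   from models.resnet import build_resnet
#
#   backbone, res5_head = build_resnet(name="resnet50", pretrained=False)
#   feats = backbone(torch.randn(1, 3, 224, 224))
#   feats["feat_res4"].shape                  # torch.Size([1, 1024, 14, 14])
#   pooled = res5_head(feats["feat_res4"])
#   pooled["feat_res5"].shape                 # torch.Size([1, 2048, 1, 1])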
4 | 5 | import math 6 | import random 7 | from functools import reduce 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from utils.mask import exchange_token, exchange_patch, get_mask_box, jigsaw_token, cutout_patch, erase_patch, mixup_patch, jigsaw_patch 12 | 13 | 14 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 15 | """1x1 convolution""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 17 | 18 | 19 | class TransformerHead(nn.Module): 20 | def __init__( 21 | self, 22 | cfg, 23 | trans_names, 24 | kernel_size, 25 | use_feature_mask, 26 | ): 27 | super(TransformerHead, self).__init__() 28 | d_model = cfg.MODEL.TRANSFORMER.DIM_MODEL 29 | 30 | # Mask parameters 31 | self.use_feature_mask = use_feature_mask 32 | mask_shape = cfg.MODEL.MASK_SHAPE 33 | mask_size = cfg.MODEL.MASK_SIZE 34 | mask_mode = cfg.MODEL.MASK_MODE 35 | 36 | self.bypass_mask = exchange_patch(mask_shape, mask_size, mask_mode) 37 | self.get_mask_box = get_mask_box(mask_shape, mask_size, mask_mode) 38 | 39 | self.transformer_encoder = Transformers( 40 | cfg=cfg, 41 | trans_names=trans_names, 42 | kernel_size=kernel_size, 43 | use_feature_mask=use_feature_mask, 44 | ) 45 | self.conv0 = conv1x1(1024, 1024) 46 | self.conv1 = conv1x1(1024, d_model) 47 | self.conv2 = conv1x1(d_model, 2048) 48 | 49 | def forward(self, box_features): 50 | mask_box = self.get_mask_box(box_features) 51 | 52 | if self.use_feature_mask: 53 | skip_features = self.conv0(box_features) 54 | if self.training: 55 | skip_features = self.bypass_mask(skip_features) 56 | else: 57 | skip_features = box_features 58 | 59 | trans_features = {} 60 | trans_features["before_trans"] = F.adaptive_max_pool2d(skip_features, 1) 61 | box_features = self.conv1(box_features) 62 | box_features = self.transformer_encoder((box_features,mask_box)) 63 | box_features = self.conv2(box_features) 64 | trans_features["after_trans"] = F.adaptive_max_pool2d(box_features, 1) 65 | 66 | return trans_features 67 | 68 | 69 | class Transformers(nn.Module): 70 | def __init__( 71 | self, 72 | cfg, 73 | trans_names, 74 | kernel_size, 75 | use_feature_mask, 76 | ): 77 | super(Transformers, self).__init__() 78 | d_model = cfg.MODEL.TRANSFORMER.DIM_MODEL 79 | self.feature_aug_type = cfg.MODEL.FEATURE_AUG_TYPE 80 | self.use_feature_mask = use_feature_mask 81 | 82 | # If no conv before transformer, we do not use scales 83 | if not cfg.MODEL.TRANSFORMER.USE_PATCH2VEC: 84 | trans_names = ['scale1'] 85 | kernel_size = [(1,1)] 86 | 87 | self.trans_names = trans_names 88 | self.scale_size = len(self.trans_names) 89 | hidden = d_model//(2*self.scale_size) 90 | 91 | # kernel_size: (padding, stride) 92 | kernels = { 93 | (1,1): [(0,0),(1,1)], 94 | (3,3): [(1,1),(1,1)] 95 | } 96 | 97 | padding = [] 98 | stride = [] 99 | for ksize in kernel_size: 100 | if ksize not in [(1,1),(3,3)]: 101 | raise ValueError('Undefined kernel size.') 102 | padding.append(kernels[ksize][0]) 103 | stride.append(kernels[ksize][1]) 104 | 105 | self.use_output_layer = cfg.MODEL.TRANSFORMER.USE_OUTPUT_LAYER 106 | self.use_global_shortcut = cfg.MODEL.TRANSFORMER.USE_GLOBAL_SHORTCUT 107 | 108 | self.blocks = nn.ModuleDict() 109 | for tname, ksize, psize, ssize in zip(self.trans_names, kernel_size, padding, stride): 110 | transblock = Transformer( 111 | cfg, d_model//self.scale_size, ksize, psize, ssize, hidden, use_feature_mask 112 | ) 113 | self.blocks[tname] = nn.Sequential(transblock) 114 | 115 | self.output_linear = nn.Sequential( 
116 | nn.Conv2d(d_model, d_model, kernel_size=3, padding=1), 117 | nn.LeakyReLU(0.2, inplace=True) 118 | ) 119 | self.mask_para = [cfg.MODEL.MASK_SHAPE, cfg.MODEL.MASK_SIZE, cfg.MODEL.MASK_MODE] 120 | 121 | def forward(self, inputs): 122 | trans_feat = [] 123 | enc_feat, mask_box = inputs 124 | 125 | if self.training and self.use_feature_mask and self.feature_aug_type == 'exchange_patch': 126 | feature_mask = exchange_patch(self.mask_para[0], self.mask_para[1], self.mask_para[2]) 127 | enc_feat = feature_mask(enc_feat) 128 | 129 | for tname, feat in zip(self.trans_names, torch.chunk(enc_feat, len(self.trans_names), dim=1)): 130 | feat = self.blocks[tname]((feat, mask_box)) 131 | trans_feat.append(feat) 132 | 133 | trans_feat = torch.cat(trans_feat, 1) 134 | if self.use_output_layer: 135 | trans_feat = self.output_linear(trans_feat) 136 | if self.use_global_shortcut: 137 | trans_feat = enc_feat + trans_feat 138 | return trans_feat 139 | 140 | 141 | class Transformer(nn.Module): 142 | def __init__(self, cfg, channel, kernel_size, padding, stride, hidden, use_feature_mask 143 | ): 144 | super(Transformer, self).__init__() 145 | self.k = kernel_size[0] 146 | stack_num = cfg.MODEL.TRANSFORMER.ENCODER_LAYERS 147 | num_head = cfg.MODEL.TRANSFORMER.N_HEAD 148 | dropout = cfg.MODEL.TRANSFORMER.DROPOUT 149 | output_size = (14,14) 150 | token_size = tuple(map(lambda x,y:x//y, output_size, stride)) 151 | blocks = [] 152 | self.transblock = TransformerBlock(token_size, hidden=hidden, num_head=num_head, dropout=dropout) 153 | for _ in range(stack_num): 154 | blocks.append(self.transblock) 155 | self.transformer = nn.Sequential(*blocks) 156 | self.patch2vec = nn.Conv2d(channel, hidden, kernel_size=kernel_size, stride=stride, padding=padding) 157 | self.vec2patch = Vec2Patch(channel, hidden, output_size, kernel_size, stride, padding) 158 | self.use_local_shortcut = cfg.MODEL.TRANSFORMER.USE_LOCAL_SHORTCUT 159 | self.use_feature_mask = use_feature_mask 160 | self.feature_aug_type = cfg.MODEL.FEATURE_AUG_TYPE 161 | self.use_patch2vec = cfg.MODEL.TRANSFORMER.USE_PATCH2VEC 162 | 163 | def forward(self, inputs): 164 | enc_feat, mask_box = inputs 165 | b, c, h, w = enc_feat.size() 166 | 167 | trans_feat = self.patch2vec(enc_feat) 168 | 169 | _, c, h, w = trans_feat.size() 170 | trans_feat = trans_feat.view(b, c, -1).permute(0, 2, 1) 171 | 172 | # For 1x1 & 3x3 kernels, exchange tokens 173 | if self.training and self.use_feature_mask: 174 | if self.feature_aug_type == 'exchange_token': 175 | feature_mask = exchange_token() 176 | trans_feat = feature_mask(trans_feat, mask_box) 177 | elif self.feature_aug_type == 'cutout_patch': 178 | feature_mask = cutout_patch() 179 | trans_feat = feature_mask(trans_feat) 180 | elif self.feature_aug_type == 'erase_patch': 181 | feature_mask = erase_patch() 182 | trans_feat = feature_mask(trans_feat) 183 | elif self.feature_aug_type == 'mixup_patch': 184 | feature_mask = mixup_patch() 185 | trans_feat = feature_mask(trans_feat) 186 | 187 | if self.use_feature_mask: 188 | if self.feature_aug_type == 'jigsaw_patch': 189 | feature_mask = jigsaw_patch() 190 | trans_feat = feature_mask(trans_feat) 191 | elif self.feature_aug_type == 'jigsaw_token': 192 | feature_mask = jigsaw_token() 193 | trans_feat = feature_mask(trans_feat) 194 | 195 | trans_feat = self.transformer(trans_feat) 196 | trans_feat = self.vec2patch(trans_feat) 197 | if self.use_local_shortcut: 198 | trans_feat = enc_feat + trans_feat 199 | 200 | return trans_feat 201 | 202 | 203 | class TransformerBlock(nn.Module): 204 
| """ 205 | Transformer = MultiHead_Attention + Feed_Forward with sublayer connection 206 | """ 207 | def __init__(self, tokensize, hidden=128, num_head=4, dropout=0.1): 208 | super().__init__() 209 | self.attention = MultiHeadedAttention(tokensize, d_model=hidden, head=num_head, p=dropout) 210 | self.ffn = FeedForward(hidden, p=dropout) 211 | self.norm1 = nn.LayerNorm(hidden) 212 | self.norm2 = nn.LayerNorm(hidden) 213 | self.dropout = nn.Dropout(p=dropout) 214 | 215 | def forward(self, x): 216 | x = self.norm1(x) 217 | x = x + self.dropout(self.attention(x)) 218 | y = self.norm2(x) 219 | x = x + self.ffn(y) 220 | 221 | return x 222 | 223 | 224 | class Attention(nn.Module): 225 | """ 226 | Compute 'Scaled Dot Product Attention 227 | """ 228 | def __init__(self, p=0.1): 229 | super(Attention, self).__init__() 230 | self.dropout = nn.Dropout(p=p) 231 | 232 | def forward(self, query, key, value): 233 | scores = torch.matmul(query, key.transpose(-2, -1) 234 | ) / math.sqrt(query.size(-1)) 235 | p_attn = F.softmax(scores, dim=-1) 236 | p_attn = self.dropout(p_attn) 237 | p_val = torch.matmul(p_attn, value) 238 | return p_val, p_attn 239 | 240 | 241 | class Vec2Patch(nn.Module): 242 | def __init__(self, channel, hidden, output_size, kernel_size, stride, padding): 243 | super(Vec2Patch, self).__init__() 244 | self.relu = nn.LeakyReLU(0.2, inplace=True) 245 | c_out = reduce((lambda x, y: x * y), kernel_size) * channel 246 | self.embedding = nn.Linear(hidden, c_out) 247 | self.to_patch = torch.nn.Fold(output_size=output_size, kernel_size=kernel_size, stride=stride, padding=padding) 248 | h, w = output_size 249 | 250 | def forward(self, x): 251 | feat = self.embedding(x) 252 | b, n, c = feat.size() 253 | feat = feat.permute(0, 2, 1) 254 | feat = self.to_patch(feat) 255 | 256 | return feat 257 | 258 | class MultiHeadedAttention(nn.Module): 259 | """ 260 | Take in model size and number of heads. 
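Queries, keys and values are linear projections of a (num_boxes, num_tokens, d_model)
token tensor, split into `head` chunks along the channel dimension before scaled
dot-product attention is applied per head.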
261 | """ 262 | def __init__(self, tokensize, d_model, head, p=0.1): 263 | super().__init__() 264 | self.query_embedding = nn.Linear(d_model, d_model) 265 | self.value_embedding = nn.Linear(d_model, d_model) 266 | self.key_embedding = nn.Linear(d_model, d_model) 267 | self.output_linear = nn.Linear(d_model, d_model) 268 | self.attention = Attention(p=p) 269 | self.head = head 270 | self.h, self.w = tokensize 271 | 272 | def forward(self, x): 273 | b, n, c = x.size() 274 | c_h = c // self.head 275 | key = self.key_embedding(x) 276 | query = self.query_embedding(x) 277 | value = self.value_embedding(x) 278 | key = key.view(b, n, self.head, c_h).permute(0, 2, 1, 3) 279 | query = query.view(b, n, self.head, c_h).permute(0, 2, 1, 3) 280 | value = value.view(b, n, self.head, c_h).permute(0, 2, 1, 3) 281 | att, _ = self.attention(query, key, value) 282 | att = att.permute(0, 2, 1, 3).contiguous().view(b, n, c) 283 | output = self.output_linear(att) 284 | 285 | return output 286 | 287 | 288 | class FeedForward(nn.Module): 289 | def __init__(self, d_model, p=0.1): 290 | super(FeedForward, self).__init__() 291 | self.conv = nn.Sequential( 292 | nn.Linear(d_model, d_model * 4), 293 | nn.ReLU(inplace=True), 294 | nn.Dropout(p=p), 295 | nn.Linear(d_model * 4, d_model), 296 | nn.Dropout(p=p)) 297 | 298 | def forward(self, x): 299 | x = self.conv(x) 300 | return x 301 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # This file is part of COAT, and is distributed under the 2 | # OSI-approved BSD 3-Clause License. See top-level LICENSE file or 3 | # https://github.com/Kitware/COAT/blob/master/LICENSE for details. 4 | 5 | import argparse 6 | import datetime 7 | import os.path as osp 8 | import time 9 | 10 | import torch 11 | import torch.utils.data 12 | 13 | from datasets import build_test_loader, build_train_loader 14 | from defaults import get_default_cfg 15 | from engine import evaluate_performance, train_one_epoch 16 | from models.coat import COAT 17 | from utils.utils import mkdir, resume_from_ckpt, save_on_master, set_random_seed 18 | 19 | from loss.softmax_loss import SoftmaxLoss 20 | 21 | 22 | def main(args): 23 | cfg = get_default_cfg() 24 | if args.cfg_file: 25 | cfg.merge_from_file(args.cfg_file) 26 | cfg.merge_from_list(args.opts) 27 | cfg.freeze() 28 | 29 | device = torch.device(cfg.DEVICE) 30 | if cfg.SEED >= 0: 31 | set_random_seed(cfg.SEED) 32 | 33 | print("Creating model...") 34 | model = COAT(cfg) 35 | model.to(device) 36 | 37 | print("Loading data...") 38 | train_loader = build_train_loader(cfg) 39 | gallery_loader, query_loader = build_test_loader(cfg) 40 | 41 | softmax_criterion_s2 = None 42 | softmax_criterion_s3 = None 43 | if cfg.MODEL.LOSS.USE_SOFTMAX: 44 | softmax_criterion_s2 = SoftmaxLoss(cfg) 45 | softmax_criterion_s3 = SoftmaxLoss(cfg) 46 | softmax_criterion_s2.to(device) 47 | softmax_criterion_s3.to(device) 48 | 49 | if args.eval: 50 | assert args.ckpt, "--ckpt must be specified when --eval enabled" 51 | resume_from_ckpt(args.ckpt, model) 52 | evaluate_performance( 53 | model, 54 | gallery_loader, 55 | query_loader, 56 | device, 57 | use_gt=cfg.EVAL_USE_GT, 58 | use_cache=cfg.EVAL_USE_CACHE, 59 | use_cbgm=cfg.EVAL_USE_CBGM, 60 | gallery_size=cfg.EVAL_GALLERY_SIZE, 61 | ) 62 | exit(0) 63 | 64 | params = [p for p in model.parameters() if p.requires_grad] 65 | if cfg.MODEL.LOSS.USE_SOFTMAX: 66 | params_softmax_s2 = [p for p in 
softmax_criterion_s2.parameters() if p.requires_grad] 67 | params_softmax_s3 = [p for p in softmax_criterion_s3.parameters() if p.requires_grad] 68 | params.extend(params_softmax_s2) 69 | params.extend(params_softmax_s3) 70 | 71 | optimizer = torch.optim.SGD( 72 | params, 73 | lr=cfg.SOLVER.BASE_LR, 74 | momentum=cfg.SOLVER.SGD_MOMENTUM, 75 | weight_decay=cfg.SOLVER.WEIGHT_DECAY, 76 | ) 77 | 78 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( 79 | optimizer, milestones=cfg.SOLVER.LR_DECAY_MILESTONES, gamma=cfg.SOLVER.GAMMA 80 | ) 81 | 82 | start_epoch = 0 83 | if args.resume: 84 | assert args.ckpt, "--ckpt must be specified when --resume enabled" 85 | start_epoch = resume_from_ckpt(args.ckpt, model, optimizer, lr_scheduler) + 1 86 | 87 | print("Creating output folder...") 88 | output_dir = cfg.OUTPUT_DIR 89 | mkdir(output_dir) 90 | path = osp.join(output_dir, "config.yaml") 91 | with open(path, "w") as f: 92 | f.write(cfg.dump()) 93 | print(f"Full config is saved to {path}") 94 | tfboard = None 95 | if cfg.TF_BOARD: 96 | from torch.utils.tensorboard import SummaryWriter 97 | 98 | tf_log_path = osp.join(output_dir, "tf_log") 99 | mkdir(tf_log_path) 100 | tfboard = SummaryWriter(log_dir=tf_log_path) 101 | print(f"TensorBoard files are saved to {tf_log_path}") 102 | 103 | print("Start training...") 104 | start_time = time.time() 105 | for epoch in range(start_epoch, cfg.SOLVER.MAX_EPOCHS): 106 | train_one_epoch(cfg, model, optimizer, train_loader, device, epoch, tfboard, softmax_criterion_s2, softmax_criterion_s3) 107 | lr_scheduler.step() 108 | 109 | # only save the last three checkpoints 110 | if epoch >= cfg.SOLVER.MAX_EPOCHS - 3: 111 | save_on_master( 112 | { 113 | "model": model.state_dict(), 114 | "optimizer": optimizer.state_dict(), 115 | "lr_scheduler": lr_scheduler.state_dict(), 116 | "epoch": epoch, 117 | }, 118 | osp.join(output_dir, f"epoch_{epoch}.pth"), 119 | ) 120 | 121 | # evaluate the current checkpoint 122 | evaluate_performance( 123 | model, 124 | gallery_loader, 125 | query_loader, 126 | device, 127 | use_gt=cfg.EVAL_USE_GT, 128 | use_cache=cfg.EVAL_USE_CACHE, 129 | use_cbgm=cfg.EVAL_USE_CBGM, 130 | gallery_size=cfg.EVAL_GALLERY_SIZE, 131 | ) 132 | 133 | if tfboard: 134 | tfboard.close() 135 | total_time = time.time() - start_time 136 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 137 | print(f"Total training time {total_time_str}") 138 | 139 | 140 | if __name__ == "__main__": 141 | parser = argparse.ArgumentParser(description="Train a person search network.") 142 | parser.add_argument("--cfg", dest="cfg_file", help="Path to configuration file.") 143 | parser.add_argument( 144 | "--eval", action="store_true", help="Evaluate the performance of a given checkpoint." 145 | ) 146 | parser.add_argument( 147 | "--resume", action="store_true", help="Resume from the specified checkpoint." 
148 | ) 149 | parser.add_argument("--ckpt", help="Path to checkpoint to resume or evaluate.") 150 | parser.add_argument( 151 | "opts", nargs=argparse.REMAINDER, help="Modify config options using the command-line" 152 | ) 153 | args = parser.parse_args() 154 | main(args) 155 | -------------------------------------------------------------------------------- /utils/__pycache__/km.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kitware/COAT/369d93418716a4a28a96203ebd012020877db550/utils/__pycache__/km.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/mask.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kitware/COAT/369d93418716a4a28a96203ebd012020877db550/utils/__pycache__/mask.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/transforms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kitware/COAT/369d93418716a4a28a96203ebd012020877db550/utils/__pycache__/transforms.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kitware/COAT/369d93418716a4a28a96203ebd012020877db550/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/km.py: -------------------------------------------------------------------------------- 1 | # This file is part of COAT, and is distributed under the 2 | # OSI-approved BSD 3-Clause License. See top-level LICENSE file or 3 | # https://github.com/Kitware/COAT/blob/master/LICENSE for details. 
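# This module implements the Kuhn-Munkres (Hungarian) algorithm for maximum-weight
# bipartite matching. eval_func.py uses it for Context Bipartite Graph Matching (CBGM),
# matching query-context boxes to gallery detections by feature similarity.
#
# Minimal usage sketch (hypothetical, not from the repository):
#
#   graph = [(0, 0, 0.9), (0, 1, 0.2), (1, 0, 0.4), (1, 1, 0.8)]  # (left id, right id, weight)
#   matches, max_val = run_kuhn_munkres(graph)
#   # matches pairs 0-0 and 1-1; max_val is the largest matched weight (0.9)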
4 | 5 | import random 6 | import numpy as np 7 | 8 | zero_threshold = 0.00000001 9 | 10 | class KMNode(object): 11 | def __init__(self, id, exception=0, match=None, visit=False): 12 | self.id = id 13 | self.exception = exception 14 | self.match = match 15 | self.visit = visit 16 | 17 | 18 | class KuhnMunkres(object): 19 | def __init__(self): 20 | self.matrix = None 21 | self.x_nodes = [] 22 | self.y_nodes = [] 23 | self.minz = float("inf") 24 | self.x_length = 0 25 | self.y_length = 0 26 | self.index_x = 0 27 | self.index_y = 1 28 | 29 | def __del__(self): 30 | pass 31 | 32 | def set_matrix(self, x_y_values): 33 | xs = set() 34 | ys = set() 35 | for x, y, value in x_y_values: 36 | xs.add(x) 37 | ys.add(y) 38 | 39 | if len(xs) < len(ys): 40 | self.index_x = 0 41 | self.index_y = 1 42 | else: 43 | self.index_x = 1 44 | self.index_y = 0 45 | xs, ys = ys, xs 46 | 47 | x_dic = {x: i for i, x in enumerate(xs)} 48 | y_dic = {y: j for j, y in enumerate(ys)} 49 | self.x_nodes = [KMNode(x) for x in xs] 50 | self.y_nodes = [KMNode(y) for y in ys] 51 | self.x_length = len(xs) 52 | self.y_length = len(ys) 53 | 54 | self.matrix = np.zeros((self.x_length, self.y_length)) 55 | for row in x_y_values: 56 | x = row[self.index_x] 57 | y = row[self.index_y] 58 | value = row[2] 59 | x_index = x_dic[x] 60 | y_index = y_dic[y] 61 | self.matrix[x_index, y_index] = value 62 | 63 | for i in range(self.x_length): 64 | self.x_nodes[i].exception = max(self.matrix[i, :]) 65 | 66 | def km(self): 67 | for i in range(self.x_length): 68 | while True: 69 | self.minz = float("inf") 70 | self.set_false(self.x_nodes) 71 | self.set_false(self.y_nodes) 72 | 73 | if self.dfs(i): 74 | break 75 | 76 | self.change_exception(self.x_nodes, -self.minz) 77 | self.change_exception(self.y_nodes, self.minz) 78 | 79 | def dfs(self, i): 80 | x_node = self.x_nodes[i] 81 | x_node.visit = True 82 | for j in range(self.y_length): 83 | y_node = self.y_nodes[j] 84 | if not y_node.visit: 85 | t = x_node.exception + y_node.exception - self.matrix[i][j] 86 | if abs(t) < zero_threshold: 87 | y_node.visit = True 88 | if y_node.match is None or self.dfs(y_node.match): 89 | x_node.match = j 90 | y_node.match = i 91 | return True 92 | else: 93 | if t >= zero_threshold: 94 | self.minz = min(self.minz, t) 95 | return False 96 | 97 | def set_false(self, nodes): 98 | for node in nodes: 99 | node.visit = False 100 | 101 | def change_exception(self, nodes, change): 102 | for node in nodes: 103 | if node.visit: 104 | node.exception += change 105 | 106 | def get_connect_result(self): 107 | ret = [] 108 | for i in range(self.x_length): 109 | x_node = self.x_nodes[i] 110 | j = x_node.match 111 | y_node = self.y_nodes[j] 112 | x_id = x_node.id 113 | y_id = y_node.id 114 | value = self.matrix[i][j] 115 | 116 | if self.index_x == 1 and self.index_y == 0: 117 | x_id, y_id = y_id, x_id 118 | ret.append((x_id, y_id, value)) 119 | 120 | return ret 121 | 122 | def get_max_value_result(self): 123 | ret = -100 124 | for i in range(self.x_length): 125 | j = self.x_nodes[i].match 126 | ret = max(ret, self.matrix[i][j]) 127 | return ret 128 | 129 | 130 | def run_kuhn_munkres(x_y_values): 131 | process = KuhnMunkres() 132 | process.set_matrix(x_y_values) 133 | process.km() 134 | return process.get_connect_result(), process.get_max_value_result() 135 | 136 | 137 | def test(): 138 | values = [] 139 | random.seed(0) 140 | for i in range(500): 141 | for j in range(1000): 142 | value = random.random() 143 | values.append((i, j, value)) 144 | 145 | return run_kuhn_munkres(values) 146 | 
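# Illustrative usage (added comment, not part of the original file):
# run_kuhn_munkres() takes (x, y, weight) triples and returns the assignment that
# maximizes the total weight, together with the largest weight among the matched
# pairs. For example:
#   matches, best = run_kuhn_munkres([(0, 0, 0.9), (0, 1, 0.2), (1, 1, 0.7)])
#   # matches == [(0, 0, 0.9), (1, 1, 0.7)], best == 0.9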
147 | 148 | if __name__ == "__main__": 149 | values = [(1, 1, 3), (1, 3, 4), (2, 1, 2), (2, 2, 1), (2, 3, 3), (3, 2, 4), (3, 3, 5)] 150 | print(run_kuhn_munkres(values)) 151 | -------------------------------------------------------------------------------- /utils/mask.py: -------------------------------------------------------------------------------- 1 | # This file is part of COAT, and is distributed under the 2 | # OSI-approved BSD 3-Clause License. See top-level LICENSE file or 3 | # https://github.com/Kitware/COAT/blob/master/LICENSE for details. 4 | 5 | import random 6 | import torch 7 | 8 | class exchange_token: 9 | def __init__(self): 10 | pass 11 | 12 | def __call__(self, features, mask_box): 13 | b, hw, c = features.size() 14 | assert hw == 14*14 15 | new_idx, mask_x1, mask_x2, mask_y1, mask_y2 = mask_box 16 | features = features.view(b, 14, 14, c) 17 | features[:, mask_x1 : mask_x2, mask_y1 : mask_y2, :] = features[new_idx, mask_x1 : mask_x2, mask_y1 : mask_y2, :] 18 | features = features.view(b, hw, c) 19 | return features 20 | 21 | class jigsaw_token: 22 | def __init__(self, shift=5, group=2, begin=1): 23 | self.shift = shift 24 | self.group = group 25 | self.begin = begin 26 | 27 | def __call__(self, features): 28 | batchsize = features.size(0) 29 | dim = features.size(2) 30 | 31 | num_tokens = features.size(1) 32 | if num_tokens == 196: 33 | self.group = 2 34 | elif num_tokens == 25: 35 | self.group = 5 36 | else: 37 | raise Exception("Jigsaw - Unwanted number of tokens") 38 | 39 | # Shift Operation 40 | feature_random = torch.cat([features[:, self.begin-1+self.shift:, :], features[:, self.begin-1:self.begin-1+self.shift, :]], dim=1) 41 | x = feature_random 42 | 43 | # Patch Shuffle Operation 44 | try: 45 | x = x.view(batchsize, self.group, -1, dim) 46 | except: 47 | raise Exception("Jigsaw - Unwanted number of groups") 48 | 49 | x = torch.transpose(x, 1, 2).contiguous() 50 | x = x.view(batchsize, -1, dim) 51 | 52 | return x 53 | 54 | class get_mask_box: 55 | def __init__(self, shape='stripe', mask_size=2, mode='random_direct'): 56 | self.shape = shape 57 | self.mask_size = mask_size 58 | self.mode = mode 59 | 60 | def __call__(self, features): 61 | # Stripe mask 62 | if self.shape == 'stripe': 63 | if self.mode == 'horizontal': 64 | mask_box = self.hstripe(features, self.mask_size) 65 | elif self.mode == 'vertical': 66 | mask_box = self.vstripe(features, self.mask_size) 67 | elif self.mode == 'random_direction': 68 | if random.random() < 0.5: 69 | mask_box = self.hstripe(features, self.mask_size) 70 | else: 71 | mask_box = self.vstripe(features, self.mask_size) 72 | else: 73 | raise Exception("Unknown stripe mask mode name") 74 | # Square mask 75 | elif self.shape == 'square': 76 | if self.mode == 'random_size': 77 | self.mask_size = 4 if random.random() < 0.5 else 5 78 | mask_box = self.square(features, self.mask_size) 79 | # Random stripe/square mask 80 | elif self.shape == 'random': 81 | random_num = random.random() 82 | if random_num < 0.25: 83 | mask_box = self.hstripe(features, 2) 84 | elif random_num < 0.5 and random_num >= 0.25: 85 | mask_box = self.vstripe(features, 2) 86 | elif random_num < 0.75 and random_num >= 0.5: 87 | mask_box = self.square(features, 4) 88 | else: 89 | mask_box = self.square(features, 5) 90 | else: 91 | raise Exception("Unknown mask shape name") 92 | return mask_box 93 | 94 | def hstripe(self, features, mask_size): 95 | """ 96 | """ 97 | # horizontal stripe 98 | mask_x1 = 0 99 | mask_x2 = features.shape[2] 100 | y1_max = features.shape[3] - 
mask_size 101 | mask_y1 = torch.randint(y1_max, (1,)) 102 | mask_y2 = mask_y1 + mask_size 103 | new_idx = torch.randperm(features.shape[0]) 104 | mask_box = (new_idx, mask_x1, mask_x2, mask_y1, mask_y2) 105 | return mask_box 106 | 107 | def vstripe(self, features, mask_size): 108 | """ 109 | """ 110 | # vertical stripe 111 | mask_y1 = 0 112 | mask_y2 = features.shape[3] 113 | x1_max = features.shape[2] - mask_size 114 | mask_x1 = torch.randint(x1_max, (1,)) 115 | mask_x2 = mask_x1 + mask_size 116 | new_idx = torch.randperm(features.shape[0]) 117 | mask_box = (new_idx, mask_x1, mask_x2, mask_y1, mask_y2) 118 | return mask_box 119 | 120 | def square(self, features, mask_size): 121 | """ 122 | """ 123 | # square 124 | x1_max = features.shape[2] - mask_size 125 | y1_max = features.shape[3] - mask_size 126 | mask_x1 = torch.randint(x1_max, (1,)) 127 | mask_y1 = torch.randint(y1_max, (1,)) 128 | mask_x2 = mask_x1 + mask_size 129 | mask_y2 = mask_y1 + mask_size 130 | new_idx = torch.randperm(features.shape[0]) 131 | mask_box = (new_idx, mask_x1, mask_x2, mask_y1, mask_y2) 132 | return mask_box 133 | 134 | 135 | class exchange_patch: 136 | def __init__(self, shape='stripe', mask_size=2, mode='random_direct'): 137 | self.shape = shape 138 | self.mask_size = mask_size 139 | self.mode = mode 140 | 141 | def __call__(self, features): 142 | # Stripe mask 143 | if self.shape == 'stripe': 144 | if self.mode == 'horizontal': 145 | features = self.xpatch_hstripe(features, self.mask_size) 146 | elif self.mode == 'vertical': 147 | features = self.xpatch_vstripe(features, self.mask_size) 148 | elif self.mode == 'random_direction': 149 | if random.random() < 0.5: 150 | features = self.xpatch_hstripe(features, self.mask_size) 151 | else: 152 | features = self.xpatch_vstripe(features, self.mask_size) 153 | else: 154 | raise Exception("Unknown stripe mask mode name") 155 | # Square mask 156 | elif self.shape == 'square': 157 | if self.mode == 'random_size': 158 | self.mask_size = 4 if random.random() < 0.5 else 5 159 | features = self.xpatch_square(features, self.mask_size) 160 | # Random stripe/square mask 161 | elif self.shape == 'random': 162 | random_num = random.random() 163 | if random_num < 0.25: 164 | features = self.xpatch_hstripe(features, 2) 165 | elif random_num < 0.5 and random_num >= 0.25: 166 | features = self.xpatch_vstripe(features, 2) 167 | elif random_num < 0.75 and random_num >= 0.5: 168 | features = self.xpatch_square(features, 4) 169 | else: 170 | features = self.xpatch_square(features, 5) 171 | else: 172 | raise Exception("Unknown mask shape name") 173 | 174 | return features 175 | 176 | def xpatch_hstripe(self, features, mask_size): 177 | """ 178 | """ 179 | # horizontal stripe 180 | y1_max = features.shape[3] - mask_size 181 | num_masks = 1 182 | for i in range(num_masks): 183 | mask_y1 = torch.randint(y1_max, (1,)) 184 | mask_y2 = mask_y1 + mask_size 185 | new_idx = torch.randperm(features.shape[0]) 186 | features[:, :, :, mask_y1 : mask_y2] = features[new_idx, :, :, mask_y1 : mask_y2] 187 | return features 188 | 189 | 190 | def xpatch_vstripe(self, features, mask_size): 191 | """ 192 | """ 193 | # vertical stripe 194 | x1_max = features.shape[2] - mask_size 195 | num_masks = 1 196 | for i in range(num_masks): 197 | mask_x1 = torch.randint(x1_max, (1,)) 198 | mask_x2 = mask_x1 + mask_size 199 | new_idx = torch.randperm(features.shape[0]) 200 | features[:, :, mask_x1 : mask_x2, :] = features[new_idx, :, mask_x1 : mask_x2, :] 201 | return features 202 | 203 | 204 | def xpatch_square(self, 
features, mask_size): 205 | """ 206 | """ 207 | # square 208 | x1_max = features.shape[2] - mask_size 209 | y1_max = features.shape[3] - mask_size 210 | num_masks = 1 211 | for i in range(num_masks): 212 | mask_x1 = torch.randint(x1_max, (1,)) 213 | mask_y1 = torch.randint(y1_max, (1,)) 214 | mask_x2 = mask_x1 + mask_size 215 | mask_y2 = mask_y1 + mask_size 216 | new_idx = torch.randperm(features.shape[0]) 217 | features[:, :, mask_x1 : mask_x2, mask_y1 : mask_y2] = features[new_idx, :, mask_x1 : mask_x2, mask_y1 : mask_y2] 218 | return features 219 | 220 | 221 | class cutout_patch: 222 | def __init__(self, mask_size=2): 223 | self.mask_size = mask_size 224 | 225 | def __call__(self, features): 226 | if random.random() < 0.5: 227 | y1_max = features.shape[3] - self.mask_size 228 | num_masks = 1 229 | for i in range(num_masks): 230 | mask_y1 = torch.randint(y1_max, (features.shape[0],)) 231 | mask_y2 = mask_y1 + self.mask_size 232 | for k in range(features.shape[0]): 233 | features[k, :, :, mask_y1[k] : mask_y2[k]] = 0 234 | else: 235 | x1_max = features.shape[3] - self.mask_size 236 | num_masks = 1 237 | for i in range(num_masks): 238 | mask_x1 = torch.randint(x1_max, (features.shape[0],)) 239 | mask_x2 = mask_x1 + self.mask_size 240 | for k in range(features.shape[0]): 241 | features[k, :, mask_x1[k] : mask_x2[k], :] = 0 242 | 243 | return features 244 | 245 | 246 | class erase_patch: 247 | def __init__(self, mask_size=2): 248 | self.mask_size = mask_size 249 | 250 | def __call__(self, features): 251 | std, mean = torch.std_mean(features.detach()) 252 | dim = features.shape[1] 253 | if random.random() < 0.5: 254 | y1_max = features.shape[3] - self.mask_size 255 | num_masks = 1 256 | for i in range(num_masks): 257 | mask_y1 = torch.randint(y1_max, (features.shape[0],)) 258 | mask_y2 = mask_y1 + self.mask_size 259 | for k in range(features.shape[0]): 260 | features[k, :, :, mask_y1[k] : mask_y2[k]] = torch.normal(mean.repeat(dim,14,2), std.repeat(dim,14,2)) 261 | else: 262 | x1_max = features.shape[3] - self.mask_size 263 | num_masks = 1 264 | for i in range(num_masks): 265 | mask_x1 = torch.randint(x1_max, (features.shape[0],)) 266 | mask_x2 = mask_x1 + self.mask_size 267 | for k in range(features.shape[0]): 268 | features[k, :, mask_x1[k] : mask_x2[k], :] = torch.normal(mean.repeat(dim,2,14), std.repeat(dim,2,14)) 269 | 270 | return features 271 | 272 | class mixup_patch: 273 | def __init__(self, mask_size=2): 274 | self.mask_size = mask_size 275 | 276 | def __call__(self, features): 277 | lam = random.uniform(0, 1) 278 | if random.random() < 0.5: 279 | y1_max = features.shape[3] - self.mask_size 280 | num_masks = 1 281 | for i in range(num_masks): 282 | mask_y1 = torch.randint(y1_max, (1,)) 283 | mask_y2 = mask_y1 + self.mask_size 284 | new_idx = torch.randperm(features.shape[0]) 285 | features[:, :, :, mask_y1 : mask_y2] = lam*features[:, :, :, mask_y1 : mask_y2] + (1-lam)*features[new_idx, :, :, mask_y1 : mask_y2] 286 | else: 287 | x1_max = features.shape[2] - self.mask_size 288 | num_masks = 1 289 | for i in range(num_masks): 290 | mask_x1 = torch.randint(x1_max, (1,)) 291 | mask_x2 = mask_x1 + self.mask_size 292 | new_idx = torch.randperm(features.shape[0]) 293 | features[:, :, mask_x1 : mask_x2, :] = lam*features[:, :, mask_x1 : mask_x2, :] + (1-lam)*features[new_idx, :, mask_x1 : mask_x2, :] 294 | 295 | return features 296 | 297 | 298 | class jigsaw_patch: 299 | def __init__(self, shift=5, group=2): 300 | self.shift = shift 301 | self.group = group 302 | 303 | def __call__(self, 
features): 304 | batchsize = features.size(0) 305 | dim = features.size(1) 306 | features = features.view(batchsize, dim, -1) 307 | 308 | # Shift Operation 309 | feature_random = torch.cat([features[:, :, self.shift:], features[:, :, :self.shift]], dim=2) 310 | x = feature_random 311 | 312 | # Patch Shuffle Operation 313 | try: 314 | x = x.view(batchsize, dim, self.group, -1) 315 | except: 316 | x = torch.cat([x, x[:, -2:-1, :]], dim=1) 317 | x = x.view(batchsize, self.group, -1, dim) 318 | 319 | x = torch.transpose(x, 2, 3).contiguous() 320 | 321 | x = x.view(batchsize, dim, -1) 322 | x = x.view(batchsize, dim, 14, 14) 323 | 324 | return x 325 | 326 | -------------------------------------------------------------------------------- /utils/transforms.py: -------------------------------------------------------------------------------- 1 | # This file is part of COAT, and is distributed under the 2 | # OSI-approved BSD 3-Clause License. See top-level LICENSE file or 3 | # https://github.com/Kitware/COAT/blob/master/LICENSE for details. 4 | 5 | import random 6 | import math 7 | import torch 8 | import numpy as np 9 | from copy import deepcopy 10 | from torchvision.transforms import functional as F 11 | 12 | def mixup_data(images, alpha=0.8): 13 | if alpha > 0. and alpha < 1.: 14 | lam = random.uniform(alpha, 1) 15 | else: 16 | lam = 1. 17 | 18 | batch_size = len(images) 19 | min_x = 9999 20 | min_y = 9999 21 | for i in range(batch_size): 22 | min_x = min(min_x, images[i].shape[1]) 23 | min_y = min(min_y, images[i].shape[2]) 24 | 25 | shuffle_images = deepcopy(images) 26 | random.shuffle(shuffle_images) 27 | mixed_images = deepcopy(images) 28 | for i in range(batch_size): 29 | mixed_images[i][:, :min_x, :min_y] = lam * images[i][:, :min_x, :min_y] + (1 - lam) * shuffle_images[i][:, :min_x, :min_y] 30 | 31 | return mixed_images 32 | 33 | class Compose: 34 | def __init__(self, transforms): 35 | self.transforms = transforms 36 | 37 | def __call__(self, image, target): 38 | for t in self.transforms: 39 | image, target = t(image, target) 40 | return image, target 41 | 42 | 43 | class RandomHorizontalFlip: 44 | def __init__(self, prob=0.5): 45 | self.prob = prob 46 | 47 | def __call__(self, image, target): 48 | if random.random() < self.prob: 49 | height, width = image.shape[-2:] 50 | image = image.flip(-1) 51 | bbox = target["boxes"] 52 | bbox[:, [0, 2]] = width - bbox[:, [2, 0]] 53 | target["boxes"] = bbox 54 | return image, target 55 | 56 | class Cutout(object): 57 | """Randomly mask out one or more patches from an image. 58 | https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py 59 | Args: 60 | n_holes (int): Number of patches to cut out of each image. 61 | length (int): The length (in pixels) of each square patch. 62 | """ 63 | def __init__(self, n_holes=2, length=100): 64 | self.n_holes = n_holes 65 | self.length = length 66 | 67 | def __call__(self, img, target): 68 | """ 69 | Args: 70 | img (Tensor): Tensor image of size (C, H, W). 71 | Returns: 72 | Tensor: Image with n_holes of dimension length x length cut out of it. 73 | """ 74 | h = img.size(1) 75 | w = img.size(2) 76 | mask = np.ones((h, w), np.float32) 77 | 78 | for n in range(self.n_holes): 79 | y = np.random.randint(h) 80 | x = np.random.randint(w) 81 | y1 = np.clip(y - self.length // 2, 0, h) 82 | y2 = np.clip(y + self.length // 2, 0, h) 83 | x1 = np.clip(x - self.length // 2, 0, w) 84 | x2 = np.clip(x + self.length // 2, 0, w) 85 | mask[y1: y2, x1: x2] = 0. 
86 | 87 | mask = torch.from_numpy(mask) 88 | mask = mask.expand_as(img) 89 | img = img * mask 90 | 91 | return img, target 92 | 93 | 94 | class RandomErasing(object): 95 | ''' 96 | https://github.com/zhunzhong07/CamStyle/blob/master/reid/utils/data/transforms.py 97 | ''' 98 | def __init__(self, EPSILON=0.5, mean=[0.485, 0.456, 0.406]): 99 | self.EPSILON = EPSILON 100 | self.mean = mean 101 | 102 | def __call__(self, img, target): 103 | if random.uniform(0, 1) > self.EPSILON: 104 | return img, target 105 | 106 | for attempt in range(100): 107 | area = img.size()[1] * img.size()[2] 108 | 109 | target_area = random.uniform(0.02, 0.2) * area 110 | aspect_ratio = random.uniform(0.3, 3) 111 | 112 | h = int(round(math.sqrt(target_area * aspect_ratio))) 113 | w = int(round(math.sqrt(target_area / aspect_ratio))) 114 | 115 | if w <= img.size()[2] and h <= img.size()[1]: 116 | x1 = random.randint(0, img.size()[1] - h) 117 | y1 = random.randint(0, img.size()[2] - w) 118 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 119 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 120 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 121 | 122 | return img, target 123 | 124 | return img, target 125 | 126 | 127 | class ToTensor: 128 | def __call__(self, image, target): 129 | # convert [0, 255] to [0, 1] 130 | image = F.to_tensor(image) 131 | return image, target 132 | 133 | 134 | def build_transforms(cfg, is_train): 135 | transforms = [] 136 | transforms.append(ToTensor()) 137 | if is_train: 138 | transforms.append(RandomHorizontalFlip()) 139 | if cfg.INPUT.IMAGE_CUTOUT: 140 | transforms.append(Cutout()) 141 | if cfg.INPUT.IMAGE_ERASE: 142 | transforms.append(RandomErasing()) 143 | 144 | return Compose(transforms) 145 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | # This file is part of COAT, and is distributed under the 2 | # OSI-approved BSD 3-Clause License. See top-level LICENSE file or 3 | # https://github.com/Kitware/COAT/blob/master/LICENSE for details. 4 | 5 | import datetime 6 | import errno 7 | import json 8 | import os 9 | import os.path as osp 10 | import pickle 11 | import random 12 | import time 13 | from collections import defaultdict, deque 14 | 15 | import numpy as np 16 | import torch 17 | import torch.distributed as dist 18 | from tabulate import tabulate 19 | 20 | 21 | # -------------------------------------------------------- # 22 | # Logger # 23 | # -------------------------------------------------------- # 24 | class SmoothedValue(object): 25 | """ 26 | Track a series of values and provide access to smoothed values over a 27 | window or the global series average. 28 | """ 29 | 30 | def __init__(self, window_size=20, fmt=None): 31 | if fmt is None: 32 | fmt = "{median:.4f} ({global_avg:.4f})" 33 | self.deque = deque(maxlen=window_size) 34 | self.total = 0.0 35 | self.count = 0 36 | self.fmt = fmt 37 | 38 | def update(self, value, n=1): 39 | self.deque.append(value) 40 | self.count += n 41 | self.total += value * n 42 | 43 | def synchronize_between_processes(self): 44 | """ 45 | Warning: does not synchronize the deque! 
46 | """ 47 | if not is_dist_avail_and_initialized(): 48 | return 49 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") 50 | dist.barrier() 51 | dist.all_reduce(t) 52 | t = t.tolist() 53 | self.count = int(t[0]) 54 | self.total = t[1] 55 | 56 | @property 57 | def median(self): 58 | d = torch.tensor(list(self.deque)) 59 | return d.median().item() 60 | 61 | @property 62 | def avg(self): 63 | d = torch.tensor(list(self.deque), dtype=torch.float32) 64 | return d.mean().item() 65 | 66 | @property 67 | def global_avg(self): 68 | return self.total / self.count 69 | 70 | @property 71 | def max(self): 72 | return max(self.deque) 73 | 74 | @property 75 | def value(self): 76 | return self.deque[-1] 77 | 78 | def __str__(self): 79 | return self.fmt.format( 80 | median=self.median, 81 | avg=self.avg, 82 | global_avg=self.global_avg, 83 | max=self.max, 84 | value=self.value, 85 | ) 86 | 87 | 88 | class MetricLogger(object): 89 | def __init__(self, delimiter="\t"): 90 | self.meters = defaultdict(SmoothedValue) 91 | self.delimiter = delimiter 92 | 93 | def update(self, **kwargs): 94 | for k, v in kwargs.items(): 95 | if isinstance(v, torch.Tensor): 96 | v = v.item() 97 | assert isinstance(v, (float, int)) 98 | self.meters[k].update(v) 99 | 100 | def __getattr__(self, attr): 101 | if attr in self.meters: 102 | return self.meters[attr] 103 | if attr in self.__dict__: 104 | return self.__dict__[attr] 105 | raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) 106 | 107 | def __str__(self): 108 | loss_str = [] 109 | for name, meter in self.meters.items(): 110 | loss_str.append("{}: {}".format(name, str(meter))) 111 | return self.delimiter.join(loss_str) 112 | 113 | def synchronize_between_processes(self): 114 | for meter in self.meters.values(): 115 | meter.synchronize_between_processes() 116 | 117 | def add_meter(self, name, meter): 118 | self.meters[name] = meter 119 | 120 | def log_every(self, iterable, print_freq, header=None): 121 | i = 0 122 | if not header: 123 | header = "" 124 | start_time = time.time() 125 | end = time.time() 126 | iter_time = SmoothedValue(fmt="{avg:.4f}") 127 | data_time = SmoothedValue(fmt="{avg:.4f}") 128 | space_fmt = ":" + str(len(str(len(iterable)))) + "d" 129 | if torch.cuda.is_available(): 130 | log_msg = self.delimiter.join( 131 | [ 132 | header, 133 | "[{0" + space_fmt + "}/{1}]", 134 | "eta: {eta}", 135 | "{meters}", 136 | "time: {time}", 137 | "data: {data}", 138 | "max mem: {memory:.0f}", 139 | ] 140 | ) 141 | else: 142 | log_msg = self.delimiter.join( 143 | [ 144 | header, 145 | "[{0" + space_fmt + "}/{1}]", 146 | "eta: {eta}", 147 | "{meters}", 148 | "time: {time}", 149 | "data: {data}", 150 | ] 151 | ) 152 | MB = 1024.0 * 1024.0 153 | for obj in iterable: 154 | data_time.update(time.time() - end) 155 | yield obj 156 | iter_time.update(time.time() - end) 157 | if i % print_freq == 0 or i == len(iterable) - 1: 158 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 159 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 160 | if torch.cuda.is_available(): 161 | print( 162 | log_msg.format( 163 | i, 164 | len(iterable), 165 | eta=eta_string, 166 | meters=str(self), 167 | time=str(iter_time), 168 | data=str(data_time), 169 | memory=torch.cuda.max_memory_allocated() / MB, 170 | ) 171 | ) 172 | else: 173 | print( 174 | log_msg.format( 175 | i, 176 | len(iterable), 177 | eta=eta_string, 178 | meters=str(self), 179 | time=str(iter_time), 180 | data=str(data_time), 181 | ) 182 | ) 183 | i 
+= 1 184 | end = time.time() 185 | total_time = time.time() - start_time 186 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 187 | print( 188 | "{} Total time: {} ({:.4f} s / it)".format( 189 | header, total_time_str, total_time / len(iterable) 190 | ) 191 | ) 192 | 193 | 194 | # -------------------------------------------------------- # 195 | # Distributed training # 196 | # -------------------------------------------------------- # 197 | def all_gather(data): 198 | """ 199 | Run all_gather on arbitrary picklable data (not necessarily tensors) 200 | 201 | Args: 202 | data: any picklable object 203 | 204 | Returns: 205 | list[data]: list of data gathered from each rank 206 | """ 207 | world_size = get_world_size() 208 | if world_size == 1: 209 | return [data] 210 | 211 | # serialized to a Tensor 212 | buffer = pickle.dumps(data) 213 | storage = torch.ByteStorage.from_buffer(buffer) 214 | tensor = torch.ByteTensor(storage).to("cuda") 215 | 216 | # obtain Tensor size of each rank 217 | local_size = torch.tensor([tensor.numel()], device="cuda") 218 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] 219 | dist.all_gather(size_list, local_size) 220 | size_list = [int(size.item()) for size in size_list] 221 | max_size = max(size_list) 222 | 223 | # receiving Tensor from all ranks 224 | # we pad the tensor because torch all_gather does not support 225 | # gathering tensors of different shapes 226 | tensor_list = [] 227 | for _ in size_list: 228 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) 229 | if local_size != max_size: 230 | padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") 231 | tensor = torch.cat((tensor, padding), dim=0) 232 | dist.all_gather(tensor_list, tensor) 233 | 234 | data_list = [] 235 | for size, tensor in zip(size_list, tensor_list): 236 | buffer = tensor.cpu().numpy().tobytes()[:size] 237 | data_list.append(pickle.loads(buffer)) 238 | 239 | return data_list 240 | 241 | 242 | def reduce_dict(input_dict, average=True): 243 | """ 244 | Reduce the values in the dictionary from all processes so that all processes 245 | have the averaged results. Returns a dict with the same fields as 246 | input_dict, after reduction. 
247 | 248 | Args: 249 | input_dict (dict): all the values will be reduced 250 | average (bool): whether to do average or sum 251 | """ 252 | world_size = get_world_size() 253 | if world_size < 2: 254 | return input_dict 255 | with torch.no_grad(): 256 | names = [] 257 | values = [] 258 | # sort the keys so that they are consistent across processes 259 | for k in sorted(input_dict.keys()): 260 | names.append(k) 261 | values.append(input_dict[k]) 262 | values = torch.stack(values, dim=0) 263 | dist.all_reduce(values) 264 | if average: 265 | values /= world_size 266 | reduced_dict = {k: v for k, v in zip(names, values)} 267 | return reduced_dict 268 | 269 | 270 | def setup_for_distributed(is_master): 271 | """ 272 | This function disables printing when not in master process 273 | """ 274 | import builtins as __builtin__ 275 | 276 | builtin_print = __builtin__.print 277 | 278 | def print(*args, **kwargs): 279 | force = kwargs.pop("force", False) 280 | if is_master or force: 281 | builtin_print(*args, **kwargs) 282 | 283 | __builtin__.print = print 284 | 285 | 286 | def is_dist_avail_and_initialized(): 287 | if not dist.is_available(): 288 | return False 289 | if not dist.is_initialized(): 290 | return False 291 | return True 292 | 293 | 294 | def get_world_size(): 295 | if not is_dist_avail_and_initialized(): 296 | return 1 297 | return dist.get_world_size() 298 | 299 | 300 | def get_rank(): 301 | if not is_dist_avail_and_initialized(): 302 | return 0 303 | return dist.get_rank() 304 | 305 | 306 | def is_main_process(): 307 | return get_rank() == 0 308 | 309 | 310 | def save_on_master(*args, **kwargs): 311 | if is_main_process(): 312 | torch.save(*args, **kwargs) 313 | 314 | 315 | def init_distributed_mode(args): 316 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ: 317 | args.rank = int(os.environ["RANK"]) 318 | args.world_size = int(os.environ["WORLD_SIZE"]) 319 | args.gpu = int(os.environ["LOCAL_RANK"]) 320 | elif "SLURM_PROCID" in os.environ: 321 | args.rank = int(os.environ["SLURM_PROCID"]) 322 | args.gpu = args.rank % torch.cuda.device_count() 323 | else: 324 | print("Not using distributed mode") 325 | args.distributed = False 326 | return 327 | 328 | args.distributed = True 329 | 330 | torch.cuda.set_device(args.gpu) 331 | args.dist_backend = "nccl" 332 | print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True) 333 | torch.distributed.init_process_group( 334 | backend=args.dist_backend, 335 | init_method=args.dist_url, 336 | world_size=args.world_size, 337 | rank=args.rank, 338 | ) 339 | torch.distributed.barrier() 340 | setup_for_distributed(args.rank == 0) 341 | 342 | 343 | # -------------------------------------------------------- # 344 | # File operation # 345 | # -------------------------------------------------------- # 346 | def filename(path): 347 | return osp.splitext(osp.basename(path))[0] 348 | 349 | 350 | def mkdir(path): 351 | try: 352 | os.makedirs(path) 353 | except OSError as e: 354 | if e.errno != errno.EEXIST: 355 | raise 356 | 357 | 358 | def read_json(fpath): 359 | with open(fpath, "r") as f: 360 | obj = json.load(f) 361 | return obj 362 | 363 | 364 | def write_json(obj, fpath): 365 | mkdir(osp.dirname(fpath)) 366 | _obj = obj.copy() 367 | for k, v in _obj.items(): 368 | if isinstance(v, np.ndarray): 369 | _obj.pop(k) 370 | with open(fpath, "w") as f: 371 | json.dump(_obj, f, indent=4, separators=(",", ": ")) 372 | 373 | 374 | def symlink(src, dst, overwrite=True, **kwargs): 375 | if os.path.lexists(dst) and overwrite: 376 | 
os.remove(dst) 377 | os.symlink(src, dst, **kwargs) 378 | 379 | 380 | # -------------------------------------------------------- # 381 | # Misc # 382 | # -------------------------------------------------------- # 383 | def create_small_table(small_dict): 384 | """ 385 | Create a small table using the keys of small_dict as headers. This is only 386 | suitable for small dictionaries. 387 | 388 | Args: 389 | small_dict (dict): a result dictionary of only a few items. 390 | 391 | Returns: 392 | str: the table as a string. 393 | """ 394 | keys, values = tuple(zip(*small_dict.items())) 395 | table = tabulate( 396 | [values], 397 | headers=keys, 398 | tablefmt="pipe", 399 | floatfmt=".3f", 400 | stralign="center", 401 | numalign="center", 402 | ) 403 | return table 404 | 405 | 406 | def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor): 407 | def f(x): 408 | if x >= warmup_iters: 409 | return 1 410 | alpha = float(x) / warmup_iters 411 | return warmup_factor * (1 - alpha) + alpha 412 | 413 | return torch.optim.lr_scheduler.LambdaLR(optimizer, f) 414 | 415 | 416 | def resume_from_ckpt(ckpt_path, model, optimizer=None, lr_scheduler=None): 417 | ckpt = torch.load(ckpt_path) 418 | model.load_state_dict(ckpt["model"], strict=False) 419 | if optimizer is not None: 420 | optimizer.load_state_dict(ckpt["optimizer"]) 421 | if lr_scheduler is not None: 422 | lr_scheduler.load_state_dict(ckpt["lr_scheduler"]) 423 | print(f"loaded checkpoint {ckpt_path}") 424 | print(f"model was trained for {ckpt['epoch']} epochs") 425 | return ckpt["epoch"] 426 | 427 | 428 | def set_random_seed(seed): 429 | torch.manual_seed(seed) 430 | torch.cuda.manual_seed(seed) 431 | torch.cuda.manual_seed_all(seed) 432 | torch.backends.cudnn.benchmark = False 433 | torch.backends.cudnn.deterministic = True 434 | random.seed(seed) 435 | np.random.seed(seed) 436 | os.environ["PYTHONHASHSEED"] = str(seed) 437 | --------------------------------------------------------------------------------
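The logging, checkpointing, and seeding helpers in `utils/utils.py` are generic enough to reuse on their own. Below is a minimal, hypothetical sketch of how they compose in a training loop: a toy linear model and random data stand in for the person search network, the output path `./logs/demo` is made up, and the script is assumed to run from the repository root so that `utils.utils` is importable. Only `MetricLogger`, `mkdir`, `save_on_master`, `resume_from_ckpt`, `set_random_seed`, and `warmup_lr_scheduler` come from the file above.

```
import torch
from utils.utils import (
    MetricLogger, mkdir, resume_from_ckpt, save_on_master,
    set_random_seed, warmup_lr_scheduler,
)

# Toy model and data, just to exercise the helpers (hypothetical, not from the repo).
set_random_seed(0)
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
loader = [(torch.randn(4, 8), torch.randint(0, 2, (4,))) for _ in range(10)]

mkdir("./logs/demo")
metric_logger = MetricLogger(delimiter="  ")
# Linear learning-rate warm-up over the first few iterations, as engine.py presumably
# does for the first epoch.
scheduler = warmup_lr_scheduler(optimizer, warmup_iters=5, warmup_factor=0.1)

for x, y in metric_logger.log_every(loader, print_freq=5, header="demo"):
    loss = torch.nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
    metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])

# save_on_master is a no-op on non-main ranks, so the same call works with or
# without distributed training; resume_from_ckpt restores everything it saved.
save_on_master(
    {"model": model.state_dict(), "optimizer": optimizer.state_dict(),
     "lr_scheduler": scheduler.state_dict(), "epoch": 0},
    "./logs/demo/epoch_0.pth",
)
epoch = resume_from_ckpt("./logs/demo/epoch_0.pth", model, optimizer, scheduler)
```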