├── datasets ├── __init__.py ├── base.py ├── prw.py ├── build.py └── cuhk_sysu.py ├── doc ├── title.jpg └── net_arch.jpg ├── demo_imgs ├── demo.jpg ├── query.jpg ├── gallery-1.jpg ├── gallery-2.jpg ├── gallery-3.jpg ├── gallery-4.jpg └── gallery-5.jpg ├── configs ├── prw.yaml └── cuhk_sysu.yaml ├── requirements.txt ├── convert_model.py ├── dev └── linter.sh ├── run.sh ├── utils ├── transforms.py ├── km.py └── utils.py ├── models ├── resnet.py ├── oim.py ├── swin.py ├── backbone.py ├── seqnet.py └── swin_transformer.py ├── README.md ├── demo.py ├── train.py ├── defaults.py ├── engine.py └── eval_func.py /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_test_loader, build_train_loader 2 | -------------------------------------------------------------------------------- /doc/title.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinyvision/SOLIDER-PersonSearch/HEAD/doc/title.jpg -------------------------------------------------------------------------------- /demo_imgs/demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinyvision/SOLIDER-PersonSearch/HEAD/demo_imgs/demo.jpg -------------------------------------------------------------------------------- /doc/net_arch.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinyvision/SOLIDER-PersonSearch/HEAD/doc/net_arch.jpg -------------------------------------------------------------------------------- /demo_imgs/query.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinyvision/SOLIDER-PersonSearch/HEAD/demo_imgs/query.jpg -------------------------------------------------------------------------------- /demo_imgs/gallery-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinyvision/SOLIDER-PersonSearch/HEAD/demo_imgs/gallery-1.jpg -------------------------------------------------------------------------------- /demo_imgs/gallery-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinyvision/SOLIDER-PersonSearch/HEAD/demo_imgs/gallery-2.jpg -------------------------------------------------------------------------------- /demo_imgs/gallery-3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinyvision/SOLIDER-PersonSearch/HEAD/demo_imgs/gallery-3.jpg -------------------------------------------------------------------------------- /demo_imgs/gallery-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinyvision/SOLIDER-PersonSearch/HEAD/demo_imgs/gallery-4.jpg -------------------------------------------------------------------------------- /demo_imgs/gallery-5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinyvision/SOLIDER-PersonSearch/HEAD/demo_imgs/gallery-5.jpg -------------------------------------------------------------------------------- /configs/prw.yaml: -------------------------------------------------------------------------------- 1 | OUTPUT_DIR: "./exp_prw" 2 | INPUT: 3 | DATASET: "PRW" 4 | DATA_ROOT: "data/PRW" 5 | SOLVER: 6 | 
MAX_EPOCHS: 18 7 | MODEL: 8 | LOSS: 9 | LUT_SIZE: 482 10 | CQ_SIZE: 500 -------------------------------------------------------------------------------- /configs/cuhk_sysu.yaml: -------------------------------------------------------------------------------- 1 | OUTPUT_DIR: "./exp_cuhk" 2 | INPUT: 3 | DATASET: "CUHK-SYSU" 4 | DATA_ROOT: "data/CUHK-SYSU" 5 | SOLVER: 6 | MAX_EPOCHS: 20 7 | MODEL: 8 | LOSS: 9 | LUT_SIZE: 5532 10 | CQ_SIZE: 5000 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | black==20.8b1 2 | flake8==3.9.0 3 | isort==5.8.0 4 | numpy==1.16.4 5 | Pillow==6.1.0 6 | scikit-learn==0.23.1 7 | scipy==1.5.1 8 | tabulate==0.8.7 9 | torch==1.7.1 10 | torchvision==0.8.2 11 | tqdm==4.48.2 12 | yacs==0.1.8 13 | future==0.18.2 14 | tensorboard==2.4.1 15 | -------------------------------------------------------------------------------- /convert_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | import pickle as pkl 5 | import sys 6 | import torch 7 | 8 | if __name__ == "__main__": 9 | input = sys.argv[1] 10 | obj = torch.load(input, map_location="cpu") 11 | obj = obj["teacher"] 12 | torch.save(obj,sys.argv[2]) 13 | -------------------------------------------------------------------------------- /dev/linter.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | 3 | # Run this script at project root by "./dev/linter.sh" before you commit 4 | 5 | { 6 | black --version | grep -E "20.8b1" > /dev/null 7 | } || { 8 | echo "Linter requires 'black==20.8b1' !" 9 | exit 1 10 | } 11 | 12 | ISORT_VERSION=$(isort --version-number) 13 | if [[ "$ISORT_VERSION" != 5.8.0 ]]; then 14 | echo "Linter requires isort==5.8.0 !" 15 | exit 1 16 | fi 17 | 18 | echo "Running isort ..." 19 | isort --line-length=100 --profile=black . 20 | 21 | echo "Running black ..." 22 | black --line-length=100 . 23 | 24 | echo "Running flake8 ..." 25 | if [ -x "$(command -v flake8-3)" ]; then 26 | flake8-3 . 27 | else 28 | python3 -m flake8 . 
29 | fi 30 | 31 | command -v arc > /dev/null && arc lint -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | # Swin Base 2 | #CUDA_VISIBLE_DEVICES=0 python train.py --cfg configs/cuhk_sysu.yaml --resume --ckpt path/to/SOLIDER/log/lup/swin_base/checkpoint_tea.pth OUTPUT_DIR './results/cuhk_sysu/swin_base' SOLVER.BASE_LR 0.0003 EVAL_PERIOD 5 MODEL.BONE 'swin_base' INPUT.BATCH_SIZE_TRAIN 2 MODEL.SEMANTIC_WEIGHT 0.6 3 | 4 | # Swin Small 5 | #CUDA_VISIBLE_DEVICES=0 python train.py --cfg configs/cuhk_sysu.yaml --resume --ckpt path/to/SOLIDER/log/lup/swin_small/checkpoint_tea.pth OUTPUT_DIR './results/cuhk_sysu/swin_small' SOLVER.BASE_LR 0.0003 EVAL_PERIOD 5 MODEL.BONE 'swin_small' INPUT.BATCH_SIZE_TRAIN 3 MODEL.SEMANTIC_WEIGHT 0.6 6 | 7 | # Swin Tiny 8 | CUDA_VISIBLE_DEVICES=0 python train.py --cfg configs/cuhk_sysu.yaml --resume --ckpt path/to/SOLIDER/log/lup/swin_tiny/checkpoint_tea.pth OUTPUT_DIR './results/cuhk_sysu/swin_tiny' SOLVER.BASE_LR 0.0003 EVAL_PERIOD 5 MODEL.BONE 'swin_tiny' INPUT.BATCH_SIZE_TRAIN 4 MODEL.SEMANTIC_WEIGHT 0.6 9 | -------------------------------------------------------------------------------- /utils/transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from torchvision.transforms import functional as F 4 | 5 | 6 | class Compose: 7 | def __init__(self, transforms): 8 | self.transforms = transforms 9 | 10 | def __call__(self, image, target): 11 | for t in self.transforms: 12 | image, target = t(image, target) 13 | return image, target 14 | 15 | 16 | class RandomHorizontalFlip: 17 | def __init__(self, prob=0.5): 18 | self.prob = prob 19 | 20 | def __call__(self, image, target): 21 | if random.random() < self.prob: 22 | height, width = image.shape[-2:] 23 | image = image.flip(-1) 24 | bbox = target["boxes"] 25 | bbox[:, [0, 2]] = width - bbox[:, [2, 0]] 26 | target["boxes"] = bbox 27 | return image, target 28 | 29 | 30 | class ToTensor: 31 | def __call__(self, image, target): 32 | # convert [0, 255] to [0, 1] 33 | image = F.to_tensor(image) 34 | return image, target 35 | 36 | 37 | def build_transforms(is_train): 38 | transforms = [] 39 | transforms.append(ToTensor()) 40 | if is_train: 41 | transforms.append(RandomHorizontalFlip()) 42 | return Compose(transforms) 43 | -------------------------------------------------------------------------------- /datasets/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | 4 | 5 | class BaseDataset: 6 | """ 7 | Base class of person search dataset. 
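    Subclasses such as CUHKSYSU and PRW only implement `_load_annotations`;
    `__getitem__` then yields an (image, target) pair whose target dict carries
    "img_name", "boxes" and "labels", optionally run through `self.transforms`.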
8 | """ 9 | 10 | def __init__(self, root, transforms, split): 11 | self.root = root 12 | self.transforms = transforms 13 | self.split = split 14 | assert self.split in ("train", "gallery", "query") 15 | self.annotations = self._load_annotations() 16 | 17 | def _load_annotations(self): 18 | """ 19 | For each image, load its annotation that is a dictionary with the following keys: 20 | img_name (str): image name 21 | img_path (str): image path 22 | boxes (np.array[N, 4]): ground-truth boxes in (x1, y1, x2, y2) format 23 | pids (np.array[N]): person IDs corresponding to these boxes 24 | cam_id (int): camera ID (only for PRW dataset) 25 | """ 26 | raise NotImplementedError 27 | 28 | def __getitem__(self, index): 29 | anno = self.annotations[index] 30 | img = Image.open(anno["img_path"]).convert("RGB") 31 | boxes = torch.as_tensor(anno["boxes"], dtype=torch.float32) 32 | labels = torch.as_tensor(anno["pids"], dtype=torch.int64) 33 | target = {"img_name": anno["img_name"], "boxes": boxes, "labels": labels} 34 | if self.transforms is not None: 35 | img, target = self.transforms(img, target) 36 | return img, target 37 | 38 | def __len__(self): 39 | return len(self.annotations) 40 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch.nn.functional as F 4 | import torchvision 5 | from torch import nn 6 | 7 | from .backbone import resnet50 8 | 9 | class Backbone(nn.Sequential): 10 | def __init__(self, resnet): 11 | super(Backbone, self).__init__( 12 | OrderedDict( 13 | [ 14 | ["conv1", resnet.conv1], 15 | ["bn1", resnet.bn1], 16 | ["relu", resnet.relu], 17 | ["maxpool", resnet.maxpool], 18 | ["layer1", resnet.layer1], # res2 19 | ["layer2", resnet.layer2], # res3 20 | ["layer3", resnet.layer3], # res4 21 | ] 22 | ) 23 | ) 24 | self.out_channels = 1024 25 | 26 | def forward(self, x): 27 | # using the forward method from nn.Sequential 28 | feat = super(Backbone, self).forward(x) 29 | return OrderedDict([["feat_res4", feat]]) 30 | 31 | 32 | class Res5Head(nn.Sequential): 33 | def __init__(self, resnet): 34 | super(Res5Head, self).__init__(OrderedDict([["layer4", resnet.layer4]])) # res5 35 | self.out_channels = [1024, 2048] 36 | 37 | def forward(self, x): 38 | feat = super(Res5Head, self).forward(x) 39 | x = F.adaptive_max_pool2d(x, 1) 40 | feat = F.adaptive_max_pool2d(feat, 1) 41 | print(x.shape,feat.shape) 42 | return OrderedDict([["feat_res4", x], ["feat_res5", feat]]) 43 | 44 | 45 | def build_resnet(name="resnet50", pretrained=True): 46 | #resnet = torchvision.models.resnet.__dict__[name](pretrained=pretrained) 47 | resnet = resnet50(pretrained=True) 48 | 49 | # freeze layers 50 | resnet.conv1.weight.requires_grad_(False) 51 | resnet.bn1.weight.requires_grad_(False) 52 | resnet.bn1.bias.requires_grad_(False) 53 | 54 | return Backbone(resnet), Res5Head(resnet) 55 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SOLIDER on [Person Search] 2 | 3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/beyond-appearance-a-semantic-controllable/person-search-on-cuhk-sysu)](https://paperswithcode.com/sota/person-search-on-cuhk-sysu?p=beyond-appearance-a-semantic-controllable) 4 | 
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/beyond-appearance-a-semantic-controllable/person-search-on-prw)](https://paperswithcode.com/sota/person-search-on-prw?p=beyond-appearance-a-semantic-controllable) 5 | 6 | This repo provides details about how to use [SOLIDER](https://github.com/tinyvision/SOLIDER) pretrained representation on person search task. 7 | We modify the code from [SeqNet](https://github.com/serend1p1ty/SeqNet), and you can refer to the original repo for more details. 8 | 9 | ## Installation and Datasets 10 | 11 | Details of installation and dataset preparation can be found in [SeqNet](https://github.com/serend1p1ty/SeqNet). 12 | 13 | ## Prepare Pre-trained Models 14 | You can download models from [SOLIDER](https://github.com/tinyvision/SOLIDER), or use [SOLIDER](https://github.com/tinyvision/SOLIDER) to train your own models. 15 | Before training, you should convert the models first. 16 | 17 | ```bash 18 | python convert_model.py path/to/SOLIDER/log/lup/swin_tiny/checkpoint.pth path/to/SOLIDER/log/lup/swin_tiny/checkpoint_tea.pth 19 | ``` 20 | 21 | ## Training 22 | 23 | We utilize 1 GPU for training. Please modify the `ckpt` and `OUTPUT_DIR` in the bash file. 24 | 25 | ```bash 26 | sh run.sh 27 | ``` 28 | 29 | ## Performance 30 | 31 | | Method | Model | CUHK-SYSU
(mAP/R1) | PRW
(mAP/R1) | 32 | | ------ | :---: | :---: | :---: | 33 | | SOLIDER | Swin Tiny | 94.91/95.72 | 56.84/86.78 | 34 | | SOLIDER | Swin Small | 95.46/95.79 | 59.84/86.73 | 35 | | SOLIDER | Swin Base | 94.93/95.52 | 59.72/86.83 | 36 | 37 | - We use the pretrained models from [SOLIDER](https://github.com/tinyvision/SOLIDER). 38 | - The semantic weight is set to 0.6 in these experiments. 39 | 40 | ## Citation 41 | 42 | If you find this code useful for your research, please cite our paper 43 | 44 | ``` 45 | @inproceedings{chen2023beyond, 46 | title={Beyond Appearance: a Semantic Controllable Self-Supervised Learning Framework for Human-Centric Visual Tasks}, 47 | author={Weihua Chen and Xianzhe Xu and Jian Jia and Hao Luo and Yaohua Wang and Fan Wang and Rong Jin and Xiuyu Sun}, 48 | booktitle={The IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 49 | year={2023}, 50 | } 51 | ``` 52 | -------------------------------------------------------------------------------- /models/oim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import autograd, nn 4 | 5 | # from utils.distributed import tensor_gather 6 | 7 | 8 | class OIM(autograd.Function): 9 | @staticmethod 10 | def forward(ctx, inputs, targets, lut, cq, header, momentum): 11 | ctx.save_for_backward(inputs, targets, lut, cq, header, momentum) 12 | outputs_labeled = inputs.mm(lut.t()) 13 | outputs_unlabeled = inputs.mm(cq.t()) 14 | return torch.cat([outputs_labeled, outputs_unlabeled], dim=1) 15 | 16 | @staticmethod 17 | def backward(ctx, grad_outputs): 18 | inputs, targets, lut, cq, header, momentum = ctx.saved_tensors 19 | 20 | # inputs, targets = tensor_gather((inputs, targets)) 21 | 22 | grad_inputs = None 23 | if ctx.needs_input_grad[0]: 24 | grad_inputs = grad_outputs.mm(torch.cat([lut, cq], dim=0)) 25 | if grad_inputs.dtype == torch.float16: 26 | grad_inputs = grad_inputs.to(torch.float32) 27 | 28 | for x, y in zip(inputs, targets): 29 | if y < len(lut): 30 | lut[y] = momentum * lut[y] + (1.0 - momentum) * x 31 | lut[y] /= lut[y].norm() 32 | else: 33 | cq[header] = x 34 | header = (header + 1) % cq.size(0) 35 | return grad_inputs, None, None, None, None, None 36 | 37 | 38 | def oim(inputs, targets, lut, cq, header, momentum=0.5): 39 | return OIM.apply(inputs, targets, lut, cq, torch.tensor(header), torch.tensor(momentum)) 40 | 41 | 42 | class OIMLoss(nn.Module): 43 | def __init__(self, num_features, num_pids, num_cq_size, oim_momentum, oim_scalar): 44 | super(OIMLoss, self).__init__() 45 | self.num_features = num_features 46 | self.num_pids = num_pids 47 | self.num_unlabeled = num_cq_size 48 | self.momentum = oim_momentum 49 | self.oim_scalar = oim_scalar 50 | 51 | self.register_buffer("lut", torch.zeros(self.num_pids, self.num_features)) 52 | self.register_buffer("cq", torch.zeros(self.num_unlabeled, self.num_features)) 53 | 54 | self.header_cq = 0 55 | 56 | def forward(self, inputs, roi_label): 57 | # merge into one batch, background label = 0 58 | targets = torch.cat(roi_label) 59 | label = targets - 1 # background label = -1 60 | 61 | inds = label >= 0 62 | label = label[inds] 63 | inputs = inputs[inds.unsqueeze(1).expand_as(inputs)].view(-1, self.num_features) 64 | 65 | projected = oim(inputs, label, self.lut, self.cq, self.header_cq, momentum=self.momentum) 66 | projected *= self.oim_scalar 67 | 68 | self.header_cq = ( 69 | self.header_cq + (label >= self.num_pids).long().sum().item() 70 | ) % self.num_unlabeled 71 | 
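        # projected covers [labeled LUT entries | unlabeled circular-queue entries];
        # labels equal to the unlabeled placeholder (pid 5555 -> label 5554) are skipped
        # via ignore_index below, while the queue pointer above still advances by the
        # number of unlabeled samples in the batch.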
loss_oim = F.cross_entropy(projected, label, ignore_index=5554) 72 | return loss_oim 73 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from glob import glob 3 | 4 | import matplotlib.pyplot as plt 5 | import torch 6 | import torch.utils.data 7 | from PIL import Image 8 | from torchvision.transforms import functional as F 9 | 10 | from defaults import get_default_cfg 11 | from models.seqnet import SeqNet 12 | from utils.utils import resume_from_ckpt 13 | 14 | 15 | def visualize_result(img_path, detections, similarities): 16 | fig, ax = plt.subplots(figsize=(16, 9)) 17 | ax.imshow(plt.imread(img_path)) 18 | plt.axis("off") 19 | for detection, sim in zip(detections, similarities): 20 | x1, y1, x2, y2 = detection 21 | ax.add_patch( 22 | plt.Rectangle( 23 | (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="#4CAF50", linewidth=3.5 24 | ) 25 | ) 26 | ax.add_patch( 27 | plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="white", linewidth=1) 28 | ) 29 | ax.text( 30 | x1 + 5, 31 | y1 - 18, 32 | "{:.2f}".format(sim), 33 | bbox=dict(facecolor="#4CAF50", linewidth=0), 34 | fontsize=20, 35 | color="white", 36 | ) 37 | plt.tight_layout() 38 | fig.savefig(img_path.replace("gallery", "result")) 39 | plt.show() 40 | plt.close(fig) 41 | 42 | 43 | def main(args): 44 | cfg = get_default_cfg() 45 | if args.cfg_file: 46 | cfg.merge_from_file(args.cfg_file) 47 | cfg.merge_from_list(args.opts) 48 | cfg.freeze() 49 | 50 | device = torch.device(cfg.DEVICE) 51 | 52 | print("Creating model") 53 | model = SeqNet(cfg) 54 | model.to(device) 55 | model.eval() 56 | 57 | resume_from_ckpt(args.ckpt, model) 58 | 59 | query_img = [F.to_tensor(Image.open("demo_imgs/query.jpg").convert("RGB")).to(device)] 60 | query_target = [{"boxes": torch.tensor([[0, 0, 466, 943]]).to(device)}] 61 | query_feat = model(query_img, query_target)[0] 62 | 63 | gallery_img_paths = sorted(glob("demo_imgs/gallery-*.jpg")) 64 | for gallery_img_path in gallery_img_paths: 65 | print(f"Processing {gallery_img_path}") 66 | gallery_img = [F.to_tensor(Image.open(gallery_img_path).convert("RGB")).to(device)] 67 | gallery_output = model(gallery_img)[0] 68 | detections = gallery_output["boxes"] 69 | gallery_feats = gallery_output["embeddings"] 70 | 71 | # Compute pairwise cosine similarities, 72 | # which equals to inner-products, as features are already L2-normed 73 | similarities = gallery_feats.mm(query_feat.view(-1, 1)).squeeze() 74 | 75 | visualize_result(gallery_img_path, detections, similarities) 76 | 77 | 78 | if __name__ == "__main__": 79 | parser = argparse.ArgumentParser(description="Train a person search network.") 80 | parser.add_argument("--cfg", dest="cfg_file", help="Path to configuration file.") 81 | parser.add_argument("--ckpt", required=True, help="Path to checkpoint to resume or evaluate.") 82 | parser.add_argument( 83 | "opts", nargs=argparse.REMAINDER, help="Modify config options using the command-line" 84 | ) 85 | args = parser.parse_args() 86 | main(args) 87 | -------------------------------------------------------------------------------- /datasets/prw.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import re 3 | 4 | import numpy as np 5 | from scipy.io import loadmat 6 | 7 | from .base import BaseDataset 8 | 9 | 10 | class PRW(BaseDataset): 11 | def __init__(self, root, transforms, split): 12 | self.name = 
"PRW" 13 | self.img_prefix = osp.join(root, "frames") 14 | super(PRW, self).__init__(root, transforms, split) 15 | 16 | def _get_cam_id(self, img_name): 17 | match = re.search(r"c\d", img_name).group().replace("c", "") 18 | return int(match) 19 | 20 | def _load_queries(self): 21 | query_info = osp.join(self.root, "query_info.txt") 22 | with open(query_info, "rb") as f: 23 | raw = f.readlines() 24 | 25 | queries = [] 26 | for line in raw: 27 | linelist = str(line, "utf-8").split(" ") 28 | pid = int(linelist[0]) 29 | x, y, w, h = ( 30 | float(linelist[1]), 31 | float(linelist[2]), 32 | float(linelist[3]), 33 | float(linelist[4]), 34 | ) 35 | roi = np.array([x, y, x + w, y + h]).astype(np.int32) 36 | roi = np.clip(roi, 0, None) # several coordinates are negative 37 | img_name = linelist[5][:-2] + ".jpg" 38 | queries.append( 39 | { 40 | "img_name": img_name, 41 | "img_path": osp.join(self.img_prefix, img_name), 42 | "boxes": roi[np.newaxis, :], 43 | "pids": np.array([pid]), 44 | "cam_id": self._get_cam_id(img_name), 45 | } 46 | ) 47 | return queries 48 | 49 | def _load_split_img_names(self): 50 | """ 51 | Load the image names for the specific split. 52 | """ 53 | assert self.split in ("train", "gallery") 54 | if self.split == "train": 55 | imgs = loadmat(osp.join(self.root, "frame_train.mat"))["img_index_train"] 56 | else: 57 | imgs = loadmat(osp.join(self.root, "frame_test.mat"))["img_index_test"] 58 | return [img[0][0] + ".jpg" for img in imgs] 59 | 60 | def _load_annotations(self): 61 | if self.split == "query": 62 | return self._load_queries() 63 | 64 | annotations = [] 65 | imgs = self._load_split_img_names() 66 | for img_name in imgs: 67 | anno_path = osp.join(self.root, "annotations", img_name) 68 | anno = loadmat(anno_path) 69 | box_key = "box_new" 70 | if box_key not in anno.keys(): 71 | box_key = "anno_file" 72 | if box_key not in anno.keys(): 73 | box_key = "anno_previous" 74 | 75 | rois = anno[box_key][:, 1:] 76 | ids = anno[box_key][:, 0] 77 | rois = np.clip(rois, 0, None) # several coordinates are negative 78 | 79 | assert len(rois) == len(ids) 80 | 81 | rois[:, 2:] += rois[:, :2] 82 | ids[ids == -2] = 5555 # assign pid = 5555 for unlabeled people 83 | annotations.append( 84 | { 85 | "img_name": img_name, 86 | "img_path": osp.join(self.img_prefix, img_name), 87 | "boxes": rois.astype(np.int32), 88 | # FIXME: (training pids) 1, 2,..., 478, 480, 481, 482, 483, 932, 5555 89 | "pids": ids.astype(np.int32), 90 | "cam_id": self._get_cam_id(img_name), 91 | } 92 | ) 93 | return annotations 94 | -------------------------------------------------------------------------------- /datasets/build.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from utils.transforms import build_transforms 4 | from utils.utils import create_small_table 5 | 6 | from .cuhk_sysu import CUHKSYSU 7 | from .prw import PRW 8 | 9 | 10 | def print_statistics(dataset): 11 | """ 12 | Print dataset statistics. 
13 | """ 14 | num_imgs = len(dataset.annotations) 15 | num_boxes = 0 16 | pid_set = set() 17 | for anno in dataset.annotations: 18 | num_boxes += anno["boxes"].shape[0] 19 | for pid in anno["pids"]: 20 | pid_set.add(pid) 21 | statistics = { 22 | "dataset": dataset.name, 23 | "split": dataset.split, 24 | "num_images": num_imgs, 25 | "num_boxes": num_boxes, 26 | } 27 | if dataset.name != "CUHK-SYSU" or dataset.split != "query": 28 | pid_list = sorted(list(pid_set)) 29 | if dataset.split == "query": 30 | num_pids, min_pid, max_pid = len(pid_list), min(pid_list), max(pid_list) 31 | statistics.update( 32 | { 33 | "num_labeled_pids": num_pids, 34 | "min_labeled_pid": int(min_pid), 35 | "max_labeled_pid": int(max_pid), 36 | } 37 | ) 38 | else: 39 | unlabeled_pid = pid_list[-1] 40 | pid_list = pid_list[:-1] # remove unlabeled pid 41 | num_pids, min_pid, max_pid = len(pid_list), min(pid_list), max(pid_list) 42 | statistics.update( 43 | { 44 | "num_labeled_pids": num_pids, 45 | "min_labeled_pid": int(min_pid), 46 | "max_labeled_pid": int(max_pid), 47 | "unlabeled_pid": int(unlabeled_pid), 48 | } 49 | ) 50 | print(f"=> {dataset.name}-{dataset.split} loaded:\n" + create_small_table(statistics)) 51 | 52 | 53 | def build_dataset(dataset_name, root, transforms, split, verbose=True): 54 | if dataset_name == "CUHK-SYSU": 55 | dataset = CUHKSYSU(root, transforms, split) 56 | elif dataset_name == "PRW": 57 | dataset = PRW(root, transforms, split) 58 | else: 59 | raise NotImplementedError(f"Unknow dataset: {dataset_name}") 60 | if verbose: 61 | print_statistics(dataset) 62 | return dataset 63 | 64 | 65 | def collate_fn(batch): 66 | return tuple(zip(*batch)) 67 | 68 | 69 | def build_train_loader(cfg): 70 | transforms = build_transforms(is_train=True) 71 | dataset = build_dataset(cfg.INPUT.DATASET, cfg.INPUT.DATA_ROOT, transforms, "train") 72 | return torch.utils.data.DataLoader( 73 | dataset, 74 | batch_size=cfg.INPUT.BATCH_SIZE_TRAIN, 75 | shuffle=True, 76 | num_workers=cfg.INPUT.NUM_WORKERS_TRAIN, 77 | pin_memory=True, 78 | drop_last=True, 79 | collate_fn=collate_fn, 80 | ) 81 | 82 | 83 | def build_test_loader(cfg): 84 | transforms = build_transforms(is_train=False) 85 | gallery_set = build_dataset(cfg.INPUT.DATASET, cfg.INPUT.DATA_ROOT, transforms, "gallery") 86 | query_set = build_dataset(cfg.INPUT.DATASET, cfg.INPUT.DATA_ROOT, transforms, "query") 87 | gallery_loader = torch.utils.data.DataLoader( 88 | gallery_set, 89 | batch_size=cfg.INPUT.BATCH_SIZE_TEST, 90 | shuffle=False, 91 | num_workers=cfg.INPUT.NUM_WORKERS_TEST, 92 | pin_memory=True, 93 | collate_fn=collate_fn, 94 | ) 95 | query_loader = torch.utils.data.DataLoader( 96 | query_set, 97 | batch_size=cfg.INPUT.BATCH_SIZE_TEST, 98 | shuffle=False, 99 | num_workers=cfg.INPUT.NUM_WORKERS_TEST, 100 | pin_memory=True, 101 | collate_fn=collate_fn, 102 | ) 103 | return gallery_loader, query_loader 104 | -------------------------------------------------------------------------------- /models/swin.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import torchvision 6 | from torch import nn 7 | 8 | from .swin_transformer import swin_tiny_patch4_window7_224,swin_small_patch4_window7_224,swin_base_patch4_window7_224 9 | 10 | bonenum = 3 11 | 12 | class Backbone(nn.Sequential): 13 | def __init__(self, swin, out_channels=384): 14 | super().__init__() 15 | self.swin = swin 16 | self.out_channels = out_channels 17 | 18 | def 
forward(self, x): 19 | if self.swin.semantic_weight >= 0: 20 | w = torch.ones(x.shape[0],1) * self.swin.semantic_weight 21 | w = torch.cat([w, 1-w], axis=-1) 22 | semantic_weight = w.cuda() 23 | 24 | x, hw_shape = self.swin.patch_embed(x) 25 | 26 | if self.swin.use_abs_pos_embed: 27 | x = x + self.swin.absolute_pos_embed 28 | x = self.swin.drop_after_pos(x) 29 | 30 | outs = [] 31 | for i, stage in enumerate(self.swin.stages[:bonenum]): 32 | x, hw_shape, out, out_hw_shape = stage(x, hw_shape) 33 | if self.swin.semantic_weight >= 0: 34 | sw = self.swin.semantic_embed_w[i](semantic_weight).unsqueeze(1) 35 | sb = self.swin.semantic_embed_b[i](semantic_weight).unsqueeze(1) 36 | x = x * self.swin.softplus(sw) + sb 37 | if i == bonenum-1: 38 | norm_layer = getattr(self.swin, f'norm{i}') 39 | out = norm_layer(out) 40 | out = out.view(-1, *out_hw_shape, 41 | self.swin.num_features[i]).permute(0, 3, 1, 42 | 2).contiguous() 43 | outs.append(out) 44 | return OrderedDict([["feat_res4", outs[-1]]]) 45 | 46 | class Res5Head(nn.Sequential): 47 | def __init__(self, swin, out_channels=384): 48 | super().__init__() # last block 49 | self.swin = swin 50 | self.out_channels = [out_channels, out_channels*2] 51 | 52 | def forward(self, x): 53 | if self.swin.semantic_weight >= 0: 54 | w = torch.ones(x.shape[0],1) * self.swin.semantic_weight 55 | w = torch.cat([w, 1-w], axis=-1) 56 | semantic_weight = w.cuda() 57 | 58 | feat = x 59 | hw_shape = x.shape[-2:] 60 | x = torch.flatten(x, 2) 61 | x = x.permute(0, 2, 1) 62 | x,hw_shape = self.swin.stages[bonenum-1].downsample(x,hw_shape) 63 | if self.swin.semantic_weight >= 0: 64 | sw = self.swin.semantic_embed_w[bonenum-1](semantic_weight).unsqueeze(1) 65 | sb = self.swin.semantic_embed_b[bonenum-1](semantic_weight).unsqueeze(1) 66 | x = x * self.swin.softplus(sw) + sb 67 | for i, stage in enumerate(self.swin.stages[bonenum:]): 68 | x, hw_shape, out, out_hw_shape = stage(x, hw_shape) 69 | if self.swin.semantic_weight >= 0: 70 | sw = self.swin.semantic_embed_w[bonenum+i](semantic_weight).unsqueeze(1) 71 | sb = self.swin.semantic_embed_b[bonenum+i](semantic_weight).unsqueeze(1) 72 | x = x * self.swin.softplus(sw) + sb 73 | if i == len(self.swin.stages) - bonenum - 1: 74 | norm_layer = getattr(self.swin, f'norm{bonenum+i}') 75 | out = norm_layer(out) 76 | out = out.view(-1, *out_hw_shape, 77 | self.swin.num_features[bonenum+i]).permute(0, 3, 1, 78 | 2).contiguous() 79 | feat = self.swin.avgpool(feat) 80 | out = self.swin.avgpool(out) 81 | return OrderedDict([["feat_res4", feat], ["feat_res5", out]]) 82 | 83 | def build_swin(name="swin_tiny", semantic_weight=1.0): 84 | if 'tiny' in name: 85 | swin = swin_tiny_patch4_window7_224(drop_path_rate=0.1,semantic_weight=semantic_weight) 86 | out_channels = 384 87 | elif 'small' in name: 88 | swin = swin_small_patch4_window7_224(drop_path_rate=0.1,semantic_weight=semantic_weight) 89 | out_channels = 384 90 | elif 'base' in name: 91 | swin = swin_base_patch4_window7_224(drop_path_rate=0.1,semantic_weight=semantic_weight) 92 | out_channels = 512 93 | 94 | return Backbone(swin,out_channels), Res5Head(swin,out_channels), out_channels*2 95 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import os.path as osp 4 | import time 5 | 6 | import torch 7 | import torch.utils.data 8 | 9 | from datasets import build_test_loader, build_train_loader 10 | from defaults import 
get_default_cfg 11 | from engine import evaluate_performance, train_one_epoch 12 | from models.seqnet import SeqNet 13 | from utils.utils import mkdir, resume_from_ckpt, save_on_master, set_random_seed 14 | 15 | 16 | def main(args): 17 | cfg = get_default_cfg() 18 | if args.cfg_file: 19 | cfg.merge_from_file(args.cfg_file) 20 | cfg.merge_from_list(args.opts) 21 | cfg.freeze() 22 | 23 | device = torch.device(cfg.DEVICE) 24 | if cfg.SEED >= 0: 25 | set_random_seed(cfg.SEED) 26 | 27 | print("Creating model") 28 | model = SeqNet(cfg) 29 | model.to(device) 30 | 31 | print("Loading data") 32 | train_loader = build_train_loader(cfg) 33 | gallery_loader, query_loader = build_test_loader(cfg) 34 | 35 | if args.eval: 36 | assert args.ckpt, "--ckpt must be specified when --eval enabled" 37 | resume_from_ckpt(args.ckpt, model) 38 | evaluate_performance( 39 | model, 40 | gallery_loader, 41 | query_loader, 42 | device, 43 | use_gt=cfg.EVAL_USE_GT, 44 | use_cache=cfg.EVAL_USE_CACHE, 45 | use_cbgm=cfg.EVAL_USE_CBGM, 46 | ) 47 | exit(0) 48 | 49 | params = [p for p in model.parameters() if p.requires_grad] 50 | optimizer = torch.optim.SGD( 51 | params, 52 | lr=cfg.SOLVER.BASE_LR, 53 | momentum=cfg.SOLVER.SGD_MOMENTUM, 54 | weight_decay=cfg.SOLVER.WEIGHT_DECAY, 55 | ) 56 | 57 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( 58 | optimizer, milestones=cfg.SOLVER.LR_DECAY_MILESTONES, gamma=0.1 59 | ) 60 | 61 | start_epoch = 0 62 | if args.resume: 63 | assert args.ckpt, "--ckpt must be specified when --resume enabled" 64 | start_epoch = resume_from_ckpt(args.ckpt, model, optimizer, lr_scheduler) 65 | 66 | print("Creating output folder") 67 | output_dir = cfg.OUTPUT_DIR 68 | mkdir(output_dir) 69 | path = osp.join(output_dir, "config.yaml") 70 | with open(path, "w") as f: 71 | f.write(cfg.dump()) 72 | print(f"Full config is saved to {path}") 73 | tfboard = None 74 | if cfg.TF_BOARD: 75 | from torch.utils.tensorboard import SummaryWriter 76 | 77 | tf_log_path = osp.join(output_dir, "tf_log") 78 | mkdir(tf_log_path) 79 | tfboard = SummaryWriter(log_dir=tf_log_path) 80 | print(f"TensorBoard files are saved to {tf_log_path}") 81 | 82 | print("Start training") 83 | start_time = time.time() 84 | for epoch in range(start_epoch, cfg.SOLVER.MAX_EPOCHS): 85 | train_one_epoch(cfg, model, optimizer, train_loader, device, epoch, tfboard) 86 | lr_scheduler.step() 87 | 88 | if (epoch + 1) % cfg.EVAL_PERIOD == 0 or epoch == cfg.SOLVER.MAX_EPOCHS - 1: 89 | evaluate_performance( 90 | model, 91 | gallery_loader, 92 | query_loader, 93 | device, 94 | use_gt=cfg.EVAL_USE_GT, 95 | use_cache=cfg.EVAL_USE_CACHE, 96 | use_cbgm=cfg.EVAL_USE_CBGM, 97 | ) 98 | 99 | if (epoch + 1) % cfg.CKPT_PERIOD == 0 or epoch == cfg.SOLVER.MAX_EPOCHS - 1: 100 | save_on_master( 101 | { 102 | "model": model.state_dict(), 103 | "optimizer": optimizer.state_dict(), 104 | "lr_scheduler": lr_scheduler.state_dict(), 105 | "epoch": epoch, 106 | }, 107 | osp.join(output_dir, f"epoch_{epoch}.pth"), 108 | ) 109 | 110 | if tfboard: 111 | tfboard.close() 112 | total_time = time.time() - start_time 113 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 114 | print(f"Total training time {total_time_str}") 115 | 116 | 117 | if __name__ == "__main__": 118 | parser = argparse.ArgumentParser(description="Train a person search network.") 119 | parser.add_argument("--cfg", dest="cfg_file", help="Path to configuration file.") 120 | parser.add_argument( 121 | "--eval", action="store_true", help="Evaluate the performance of a given checkpoint." 
122 | ) 123 | parser.add_argument( 124 | "--resume", action="store_true", help="Resume from the specified checkpoint." 125 | ) 126 | parser.add_argument("--ckpt", help="Path to checkpoint to resume or evaluate.") 127 | parser.add_argument( 128 | "opts", nargs=argparse.REMAINDER, help="Modify config options using the command-line" 129 | ) 130 | args = parser.parse_args() 131 | main(args) 132 | -------------------------------------------------------------------------------- /utils/km.py: -------------------------------------------------------------------------------- 1 | # encoding=utf-8 2 | 3 | import random 4 | 5 | import numpy as np 6 | 7 | zero_threshold = 0.00000001 8 | 9 | 10 | class KMNode(object): 11 | def __init__(self, id, exception=0, match=None, visit=False): 12 | self.id = id 13 | self.exception = exception 14 | self.match = match 15 | self.visit = visit 16 | 17 | 18 | class KuhnMunkres(object): 19 | def __init__(self): 20 | self.matrix = None 21 | self.x_nodes = [] 22 | self.y_nodes = [] 23 | self.minz = float("inf") 24 | self.x_length = 0 25 | self.y_length = 0 26 | self.index_x = 0 27 | self.index_y = 1 28 | 29 | def __del__(self): 30 | pass 31 | 32 | def set_matrix(self, x_y_values): 33 | xs = set() 34 | ys = set() 35 | for x, y, value in x_y_values: 36 | xs.add(x) 37 | ys.add(y) 38 | 39 | # 选取较小的作为x 40 | if len(xs) < len(ys): 41 | self.index_x = 0 42 | self.index_y = 1 43 | else: 44 | self.index_x = 1 45 | self.index_y = 0 46 | xs, ys = ys, xs 47 | 48 | x_dic = {x: i for i, x in enumerate(xs)} 49 | y_dic = {y: j for j, y in enumerate(ys)} 50 | self.x_nodes = [KMNode(x) for x in xs] 51 | self.y_nodes = [KMNode(y) for y in ys] 52 | self.x_length = len(xs) 53 | self.y_length = len(ys) 54 | 55 | self.matrix = np.zeros((self.x_length, self.y_length)) 56 | for row in x_y_values: 57 | x = row[self.index_x] 58 | y = row[self.index_y] 59 | value = row[2] 60 | x_index = x_dic[x] 61 | y_index = y_dic[y] 62 | self.matrix[x_index, y_index] = value 63 | 64 | for i in range(self.x_length): 65 | self.x_nodes[i].exception = max(self.matrix[i, :]) 66 | 67 | def km(self): 68 | for i in range(self.x_length): 69 | while True: 70 | self.minz = float("inf") 71 | self.set_false(self.x_nodes) 72 | self.set_false(self.y_nodes) 73 | 74 | if self.dfs(i): 75 | break 76 | 77 | self.change_exception(self.x_nodes, -self.minz) 78 | self.change_exception(self.y_nodes, self.minz) 79 | 80 | def dfs(self, i): 81 | x_node = self.x_nodes[i] 82 | x_node.visit = True 83 | for j in range(self.y_length): 84 | y_node = self.y_nodes[j] 85 | if not y_node.visit: 86 | t = x_node.exception + y_node.exception - self.matrix[i][j] 87 | if abs(t) < zero_threshold: 88 | y_node.visit = True 89 | if y_node.match is None or self.dfs(y_node.match): 90 | x_node.match = j 91 | y_node.match = i 92 | return True 93 | else: 94 | if t >= zero_threshold: 95 | self.minz = min(self.minz, t) 96 | return False 97 | 98 | def set_false(self, nodes): 99 | for node in nodes: 100 | node.visit = False 101 | 102 | def change_exception(self, nodes, change): 103 | for node in nodes: 104 | if node.visit: 105 | node.exception += change 106 | 107 | def get_connect_result(self): 108 | ret = [] 109 | for i in range(self.x_length): 110 | x_node = self.x_nodes[i] 111 | j = x_node.match 112 | y_node = self.y_nodes[j] 113 | x_id = x_node.id 114 | y_id = y_node.id 115 | value = self.matrix[i][j] 116 | 117 | if self.index_x == 1 and self.index_y == 0: 118 | x_id, y_id = y_id, x_id 119 | ret.append((x_id, y_id, value)) 120 | 121 | return ret 122 | 123 | def 
get_max_value_result(self): 124 | # ret = 0 125 | ret = -100 126 | # ret = [] 127 | for i in range(self.x_length): 128 | j = self.x_nodes[i].match 129 | # ret += self.matrix[i][j] 130 | ret = max(ret, self.matrix[i][j]) 131 | # ret.append(self.matrix[i][j]) 132 | # ret.sort() 133 | # ret = ret[-1:] 134 | # ret = np.array(ret).mean() 135 | return ret 136 | 137 | 138 | def run_kuhn_munkres(x_y_values): 139 | process = KuhnMunkres() 140 | process.set_matrix(x_y_values) 141 | process.km() 142 | return process.get_connect_result(), process.get_max_value_result() 143 | 144 | 145 | def test(): 146 | values = [] 147 | random.seed(0) 148 | for i in range(500): 149 | for j in range(1000): 150 | value = random.random() 151 | values.append((i, j, value)) 152 | 153 | return run_kuhn_munkres(values) 154 | 155 | 156 | if __name__ == "__main__": 157 | # s_time = time.time() 158 | # ret = test() 159 | # print "time usage: %s " % str(time.time() - s_time) 160 | values = [(1, 1, 3), (1, 3, 4), (2, 1, 2), (2, 2, 1), (2, 3, 3), (3, 2, 4), (3, 3, 5)] 161 | print(run_kuhn_munkres(values)) 162 | -------------------------------------------------------------------------------- /datasets/cuhk_sysu.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import numpy as np 4 | from scipy.io import loadmat 5 | 6 | from .base import BaseDataset 7 | 8 | 9 | class CUHKSYSU(BaseDataset): 10 | def __init__(self, root, transforms, split): 11 | self.name = "CUHK-SYSU" 12 | self.img_prefix = osp.join(root, "Image", "SSM") 13 | super(CUHKSYSU, self).__init__(root, transforms, split) 14 | 15 | def _load_queries(self): 16 | # TestG50: a test protocol, 50 gallery images per query 17 | protoc = loadmat(osp.join(self.root, "annotation/test/train_test/TestG50.mat")) 18 | protoc = protoc["TestG50"].squeeze() 19 | queries = [] 20 | for item in protoc["Query"]: 21 | img_name = str(item["imname"][0, 0][0]) 22 | roi = item["idlocate"][0, 0][0].astype(np.int32) 23 | roi[2:] += roi[:2] 24 | queries.append( 25 | { 26 | "img_name": img_name, 27 | "img_path": osp.join(self.img_prefix, img_name), 28 | "boxes": roi[np.newaxis, :], 29 | "pids": np.array([-100]), # dummy pid 30 | } 31 | ) 32 | return queries 33 | 34 | def _load_split_img_names(self): 35 | """ 36 | Load the image names for the specific split. 
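        Gallery image names come from annotation/pool.mat; the training split is the
        set of all images listed in annotation/Images.mat minus the gallery images.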
37 | """ 38 | assert self.split in ("train", "gallery") 39 | # gallery images 40 | gallery_imgs = loadmat(osp.join(self.root, "annotation", "pool.mat")) 41 | gallery_imgs = gallery_imgs["pool"].squeeze() 42 | gallery_imgs = [str(a[0]) for a in gallery_imgs] 43 | if self.split == "gallery": 44 | return gallery_imgs 45 | # all images 46 | all_imgs = loadmat(osp.join(self.root, "annotation", "Images.mat")) 47 | all_imgs = all_imgs["Img"].squeeze() 48 | all_imgs = [str(a[0][0]) for a in all_imgs] 49 | # training images = all images - gallery images 50 | training_imgs = sorted(list(set(all_imgs) - set(gallery_imgs))) 51 | return training_imgs 52 | 53 | def _load_annotations(self): 54 | if self.split == "query": 55 | return self._load_queries() 56 | 57 | # load all images and build a dict from image to boxes 58 | all_imgs = loadmat(osp.join(self.root, "annotation", "Images.mat")) 59 | all_imgs = all_imgs["Img"].squeeze() 60 | name_to_boxes = {} 61 | name_to_pids = {} 62 | unlabeled_pid = 5555 # default pid for unlabeled people 63 | for img_name, _, boxes in all_imgs: 64 | img_name = str(img_name[0]) 65 | boxes = np.asarray([b[0] for b in boxes[0]]) 66 | boxes = boxes.reshape(boxes.shape[0], 4) # (x1, y1, w, h) 67 | valid_index = np.where((boxes[:, 2] > 0) & (boxes[:, 3] > 0))[0] 68 | assert valid_index.size > 0, "Warning: {} has no valid boxes.".format(img_name) 69 | boxes = boxes[valid_index] 70 | name_to_boxes[img_name] = boxes.astype(np.int32) 71 | name_to_pids[img_name] = unlabeled_pid * np.ones(boxes.shape[0], dtype=np.int32) 72 | 73 | def set_box_pid(boxes, box, pids, pid): 74 | for i in range(boxes.shape[0]): 75 | if np.all(boxes[i] == box): 76 | pids[i] = pid 77 | return 78 | 79 | # assign a unique pid from 1 to N for each identity 80 | if self.split == "train": 81 | train = loadmat(osp.join(self.root, "annotation/test/train_test/Train.mat")) 82 | train = train["Train"].squeeze() 83 | for index, item in enumerate(train): 84 | scenes = item[0, 0][2].squeeze() 85 | for img_name, box, _ in scenes: 86 | img_name = str(img_name[0]) 87 | box = box.squeeze().astype(np.int32) 88 | set_box_pid(name_to_boxes[img_name], box, name_to_pids[img_name], index + 1) 89 | else: 90 | protoc = loadmat(osp.join(self.root, "annotation/test/train_test/TestG50.mat")) 91 | protoc = protoc["TestG50"].squeeze() 92 | for index, item in enumerate(protoc): 93 | # query 94 | im_name = str(item["Query"][0, 0][0][0]) 95 | box = item["Query"][0, 0][1].squeeze().astype(np.int32) 96 | set_box_pid(name_to_boxes[im_name], box, name_to_pids[im_name], index + 1) 97 | # gallery 98 | gallery = item["Gallery"].squeeze() 99 | for im_name, box, _ in gallery: 100 | im_name = str(im_name[0]) 101 | if box.size == 0: 102 | break 103 | box = box.squeeze().astype(np.int32) 104 | set_box_pid(name_to_boxes[im_name], box, name_to_pids[im_name], index + 1) 105 | 106 | annotations = [] 107 | imgs = self._load_split_img_names() 108 | for img_name in imgs: 109 | boxes = name_to_boxes[img_name] 110 | boxes[:, 2:] += boxes[:, :2] # (x1, y1, w, h) -> (x1, y1, x2, y2) 111 | pids = name_to_pids[img_name] 112 | annotations.append( 113 | { 114 | "img_name": img_name, 115 | "img_path": osp.join(self.img_prefix, img_name), 116 | "boxes": boxes, 117 | "pids": pids, 118 | } 119 | ) 120 | return annotations 121 | -------------------------------------------------------------------------------- /defaults.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | _C = CN() 4 | 5 | # 
-------------------------------------------------------- # 6 | # Input # 7 | # -------------------------------------------------------- # 8 | _C.INPUT = CN() 9 | _C.INPUT.DATASET = "CUHK-SYSU" 10 | _C.INPUT.DATA_ROOT = "data/CUHK-SYSU" 11 | 12 | # Size of the smallest side of the image 13 | _C.INPUT.MIN_SIZE = 900 14 | # Maximum size of the side of the image 15 | _C.INPUT.MAX_SIZE = 1500 16 | 17 | # TODO: support aspect ratio grouping 18 | # Whether to use aspect ratio grouping for saving GPU memory 19 | # _C.INPUT.ASPECT_RATIO_GROUPING_TRAIN = False 20 | 21 | # Number of images per batch 22 | _C.INPUT.BATCH_SIZE_TRAIN = 2 23 | _C.INPUT.BATCH_SIZE_TEST = 1 24 | 25 | # Number of data loading threads 26 | _C.INPUT.NUM_WORKERS_TRAIN = 5 27 | _C.INPUT.NUM_WORKERS_TEST = 1 28 | 29 | # -------------------------------------------------------- # 30 | # Solver # 31 | # -------------------------------------------------------- # 32 | _C.SOLVER = CN() 33 | _C.SOLVER.MAX_EPOCHS = 20 34 | 35 | # Learning rate settings 36 | _C.SOLVER.BASE_LR = 0.003 37 | 38 | # TODO: add config option WARMUP_EPOCHS 39 | _C.SOLVER.WARMUP_FACTOR = 1.0 / 1000 40 | # _C.SOLVER.WARMUP_EPOCHS = 1 41 | 42 | # The epoch milestones to decrease the learning rate by GAMMA 43 | _C.SOLVER.LR_DECAY_MILESTONES = [16] 44 | _C.SOLVER.GAMMA = 0.1 45 | 46 | _C.SOLVER.WEIGHT_DECAY = 0.0005 47 | _C.SOLVER.SGD_MOMENTUM = 0.9 48 | 49 | # Loss weight of RPN regression 50 | _C.SOLVER.LW_RPN_REG = 1 51 | # Loss weight of RPN classification 52 | _C.SOLVER.LW_RPN_CLS = 1 53 | # Loss weight of proposal regression 54 | _C.SOLVER.LW_PROPOSAL_REG = 10 55 | # Loss weight of proposal classification 56 | _C.SOLVER.LW_PROPOSAL_CLS = 1 57 | # Loss weight of box regression 58 | _C.SOLVER.LW_BOX_REG = 1 59 | # Loss weight of box classification 60 | _C.SOLVER.LW_BOX_CLS = 1 61 | # Loss weight of box OIM (i.e. Online Instance Matching) 62 | _C.SOLVER.LW_BOX_REID = 1 63 | 64 | # Set to negative value to disable gradient clipping 65 | _C.SOLVER.CLIP_GRADIENTS = 10.0 66 | 67 | # -------------------------------------------------------- # 68 | # RPN # 69 | # -------------------------------------------------------- # 70 | _C.MODEL = CN() 71 | _C.MODEL.BONE = "swin_tiny" 72 | _C.MODEL.SEMANTIC_WEIGHT = 1.0 73 | 74 | _C.MODEL.RPN = CN() 75 | # NMS threshold used on RoIs 76 | _C.MODEL.RPN.NMS_THRESH = 0.7 77 | # Number of anchors per image used to train RPN 78 | _C.MODEL.RPN.BATCH_SIZE_TRAIN = 256 79 | # Target fraction of foreground examples per RPN minibatch 80 | _C.MODEL.RPN.POS_FRAC_TRAIN = 0.5 81 | # Overlap threshold for an anchor to be considered foreground (if >= POS_THRESH_TRAIN) 82 | _C.MODEL.RPN.POS_THRESH_TRAIN = 0.7 83 | # Overlap threshold for an anchor to be considered background (if < NEG_THRESH_TRAIN) 84 | _C.MODEL.RPN.NEG_THRESH_TRAIN = 0.3 85 | # Number of top scoring RPN RoIs to keep before applying NMS 86 | _C.MODEL.RPN.PRE_NMS_TOPN_TRAIN = 12000 87 | _C.MODEL.RPN.PRE_NMS_TOPN_TEST = 6000 88 | # Number of top scoring RPN RoIs to keep after applying NMS 89 | _C.MODEL.RPN.POST_NMS_TOPN_TRAIN = 2000 90 | _C.MODEL.RPN.POST_NMS_TOPN_TEST = 300 91 | 92 | # -------------------------------------------------------- # 93 | # RoI head # 94 | # -------------------------------------------------------- # 95 | _C.MODEL.ROI_HEAD = CN() 96 | # Whether to use bn neck (i.e. 
batch normalization after linear) 97 | _C.MODEL.ROI_HEAD.BN_NECK = True 98 | # Number of RoIs per image used to train RoI head 99 | _C.MODEL.ROI_HEAD.BATCH_SIZE_TRAIN = 128 100 | # Target fraction of foreground examples per RoI minibatch 101 | _C.MODEL.ROI_HEAD.POS_FRAC_TRAIN = 0.5 102 | # Overlap threshold for an RoI to be considered foreground (if >= POS_THRESH_TRAIN) 103 | _C.MODEL.ROI_HEAD.POS_THRESH_TRAIN = 0.5 104 | # Overlap threshold for an RoI to be considered background (if < NEG_THRESH_TRAIN) 105 | _C.MODEL.ROI_HEAD.NEG_THRESH_TRAIN = 0.5 106 | # Minimum score threshold 107 | _C.MODEL.ROI_HEAD.SCORE_THRESH_TEST = 0.5 108 | # NMS threshold used on boxes 109 | _C.MODEL.ROI_HEAD.NMS_THRESH_TEST = 0.4 110 | # Maximum number of detected objects 111 | _C.MODEL.ROI_HEAD.DETECTIONS_PER_IMAGE_TEST = 300 112 | 113 | # -------------------------------------------------------- # 114 | # Loss # 115 | # -------------------------------------------------------- # 116 | _C.MODEL.LOSS = CN() 117 | # Size of the lookup table in OIM 118 | _C.MODEL.LOSS.LUT_SIZE = 5532 119 | # Size of the circular queue in OIM 120 | _C.MODEL.LOSS.CQ_SIZE = 5000 121 | _C.MODEL.LOSS.OIM_MOMENTUM = 0.5 122 | _C.MODEL.LOSS.OIM_SCALAR = 30.0 123 | 124 | # -------------------------------------------------------- # 125 | # Evaluation # 126 | # -------------------------------------------------------- # 127 | # The period to evaluate the model during training 128 | _C.EVAL_PERIOD = 1 129 | # Evaluation with GT boxes to verify the upper bound of person search performance 130 | _C.EVAL_USE_GT = False 131 | # Fast evaluation with cached features 132 | _C.EVAL_USE_CACHE = False 133 | # Evaluation with Context Bipartite Graph Matching (CBGM) algorithm 134 | _C.EVAL_USE_CBGM = False 135 | 136 | # -------------------------------------------------------- # 137 | # Miscs # 138 | # -------------------------------------------------------- # 139 | # Save a checkpoint after every this number of epochs 140 | _C.CKPT_PERIOD = 1 141 | # The period (in terms of iterations) to display training losses 142 | _C.DISP_PERIOD = 10 143 | # Whether to use tensorboard for visualization 144 | _C.TF_BOARD = True 145 | # The device loading the model 146 | _C.DEVICE = "cuda" 147 | # Set seed to negative to fully randomize everything 148 | _C.SEED = 1 149 | # Directory where output files are written 150 | _C.OUTPUT_DIR = "./output" 151 | 152 | 153 | def get_default_cfg(): 154 | """ 155 | Get a copy of the default config. 
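    The returned CfgNode is typically merged with a YAML config and command-line
    options before being frozen, mirroring train.py:

        cfg = get_default_cfg()
        cfg.merge_from_file("configs/prw.yaml")
        cfg.merge_from_list(["SOLVER.BASE_LR", "0.0003"])
        cfg.freeze()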
156 | """ 157 | return _C.clone() 158 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | from copy import deepcopy 4 | 5 | import torch 6 | from torch.nn.utils import clip_grad_norm_ 7 | from tqdm import tqdm 8 | 9 | from eval_func import eval_detection, eval_search_cuhk, eval_search_prw 10 | from utils.utils import MetricLogger, SmoothedValue, mkdir, reduce_dict, warmup_lr_scheduler 11 | 12 | 13 | def to_device(images, targets, device): 14 | images = [image.to(device) for image in images] 15 | for t in targets: 16 | t["boxes"] = t["boxes"].to(device) 17 | t["labels"] = t["labels"].to(device) 18 | return images, targets 19 | 20 | 21 | def train_one_epoch(cfg, model, optimizer, data_loader, device, epoch, tfboard=None): 22 | model.train() 23 | metric_logger = MetricLogger(delimiter=" ") 24 | metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}")) 25 | header = "Epoch: [{}]".format(epoch) 26 | 27 | # warmup learning rate in the first epoch 28 | if epoch == 0: 29 | warmup_factor = 1.0 / 1000 30 | # FIXME: min(1000, len(data_loader) - 1) 31 | warmup_iters = len(data_loader) - 1 32 | warmup_scheduler = warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor) 33 | 34 | for i, (images, targets) in enumerate( 35 | metric_logger.log_every(data_loader, cfg.DISP_PERIOD, header) 36 | ): 37 | images, targets = to_device(images, targets, device) 38 | 39 | loss_dict = model(images, targets) 40 | losses = sum(loss for loss in loss_dict.values()) 41 | 42 | # reduce losses over all GPUs for logging purposes 43 | loss_dict_reduced = reduce_dict(loss_dict) 44 | losses_reduced = sum(loss for loss in loss_dict_reduced.values()) 45 | loss_value = losses_reduced.item() 46 | 47 | if not math.isfinite(loss_value): 48 | print(f"Loss is {loss_value}, stopping training") 49 | print(loss_dict_reduced) 50 | sys.exit(1) 51 | 52 | optimizer.zero_grad() 53 | losses.backward() 54 | if cfg.SOLVER.CLIP_GRADIENTS > 0: 55 | clip_grad_norm_(model.parameters(), cfg.SOLVER.CLIP_GRADIENTS) 56 | optimizer.step() 57 | 58 | if epoch == 0: 59 | warmup_scheduler.step() 60 | 61 | metric_logger.update(loss=loss_value, **loss_dict_reduced) 62 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 63 | if tfboard: 64 | iter = epoch * len(data_loader) + i 65 | for k, v in loss_dict_reduced.items(): 66 | tfboard.add_scalars("train", {k: v}, iter) 67 | 68 | 69 | @torch.no_grad() 70 | def evaluate_performance( 71 | model, gallery_loader, query_loader, device, use_gt=False, use_cache=False, use_cbgm=False 72 | ): 73 | """ 74 | Args: 75 | use_gt (bool, optional): Whether to use GT as detection results to verify the upper 76 | bound of person search performance. Defaults to False. 77 | use_cache (bool, optional): Whether to use the cached features. Defaults to False. 78 | use_cbgm (bool, optional): Whether to use Context Bipartite Graph Matching algorithm. 79 | Defaults to False. 
80 | """ 81 | model.eval() 82 | if use_cache: 83 | eval_cache = torch.load("data/eval_cache/eval_cache.pth") 84 | gallery_dets = eval_cache["gallery_dets"] 85 | gallery_feats = eval_cache["gallery_feats"] 86 | query_dets = eval_cache["query_dets"] 87 | query_feats = eval_cache["query_feats"] 88 | query_box_feats = eval_cache["query_box_feats"] 89 | else: 90 | gallery_dets, gallery_feats = [], [] 91 | for images, targets in tqdm(gallery_loader, ncols=0): 92 | images, targets = to_device(images, targets, device) 93 | if not use_gt: 94 | outputs = model(images) 95 | else: 96 | boxes = targets[0]["boxes"] 97 | n_boxes = boxes.size(0) 98 | embeddings = model(images, targets) 99 | outputs = [ 100 | { 101 | "boxes": boxes, 102 | "embeddings": torch.cat(embeddings), 103 | "labels": torch.ones(n_boxes).to(device), 104 | "scores": torch.ones(n_boxes).to(device), 105 | } 106 | ] 107 | 108 | for output in outputs: 109 | box_w_scores = torch.cat([output["boxes"], output["scores"].unsqueeze(1)], dim=1) 110 | gallery_dets.append(box_w_scores.cpu().numpy()) 111 | gallery_feats.append(output["embeddings"].cpu().numpy()) 112 | 113 | # regarding query image as gallery to detect all people 114 | # i.e. query person + surrounding people (context information) 115 | query_dets, query_feats = [], [] 116 | for images, targets in tqdm(query_loader, ncols=0): 117 | images, targets = to_device(images, targets, device) 118 | # targets will be modified in the model, so deepcopy it 119 | outputs = model(images, deepcopy(targets), query_img_as_gallery=True) 120 | 121 | # consistency check 122 | gt_box = targets[0]["boxes"].squeeze() 123 | assert ( 124 | gt_box - outputs[0]["boxes"][0] 125 | ).sum() <= 0.001, "GT box must be the first one in the detected boxes of query image" 126 | 127 | for output in outputs: 128 | box_w_scores = torch.cat([output["boxes"], output["scores"].unsqueeze(1)], dim=1) 129 | query_dets.append(box_w_scores.cpu().numpy()) 130 | query_feats.append(output["embeddings"].cpu().numpy()) 131 | 132 | # extract the features of query boxes 133 | query_box_feats = [] 134 | for images, targets in tqdm(query_loader, ncols=0): 135 | images, targets = to_device(images, targets, device) 136 | embeddings = model(images, targets) 137 | assert len(embeddings) == 1, "batch size in test phase should be 1" 138 | query_box_feats.append(embeddings[0].cpu().numpy()) 139 | 140 | mkdir("data/eval_cache") 141 | save_dict = { 142 | "gallery_dets": gallery_dets, 143 | "gallery_feats": gallery_feats, 144 | "query_dets": query_dets, 145 | "query_feats": query_feats, 146 | "query_box_feats": query_box_feats, 147 | } 148 | torch.save(save_dict, "data/eval_cache/eval_cache.pth") 149 | 150 | eval_detection(gallery_loader.dataset, gallery_dets, det_thresh=0.01) 151 | eval_search_func = ( 152 | eval_search_cuhk if gallery_loader.dataset.name == "CUHK-SYSU" else eval_search_prw 153 | ) 154 | eval_search_func( 155 | gallery_loader.dataset, 156 | query_loader.dataset, 157 | gallery_dets, 158 | gallery_feats, 159 | query_box_feats, 160 | query_dets, 161 | query_feats, 162 | cbgm=use_cbgm, 163 | ) 164 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import errno 3 | import json 4 | import os 5 | import os.path as osp 6 | import pickle 7 | import random 8 | import time 9 | from collections import defaultdict, deque 10 | 11 | import numpy as np 12 | import torch 13 | import 
torch.distributed as dist 14 | from tabulate import tabulate 15 | 16 | 17 | # -------------------------------------------------------- # 18 | # Logger # 19 | # -------------------------------------------------------- # 20 | class SmoothedValue(object): 21 | """ 22 | Track a series of values and provide access to smoothed values over a 23 | window or the global series average. 24 | """ 25 | 26 | def __init__(self, window_size=20, fmt=None): 27 | if fmt is None: 28 | fmt = "{median:.4f} ({global_avg:.4f})" 29 | self.deque = deque(maxlen=window_size) 30 | self.total = 0.0 31 | self.count = 0 32 | self.fmt = fmt 33 | 34 | def update(self, value, n=1): 35 | self.deque.append(value) 36 | self.count += n 37 | self.total += value * n 38 | 39 | def synchronize_between_processes(self): 40 | """ 41 | Warning: does not synchronize the deque! 42 | """ 43 | if not is_dist_avail_and_initialized(): 44 | return 45 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") 46 | dist.barrier() 47 | dist.all_reduce(t) 48 | t = t.tolist() 49 | self.count = int(t[0]) 50 | self.total = t[1] 51 | 52 | @property 53 | def median(self): 54 | d = torch.tensor(list(self.deque)) 55 | return d.median().item() 56 | 57 | @property 58 | def avg(self): 59 | d = torch.tensor(list(self.deque), dtype=torch.float32) 60 | return d.mean().item() 61 | 62 | @property 63 | def global_avg(self): 64 | return self.total / self.count 65 | 66 | @property 67 | def max(self): 68 | return max(self.deque) 69 | 70 | @property 71 | def value(self): 72 | return self.deque[-1] 73 | 74 | def __str__(self): 75 | return self.fmt.format( 76 | median=self.median, 77 | avg=self.avg, 78 | global_avg=self.global_avg, 79 | max=self.max, 80 | value=self.value, 81 | ) 82 | 83 | 84 | class MetricLogger(object): 85 | def __init__(self, delimiter="\t"): 86 | self.meters = defaultdict(SmoothedValue) 87 | self.delimiter = delimiter 88 | 89 | def update(self, **kwargs): 90 | for k, v in kwargs.items(): 91 | if isinstance(v, torch.Tensor): 92 | v = v.item() 93 | assert isinstance(v, (float, int)) 94 | self.meters[k].update(v) 95 | 96 | def __getattr__(self, attr): 97 | if attr in self.meters: 98 | return self.meters[attr] 99 | if attr in self.__dict__: 100 | return self.__dict__[attr] 101 | raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) 102 | 103 | def __str__(self): 104 | loss_str = [] 105 | for name, meter in self.meters.items(): 106 | loss_str.append("{}: {}".format(name, str(meter))) 107 | return self.delimiter.join(loss_str) 108 | 109 | def synchronize_between_processes(self): 110 | for meter in self.meters.values(): 111 | meter.synchronize_between_processes() 112 | 113 | def add_meter(self, name, meter): 114 | self.meters[name] = meter 115 | 116 | def log_every(self, iterable, print_freq, header=None): 117 | i = 0 118 | if not header: 119 | header = "" 120 | start_time = time.time() 121 | end = time.time() 122 | iter_time = SmoothedValue(fmt="{avg:.4f}") 123 | data_time = SmoothedValue(fmt="{avg:.4f}") 124 | space_fmt = ":" + str(len(str(len(iterable)))) + "d" 125 | if torch.cuda.is_available(): 126 | log_msg = self.delimiter.join( 127 | [ 128 | header, 129 | "[{0" + space_fmt + "}/{1}]", 130 | "eta: {eta}", 131 | "{meters}", 132 | "time: {time}", 133 | "data: {data}", 134 | "max mem: {memory:.0f}", 135 | ] 136 | ) 137 | else: 138 | log_msg = self.delimiter.join( 139 | [ 140 | header, 141 | "[{0" + space_fmt + "}/{1}]", 142 | "eta: {eta}", 143 | "{meters}", 144 | "time: {time}", 145 | "data: 
{data}", 146 | ] 147 | ) 148 | MB = 1024.0 * 1024.0 149 | for obj in iterable: 150 | data_time.update(time.time() - end) 151 | yield obj 152 | iter_time.update(time.time() - end) 153 | if i % print_freq == 0 or i == len(iterable) - 1: 154 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 155 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 156 | if torch.cuda.is_available(): 157 | print( 158 | log_msg.format( 159 | i, 160 | len(iterable), 161 | eta=eta_string, 162 | meters=str(self), 163 | time=str(iter_time), 164 | data=str(data_time), 165 | memory=torch.cuda.max_memory_allocated() / MB, 166 | ) 167 | ) 168 | else: 169 | print( 170 | log_msg.format( 171 | i, 172 | len(iterable), 173 | eta=eta_string, 174 | meters=str(self), 175 | time=str(iter_time), 176 | data=str(data_time), 177 | ) 178 | ) 179 | i += 1 180 | end = time.time() 181 | total_time = time.time() - start_time 182 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 183 | print( 184 | "{} Total time: {} ({:.4f} s / it)".format( 185 | header, total_time_str, total_time / len(iterable) 186 | ) 187 | ) 188 | 189 | 190 | # -------------------------------------------------------- # 191 | # Distributed training # 192 | # -------------------------------------------------------- # 193 | def all_gather(data): 194 | """ 195 | Run all_gather on arbitrary picklable data (not necessarily tensors) 196 | 197 | Args: 198 | data: any picklable object 199 | 200 | Returns: 201 | list[data]: list of data gathered from each rank 202 | """ 203 | world_size = get_world_size() 204 | if world_size == 1: 205 | return [data] 206 | 207 | # serialized to a Tensor 208 | buffer = pickle.dumps(data) 209 | storage = torch.ByteStorage.from_buffer(buffer) 210 | tensor = torch.ByteTensor(storage).to("cuda") 211 | 212 | # obtain Tensor size of each rank 213 | local_size = torch.tensor([tensor.numel()], device="cuda") 214 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] 215 | dist.all_gather(size_list, local_size) 216 | size_list = [int(size.item()) for size in size_list] 217 | max_size = max(size_list) 218 | 219 | # receiving Tensor from all ranks 220 | # we pad the tensor because torch all_gather does not support 221 | # gathering tensors of different shapes 222 | tensor_list = [] 223 | for _ in size_list: 224 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) 225 | if local_size != max_size: 226 | padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") 227 | tensor = torch.cat((tensor, padding), dim=0) 228 | dist.all_gather(tensor_list, tensor) 229 | 230 | data_list = [] 231 | for size, tensor in zip(size_list, tensor_list): 232 | buffer = tensor.cpu().numpy().tobytes()[:size] 233 | data_list.append(pickle.loads(buffer)) 234 | 235 | return data_list 236 | 237 | 238 | def reduce_dict(input_dict, average=True): 239 | """ 240 | Reduce the values in the dictionary from all processes so that all processes 241 | have the averaged results. Returns a dict with the same fields as 242 | input_dict, after reduction. 
243 | 244 | Args: 245 | input_dict (dict): all the values will be reduced 246 | average (bool): whether to do average or sum 247 | """ 248 | world_size = get_world_size() 249 | if world_size < 2: 250 | return input_dict 251 | with torch.no_grad(): 252 | names = [] 253 | values = [] 254 | # sort the keys so that they are consistent across processes 255 | for k in sorted(input_dict.keys()): 256 | names.append(k) 257 | values.append(input_dict[k]) 258 | values = torch.stack(values, dim=0) 259 | dist.all_reduce(values) 260 | if average: 261 | values /= world_size 262 | reduced_dict = {k: v for k, v in zip(names, values)} 263 | return reduced_dict 264 | 265 | 266 | def setup_for_distributed(is_master): 267 | """ 268 | This function disables printing when not in master process 269 | """ 270 | import builtins as __builtin__ 271 | 272 | builtin_print = __builtin__.print 273 | 274 | def print(*args, **kwargs): 275 | force = kwargs.pop("force", False) 276 | if is_master or force: 277 | builtin_print(*args, **kwargs) 278 | 279 | __builtin__.print = print 280 | 281 | 282 | def is_dist_avail_and_initialized(): 283 | if not dist.is_available(): 284 | return False 285 | if not dist.is_initialized(): 286 | return False 287 | return True 288 | 289 | 290 | def get_world_size(): 291 | if not is_dist_avail_and_initialized(): 292 | return 1 293 | return dist.get_world_size() 294 | 295 | 296 | def get_rank(): 297 | if not is_dist_avail_and_initialized(): 298 | return 0 299 | return dist.get_rank() 300 | 301 | 302 | def is_main_process(): 303 | return get_rank() == 0 304 | 305 | 306 | def save_on_master(*args, **kwargs): 307 | if is_main_process(): 308 | torch.save(*args, **kwargs) 309 | 310 | 311 | def init_distributed_mode(args): 312 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ: 313 | args.rank = int(os.environ["RANK"]) 314 | args.world_size = int(os.environ["WORLD_SIZE"]) 315 | args.gpu = int(os.environ["LOCAL_RANK"]) 316 | elif "SLURM_PROCID" in os.environ: 317 | args.rank = int(os.environ["SLURM_PROCID"]) 318 | args.gpu = args.rank % torch.cuda.device_count() 319 | else: 320 | print("Not using distributed mode") 321 | args.distributed = False 322 | return 323 | 324 | args.distributed = True 325 | 326 | torch.cuda.set_device(args.gpu) 327 | args.dist_backend = "nccl" 328 | print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True) 329 | torch.distributed.init_process_group( 330 | backend=args.dist_backend, 331 | init_method=args.dist_url, 332 | world_size=args.world_size, 333 | rank=args.rank, 334 | ) 335 | torch.distributed.barrier() 336 | setup_for_distributed(args.rank == 0) 337 | 338 | 339 | # -------------------------------------------------------- # 340 | # File operation # 341 | # -------------------------------------------------------- # 342 | def filename(path): 343 | return osp.splitext(osp.basename(path))[0] 344 | 345 | 346 | def mkdir(path): 347 | try: 348 | os.makedirs(path) 349 | except OSError as e: 350 | if e.errno != errno.EEXIST: 351 | raise 352 | 353 | 354 | def read_json(fpath): 355 | with open(fpath, "r") as f: 356 | obj = json.load(f) 357 | return obj 358 | 359 | 360 | def write_json(obj, fpath): 361 | mkdir(osp.dirname(fpath)) 362 | _obj = obj.copy() 363 | for k, v in _obj.items(): 364 | if isinstance(v, np.ndarray): 365 | _obj.pop(k) 366 | with open(fpath, "w") as f: 367 | json.dump(_obj, f, indent=4, separators=(",", ": ")) 368 | 369 | 370 | def symlink(src, dst, overwrite=True, **kwargs): 371 | if os.path.lexists(dst) and overwrite: 372 | 
os.remove(dst) 373 | os.symlink(src, dst, **kwargs) 374 | 375 | 376 | # -------------------------------------------------------- # 377 | # Misc # 378 | # -------------------------------------------------------- # 379 | def create_small_table(small_dict): 380 | """ 381 | Create a small table using the keys of small_dict as headers. This is only 382 | suitable for small dictionaries. 383 | 384 | Args: 385 | small_dict (dict): a result dictionary of only a few items. 386 | 387 | Returns: 388 | str: the table as a string. 389 | """ 390 | keys, values = tuple(zip(*small_dict.items())) 391 | table = tabulate( 392 | [values], 393 | headers=keys, 394 | tablefmt="pipe", 395 | floatfmt=".3f", 396 | stralign="center", 397 | numalign="center", 398 | ) 399 | return table 400 | 401 | 402 | def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor): 403 | def f(x): 404 | if x >= warmup_iters: 405 | return 1 406 | alpha = float(x) / warmup_iters 407 | return warmup_factor * (1 - alpha) + alpha 408 | 409 | return torch.optim.lr_scheduler.LambdaLR(optimizer, f) 410 | 411 | def resume_from_ckpt(ckpt_path, model, optimizer=None, lr_scheduler=None): 412 | ckpt = torch.load(ckpt_path) 413 | if 'state_dict' in ckpt.keys(): 414 | ckpt = ckpt['state_dict'] 415 | 416 | count = 0 417 | miss = [] 418 | for i in ckpt: 419 | if 'backbone' in i: 420 | model.state_dict()[i.replace('backbone.','backbone.swin.')].copy_(ckpt[i]) 421 | model.state_dict()[i.replace('backbone.','roi_heads.reid_head.swin.')].copy_(ckpt[i]) 422 | count += 1 423 | else: 424 | miss.append(i) 425 | print('%d loaded, %d missed:' %(count,len(miss)),miss) 426 | return 0 427 | 428 | def set_random_seed(seed): 429 | torch.manual_seed(seed) 430 | torch.cuda.manual_seed(seed) 431 | torch.cuda.manual_seed_all(seed) 432 | torch.backends.cudnn.benchmark = False 433 | torch.backends.cudnn.deterministic = True 434 | random.seed(seed) 435 | np.random.seed(seed) 436 | os.environ["PYTHONHASHSEED"] = str(seed) 437 | -------------------------------------------------------------------------------- /models/backbone.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torchvision.models.utils import load_state_dict_from_url 3 | 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 7 | 'wide_resnet50_2', 'wide_resnet101_2'] 8 | 9 | 10 | model_urls = { 11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 16 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 17 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 18 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 19 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 20 | } 21 | 22 | 23 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 24 | """3x3 convolution with padding""" 25 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 26 | padding=dilation, groups=groups, bias=False, dilation=dilation) 27 | 28 | 29 | def conv1x1(in_planes, out_planes, 
stride=1): 30 | """1x1 convolution""" 31 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 32 | 33 | 34 | class BasicBlock(nn.Module): 35 | expansion = 1 36 | 37 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 38 | base_width=64, dilation=1, norm_layer=None): 39 | super(BasicBlock, self).__init__() 40 | if norm_layer is None: 41 | norm_layer = nn.BatchNorm2d 42 | if groups != 1 or base_width != 64: 43 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 44 | if dilation > 1: 45 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 46 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 47 | self.conv1 = conv3x3(inplanes, planes, stride) 48 | self.bn1 = norm_layer(planes) 49 | self.relu = nn.ReLU(inplace=True) 50 | self.conv2 = conv3x3(planes, planes) 51 | self.bn2 = norm_layer(planes) 52 | self.downsample = downsample 53 | self.stride = stride 54 | 55 | def forward(self, x): 56 | identity = x 57 | 58 | out = self.conv1(x) 59 | out = self.bn1(out) 60 | out = self.relu(out) 61 | 62 | out = self.conv2(out) 63 | out = self.bn2(out) 64 | 65 | if self.downsample is not None: 66 | identity = self.downsample(x) 67 | 68 | out += identity 69 | out = self.relu(out) 70 | 71 | return out 72 | 73 | 74 | class Bottleneck(nn.Module): 75 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 76 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 77 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 78 | # This variant is also known as ResNet V1.5 and improves accuracy according to 79 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
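    # In code terms the difference is only where the stride lands (a sketch; the v1
    # ordering is not what this file implements):
    #   v1  : conv1x1(inplanes, width, stride) -> conv3x3(width, width)         -> conv1x1(width, planes * 4)
    #   v1.5: conv1x1(inplanes, width)         -> conv3x3(width, width, stride) -> conv1x1(width, planes * 4)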
80 | 81 | expansion = 4 82 | 83 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 84 | base_width=64, dilation=1, norm_layer=None): 85 | super(Bottleneck, self).__init__() 86 | if norm_layer is None: 87 | norm_layer = nn.BatchNorm2d 88 | width = int(planes * (base_width / 64.)) * groups 89 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 90 | self.conv1 = conv1x1(inplanes, width) 91 | self.bn1 = norm_layer(width) 92 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 93 | self.bn2 = norm_layer(width) 94 | self.conv3 = conv1x1(width, planes * self.expansion) 95 | self.bn3 = norm_layer(planes * self.expansion) 96 | self.relu = nn.ReLU(inplace=True) 97 | self.downsample = downsample 98 | self.stride = stride 99 | 100 | def forward(self, x): 101 | identity = x 102 | 103 | out = self.conv1(x) 104 | out = self.bn1(out) 105 | out = self.relu(out) 106 | 107 | out = self.conv2(out) 108 | out = self.bn2(out) 109 | out = self.relu(out) 110 | 111 | out = self.conv3(out) 112 | out = self.bn3(out) 113 | 114 | if self.downsample is not None: 115 | identity = self.downsample(x) 116 | 117 | out += identity 118 | out = self.relu(out) 119 | 120 | return out 121 | 122 | 123 | class ResNet(nn.Module): 124 | 125 | def __init__(self, block, layers, zero_init_residual=False, 126 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 127 | norm_layer=None): 128 | super(ResNet, self).__init__() 129 | if norm_layer is None: 130 | norm_layer = nn.BatchNorm2d 131 | self._norm_layer = norm_layer 132 | 133 | self.inplanes = 64 134 | self.dilation = 1 135 | if replace_stride_with_dilation is None: 136 | # each element in the tuple indicates if we should replace 137 | # the 2x2 stride with a dilated convolution instead 138 | replace_stride_with_dilation = [False, False, False] 139 | if len(replace_stride_with_dilation) != 3: 140 | raise ValueError("replace_stride_with_dilation should be None " 141 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 142 | self.groups = groups 143 | self.base_width = width_per_group 144 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 145 | bias=False) 146 | self.bn1 = norm_layer(self.inplanes) 147 | self.relu = nn.ReLU(inplace=True) 148 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 149 | self.layer1 = self._make_layer(block, 64, layers[0]) 150 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 151 | dilate=replace_stride_with_dilation[0]) 152 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 153 | dilate=replace_stride_with_dilation[1]) 154 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 155 | dilate=replace_stride_with_dilation[2]) 156 | # self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 157 | # self.fc = nn.Linear(512 * block.expansion, num_classes) 158 | 159 | for m in self.modules(): 160 | if isinstance(m, nn.Conv2d): 161 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 162 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 163 | nn.init.constant_(m.weight, 1) 164 | nn.init.constant_(m.bias, 0) 165 | 166 | # Zero-initialize the last BN in each residual branch, 167 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
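        # (Concretely, the loop below sets bn3.weight to zero for Bottleneck blocks and
        # bn2.weight to zero for BasicBlock blocks, so "out += identity" starts out as an
        # identity mapping.)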
168 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 169 | if zero_init_residual: 170 | for m in self.modules(): 171 | if isinstance(m, Bottleneck): 172 | nn.init.constant_(m.bn3.weight, 0) 173 | elif isinstance(m, BasicBlock): 174 | nn.init.constant_(m.bn2.weight, 0) 175 | 176 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 177 | norm_layer = self._norm_layer 178 | downsample = None 179 | previous_dilation = self.dilation 180 | if dilate: 181 | self.dilation *= stride 182 | stride = 1 183 | if stride != 1 or self.inplanes != planes * block.expansion: 184 | downsample = nn.Sequential( 185 | conv1x1(self.inplanes, planes * block.expansion, stride), 186 | norm_layer(planes * block.expansion), 187 | ) 188 | 189 | layers = [] 190 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 191 | self.base_width, previous_dilation, norm_layer)) 192 | self.inplanes = planes * block.expansion 193 | for _ in range(1, blocks): 194 | layers.append(block(self.inplanes, planes, groups=self.groups, 195 | base_width=self.base_width, dilation=self.dilation, 196 | norm_layer=norm_layer)) 197 | 198 | return nn.Sequential(*layers) 199 | 200 | def _forward_impl(self, x): 201 | outputs = {} 202 | # See note [TorchScript super()] 203 | x = self.conv1(x) 204 | x = self.bn1(x) 205 | x = self.relu(x) 206 | x = self.maxpool(x) 207 | #outputs['stem'] = x 208 | 209 | x = self.layer1(x) # 1/4 210 | outputs['res2'] = x 211 | 212 | x = self.layer2(x) # 1/8 213 | outputs['res3'] = x 214 | 215 | x = self.layer3(x) # 1/16 216 | outputs['res4'] = x 217 | 218 | x = self.layer4(x) # 1/32 219 | outputs['res5'] = x 220 | 221 | return outputs['res5'] 222 | 223 | def forward(self, x): 224 | return self._forward_impl(x) 225 | 226 | 227 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 228 | model = ResNet(block, layers, **kwargs) 229 | if pretrained: 230 | state_dict = load_state_dict_from_url(model_urls[arch], 231 | progress=progress) 232 | model.load_state_dict(state_dict, strict=False) 233 | return model 234 | 235 | 236 | def resnet18(pretrained=False, progress=True, **kwargs): 237 | r"""ResNet-18 model from 238 | `"Deep Residual Learning for Image Recognition" `_ 239 | Args: 240 | pretrained (bool): If True, returns a model pre-trained on ImageNet 241 | progress (bool): If True, displays a progress bar of the download to stderr 242 | """ 243 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 244 | **kwargs) 245 | 246 | 247 | def resnet34(pretrained=False, progress=True, **kwargs): 248 | r"""ResNet-34 model from 249 | `"Deep Residual Learning for Image Recognition" `_ 250 | Args: 251 | pretrained (bool): If True, returns a model pre-trained on ImageNet 252 | progress (bool): If True, displays a progress bar of the download to stderr 253 | """ 254 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 255 | **kwargs) 256 | 257 | 258 | def resnet50(pretrained=False, progress=True, **kwargs): 259 | r"""ResNet-50 model from 260 | `"Deep Residual Learning for Image Recognition" `_ 261 | Args: 262 | pretrained (bool): If True, returns a model pre-trained on ImageNet 263 | progress (bool): If True, displays a progress bar of the download to stderr 264 | """ 265 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 266 | **kwargs) 267 | 268 | 269 | def resnet101(pretrained=False, progress=True, **kwargs): 270 | r"""ResNet-101 model from 271 | `"Deep Residual Learning for 
Image Recognition" `_ 272 | Args: 273 | pretrained (bool): If True, returns a model pre-trained on ImageNet 274 | progress (bool): If True, displays a progress bar of the download to stderr 275 | """ 276 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 277 | **kwargs) 278 | 279 | 280 | def resnet152(pretrained=False, progress=True, **kwargs): 281 | r"""ResNet-152 model from 282 | `"Deep Residual Learning for Image Recognition" `_ 283 | Args: 284 | pretrained (bool): If True, returns a model pre-trained on ImageNet 285 | progress (bool): If True, displays a progress bar of the download to stderr 286 | """ 287 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 288 | **kwargs) 289 | 290 | 291 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 292 | r"""ResNeXt-50 32x4d model from 293 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 294 | Args: 295 | pretrained (bool): If True, returns a model pre-trained on ImageNet 296 | progress (bool): If True, displays a progress bar of the download to stderr 297 | """ 298 | kwargs['groups'] = 32 299 | kwargs['width_per_group'] = 4 300 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 301 | pretrained, progress, **kwargs) 302 | 303 | 304 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 305 | r"""ResNeXt-101 32x8d model from 306 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 307 | Args: 308 | pretrained (bool): If True, returns a model pre-trained on ImageNet 309 | progress (bool): If True, displays a progress bar of the download to stderr 310 | """ 311 | kwargs['groups'] = 32 312 | kwargs['width_per_group'] = 8 313 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 314 | pretrained, progress, **kwargs) 315 | 316 | 317 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 318 | r"""Wide ResNet-50-2 model from 319 | `"Wide Residual Networks" `_ 320 | The model is the same as ResNet except for the bottleneck number of channels 321 | which is twice larger in every block. The number of channels in outer 1x1 322 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 323 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 324 | Args: 325 | pretrained (bool): If True, returns a model pre-trained on ImageNet 326 | progress (bool): If True, displays a progress bar of the download to stderr 327 | """ 328 | kwargs['width_per_group'] = 64 * 2 329 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 330 | pretrained, progress, **kwargs) 331 | 332 | 333 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 334 | r"""Wide ResNet-101-2 model from 335 | `"Wide Residual Networks" `_ 336 | The model is the same as ResNet except for the bottleneck number of channels 337 | which is twice larger in every block. The number of channels in outer 1x1 338 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 339 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
340 | Args: 341 | pretrained (bool): If True, returns a model pre-trained on ImageNet 342 | progress (bool): If True, displays a progress bar of the download to stderr 343 | """ 344 | kwargs['width_per_group'] = 64 * 2 345 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 346 | pretrained, progress, **kwargs) 347 | -------------------------------------------------------------------------------- /eval_func.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import numpy as np 4 | from scipy.io import loadmat 5 | from sklearn.metrics import average_precision_score 6 | 7 | from utils.km import run_kuhn_munkres 8 | from utils.utils import write_json 9 | 10 | 11 | def _compute_iou(a, b): 12 | x1 = max(a[0], b[0]) 13 | y1 = max(a[1], b[1]) 14 | x2 = min(a[2], b[2]) 15 | y2 = min(a[3], b[3]) 16 | inter = max(0, x2 - x1) * max(0, y2 - y1) 17 | union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter 18 | return inter * 1.0 / union 19 | 20 | 21 | def eval_detection( 22 | gallery_dataset, gallery_dets, det_thresh=0.5, iou_thresh=0.5, labeled_only=False 23 | ): 24 | """ 25 | gallery_det (list of ndarray): n_det x [x1, y1, x2, y2, score] per image 26 | det_thresh (float): filter out gallery detections whose scores below this 27 | iou_thresh (float): treat as true positive if IoU is above this threshold 28 | labeled_only (bool): filter out unlabeled background people 29 | """ 30 | assert len(gallery_dataset) == len(gallery_dets) 31 | annos = gallery_dataset.annotations 32 | 33 | y_true, y_score = [], [] 34 | count_gt, count_tp = 0, 0 35 | for anno, det in zip(annos, gallery_dets): 36 | gt_boxes = anno["boxes"] 37 | if labeled_only: 38 | # exclude the unlabeled people (pid == 5555) 39 | inds = np.where(anno["pids"].ravel() != 5555)[0] 40 | if len(inds) == 0: 41 | continue 42 | gt_boxes = gt_boxes[inds] 43 | num_gt = gt_boxes.shape[0] 44 | 45 | if det != []: 46 | det = np.asarray(det) 47 | inds = np.where(det[:, 4].ravel() >= det_thresh)[0] 48 | det = det[inds] 49 | num_det = det.shape[0] 50 | else: 51 | num_det = 0 52 | if num_det == 0: 53 | count_gt += num_gt 54 | continue 55 | 56 | ious = np.zeros((num_gt, num_det), dtype=np.float32) 57 | for i in range(num_gt): 58 | for j in range(num_det): 59 | ious[i, j] = _compute_iou(gt_boxes[i], det[j, :4]) 60 | tfmat = ious >= iou_thresh 61 | # for each det, keep only the largest iou of all the gt 62 | for j in range(num_det): 63 | largest_ind = np.argmax(ious[:, j]) 64 | for i in range(num_gt): 65 | if i != largest_ind: 66 | tfmat[i, j] = False 67 | # for each gt, keep only the largest iou of all the det 68 | for i in range(num_gt): 69 | largest_ind = np.argmax(ious[i, :]) 70 | for j in range(num_det): 71 | if j != largest_ind: 72 | tfmat[i, j] = False 73 | for j in range(num_det): 74 | y_score.append(det[j, -1]) 75 | y_true.append(tfmat[:, j].any()) 76 | count_tp += tfmat.sum() 77 | count_gt += num_gt 78 | 79 | det_rate = count_tp * 1.0 / count_gt 80 | ap = average_precision_score(y_true, y_score) * det_rate 81 | 82 | print("{} detection:".format("labeled only" if labeled_only else "all")) 83 | print(" recall = {:.2%}".format(det_rate)) 84 | if not labeled_only: 85 | print(" ap = {:.2%}".format(ap)) 86 | return det_rate, ap 87 | 88 | 89 | def eval_search_cuhk( 90 | gallery_dataset, 91 | query_dataset, 92 | gallery_dets, 93 | gallery_feats, 94 | query_box_feats, 95 | query_dets, 96 | query_feats, 97 | k1=10, 98 | k2=3, 99 | det_thresh=0.5, 100 | cbgm=False, 
101 | gallery_size=100, 102 | ): 103 | """ 104 | gallery_dataset/query_dataset: an instance of BaseDataset 105 | gallery_det (list of ndarray): n_det x [x1, x2, y1, y2, score] per image 106 | gallery_feat (list of ndarray): n_det x D features per image 107 | query_feat (list of ndarray): D dimensional features per query image 108 | det_thresh (float): filter out gallery detections whose scores below this 109 | gallery_size (int): gallery size [-1, 50, 100, 500, 1000, 2000, 4000] 110 | -1 for using full set 111 | """ 112 | assert len(gallery_dataset) == len(gallery_dets) 113 | assert len(gallery_dataset) == len(gallery_feats) 114 | assert len(query_dataset) == len(query_box_feats) 115 | 116 | use_full_set = gallery_size == -1 117 | fname = "TestG{}".format(gallery_size if not use_full_set else 50) 118 | protoc = loadmat(osp.join(gallery_dataset.root, "annotation/test/train_test", fname + ".mat")) 119 | protoc = protoc[fname].squeeze() 120 | 121 | # mapping from gallery image to (det, feat) 122 | annos = gallery_dataset.annotations 123 | name_to_det_feat = {} 124 | for anno, det, feat in zip(annos, gallery_dets, gallery_feats): 125 | name = anno["img_name"] 126 | if len(det) != 0: 127 | scores = det[:, 4].ravel() 128 | inds = np.where(scores >= det_thresh)[0] 129 | if len(inds) > 0: 130 | name_to_det_feat[name] = (det[inds], feat[inds]) 131 | 132 | aps = [] 133 | accs = [] 134 | topk = [1, 5, 10] 135 | ret = {"image_root": gallery_dataset.img_prefix, "results": []} 136 | for i in range(len(query_dataset)): 137 | y_true, y_score = [], [] 138 | imgs, rois = [], [] 139 | count_gt, count_tp = 0, 0 140 | # get L2-normalized feature vector 141 | feat_q = query_box_feats[i].ravel() 142 | # ignore the query image 143 | query_imname = str(protoc["Query"][i]["imname"][0, 0][0]) 144 | query_roi = protoc["Query"][i]["idlocate"][0, 0][0].astype(np.int32) 145 | query_roi[2:] += query_roi[:2] 146 | query_gt = [] 147 | tested = set([query_imname]) 148 | 149 | name2sim = {} 150 | name2gt = {} 151 | sims = [] 152 | imgs_cbgm = [] 153 | # 1. Go through the gallery samples defined by the protocol 154 | for item in protoc["Gallery"][i].squeeze(): 155 | gallery_imname = str(item[0][0]) 156 | # some contain the query (gt not empty), some not 157 | gt = item[1][0].astype(np.int32) 158 | count_gt += gt.size > 0 159 | # compute distance between query and gallery dets 160 | if gallery_imname not in name_to_det_feat: 161 | continue 162 | det, feat_g = name_to_det_feat[gallery_imname] 163 | # no detection in this gallery, skip it 164 | if det.shape[0] == 0: 165 | continue 166 | # get L2-normalized feature matrix NxD 167 | assert feat_g.size == np.prod(feat_g.shape[:2]) 168 | feat_g = feat_g.reshape(feat_g.shape[:2]) 169 | # compute cosine similarities 170 | sim = feat_g.dot(feat_q).ravel() 171 | 172 | if gallery_imname in name2sim: 173 | continue 174 | name2sim[gallery_imname] = sim 175 | name2gt[gallery_imname] = gt 176 | sims.extend(list(sim)) 177 | imgs_cbgm.extend([gallery_imname] * len(sim)) 178 | # 2. 
Go through the remaining gallery images if using full set 179 | if use_full_set: 180 | # TODO: support CBGM when using full set 181 | for gallery_imname in gallery_dataset.imgs: 182 | if gallery_imname in tested: 183 | continue 184 | if gallery_imname not in name_to_det_feat: 185 | continue 186 | det, feat_g = name_to_det_feat[gallery_imname] 187 | # get L2-normalized feature matrix NxD 188 | assert feat_g.size == np.prod(feat_g.shape[:2]) 189 | feat_g = feat_g.reshape(feat_g.shape[:2]) 190 | # compute cosine similarities 191 | sim = feat_g.dot(feat_q).ravel() 192 | # guaranteed no target query in these gallery images 193 | label = np.zeros(len(sim), dtype=np.int32) 194 | y_true.extend(list(label)) 195 | y_score.extend(list(sim)) 196 | imgs.extend([gallery_imname] * len(sim)) 197 | rois.extend(list(det)) 198 | 199 | if cbgm: 200 | # -------- Context Bipartite Graph Matching (CBGM) ------- # 201 | sims = np.array(sims) 202 | imgs_cbgm = np.array(imgs_cbgm) 203 | # only process the top-k1 gallery images for efficiency 204 | inds = np.argsort(sims)[-k1:] 205 | imgs_cbgm = set(imgs_cbgm[inds]) 206 | for img in imgs_cbgm: 207 | sim = name2sim[img] 208 | det, feat_g = name_to_det_feat[img] 209 | # only regard the people with top-k2 detection confidence 210 | # in the query image as context information 211 | qboxes = query_dets[i][:k2] 212 | qfeats = query_feats[i][:k2] 213 | assert ( 214 | query_roi - qboxes[0][:4] 215 | ).sum() <= 0.001, "query_roi must be the first one in pboxes" 216 | 217 | # build the bipartite graph and run Kuhn-Munkres (K-M) algorithm 218 | # to find the best match 219 | graph = [] 220 | for indx_i, pfeat in enumerate(qfeats): 221 | for indx_j, gfeat in enumerate(feat_g): 222 | graph.append((indx_i, indx_j, (pfeat * gfeat).sum())) 223 | km_res, max_val = run_kuhn_munkres(graph) 224 | 225 | # revise the similarity between query person and its matching 226 | for indx_i, indx_j, _ in km_res: 227 | # 0 denotes the query roi 228 | if indx_i == 0: 229 | sim[indx_j] = max_val 230 | break 231 | for gallery_imname, sim in name2sim.items(): 232 | gt = name2gt[gallery_imname] 233 | det, feat_g = name_to_det_feat[gallery_imname] 234 | # assign label for each det 235 | label = np.zeros(len(sim), dtype=np.int32) 236 | if gt.size > 0: 237 | w, h = gt[2], gt[3] 238 | gt[2:] += gt[:2] 239 | query_gt.append({"img": str(gallery_imname), "roi": list(map(float, list(gt)))}) 240 | iou_thresh = min(0.5, (w * h * 1.0) / ((w + 10) * (h + 10))) 241 | inds = np.argsort(sim)[::-1] 242 | sim = sim[inds] 243 | det = det[inds] 244 | # only set the first matched det as true positive 245 | for j, roi in enumerate(det[:, :4]): 246 | if _compute_iou(roi, gt) >= iou_thresh: 247 | label[j] = 1 248 | count_tp += 1 249 | break 250 | y_true.extend(list(label)) 251 | y_score.extend(list(sim)) 252 | imgs.extend([gallery_imname] * len(sim)) 253 | rois.extend(list(det)) 254 | tested.add(gallery_imname) 255 | # 3. Compute AP for this query (need to scale by recall rate) 256 | y_score = np.asarray(y_score) 257 | y_true = np.asarray(y_true) 258 | assert count_tp <= count_gt 259 | recall_rate = count_tp * 1.0 / count_gt 260 | ap = 0 if count_tp == 0 else average_precision_score(y_true, y_score) * recall_rate 261 | aps.append(ap) 262 | inds = np.argsort(y_score)[::-1] 263 | y_score = y_score[inds] 264 | y_true = y_true[inds] 265 | accs.append([min(1, sum(y_true[:k])) for k in topk]) 266 | # 4. 
Save result for JSON dump 267 | new_entry = { 268 | "query_img": str(query_imname), 269 | "query_roi": list(map(float, list(query_roi))), 270 | "query_gt": query_gt, 271 | "gallery": [], 272 | } 273 | # only record wrong results 274 | if int(y_true[0]): 275 | continue 276 | # only save top-10 predictions 277 | for k in range(10): 278 | new_entry["gallery"].append( 279 | { 280 | "img": str(imgs[inds[k]]), 281 | "roi": list(map(float, list(rois[inds[k]]))), 282 | "score": float(y_score[k]), 283 | "correct": int(y_true[k]), 284 | } 285 | ) 286 | ret["results"].append(new_entry) 287 | 288 | print("search ranking:") 289 | print(" mAP = {:.2%}".format(np.mean(aps))) 290 | accs = np.mean(accs, axis=0) 291 | for i, k in enumerate(topk): 292 | print(" top-{:2d} = {:.2%}".format(k, accs[i])) 293 | 294 | write_json(ret, "vis/results.json") 295 | 296 | ret["mAP"] = np.mean(aps) 297 | ret["accs"] = accs 298 | return ret 299 | 300 | 301 | def eval_search_prw( 302 | gallery_dataset, 303 | query_dataset, 304 | gallery_dets, 305 | gallery_feats, 306 | query_box_feats, 307 | query_dets, 308 | query_feats, 309 | k1=30, 310 | k2=4, 311 | det_thresh=0.5, 312 | cbgm=False, 313 | ignore_cam_id=True, 314 | ): 315 | """ 316 | gallery_det (list of ndarray): n_det x [x1, x2, y1, y2, score] per image 317 | gallery_feat (list of ndarray): n_det x D features per image 318 | query_feat (list of ndarray): D dimensional features per query image 319 | det_thresh (float): filter out gallery detections whose scores below this 320 | gallery_size (int): -1 for using full set 321 | ignore_cam_id (bool): Set to True acoording to CUHK-SYSU, 322 | although it's a common practice to focus on cross-cam match only. 323 | """ 324 | assert len(gallery_dataset) == len(gallery_dets) 325 | assert len(gallery_dataset) == len(gallery_feats) 326 | assert len(query_dataset) == len(query_box_feats) 327 | 328 | annos = gallery_dataset.annotations 329 | name_to_det_feat = {} 330 | for anno, det, feat in zip(annos, gallery_dets, gallery_feats): 331 | name = anno["img_name"] 332 | scores = det[:, 4].ravel() 333 | inds = np.where(scores >= det_thresh)[0] 334 | if len(inds) > 0: 335 | name_to_det_feat[name] = (det[inds], feat[inds]) 336 | 337 | aps = [] 338 | accs = [] 339 | topk = [1, 5, 10] 340 | ret = {"image_root": gallery_dataset.img_prefix, "results": []} 341 | for i in range(len(query_dataset)): 342 | y_true, y_score = [], [] 343 | imgs, rois = [], [] 344 | count_gt, count_tp = 0, 0 345 | 346 | feat_p = query_box_feats[i].ravel() 347 | 348 | query_imname = query_dataset.annotations[i]["img_name"] 349 | query_roi = query_dataset.annotations[i]["boxes"] 350 | query_pid = query_dataset.annotations[i]["pids"] 351 | query_cam = query_dataset.annotations[i]["cam_id"] 352 | 353 | # Find all occurence of this query 354 | gallery_imgs = [] 355 | for x in annos: 356 | if query_pid in x["pids"] and x["img_name"] != query_imname: 357 | gallery_imgs.append(x) 358 | query_gts = {} 359 | for item in gallery_imgs: 360 | query_gts[item["img_name"]] = item["boxes"][item["pids"] == query_pid] 361 | 362 | # Construct gallery set for this query 363 | if ignore_cam_id: 364 | gallery_imgs = [] 365 | for x in annos: 366 | if x["img_name"] != query_imname: 367 | gallery_imgs.append(x) 368 | else: 369 | gallery_imgs = [] 370 | for x in annos: 371 | if x["img_name"] != query_imname and x["cam_id"] != query_cam: 372 | gallery_imgs.append(x) 373 | 374 | name2sim = {} 375 | sims = [] 376 | imgs_cbgm = [] 377 | # 1. 
Go through all gallery samples 378 | for item in gallery_imgs: 379 | gallery_imname = item["img_name"] 380 | # some contain the query (gt not empty), some not 381 | count_gt += gallery_imname in query_gts 382 | # compute distance between query and gallery dets 383 | if gallery_imname not in name_to_det_feat: 384 | continue 385 | det, feat_g = name_to_det_feat[gallery_imname] 386 | # get L2-normalized feature matrix NxD 387 | assert feat_g.size == np.prod(feat_g.shape[:2]) 388 | feat_g = feat_g.reshape(feat_g.shape[:2]) 389 | # compute cosine similarities 390 | sim = feat_g.dot(feat_p).ravel() 391 | 392 | if gallery_imname in name2sim: 393 | continue 394 | name2sim[gallery_imname] = sim 395 | sims.extend(list(sim)) 396 | imgs_cbgm.extend([gallery_imname] * len(sim)) 397 | 398 | if cbgm: 399 | sims = np.array(sims) 400 | imgs_cbgm = np.array(imgs_cbgm) 401 | inds = np.argsort(sims)[-k1:] 402 | imgs_cbgm = set(imgs_cbgm[inds]) 403 | for img in imgs_cbgm: 404 | sim = name2sim[img] 405 | det, feat_g = name_to_det_feat[img] 406 | qboxes = query_dets[i][:k2] 407 | qfeats = query_feats[i][:k2] 408 | assert ( 409 | query_roi - qboxes[0][:4] 410 | ).sum() <= 0.001, "query_roi must be the first one in pboxes" 411 | 412 | graph = [] 413 | for indx_i, pfeat in enumerate(qfeats): 414 | for indx_j, gfeat in enumerate(feat_g): 415 | graph.append((indx_i, indx_j, (pfeat * gfeat).sum())) 416 | km_res, max_val = run_kuhn_munkres(graph) 417 | 418 | for indx_i, indx_j, _ in km_res: 419 | if indx_i == 0: 420 | sim[indx_j] = max_val 421 | break 422 | for gallery_imname, sim in name2sim.items(): 423 | det, feat_g = name_to_det_feat[gallery_imname] 424 | # assign label for each det 425 | label = np.zeros(len(sim), dtype=np.int32) 426 | if gallery_imname in query_gts: 427 | gt = query_gts[gallery_imname].ravel() 428 | w, h = gt[2] - gt[0], gt[3] - gt[1] 429 | iou_thresh = min(0.5, (w * h * 1.0) / ((w + 10) * (h + 10))) 430 | inds = np.argsort(sim)[::-1] 431 | sim = sim[inds] 432 | det = det[inds] 433 | # only set the first matched det as true positive 434 | for j, roi in enumerate(det[:, :4]): 435 | if _compute_iou(roi, gt) >= iou_thresh: 436 | label[j] = 1 437 | count_tp += 1 438 | break 439 | y_true.extend(list(label)) 440 | y_score.extend(list(sim)) 441 | imgs.extend([gallery_imname] * len(sim)) 442 | rois.extend(list(det)) 443 | 444 | # 2. Compute AP for this query (need to scale by recall rate) 445 | y_score = np.asarray(y_score) 446 | y_true = np.asarray(y_true) 447 | assert count_tp <= count_gt 448 | recall_rate = count_tp * 1.0 / count_gt 449 | ap = 0 if count_tp == 0 else average_precision_score(y_true, y_score) * recall_rate 450 | aps.append(ap) 451 | inds = np.argsort(y_score)[::-1] 452 | y_score = y_score[inds] 453 | y_true = y_true[inds] 454 | accs.append([min(1, sum(y_true[:k])) for k in topk]) 455 | # 4. 
Save result for JSON dump 456 | new_entry = { 457 | "query_img": str(query_imname), 458 | "query_roi": list(map(float, list(query_roi.squeeze()))), 459 | "query_gt": query_gts, 460 | "gallery": [], 461 | } 462 | # only save top-10 predictions 463 | for k in range(10): 464 | new_entry["gallery"].append( 465 | { 466 | "img": str(imgs[inds[k]]), 467 | "roi": list(map(float, list(rois[inds[k]]))), 468 | "score": float(y_score[k]), 469 | "correct": int(y_true[k]), 470 | } 471 | ) 472 | ret["results"].append(new_entry) 473 | 474 | print("search ranking:") 475 | mAP = np.mean(aps) 476 | print(" mAP = {:.2%}".format(mAP)) 477 | accs = np.mean(accs, axis=0) 478 | for i, k in enumerate(topk): 479 | print(" top-{:2d} = {:.2%}".format(k, accs[i])) 480 | 481 | # write_json(ret, "vis/results.json") 482 | 483 | ret["mAP"] = np.mean(aps) 484 | ret["accs"] = accs 485 | return ret 486 | -------------------------------------------------------------------------------- /models/seqnet.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import init 7 | from torchvision.models.detection.faster_rcnn import FastRCNNPredictor 8 | from torchvision.models.detection.roi_heads import RoIHeads 9 | from torchvision.models.detection.rpn import AnchorGenerator, RegionProposalNetwork, RPNHead 10 | from torchvision.models.detection.transform import GeneralizedRCNNTransform 11 | from torchvision.ops import MultiScaleRoIAlign 12 | from torchvision.ops import boxes as box_ops 13 | 14 | from models.oim import OIMLoss 15 | from models.resnet import build_resnet 16 | from models.swin import build_swin 17 | 18 | class SeqNet(nn.Module): 19 | def __init__(self, cfg): 20 | super(SeqNet, self).__init__() 21 | 22 | backbone_name = cfg.MODEL.BONE 23 | semantic_weight = cfg.MODEL.SEMANTIC_WEIGHT 24 | if backbone_name == 'resnet50': 25 | backbone, box_head = build_resnet(name="resnet50", pretrained=True) 26 | feat_len = 2048 27 | elif 'swin' in backbone_name: 28 | backbone, box_head, feat_len = build_swin(name=backbone_name, semantic_weight=semantic_weight) 29 | 30 | anchor_generator = AnchorGenerator( 31 | sizes=((32, 64, 128, 256, 512),), aspect_ratios=((0.5, 1.0, 2.0),) 32 | ) 33 | head = RPNHead( 34 | in_channels=backbone.out_channels, 35 | num_anchors=anchor_generator.num_anchors_per_location()[0], 36 | ) 37 | pre_nms_top_n = dict( 38 | training=cfg.MODEL.RPN.PRE_NMS_TOPN_TRAIN, testing=cfg.MODEL.RPN.PRE_NMS_TOPN_TEST 39 | ) 40 | post_nms_top_n = dict( 41 | training=cfg.MODEL.RPN.POST_NMS_TOPN_TRAIN, testing=cfg.MODEL.RPN.POST_NMS_TOPN_TEST 42 | ) 43 | rpn = RegionProposalNetwork( 44 | anchor_generator=anchor_generator, 45 | head=head, 46 | fg_iou_thresh=cfg.MODEL.RPN.POS_THRESH_TRAIN, 47 | bg_iou_thresh=cfg.MODEL.RPN.NEG_THRESH_TRAIN, 48 | batch_size_per_image=cfg.MODEL.RPN.BATCH_SIZE_TRAIN, 49 | positive_fraction=cfg.MODEL.RPN.POS_FRAC_TRAIN, 50 | pre_nms_top_n=pre_nms_top_n, 51 | post_nms_top_n=post_nms_top_n, 52 | nms_thresh=cfg.MODEL.RPN.NMS_THRESH, 53 | ) 54 | 55 | faster_rcnn_predictor = FastRCNNPredictor(feat_len, 2) 56 | reid_head = deepcopy(box_head) 57 | box_roi_pool = MultiScaleRoIAlign( 58 | featmap_names=["feat_res4"], output_size=14, sampling_ratio=2 59 | ) 60 | box_predictor = BBoxRegressor(feat_len, num_classes=2, bn_neck=cfg.MODEL.ROI_HEAD.BN_NECK) 61 | roi_heads = SeqRoIHeads( 62 | # OIM 63 | num_pids=cfg.MODEL.LOSS.LUT_SIZE, 64 | 
num_cq_size=cfg.MODEL.LOSS.CQ_SIZE, 65 | oim_momentum=cfg.MODEL.LOSS.OIM_MOMENTUM, 66 | oim_scalar=cfg.MODEL.LOSS.OIM_SCALAR, 67 | # SeqNet 68 | faster_rcnn_predictor=faster_rcnn_predictor, 69 | reid_head=reid_head, 70 | # parent class 71 | box_roi_pool=box_roi_pool, 72 | box_head=box_head, 73 | box_predictor=box_predictor, 74 | fg_iou_thresh=cfg.MODEL.ROI_HEAD.POS_THRESH_TRAIN, 75 | bg_iou_thresh=cfg.MODEL.ROI_HEAD.NEG_THRESH_TRAIN, 76 | batch_size_per_image=cfg.MODEL.ROI_HEAD.BATCH_SIZE_TRAIN, 77 | positive_fraction=cfg.MODEL.ROI_HEAD.POS_FRAC_TRAIN, 78 | bbox_reg_weights=None, 79 | score_thresh=cfg.MODEL.ROI_HEAD.SCORE_THRESH_TEST, 80 | nms_thresh=cfg.MODEL.ROI_HEAD.NMS_THRESH_TEST, 81 | detections_per_img=cfg.MODEL.ROI_HEAD.DETECTIONS_PER_IMAGE_TEST, 82 | feat_len=feat_len, 83 | ) 84 | 85 | transform = GeneralizedRCNNTransform( 86 | min_size=cfg.INPUT.MIN_SIZE, 87 | max_size=cfg.INPUT.MAX_SIZE, 88 | image_mean=[0.485, 0.456, 0.406], 89 | image_std=[0.229, 0.224, 0.225], 90 | ) 91 | 92 | self.backbone = backbone 93 | self.rpn = rpn 94 | self.roi_heads = roi_heads 95 | self.transform = transform 96 | 97 | # loss weights 98 | self.lw_rpn_reg = cfg.SOLVER.LW_RPN_REG 99 | self.lw_rpn_cls = cfg.SOLVER.LW_RPN_CLS 100 | self.lw_proposal_reg = cfg.SOLVER.LW_PROPOSAL_REG 101 | self.lw_proposal_cls = cfg.SOLVER.LW_PROPOSAL_CLS 102 | self.lw_box_reg = cfg.SOLVER.LW_BOX_REG 103 | self.lw_box_cls = cfg.SOLVER.LW_BOX_CLS 104 | self.lw_box_reid = cfg.SOLVER.LW_BOX_REID 105 | 106 | def inference(self, images, targets=None, query_img_as_gallery=False): 107 | """ 108 | query_img_as_gallery: Set to True to detect all people in the query image. 109 | Meanwhile, the gt box should be the first of the detected boxes. 110 | This option serves CBGM. 111 | """ 112 | original_image_sizes = [img.shape[-2:] for img in images] 113 | images, targets = self.transform(images, targets) 114 | features = self.backbone(images.tensors) 115 | 116 | if query_img_as_gallery: 117 | assert targets is not None 118 | 119 | if targets is not None and not query_img_as_gallery: 120 | # query 121 | boxes = [t["boxes"] for t in targets] 122 | box_features = self.roi_heads.box_roi_pool(features, boxes, images.image_sizes) 123 | box_features = self.roi_heads.reid_head(box_features) 124 | embeddings, _ = self.roi_heads.embedding_head(box_features) 125 | return embeddings.split(1, 0) 126 | else: 127 | # gallery 128 | proposals, _ = self.rpn(images, features, targets) 129 | detections, _ = self.roi_heads( 130 | features, proposals, images.image_sizes, targets, query_img_as_gallery 131 | ) 132 | detections = self.transform.postprocess( 133 | detections, images.image_sizes, original_image_sizes 134 | ) 135 | return detections 136 | 137 | def forward(self, images, targets=None, query_img_as_gallery=False): 138 | if not self.training: 139 | return self.inference(images, targets, query_img_as_gallery) 140 | 141 | images, targets = self.transform(images, targets) 142 | features = self.backbone(images.tensors) 143 | proposals, proposal_losses = self.rpn(images, features, targets) 144 | _, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets) 145 | 146 | # rename rpn losses to be consistent with detection losses 147 | proposal_losses["loss_rpn_reg"] = proposal_losses.pop("loss_rpn_box_reg") 148 | proposal_losses["loss_rpn_cls"] = proposal_losses.pop("loss_objectness") 149 | 150 | losses = {} 151 | losses.update(detector_losses) 152 | losses.update(proposal_losses) 153 | 154 | # apply loss weights 155 | 
losses["loss_rpn_reg"] *= self.lw_rpn_reg 156 | losses["loss_rpn_cls"] *= self.lw_rpn_cls 157 | losses["loss_proposal_reg"] *= self.lw_proposal_reg 158 | losses["loss_proposal_cls"] *= self.lw_proposal_cls 159 | losses["loss_box_reg"] *= self.lw_box_reg 160 | losses["loss_box_cls"] *= self.lw_box_cls 161 | losses["loss_box_reid"] *= self.lw_box_reid 162 | return losses 163 | 164 | 165 | class SeqRoIHeads(RoIHeads): 166 | def __init__( 167 | self, 168 | num_pids, 169 | num_cq_size, 170 | oim_momentum, 171 | oim_scalar, 172 | faster_rcnn_predictor, 173 | reid_head, 174 | feat_len, 175 | *args, 176 | **kwargs 177 | ): 178 | super(SeqRoIHeads, self).__init__(*args, **kwargs) 179 | self.embedding_head = NormAwareEmbedding(in_channels=[int(feat_len/2),feat_len]) 180 | self.reid_loss = OIMLoss(256, num_pids, num_cq_size, oim_momentum, oim_scalar) 181 | self.faster_rcnn_predictor = faster_rcnn_predictor 182 | self.reid_head = reid_head 183 | # rename the method inherited from parent class 184 | self.postprocess_proposals = self.postprocess_detections 185 | 186 | def forward(self, features, proposals, image_shapes, targets=None, query_img_as_gallery=False): 187 | """ 188 | Arguments: 189 | features (List[Tensor]) 190 | proposals (List[Tensor[N, 4]]) 191 | image_shapes (List[Tuple[H, W]]) 192 | targets (List[Dict]) 193 | """ 194 | if self.training: 195 | proposals, _, proposal_pid_labels, proposal_reg_targets = self.select_training_samples( 196 | proposals, targets 197 | ) 198 | 199 | # ------------------- Faster R-CNN head ------------------ # 200 | proposal_features = self.box_roi_pool(features, proposals, image_shapes) 201 | proposal_features = self.box_head(proposal_features) 202 | proposal_cls_scores, proposal_regs = self.faster_rcnn_predictor( 203 | proposal_features["feat_res5"] 204 | ) 205 | 206 | if self.training: 207 | boxes = self.get_boxes(proposal_regs, proposals, image_shapes) 208 | boxes = [boxes_per_image.detach() for boxes_per_image in boxes] 209 | boxes, _, box_pid_labels, box_reg_targets = self.select_training_samples(boxes, targets) 210 | else: 211 | # invoke the postprocess method inherited from parent class to process proposals 212 | boxes, scores, _ = self.postprocess_proposals( 213 | proposal_cls_scores, proposal_regs, proposals, image_shapes 214 | ) 215 | 216 | cws = True 217 | gt_det = None 218 | if not self.training and query_img_as_gallery: 219 | # When regarding the query image as gallery, GT boxes may be excluded 220 | # from detected boxes. To avoid this, we compulsorily include GT in the 221 | # detection results. 
Additionally, CWS should be disabled as the 222 | # confidences of these people in query image are 1 223 | cws = False 224 | gt_box = [targets[0]["boxes"]] 225 | gt_box_features = self.box_roi_pool(features, gt_box, image_shapes) 226 | gt_box_features = self.reid_head(gt_box_features) 227 | embeddings, _ = self.embedding_head(gt_box_features) 228 | gt_det = {"boxes": targets[0]["boxes"], "embeddings": embeddings} 229 | 230 | # no detection predicted by Faster R-CNN head in test phase 231 | if boxes[0].shape[0] == 0: 232 | assert not self.training 233 | boxes = gt_det["boxes"] if gt_det else torch.zeros(0, 4) 234 | labels = torch.ones(1).type_as(boxes) if gt_det else torch.zeros(0) 235 | scores = torch.ones(1).type_as(boxes) if gt_det else torch.zeros(0) 236 | embeddings = gt_det["embeddings"] if gt_det else torch.zeros(0, 256) 237 | return [dict(boxes=boxes, labels=labels, scores=scores, embeddings=embeddings)], [] 238 | 239 | # --------------------- Baseline head -------------------- # 240 | box_features = self.box_roi_pool(features, boxes, image_shapes) 241 | box_features = self.reid_head(box_features) 242 | box_regs = self.box_predictor(box_features["feat_res5"]) 243 | box_embeddings, box_cls_scores = self.embedding_head(box_features) 244 | if box_cls_scores.dim() == 0: 245 | box_cls_scores = box_cls_scores.unsqueeze(0) 246 | 247 | result, losses = [], {} 248 | if self.training: 249 | proposal_labels = [y.clamp(0, 1) for y in proposal_pid_labels] 250 | box_labels = [y.clamp(0, 1) for y in box_pid_labels] 251 | losses = detection_losses( 252 | proposal_cls_scores, 253 | proposal_regs, 254 | proposal_labels, 255 | proposal_reg_targets, 256 | box_cls_scores, 257 | box_regs, 258 | box_labels, 259 | box_reg_targets, 260 | ) 261 | loss_box_reid = self.reid_loss(box_embeddings, box_pid_labels) 262 | losses.update(loss_box_reid=loss_box_reid) 263 | else: 264 | # The IoUs of these boxes are higher than that of proposals, 265 | # so a higher NMS threshold is needed 266 | orig_thresh = self.nms_thresh 267 | self.nms_thresh = 0.5 268 | boxes, scores, embeddings, labels = self.postprocess_boxes( 269 | box_cls_scores, 270 | box_regs, 271 | box_embeddings, 272 | boxes, 273 | image_shapes, 274 | fcs=scores, 275 | gt_det=gt_det, 276 | cws=cws, 277 | ) 278 | # set to original thresh after finishing postprocess 279 | self.nms_thresh = orig_thresh 280 | num_images = len(boxes) 281 | for i in range(num_images): 282 | result.append( 283 | dict( 284 | boxes=boxes[i], labels=labels[i], scores=scores[i], embeddings=embeddings[i] 285 | ) 286 | ) 287 | return result, losses 288 | 289 | def get_boxes(self, box_regression, proposals, image_shapes): 290 | """ 291 | Get boxes from proposals. 
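        Decodes the regression deltas w.r.t. the proposals, clips the results to the
        image, and drops the background column, so one (num_proposals_i, 4) tensor is
        returned per image.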
292 | """ 293 | boxes_per_image = [len(boxes_in_image) for boxes_in_image in proposals] 294 | pred_boxes = self.box_coder.decode(box_regression, proposals) 295 | pred_boxes = pred_boxes.split(boxes_per_image, 0) 296 | 297 | all_boxes = [] 298 | for boxes, image_shape in zip(pred_boxes, image_shapes): 299 | boxes = box_ops.clip_boxes_to_image(boxes, image_shape) 300 | # remove predictions with the background label 301 | boxes = boxes[:, 1:].reshape(-1, 4) 302 | all_boxes.append(boxes) 303 | 304 | return all_boxes 305 | 306 | def postprocess_boxes( 307 | self, 308 | class_logits, 309 | box_regression, 310 | embeddings, 311 | proposals, 312 | image_shapes, 313 | fcs=None, 314 | gt_det=None, 315 | cws=True, 316 | ): 317 | """ 318 | Similar to RoIHeads.postprocess_detections, but can handle embeddings and implement 319 | First Classification Score (FCS). 320 | """ 321 | device = class_logits.device 322 | 323 | boxes_per_image = [len(boxes_in_image) for boxes_in_image in proposals] 324 | pred_boxes = self.box_coder.decode(box_regression, proposals) 325 | 326 | if fcs is not None: 327 | # Fist Classification Score (FCS) 328 | pred_scores = fcs[0] 329 | else: 330 | pred_scores = torch.sigmoid(class_logits) 331 | if cws: 332 | # Confidence Weighted Similarity (CWS) 333 | embeddings = embeddings * pred_scores.view(-1, 1) 334 | 335 | # split boxes and scores per image 336 | pred_boxes = pred_boxes.split(boxes_per_image, 0) 337 | pred_scores = pred_scores.split(boxes_per_image, 0) 338 | pred_embeddings = embeddings.split(boxes_per_image, 0) 339 | 340 | all_boxes = [] 341 | all_scores = [] 342 | all_labels = [] 343 | all_embeddings = [] 344 | for boxes, scores, embeddings, image_shape in zip( 345 | pred_boxes, pred_scores, pred_embeddings, image_shapes 346 | ): 347 | boxes = box_ops.clip_boxes_to_image(boxes, image_shape) 348 | 349 | # create labels for each prediction 350 | labels = torch.ones(scores.size(0), device=device) 351 | 352 | # remove predictions with the background label 353 | boxes = boxes[:, 1:] 354 | scores = scores.unsqueeze(1) 355 | labels = labels.unsqueeze(1) 356 | 357 | # batch everything, by making every class prediction be a separate instance 358 | boxes = boxes.reshape(-1, 4) 359 | scores = scores.flatten() 360 | labels = labels.flatten() 361 | embeddings = embeddings.reshape(-1, self.embedding_head.dim) 362 | 363 | # remove low scoring boxes 364 | inds = torch.nonzero(scores > self.score_thresh).squeeze(1) 365 | boxes, scores, labels, embeddings = ( 366 | boxes[inds], 367 | scores[inds], 368 | labels[inds], 369 | embeddings[inds], 370 | ) 371 | 372 | # remove empty boxes 373 | keep = box_ops.remove_small_boxes(boxes, min_size=1e-2) 374 | boxes, scores, labels, embeddings = ( 375 | boxes[keep], 376 | scores[keep], 377 | labels[keep], 378 | embeddings[keep], 379 | ) 380 | 381 | if gt_det is not None: 382 | # include GT into the detection results 383 | boxes = torch.cat((boxes, gt_det["boxes"]), dim=0) 384 | labels = torch.cat((labels, torch.tensor([1.0]).to(device)), dim=0) 385 | scores = torch.cat((scores, torch.tensor([1.0]).to(device)), dim=0) 386 | embeddings = torch.cat((embeddings, gt_det["embeddings"]), dim=0) 387 | 388 | # non-maximum suppression, independently done per class 389 | keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh) 390 | # keep only topk scoring predictions 391 | keep = keep[: self.detections_per_img] 392 | boxes, scores, labels, embeddings = ( 393 | boxes[keep], 394 | scores[keep], 395 | labels[keep], 396 | embeddings[keep], 397 | ) 398 | 
399 | all_boxes.append(boxes) 400 | all_scores.append(scores) 401 | all_labels.append(labels) 402 | all_embeddings.append(embeddings) 403 | 404 | return all_boxes, all_scores, all_embeddings, all_labels 405 | 406 | 407 | class NormAwareEmbedding(nn.Module): 408 | """ 409 | Implements the Norm-Aware Embedding proposed in 410 | Chen, Di, et al. "Norm-aware embedding for efficient person search." CVPR 2020. 411 | """ 412 | 413 | def __init__(self, featmap_names=["feat_res4", "feat_res5"], in_channels=[1024, 2048], dim=256): 414 | super(NormAwareEmbedding, self).__init__() 415 | self.featmap_names = featmap_names 416 | self.in_channels = in_channels 417 | self.dim = dim 418 | 419 | self.projectors = nn.ModuleDict() 420 | indv_dims = self._split_embedding_dim() 421 | for ftname, in_channel, indv_dim in zip(self.featmap_names, self.in_channels, indv_dims): 422 | proj = nn.Sequential(nn.Linear(in_channel, indv_dim), nn.BatchNorm1d(indv_dim)) 423 | init.normal_(proj[0].weight, std=0.01) 424 | init.normal_(proj[1].weight, std=0.01) 425 | init.constant_(proj[0].bias, 0) 426 | init.constant_(proj[1].bias, 0) 427 | self.projectors[ftname] = proj 428 | 429 | self.rescaler = nn.BatchNorm1d(1, affine=True) 430 | 431 | def forward(self, featmaps): 432 | """ 433 | Arguments: 434 | featmaps: OrderedDict[Tensor], and in featmap_names you can choose which 435 | featmaps to use 436 | Returns: 437 | tensor of size (BatchSize, dim), L2 normalized embeddings. 438 | tensor of size (BatchSize, ) rescaled norm of embeddings, as class_logits. 439 | """ 440 | assert len(featmaps) == len(self.featmap_names) 441 | if len(featmaps) == 1: 442 | k, v = featmaps.items()[0] 443 | v = self._flatten_fc_input(v) 444 | embeddings = self.projectors[k](v) 445 | norms = embeddings.norm(2, 1, keepdim=True) 446 | embeddings = embeddings / norms.expand_as(embeddings).clamp(min=1e-12) 447 | norms = self.rescaler(norms).squeeze() 448 | return embeddings, norms 449 | else: 450 | outputs = [] 451 | for k, v in featmaps.items(): 452 | v = self._flatten_fc_input(v) 453 | outputs.append(self.projectors[k](v)) 454 | embeddings = torch.cat(outputs, dim=1) 455 | norms = embeddings.norm(2, 1, keepdim=True) 456 | embeddings = embeddings / norms.expand_as(embeddings).clamp(min=1e-12) 457 | norms = self.rescaler(norms).squeeze() 458 | return embeddings, norms 459 | 460 | def _flatten_fc_input(self, x): 461 | if x.ndimension() == 4: 462 | assert list(x.shape[2:]) == [1, 1] 463 | return x.flatten(start_dim=1) 464 | return x 465 | 466 | def _split_embedding_dim(self): 467 | parts = len(self.in_channels) 468 | tmp = [self.dim // parts] * parts 469 | if sum(tmp) == self.dim: 470 | return tmp 471 | else: 472 | res = self.dim % parts 473 | for i in range(1, res + 1): 474 | tmp[-i] += 1 475 | assert sum(tmp) == self.dim 476 | return tmp 477 | 478 | 479 | class BBoxRegressor(nn.Module): 480 | """ 481 | Bounding box regression layer. 482 | """ 483 | 484 | def __init__(self, in_channels, num_classes=2, bn_neck=True): 485 | """ 486 | Args: 487 | in_channels (int): Input channels. 488 | num_classes (int, optional): Defaults to 2 (background and pedestrian). 489 | bn_neck (bool, optional): Whether to use BN after Linear. Defaults to True. 
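        The head predicts 4 * num_classes box deltas per RoI, i.e. 8 values with the
        default background + pedestrian classes.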
490 | """ 491 | super(BBoxRegressor, self).__init__() 492 | if bn_neck: 493 | self.bbox_pred = nn.Sequential( 494 | nn.Linear(in_channels, 4 * num_classes), nn.BatchNorm1d(4 * num_classes) 495 | ) 496 | init.normal_(self.bbox_pred[0].weight, std=0.01) 497 | init.normal_(self.bbox_pred[1].weight, std=0.01) 498 | init.constant_(self.bbox_pred[0].bias, 0) 499 | init.constant_(self.bbox_pred[1].bias, 0) 500 | else: 501 | self.bbox_pred = nn.Linear(in_channels, 4 * num_classes) 502 | init.normal_(self.bbox_pred.weight, std=0.01) 503 | init.constant_(self.bbox_pred.bias, 0) 504 | 505 | def forward(self, x): 506 | if x.ndimension() == 4: 507 | if list(x.shape[2:]) != [1, 1]: 508 | x = F.adaptive_avg_pool2d(x, output_size=1) 509 | x = x.flatten(start_dim=1) 510 | bbox_deltas = self.bbox_pred(x) 511 | return bbox_deltas 512 | 513 | 514 | def detection_losses( 515 | proposal_cls_scores, 516 | proposal_regs, 517 | proposal_labels, 518 | proposal_reg_targets, 519 | box_cls_scores, 520 | box_regs, 521 | box_labels, 522 | box_reg_targets, 523 | ): 524 | proposal_labels = torch.cat(proposal_labels, dim=0) 525 | box_labels = torch.cat(box_labels, dim=0) 526 | proposal_reg_targets = torch.cat(proposal_reg_targets, dim=0) 527 | box_reg_targets = torch.cat(box_reg_targets, dim=0) 528 | 529 | loss_proposal_cls = F.cross_entropy(proposal_cls_scores, proposal_labels) 530 | loss_box_cls = F.binary_cross_entropy_with_logits(box_cls_scores, box_labels.float()) 531 | 532 | # get indices that correspond to the regression targets for the 533 | # corresponding ground truth labels, to be used with advanced indexing 534 | sampled_pos_inds_subset = torch.nonzero(proposal_labels > 0).squeeze(1) 535 | labels_pos = proposal_labels[sampled_pos_inds_subset] 536 | N = proposal_cls_scores.size(0) 537 | proposal_regs = proposal_regs.reshape(N, -1, 4) 538 | 539 | loss_proposal_reg = F.smooth_l1_loss( 540 | proposal_regs[sampled_pos_inds_subset, labels_pos], 541 | proposal_reg_targets[sampled_pos_inds_subset], 542 | reduction="sum", 543 | ) 544 | loss_proposal_reg = loss_proposal_reg / proposal_labels.numel() 545 | 546 | sampled_pos_inds_subset = torch.nonzero(box_labels > 0).squeeze(1) 547 | labels_pos = box_labels[sampled_pos_inds_subset] 548 | N = box_cls_scores.size(0) 549 | box_regs = box_regs.reshape(N, -1, 4) 550 | 551 | loss_box_reg = F.smooth_l1_loss( 552 | box_regs[sampled_pos_inds_subset, labels_pos], 553 | box_reg_targets[sampled_pos_inds_subset], 554 | reduction="sum", 555 | ) 556 | loss_box_reg = loss_box_reg / box_labels.numel() 557 | 558 | return dict( 559 | loss_proposal_cls=loss_proposal_cls, 560 | loss_proposal_reg=loss_proposal_reg, 561 | loss_box_cls=loss_box_cls, 562 | loss_box_reg=loss_box_reg, 563 | ) 564 | -------------------------------------------------------------------------------- /models/swin_transformer.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import OrderedDict 3 | from copy import deepcopy 4 | import logging 5 | 6 | import math 7 | from typing import Sequence 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.utils.checkpoint as cp 12 | import numpy as np 13 | import cv2 14 | 15 | from torch.nn import Module as BaseModule 16 | from torch.nn import ModuleList 17 | from torch.nn import Sequential 18 | from torch.nn import Linear 19 | from torch import Tensor 20 | from mmcv.runner import load_checkpoint as _load_checkpoint 21 | 22 | from itertools import repeat 23 | 
import collections.abc 24 | def _ntuple(n): 25 | 26 | def parse(x): 27 | if isinstance(x, collections.abc.Iterable): 28 | return x 29 | return tuple(repeat(x, n)) 30 | 31 | return parse 32 | to_2tuple = _ntuple(2) 33 | 34 | def trunc_normal_init(module: nn.Module, 35 | mean: float = 0, 36 | std: float = 1, 37 | a: float = -2, 38 | b: float = 2, 39 | bias: float = 0) -> None: 40 | if hasattr(module, 'weight') and module.weight is not None: 41 | #trunc_normal_(module.weight, mean, std, a, b) # type: ignore 42 | _no_grad_trunc_normal_(module.weight, mean, std, a, b) # type: ignore 43 | if hasattr(module, 'bias') and module.bias is not None: 44 | nn.init.constant_(module.bias, bias) # type: ignore 45 | 46 | def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float, 47 | b: float) -> Tensor: 48 | # Method based on 49 | # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 50 | # Modified from 51 | # https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py 52 | def norm_cdf(x): 53 | # Computes standard normal cumulative distribution function 54 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 55 | 56 | if (mean < a - 2 * std) or (mean > b + 2 * std): 57 | warnings.warn( 58 | 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. ' 59 | 'The distribution of values may be incorrect.', 60 | stacklevel=2) 61 | 62 | with torch.no_grad(): 63 | # Values are generated by using a truncated uniform distribution and 64 | # then using the inverse CDF for the normal distribution. 65 | # Get upper and lower cdf values 66 | lower = norm_cdf((a - mean) / std) 67 | upper = norm_cdf((b - mean) / std) 68 | 69 | # Uniformly fill tensor with values from [lower, upper], then translate 70 | # to [2lower-1, 2upper-1]. 71 | tensor.uniform_(2 * lower - 1, 2 * upper - 1) 72 | 73 | # Use inverse cdf transform for normal distribution to get truncated 74 | # standard normal 75 | tensor.erfinv_() 76 | 77 | # Transform to proper mean, std 78 | tensor.mul_(std * math.sqrt(2.)) 79 | tensor.add_(mean) 80 | 81 | # Clamp to ensure it's in the proper range 82 | tensor.clamp_(min=a, max=b) 83 | return tensor 84 | 85 | def trunc_normal_(tensor: Tensor, 86 | mean: float = 0., 87 | std: float = 1., 88 | a: float = -2., 89 | b: float = 2.) -> Tensor: 90 | r"""Fills the input Tensor with values drawn from a truncated 91 | normal distribution. The values are effectively drawn from the 92 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 93 | with values outside :math:`[a, b]` redrawn until they are within 94 | the bounds. The method used for generating the random values works 95 | best when :math:`a \leq \text{mean} \leq b`. 96 | 97 | Modified from 98 | https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py 99 | 100 | Args: 101 | tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`. 102 | mean (float): the mean of the normal distribution. 103 | std (float): the standard deviation of the normal distribution. 104 | a (float): the minimum cutoff value. 105 | b (float): the maximum cutoff value. 
106 | """ 107 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 108 | 109 | def constant_init(module, val, bias=0): 110 | if hasattr(module, 'weight') and module.weight is not None: 111 | nn.init.constant_(module.weight, val) 112 | if hasattr(module, 'bias') and module.bias is not None: 113 | nn.init.constant_(module.bias, bias) 114 | 115 | def build_norm_layer(norm_cfg,embed_dims): 116 | assert norm_cfg['type'] == 'LN' 117 | norm_layer = nn.LayerNorm(embed_dims) 118 | return norm_cfg['type'],norm_layer 119 | 120 | class GELU(nn.Module): 121 | r"""Applies the Gaussian Error Linear Units function: 122 | 123 | .. math:: 124 | \text{GELU}(x) = x * \Phi(x) 125 | where :math:`\Phi(x)` is the Cumulative Distribution Function for 126 | Gaussian Distribution. 127 | 128 | Shape: 129 | - Input: :math:`(N, *)` where `*` means, any number of additional 130 | dimensions 131 | - Output: :math:`(N, *)`, same shape as the input 132 | 133 | .. image:: scripts/activation_images/GELU.png 134 | 135 | Examples:: 136 | 137 | >>> m = nn.GELU() 138 | >>> input = torch.randn(2) 139 | >>> output = m(input) 140 | """ 141 | 142 | def forward(self, input): 143 | return F.gelu(input) 144 | 145 | def build_activation_layer(act_cfg): 146 | if act_cfg['type'] == 'ReLU': 147 | act_layer = nn.ReLU(inplace=act_cfg['inplace']) 148 | elif act_cfg['type'] == 'GELU': 149 | act_layer = GELU() 150 | return act_layer 151 | 152 | def build_conv_layer(conv_cfg, 153 | in_channels, 154 | out_channels, 155 | kernel_size, 156 | stride, 157 | padding, 158 | dilation, 159 | bias): 160 | conv_layer = nn.Conv2d( 161 | in_channels=in_channels, 162 | out_channels=out_channels, 163 | kernel_size=kernel_size, 164 | stride=stride, 165 | padding=padding, 166 | dilation=dilation, 167 | bias=bias) 168 | return conv_layer 169 | 170 | def drop_path(x, drop_prob=0., training=False): 171 | """Drop paths (Stochastic Depth) per sample (when applied in main path of 172 | residual blocks). 173 | 174 | We follow the implementation 175 | https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501 176 | """ 177 | if drop_prob == 0. or not training: 178 | return x 179 | keep_prob = 1 - drop_prob 180 | # handle tensors with different dimensions, not just 4D tensors. 181 | shape = (x.shape[0], ) + (1, ) * (x.ndim - 1) 182 | random_tensor = keep_prob + torch.rand( 183 | shape, dtype=x.dtype, device=x.device) 184 | output = x.div(keep_prob) * random_tensor.floor() 185 | return output 186 | 187 | class DropPath(nn.Module): 188 | """Drop paths (Stochastic Depth) per sample (when applied in main path of 189 | residual blocks). 190 | 191 | We follow the implementation 192 | https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501 193 | 194 | Args: 195 | drop_prob (float): Probability of the path to be zeroed. 
Default: 0.1 196 | """ 197 | 198 | def __init__(self, drop_prob=0.1): 199 | super(DropPath, self).__init__() 200 | self.drop_prob = drop_prob 201 | 202 | def forward(self, x): 203 | return drop_path(x, self.drop_prob, self.training) 204 | 205 | def build_dropout(drop_cfg): 206 | drop_layer = DropPath(drop_cfg['drop_prob']) 207 | return drop_layer 208 | 209 | class FFN(BaseModule): 210 | def __init__(self, 211 | embed_dims=256, 212 | feedforward_channels=1024, 213 | num_fcs=2, 214 | act_cfg=dict(type='ReLU', inplace=True), 215 | ffn_drop=0., 216 | dropout_layer=None, 217 | add_identity=True, 218 | init_cfg=None, 219 | **kwargs): 220 | super(FFN, self).__init__() 221 | assert num_fcs >= 2, 'num_fcs should be no less ' \ 222 | f'than 2. got {num_fcs}.' 223 | self.embed_dims = embed_dims 224 | self.feedforward_channels = feedforward_channels 225 | self.num_fcs = num_fcs 226 | self.act_cfg = act_cfg 227 | self.activate = build_activation_layer(act_cfg) 228 | 229 | layers = [] 230 | in_channels = embed_dims 231 | for _ in range(num_fcs - 1): 232 | layers.append( 233 | Sequential( 234 | Linear(in_channels, feedforward_channels), self.activate, 235 | nn.Dropout(ffn_drop))) 236 | in_channels = feedforward_channels 237 | layers.append(Linear(feedforward_channels, embed_dims)) 238 | layers.append(nn.Dropout(ffn_drop)) 239 | self.layers = Sequential(*layers) 240 | self.dropout_layer = build_dropout( 241 | dropout_layer) if dropout_layer else torch.nn.Identity() 242 | self.add_identity = add_identity 243 | 244 | def forward(self, x, identity=None): 245 | """Forward function for `FFN`. 246 | 247 | The function would add x to the output tensor if residue is None. 248 | """ 249 | out = self.layers(x) 250 | if not self.add_identity: 251 | return self.dropout_layer(out) 252 | if identity is None: 253 | identity = x 254 | return identity + self.dropout_layer(out) 255 | 256 | def swin_converter(ckpt): 257 | 258 | new_ckpt = OrderedDict() 259 | 260 | def correct_unfold_reduction_order(x): 261 | out_channel, in_channel = x.shape 262 | x = x.reshape(out_channel, 4, in_channel // 4) 263 | x = x[:, [0, 2, 1, 3], :].transpose(1, 264 | 2).reshape(out_channel, in_channel) 265 | return x 266 | 267 | def correct_unfold_norm_order(x): 268 | in_channel = x.shape[0] 269 | x = x.reshape(4, in_channel // 4) 270 | x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel) 271 | return x 272 | 273 | for k, v in ckpt.items(): 274 | if k.startswith('head'): 275 | continue 276 | elif k.startswith('layers'): 277 | new_v = v 278 | if 'attn.' in k: 279 | new_k = k.replace('attn.', 'attn.w_msa.') 280 | elif 'mlp.' in k: 281 | if 'mlp.fc1.' in k: 282 | new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.') 283 | elif 'mlp.fc2.' in k: 284 | new_k = k.replace('mlp.fc2.', 'ffn.layers.1.') 285 | else: 286 | new_k = k.replace('mlp.', 'ffn.') 287 | elif 'downsample' in k: 288 | new_k = k 289 | if 'reduction.' in k: 290 | new_v = correct_unfold_reduction_order(v) 291 | elif 'norm.' in k: 292 | new_v = correct_unfold_norm_order(v) 293 | else: 294 | new_k = k 295 | new_k = new_k.replace('layers', 'stages', 1) 296 | elif k.startswith('patch_embed'): 297 | new_v = v 298 | if 'proj' in k: 299 | new_k = k.replace('proj', 'projection') 300 | else: 301 | new_k = k 302 | else: 303 | new_v = v 304 | new_k = k 305 | 306 | new_ckpt['backbone.' + new_k] = new_v 307 | 308 | return new_ckpt 309 | 310 | class AdaptivePadding(nn.Module): 311 | """Applies padding to input (if needed) so that input can get fully covered 312 | by filter you specified. 
It support two modes "same" and "corner". The 313 | "same" mode is same with "SAME" padding mode in TensorFlow, pad zero around 314 | input. The "corner" mode would pad zero to bottom right. 315 | Args: 316 | kernel_size (int | tuple): Size of the kernel: 317 | stride (int | tuple): Stride of the filter. Default: 1: 318 | dilation (int | tuple): Spacing between kernel elements. 319 | Default: 1 320 | padding (str): Support "same" and "corner", "corner" mode 321 | would pad zero to bottom right, and "same" mode would 322 | pad zero around input. Default: "corner". 323 | Example: 324 | >>> kernel_size = 16 325 | >>> stride = 16 326 | >>> dilation = 1 327 | >>> input = torch.rand(1, 1, 15, 17) 328 | >>> adap_pad = AdaptivePadding( 329 | >>> kernel_size=kernel_size, 330 | >>> stride=stride, 331 | >>> dilation=dilation, 332 | >>> padding="corner") 333 | >>> out = adap_pad(input) 334 | >>> assert (out.shape[2], out.shape[3]) == (16, 32) 335 | >>> input = torch.rand(1, 1, 16, 17) 336 | >>> out = adap_pad(input) 337 | >>> assert (out.shape[2], out.shape[3]) == (16, 32) 338 | """ 339 | 340 | def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'): 341 | 342 | super(AdaptivePadding, self).__init__() 343 | 344 | assert padding in ('same', 'corner') 345 | 346 | kernel_size = to_2tuple(kernel_size) 347 | stride = to_2tuple(stride) 348 | padding = to_2tuple(padding) 349 | dilation = to_2tuple(dilation) 350 | 351 | self.padding = padding 352 | self.kernel_size = kernel_size 353 | self.stride = stride 354 | self.dilation = dilation 355 | 356 | def get_pad_shape(self, input_shape): 357 | input_h, input_w = input_shape 358 | kernel_h, kernel_w = self.kernel_size 359 | stride_h, stride_w = self.stride 360 | output_h = math.ceil(input_h / stride_h) 361 | output_w = math.ceil(input_w / stride_w) 362 | pad_h = max((output_h - 1) * stride_h + 363 | (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0) 364 | pad_w = max((output_w - 1) * stride_w + 365 | (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0) 366 | return pad_h, pad_w 367 | 368 | def forward(self, x): 369 | pad_h, pad_w = self.get_pad_shape(x.size()[-2:]) 370 | if pad_h > 0 or pad_w > 0: 371 | if self.padding == 'corner': 372 | x = F.pad(x, [0, pad_w, 0, pad_h]) 373 | elif self.padding == 'same': 374 | x = F.pad(x, [ 375 | pad_w // 2, pad_w - pad_w // 2, pad_h // 2, 376 | pad_h - pad_h // 2 377 | ]) 378 | return x 379 | 380 | class PatchEmbed(BaseModule): 381 | """Image to Patch Embedding. 382 | We use a conv layer to implement PatchEmbed. 383 | Args: 384 | in_channels (int): The num of input channels. Default: 3 385 | embed_dims (int): The dimensions of embedding. Default: 768 386 | conv_type (str): The config dict for embedding 387 | conv layer type selection. Default: "Conv2d. 388 | kernel_size (int): The kernel_size of embedding conv. Default: 16. 389 | stride (int): The slide stride of embedding conv. 390 | Default: None (Would be set as `kernel_size`). 391 | padding (int | tuple | string ): The padding length of 392 | embedding conv. When it is a string, it means the mode 393 | of adaptive padding, support "same" and "corner" now. 394 | Default: "corner". 395 | dilation (int): The dilation rate of embedding conv. Default: 1. 396 | bias (bool): Bias of embed conv. Default: True. 397 | norm_cfg (dict, optional): Config dict for normalization layer. 398 | Default: None. 399 | input_size (int | tuple | None): The size of input, which will be 400 | used to calculate the out size. Only work when `dynamic_size` 401 | is False. 
Default: None. 402 | init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization. 403 | Default: None. 404 | """ 405 | 406 | def __init__( 407 | self, 408 | in_channels=3, 409 | embed_dims=768, 410 | conv_type='Conv2d', 411 | kernel_size=16, 412 | stride=16, 413 | padding='corner', 414 | dilation=1, 415 | bias=True, 416 | norm_cfg=None, 417 | input_size=None, 418 | init_cfg=None, 419 | ): 420 | super(PatchEmbed, self).__init__() 421 | 422 | self.embed_dims = embed_dims 423 | if stride is None: 424 | stride = kernel_size 425 | 426 | kernel_size = to_2tuple(kernel_size) 427 | stride = to_2tuple(stride) 428 | dilation = to_2tuple(dilation) 429 | 430 | if isinstance(padding, str): 431 | self.adap_padding = AdaptivePadding( 432 | kernel_size=kernel_size, 433 | stride=stride, 434 | dilation=dilation, 435 | padding=padding) 436 | # disable the padding of conv 437 | padding = 0 438 | else: 439 | self.adap_padding = None 440 | padding = to_2tuple(padding) 441 | 442 | self.projection = build_conv_layer( 443 | dict(type=conv_type), 444 | in_channels=in_channels, 445 | out_channels=embed_dims, 446 | kernel_size=kernel_size, 447 | stride=stride, 448 | padding=padding, 449 | dilation=dilation, 450 | bias=bias) 451 | 452 | if norm_cfg is not None: 453 | self.norm = build_norm_layer(norm_cfg, embed_dims)[1] 454 | else: 455 | self.norm = None 456 | 457 | if input_size: 458 | input_size = to_2tuple(input_size) 459 | # `init_out_size` would be used outside to 460 | # calculate the num_patches 461 | # when `use_abs_pos_embed` outside 462 | self.init_input_size = input_size 463 | if self.adap_padding: 464 | pad_h, pad_w = self.adap_padding.get_pad_shape(input_size) 465 | input_h, input_w = input_size 466 | input_h = input_h + pad_h 467 | input_w = input_w + pad_w 468 | input_size = (input_h, input_w) 469 | 470 | # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html 471 | h_out = (input_size[0] + 2 * padding[0] - dilation[0] * 472 | (kernel_size[0] - 1) - 1) // stride[0] + 1 473 | w_out = (input_size[1] + 2 * padding[1] - dilation[1] * 474 | (kernel_size[1] - 1) - 1) // stride[1] + 1 475 | self.init_out_size = (h_out, w_out) 476 | else: 477 | self.init_input_size = None 478 | self.init_out_size = None 479 | 480 | def forward(self, x): 481 | """ 482 | Args: 483 | x (Tensor): Has shape (B, C, H, W). In most case, C is 3. 484 | Returns: 485 | tuple: Contains merged results and its spatial shape. 486 | - x (Tensor): Has shape (B, out_h * out_w, embed_dims) 487 | - out_size (tuple[int]): Spatial shape of x, arrange as 488 | (out_h, out_w). 489 | """ 490 | 491 | if self.adap_padding: 492 | x = self.adap_padding(x) 493 | 494 | x = self.projection(x) 495 | out_size = (x.shape[2], x.shape[3]) 496 | x = x.flatten(2).transpose(1, 2) 497 | if self.norm is not None: 498 | x = self.norm(x) 499 | return x, out_size 500 | 501 | 502 | class PatchMerging(BaseModule): 503 | """Merge patch feature map. 504 | This layer groups feature map by kernel_size, and applies norm and linear 505 | layers to the grouped feature map. Our implementation uses `nn.Unfold` to 506 | merge patch, which is about 25% faster than original implementation. 507 | Instead, we need to modify pretrained models for compatibility. 508 | Args: 509 | in_channels (int): The num of input channels. 510 | to gets fully covered by filter and stride you specified.. 511 | Default: True. 512 | out_channels (int): The num of output channels. 513 | kernel_size (int | tuple, optional): the kernel size in the unfold 514 | layer. Defaults to 2. 
515 | stride (int | tuple, optional): the stride of the sliding blocks in the 516 | unfold layer. Default: None. (Would be set as `kernel_size`) 517 | padding (int | tuple | string ): The padding length of 518 | embedding conv. When it is a string, it means the mode 519 | of adaptive padding, support "same" and "corner" now. 520 | Default: "corner". 521 | dilation (int | tuple, optional): dilation parameter in the unfold 522 | layer. Default: 1. 523 | bias (bool, optional): Whether to add bias in linear layer or not. 524 | Defaults: False. 525 | norm_cfg (dict, optional): Config dict for normalization layer. 526 | Default: dict(type='LN'). 527 | init_cfg (dict, optional): The extra config for initialization. 528 | Default: None. 529 | """ 530 | 531 | def __init__(self, 532 | in_channels, 533 | out_channels, 534 | kernel_size=2, 535 | stride=None, 536 | padding='corner', 537 | dilation=1, 538 | bias=False, 539 | norm_cfg=dict(type='LN'), 540 | init_cfg=None): 541 | super().__init__() 542 | self.in_channels = in_channels 543 | self.out_channels = out_channels 544 | if stride: 545 | stride = stride 546 | else: 547 | stride = kernel_size 548 | 549 | kernel_size = to_2tuple(kernel_size) 550 | stride = to_2tuple(stride) 551 | dilation = to_2tuple(dilation) 552 | 553 | if isinstance(padding, str): 554 | self.adap_padding = AdaptivePadding( 555 | kernel_size=kernel_size, 556 | stride=stride, 557 | dilation=dilation, 558 | padding=padding) 559 | # disable the padding of unfold 560 | padding = 0 561 | else: 562 | self.adap_padding = None 563 | 564 | padding = to_2tuple(padding) 565 | self.sampler = nn.Unfold( 566 | kernel_size=kernel_size, 567 | dilation=dilation, 568 | padding=padding, 569 | stride=stride) 570 | 571 | sample_dim = kernel_size[0] * kernel_size[1] * in_channels 572 | 573 | if norm_cfg is not None: 574 | self.norm = build_norm_layer(norm_cfg, sample_dim)[1] 575 | else: 576 | self.norm = None 577 | 578 | self.reduction = nn.Linear(sample_dim, out_channels, bias=bias) 579 | 580 | def forward(self, x, input_size): 581 | """ 582 | Args: 583 | x (Tensor): Has shape (B, H*W, C_in). 584 | input_size (tuple[int]): The spatial shape of x, arrange as (H, W). 585 | Default: None. 586 | Returns: 587 | tuple: Contains merged results and its spatial shape. 588 | - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out) 589 | - out_size (tuple[int]): Spatial shape of x, arrange as 590 | (Merged_H, Merged_W). 591 | """ 592 | B, L, C = x.shape 593 | assert isinstance(input_size, Sequence), f'Expect ' \ 594 | f'input_size is ' \ 595 | f'`Sequence` ' \ 596 | f'but get {input_size}' 597 | 598 | H, W = input_size 599 | assert L == H * W, 'input feature has wrong size' 600 | 601 | x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W 602 | # Use nn.Unfold to merge patch. 
About 25% faster than original method, 603 | # but need to modify pretrained model for compatibility 604 | 605 | if self.adap_padding: 606 | x = self.adap_padding(x) 607 | H, W = x.shape[-2:] 608 | 609 | x = self.sampler(x) 610 | # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2) 611 | 612 | out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] * 613 | (self.sampler.kernel_size[0] - 1) - 614 | 1) // self.sampler.stride[0] + 1 615 | out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] * 616 | (self.sampler.kernel_size[1] - 1) - 617 | 1) // self.sampler.stride[1] + 1 618 | 619 | output_size = (out_h, out_w) 620 | x = x.transpose(1, 2) # B, H/2*W/2, 4*C 621 | x = self.norm(x) if self.norm else x 622 | x = self.reduction(x) 623 | return x, output_size 624 | 625 | class WindowMSA(BaseModule): 626 | """Window based multi-head self-attention (W-MSA) module with relative 627 | position bias. 628 | Args: 629 | embed_dims (int): Number of input channels. 630 | num_heads (int): Number of attention heads. 631 | window_size (tuple[int]): The height and width of the window. 632 | qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. 633 | Default: True. 634 | qk_scale (float | None, optional): Override default qk scale of 635 | head_dim ** -0.5 if set. Default: None. 636 | attn_drop_rate (float, optional): Dropout ratio of attention weight. 637 | Default: 0.0 638 | proj_drop_rate (float, optional): Dropout ratio of output. Default: 0. 639 | init_cfg (dict | None, optional): The Config for initialization. 640 | Default: None. 641 | """ 642 | 643 | def __init__(self, 644 | embed_dims, 645 | num_heads, 646 | window_size, 647 | qkv_bias=True, 648 | qk_scale=None, 649 | attn_drop_rate=0., 650 | proj_drop_rate=0., 651 | init_cfg=None): 652 | 653 | super().__init__() 654 | self.embed_dims = embed_dims 655 | self.window_size = window_size # Wh, Ww 656 | self.num_heads = num_heads 657 | head_embed_dims = embed_dims // num_heads 658 | self.scale = qk_scale or head_embed_dims**-0.5 659 | self.init_cfg = init_cfg 660 | 661 | # define a parameter table of relative position bias 662 | self.relative_position_bias_table = nn.Parameter( 663 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), 664 | num_heads)) # 2*Wh-1 * 2*Ww-1, nH 665 | 666 | # About 2x faster than original impl 667 | Wh, Ww = self.window_size 668 | rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww) 669 | rel_position_index = rel_index_coords + rel_index_coords.T 670 | rel_position_index = rel_position_index.flip(1).contiguous() 671 | self.register_buffer('relative_position_index', rel_position_index) 672 | 673 | self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) 674 | self.attn_drop = nn.Dropout(attn_drop_rate) 675 | self.proj = nn.Linear(embed_dims, embed_dims) 676 | self.proj_drop = nn.Dropout(proj_drop_rate) 677 | 678 | self.softmax = nn.Softmax(dim=-1) 679 | 680 | def init_weights(self): 681 | trunc_normal_(self.relative_position_bias_table, std=0.02) 682 | 683 | def forward(self, x, mask=None): 684 | """ 685 | Args: 686 | x (tensor): input features with shape of (num_windows*B, N, C) 687 | mask (tensor | None, Optional): mask with shape of (num_windows, 688 | Wh*Ww, Wh*Ww), value should be between (-inf, 0]. 
689 | """ 690 | B, N, C = x.shape 691 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, 692 | C // self.num_heads).permute(2, 0, 3, 1, 4) 693 | # make torchscript happy (cannot use tensor as tuple) 694 | q, k, v = qkv[0], qkv[1], qkv[2] 695 | 696 | q = q * self.scale 697 | attn = (q @ k.transpose(-2, -1)) 698 | 699 | relative_position_bias = self.relative_position_bias_table[ 700 | self.relative_position_index.view(-1)].view( 701 | self.window_size[0] * self.window_size[1], 702 | self.window_size[0] * self.window_size[1], 703 | -1) # Wh*Ww,Wh*Ww,nH 704 | relative_position_bias = relative_position_bias.permute( 705 | 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 706 | attn = attn + relative_position_bias.unsqueeze(0) 707 | 708 | if mask is not None: 709 | nW = mask.shape[0] 710 | attn = attn.view(B // nW, nW, self.num_heads, N, 711 | N) + mask.unsqueeze(1).unsqueeze(0) 712 | attn = attn.view(-1, self.num_heads, N, N) 713 | attn = self.softmax(attn) 714 | 715 | attn = self.attn_drop(attn) 716 | 717 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 718 | x = self.proj(x) 719 | x = self.proj_drop(x) 720 | return x 721 | 722 | @staticmethod 723 | def double_step_seq(step1, len1, step2, len2): 724 | seq1 = torch.arange(0, step1 * len1, step1) 725 | seq2 = torch.arange(0, step2 * len2, step2) 726 | return (seq1[:, None] + seq2[None, :]).reshape(1, -1) 727 | 728 | 729 | class ShiftWindowMSA(BaseModule): 730 | """Shifted Window Multihead Self-Attention Module. 731 | Args: 732 | embed_dims (int): Number of input channels. 733 | num_heads (int): Number of attention heads. 734 | window_size (int): The height and width of the window. 735 | shift_size (int, optional): The shift step of each window towards 736 | right-bottom. If zero, act as regular window-msa. Defaults to 0. 737 | qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. 738 | Default: True 739 | qk_scale (float | None, optional): Override default qk scale of 740 | head_dim ** -0.5 if set. Defaults: None. 741 | attn_drop_rate (float, optional): Dropout ratio of attention weight. 742 | Defaults: 0. 743 | proj_drop_rate (float, optional): Dropout ratio of output. 744 | Defaults: 0. 745 | dropout_layer (dict, optional): The dropout_layer used before output. 746 | Defaults: dict(type='DropPath', drop_prob=0.). 747 | init_cfg (dict, optional): The extra config for initialization. 748 | Default: None. 
749 | """ 750 | 751 | def __init__(self, 752 | embed_dims, 753 | num_heads, 754 | window_size, 755 | shift_size=0, 756 | qkv_bias=True, 757 | qk_scale=None, 758 | attn_drop_rate=0, 759 | proj_drop_rate=0, 760 | dropout_layer=dict(type='DropPath', drop_prob=0.), 761 | init_cfg=None): 762 | super().__init__() 763 | 764 | self.window_size = window_size 765 | self.shift_size = shift_size 766 | assert 0 <= self.shift_size < self.window_size 767 | 768 | self.w_msa = WindowMSA( 769 | embed_dims=embed_dims, 770 | num_heads=num_heads, 771 | window_size=to_2tuple(window_size), 772 | qkv_bias=qkv_bias, 773 | qk_scale=qk_scale, 774 | attn_drop_rate=attn_drop_rate, 775 | proj_drop_rate=proj_drop_rate, 776 | init_cfg=None) 777 | 778 | self.drop = build_dropout(dropout_layer) 779 | 780 | def forward(self, query, hw_shape): 781 | B, L, C = query.shape 782 | H, W = hw_shape 783 | assert L == H * W, 'input feature has wrong size' 784 | query = query.view(B, H, W, C) 785 | 786 | # pad feature maps to multiples of window size 787 | pad_r = (self.window_size - W % self.window_size) % self.window_size 788 | pad_b = (self.window_size - H % self.window_size) % self.window_size 789 | query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b)) 790 | H_pad, W_pad = query.shape[1], query.shape[2] 791 | 792 | # cyclic shift 793 | if self.shift_size > 0: 794 | shifted_query = torch.roll( 795 | query, 796 | shifts=(-self.shift_size, -self.shift_size), 797 | dims=(1, 2)) 798 | 799 | # calculate attention mask for SW-MSA 800 | img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device) 801 | h_slices = (slice(0, -self.window_size), 802 | slice(-self.window_size, 803 | -self.shift_size), slice(-self.shift_size, None)) 804 | w_slices = (slice(0, -self.window_size), 805 | slice(-self.window_size, 806 | -self.shift_size), slice(-self.shift_size, None)) 807 | cnt = 0 808 | for h in h_slices: 809 | for w in w_slices: 810 | img_mask[:, h, w, :] = cnt 811 | cnt += 1 812 | 813 | # nW, window_size, window_size, 1 814 | mask_windows = self.window_partition(img_mask) 815 | mask_windows = mask_windows.view( 816 | -1, self.window_size * self.window_size) 817 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 818 | attn_mask = attn_mask.masked_fill(attn_mask != 0, 819 | float(-100.0)).masked_fill( 820 | attn_mask == 0, float(0.0)) 821 | else: 822 | shifted_query = query 823 | attn_mask = None 824 | 825 | # nW*B, window_size, window_size, C 826 | query_windows = self.window_partition(shifted_query) 827 | # nW*B, window_size*window_size, C 828 | query_windows = query_windows.view(-1, self.window_size**2, C) 829 | 830 | # W-MSA/SW-MSA (nW*B, window_size*window_size, C) 831 | attn_windows = self.w_msa(query_windows, mask=attn_mask) 832 | 833 | # merge windows 834 | attn_windows = attn_windows.view(-1, self.window_size, 835 | self.window_size, C) 836 | 837 | # B H' W' C 838 | shifted_x = self.window_reverse(attn_windows, H_pad, W_pad) 839 | # reverse cyclic shift 840 | if self.shift_size > 0: 841 | x = torch.roll( 842 | shifted_x, 843 | shifts=(self.shift_size, self.shift_size), 844 | dims=(1, 2)) 845 | else: 846 | x = shifted_x 847 | 848 | if pad_r > 0 or pad_b: 849 | x = x[:, :H, :W, :].contiguous() 850 | 851 | x = x.view(B, H * W, C) 852 | 853 | x = self.drop(x) 854 | return x 855 | 856 | def window_reverse(self, windows, H, W): 857 | """ 858 | Args: 859 | windows: (num_windows*B, window_size, window_size, C) 860 | H (int): Height of image 861 | W (int): Width of image 862 | Returns: 863 | x: (B, H, W, C) 864 | """ 865 | 
window_size = self.window_size 866 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 867 | x = windows.view(B, H // window_size, W // window_size, window_size, 868 | window_size, -1) 869 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 870 | return x 871 | 872 | def window_partition(self, x): 873 | """ 874 | Args: 875 | x: (B, H, W, C) 876 | Returns: 877 | windows: (num_windows*B, window_size, window_size, C) 878 | """ 879 | B, H, W, C = x.shape 880 | window_size = self.window_size 881 | x = x.view(B, H // window_size, window_size, W // window_size, 882 | window_size, C) 883 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() 884 | windows = windows.view(-1, window_size, window_size, C) 885 | return windows 886 | 887 | 888 | class SwinBlock(BaseModule): 889 | """" 890 | Args: 891 | embed_dims (int): The feature dimension. 892 | num_heads (int): Parallel attention heads. 893 | feedforward_channels (int): The hidden dimension for FFNs. 894 | window_size (int, optional): The local window scale. Default: 7. 895 | shift (bool, optional): whether to shift window or not. Default False. 896 | qkv_bias (bool, optional): enable bias for qkv if True. Default: True. 897 | qk_scale (float | None, optional): Override default qk scale of 898 | head_dim ** -0.5 if set. Default: None. 899 | drop_rate (float, optional): Dropout rate. Default: 0. 900 | attn_drop_rate (float, optional): Attention dropout rate. Default: 0. 901 | drop_path_rate (float, optional): Stochastic depth rate. Default: 0. 902 | act_cfg (dict, optional): The config dict of activation function. 903 | Default: dict(type='GELU'). 904 | norm_cfg (dict, optional): The config dict of normalization. 905 | Default: dict(type='LN'). 906 | with_cp (bool, optional): Use checkpoint or not. Using checkpoint 907 | will save some memory while slowing down the training speed. 908 | Default: False. 909 | init_cfg (dict | list | None, optional): The init config. 910 | Default: None. 
911 | """ 912 | 913 | def __init__(self, 914 | embed_dims, 915 | num_heads, 916 | feedforward_channels, 917 | window_size=7, 918 | shift=False, 919 | qkv_bias=True, 920 | qk_scale=None, 921 | drop_rate=0., 922 | attn_drop_rate=0., 923 | drop_path_rate=0., 924 | act_cfg=dict(type='GELU'), 925 | norm_cfg=dict(type='LN'), 926 | with_cp=False, 927 | init_cfg=None): 928 | 929 | super(SwinBlock, self).__init__() 930 | 931 | self.init_cfg = init_cfg 932 | self.with_cp = with_cp 933 | 934 | self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] 935 | self.attn = ShiftWindowMSA( 936 | embed_dims=embed_dims, 937 | num_heads=num_heads, 938 | window_size=window_size, 939 | shift_size=window_size // 2 if shift else 0, 940 | qkv_bias=qkv_bias, 941 | qk_scale=qk_scale, 942 | attn_drop_rate=attn_drop_rate, 943 | proj_drop_rate=drop_rate, 944 | dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), 945 | init_cfg=None) 946 | 947 | self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] 948 | self.ffn = FFN( 949 | embed_dims=embed_dims, 950 | feedforward_channels=feedforward_channels, 951 | num_fcs=2, 952 | ffn_drop=drop_rate, 953 | dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), 954 | act_cfg=act_cfg, 955 | add_identity=True, 956 | init_cfg=None) 957 | 958 | def forward(self, x, hw_shape): 959 | 960 | def _inner_forward(x): 961 | identity = x 962 | x = self.norm1(x) 963 | x = self.attn(x, hw_shape) 964 | 965 | x = x + identity 966 | 967 | identity = x 968 | x = self.norm2(x) 969 | x = self.ffn(x, identity=identity) 970 | 971 | return x 972 | 973 | if self.with_cp and x.requires_grad: 974 | x = cp.checkpoint(_inner_forward, x) 975 | else: 976 | x = _inner_forward(x) 977 | 978 | return x 979 | 980 | 981 | class SwinBlockSequence(BaseModule): 982 | """Implements one stage in Swin Transformer. 983 | Args: 984 | embed_dims (int): The feature dimension. 985 | num_heads (int): Parallel attention heads. 986 | feedforward_channels (int): The hidden dimension for FFNs. 987 | depth (int): The number of blocks in this stage. 988 | window_size (int, optional): The local window scale. Default: 7. 989 | qkv_bias (bool, optional): enable bias for qkv if True. Default: True. 990 | qk_scale (float | None, optional): Override default qk scale of 991 | head_dim ** -0.5 if set. Default: None. 992 | drop_rate (float, optional): Dropout rate. Default: 0. 993 | attn_drop_rate (float, optional): Attention dropout rate. Default: 0. 994 | drop_path_rate (float | list[float], optional): Stochastic depth 995 | rate. Default: 0. 996 | downsample (BaseModule | None, optional): The downsample operation 997 | module. Default: None. 998 | act_cfg (dict, optional): The config dict of activation function. 999 | Default: dict(type='GELU'). 1000 | norm_cfg (dict, optional): The config dict of normalization. 1001 | Default: dict(type='LN'). 1002 | with_cp (bool, optional): Use checkpoint or not. Using checkpoint 1003 | will save some memory while slowing down the training speed. 1004 | Default: False. 1005 | init_cfg (dict | list | None, optional): The init config. 1006 | Default: None. 
1007 | """ 1008 | 1009 | def __init__(self, 1010 | embed_dims, 1011 | num_heads, 1012 | feedforward_channels, 1013 | depth, 1014 | window_size=7, 1015 | qkv_bias=True, 1016 | qk_scale=None, 1017 | drop_rate=0., 1018 | attn_drop_rate=0., 1019 | drop_path_rate=0., 1020 | downsample=None, 1021 | act_cfg=dict(type='GELU'), 1022 | norm_cfg=dict(type='LN'), 1023 | with_cp=False, 1024 | init_cfg=None): 1025 | super().__init__() 1026 | 1027 | if isinstance(drop_path_rate, list): 1028 | drop_path_rates = drop_path_rate 1029 | assert len(drop_path_rates) == depth 1030 | else: 1031 | drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)] 1032 | 1033 | self.blocks = ModuleList() 1034 | for i in range(depth): 1035 | block = SwinBlock( 1036 | embed_dims=embed_dims, 1037 | num_heads=num_heads, 1038 | feedforward_channels=feedforward_channels, 1039 | window_size=window_size, 1040 | shift=False if i % 2 == 0 else True, 1041 | qkv_bias=qkv_bias, 1042 | qk_scale=qk_scale, 1043 | drop_rate=drop_rate, 1044 | attn_drop_rate=attn_drop_rate, 1045 | drop_path_rate=drop_path_rates[i], 1046 | act_cfg=act_cfg, 1047 | norm_cfg=norm_cfg, 1048 | with_cp=with_cp, 1049 | init_cfg=None) 1050 | self.blocks.append(block) 1051 | 1052 | self.downsample = downsample 1053 | 1054 | def forward(self, x, hw_shape): 1055 | for block in self.blocks: 1056 | x = block(x, hw_shape) 1057 | 1058 | if self.downsample: 1059 | x_down, down_hw_shape = self.downsample(x, hw_shape) 1060 | return x_down, down_hw_shape, x, hw_shape 1061 | else: 1062 | return x, hw_shape, x, hw_shape 1063 | 1064 | class SwinTransformer(BaseModule): 1065 | """ Swin Transformer 1066 | A PyTorch implement of : `Swin Transformer: 1067 | Hierarchical Vision Transformer using Shifted Windows` - 1068 | https://arxiv.org/abs/2103.14030 1069 | Inspiration from 1070 | https://github.com/microsoft/Swin-Transformer 1071 | Args: 1072 | pretrain_img_size (int | tuple[int]): The size of input image when 1073 | pretrain. Defaults: 224. 1074 | in_channels (int): The num of input channels. 1075 | Defaults: 3. 1076 | embed_dims (int): The feature dimension. Default: 96. 1077 | patch_size (int | tuple[int]): Patch size. Default: 4. 1078 | window_size (int): Window size. Default: 7. 1079 | mlp_ratio (int): Ratio of mlp hidden dim to embedding dim. 1080 | Default: 4. 1081 | depths (tuple[int]): Depths of each Swin Transformer stage. 1082 | Default: (2, 2, 6, 2). 1083 | num_heads (tuple[int]): Parallel attention heads of each Swin 1084 | Transformer stage. Default: (3, 6, 12, 24). 1085 | strides (tuple[int]): The patch merging or patch embedding stride of 1086 | each Swin Transformer stage. (In swin, we set kernel size equal to 1087 | stride.) Default: (4, 2, 2, 2). 1088 | out_indices (tuple[int]): Output from which stages. 1089 | Default: (0, 1, 2, 3). 1090 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, 1091 | value. Default: True 1092 | qk_scale (float | None, optional): Override default qk scale of 1093 | head_dim ** -0.5 if set. Default: None. 1094 | patch_norm (bool): If add a norm layer for patch embed and patch 1095 | merging. Default: True. 1096 | drop_rate (float): Dropout rate. Defaults: 0. 1097 | attn_drop_rate (float): Attention dropout rate. Default: 0. 1098 | drop_path_rate (float): Stochastic depth rate. Defaults: 0.1. 1099 | use_abs_pos_embed (bool): If True, add absolute position embedding to 1100 | the patch embedding. Defaults: False. 1101 | act_cfg (dict): Config dict for activation layer. 1102 | Default: dict(type='LN'). 
1103 | norm_cfg (dict): Config dict for normalization layer at 1104 | output of backone. Defaults: dict(type='LN'). 1105 | with_cp (bool, optional): Use checkpoint or not. Using checkpoint 1106 | will save some memory while slowing down the training speed. 1107 | Default: False. 1108 | pretrained (str, optional): model pretrained path. Default: None. 1109 | convert_weights (bool): The flag indicates whether the 1110 | pre-trained model is from the original repo. We may need 1111 | to convert some keys to make it compatible. 1112 | Default: False. 1113 | frozen_stages (int): Stages to be frozen (stop grad and set eval mode). 1114 | -1 means not freezing any parameters. 1115 | init_cfg (dict, optional): The Config for initialization. 1116 | Defaults to None. 1117 | """ 1118 | 1119 | def __init__(self, 1120 | pretrain_img_size=224, 1121 | in_channels=3, 1122 | embed_dims=96, 1123 | patch_size=4, 1124 | window_size=7, 1125 | mlp_ratio=4, 1126 | depths=(2, 2, 6, 2), 1127 | num_heads=(3, 6, 12, 24), 1128 | strides=(4, 2, 2, 2), 1129 | out_indices=(0, 1, 2, 3), 1130 | qkv_bias=True, 1131 | qk_scale=None, 1132 | patch_norm=True, 1133 | drop_rate=0., 1134 | attn_drop_rate=0., 1135 | drop_path_rate=0.1, 1136 | use_abs_pos_embed=False, 1137 | act_cfg=dict(type='GELU'), 1138 | norm_cfg=dict(type='LN'), 1139 | with_cp=False, 1140 | pretrained=None, 1141 | convert_weights=False, 1142 | frozen_stages=-1, 1143 | init_cfg=None, 1144 | semantic_weight=0.0): 1145 | self.convert_weights = convert_weights 1146 | self.frozen_stages = frozen_stages 1147 | if isinstance(pretrain_img_size, int): 1148 | pretrain_img_size = to_2tuple(pretrain_img_size) 1149 | elif isinstance(pretrain_img_size, tuple): 1150 | if len(pretrain_img_size) == 1: 1151 | pretrain_img_size = to_2tuple(pretrain_img_size[0]) 1152 | assert len(pretrain_img_size) == 2, \ 1153 | f'The size of image should have length 1 or 2, ' \ 1154 | f'but got {len(pretrain_img_size)}' 1155 | 1156 | assert not (init_cfg and pretrained), \ 1157 | 'init_cfg and pretrained cannot be specified at the same time' 1158 | if isinstance(pretrained, str): 1159 | warnings.warn('DeprecationWarning: pretrained is deprecated, ' 1160 | 'please use "init_cfg" instead') 1161 | self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) 1162 | elif pretrained is None: 1163 | self.init_cfg = init_cfg 1164 | else: 1165 | raise TypeError('pretrained must be a str or None') 1166 | 1167 | super(SwinTransformer, self).__init__() 1168 | 1169 | num_layers = len(depths) 1170 | self.out_indices = out_indices 1171 | self.use_abs_pos_embed = use_abs_pos_embed 1172 | 1173 | assert strides[0] == patch_size, 'Use non-overlapping patch embed.' 
1174 | 1175 | self.patch_embed = PatchEmbed( 1176 | in_channels=in_channels, 1177 | embed_dims=embed_dims, 1178 | conv_type='Conv2d', 1179 | kernel_size=patch_size, 1180 | stride=strides[0], 1181 | norm_cfg=norm_cfg if patch_norm else None, 1182 | init_cfg=None) 1183 | 1184 | if self.use_abs_pos_embed: 1185 | patch_row = pretrain_img_size[0] // patch_size 1186 | patch_col = pretrain_img_size[1] // patch_size 1187 | num_patches = patch_row * patch_col 1188 | self.absolute_pos_embed = nn.Parameter( 1189 | torch.zeros((1, num_patches, embed_dims))) 1190 | 1191 | self.drop_after_pos = nn.Dropout(p=drop_rate) 1192 | 1193 | # set stochastic depth decay rule 1194 | total_depth = sum(depths) 1195 | dpr = [ 1196 | x.item() for x in torch.linspace(0, drop_path_rate, total_depth) 1197 | ] 1198 | 1199 | self.stages = ModuleList() 1200 | in_channels = embed_dims 1201 | for i in range(num_layers): 1202 | if i < num_layers - 1: 1203 | downsample = PatchMerging( 1204 | in_channels=in_channels, 1205 | out_channels=2 * in_channels, 1206 | stride=strides[i + 1], 1207 | norm_cfg=norm_cfg if patch_norm else None, 1208 | init_cfg=None) 1209 | else: 1210 | downsample = None 1211 | 1212 | stage = SwinBlockSequence( 1213 | embed_dims=in_channels, 1214 | num_heads=num_heads[i], 1215 | feedforward_channels=mlp_ratio * in_channels, 1216 | depth=depths[i], 1217 | window_size=window_size, 1218 | qkv_bias=qkv_bias, 1219 | qk_scale=qk_scale, 1220 | drop_rate=drop_rate, 1221 | attn_drop_rate=attn_drop_rate, 1222 | drop_path_rate=dpr[sum(depths[:i]):sum(depths[:i + 1])], 1223 | downsample=downsample, 1224 | act_cfg=act_cfg, 1225 | norm_cfg=norm_cfg, 1226 | with_cp=with_cp, 1227 | init_cfg=None) 1228 | self.stages.append(stage) 1229 | if downsample: 1230 | in_channels = downsample.out_channels 1231 | 1232 | self.num_features = [int(embed_dims * 2**i) for i in range(num_layers)] 1233 | # Add a norm layer for each output 1234 | for i in out_indices: 1235 | layer = build_norm_layer(norm_cfg, self.num_features[i])[1] 1236 | layer_name = f'norm{i}' 1237 | self.add_module(layer_name, layer) 1238 | 1239 | self.avgpool = nn.AdaptiveAvgPool2d((1,1)) 1240 | 1241 | # semantic embedding 1242 | self.semantic_weight = semantic_weight 1243 | if self.semantic_weight >= 0: 1244 | self.semantic_embed_w = ModuleList() 1245 | self.semantic_embed_b = ModuleList() 1246 | for i in range(len(depths)): 1247 | if i >= len(depths) - 1: 1248 | i = len(depths) - 2 1249 | semantic_embed_w = nn.Linear(2, self.num_features[i+1]) 1250 | semantic_embed_b = nn.Linear(2, self.num_features[i+1]) 1251 | trunc_normal_init(semantic_embed_w, std=.02, bias=0.) 1252 | trunc_normal_init(semantic_embed_b, std=.02, bias=0.) 
1253 | self.semantic_embed_w.append(semantic_embed_w) 1254 | self.semantic_embed_b.append(semantic_embed_b) 1255 | self.softplus = nn.Softplus() 1256 | 1257 | def train(self, mode=True): 1258 | """Convert the model into training mode while keep layers freezed.""" 1259 | super(SwinTransformer, self).train(mode) 1260 | self._freeze_stages() 1261 | 1262 | def _freeze_stages(self): 1263 | if self.frozen_stages >= 0: 1264 | self.patch_embed.eval() 1265 | for param in self.patch_embed.parameters(): 1266 | param.requires_grad = False 1267 | if self.use_abs_pos_embed: 1268 | self.absolute_pos_embed.requires_grad = False 1269 | self.drop_after_pos.eval() 1270 | 1271 | for i in range(1, self.frozen_stages + 1): 1272 | 1273 | if (i - 1) in self.out_indices: 1274 | norm_layer = getattr(self, f'norm{i-1}') 1275 | norm_layer.eval() 1276 | for param in norm_layer.parameters(): 1277 | param.requires_grad = False 1278 | 1279 | m = self.stages[i - 1] 1280 | m.eval() 1281 | for param in m.parameters(): 1282 | param.requires_grad = False 1283 | 1284 | def init_weights(self, pretrained=None): 1285 | logger = logging.getLogger("loading parameters.") 1286 | if pretrained is None: 1287 | logger.warn(f'No pre-trained weights for ' 1288 | f'{self.__class__.__name__}, ' 1289 | f'training start from scratch') 1290 | if self.use_abs_pos_embed: 1291 | trunc_normal_(self.absolute_pos_embed, std=0.02) 1292 | for m in self.modules(): 1293 | if isinstance(m, nn.Linear): 1294 | trunc_normal_init(m, std=.02, bias=0.) 1295 | elif isinstance(m, nn.LayerNorm): 1296 | constant_init(m.bias, 0) 1297 | constant_init(m.weight, 1.0) 1298 | else: 1299 | ckpt = torch.load(pretrained,map_location='cpu') 1300 | if 'teacher' in ckpt: 1301 | ckpt = ckpt['teacher'] 1302 | 1303 | if 'state_dict' in ckpt: 1304 | _state_dict = ckpt['state_dict'] 1305 | elif 'model' in ckpt: 1306 | _state_dict = ckpt['model'] 1307 | else: 1308 | _state_dict = ckpt 1309 | if self.convert_weights: 1310 | # supported loading weight from original repo, 1311 | _state_dict = swin_converter(_state_dict) 1312 | 1313 | state_dict = OrderedDict() 1314 | for k, v in _state_dict.items(): 1315 | if k.startswith('backbone.'): 1316 | state_dict[k[9:]] = v 1317 | 1318 | # strip prefix of state_dict 1319 | if list(state_dict.keys())[0].startswith('module.'): 1320 | state_dict = {k[7:]: v for k, v in state_dict.items()} 1321 | 1322 | # reshape absolute position embedding 1323 | if state_dict.get('absolute_pos_embed') is not None: 1324 | absolute_pos_embed = state_dict['absolute_pos_embed'] 1325 | N1, L, C1 = absolute_pos_embed.size() 1326 | N2, C2, H, W = self.absolute_pos_embed.size() 1327 | if N1 != N2 or C1 != C2 or L != H * W: 1328 | logger.warning('Error in loading absolute_pos_embed, pass') 1329 | else: 1330 | state_dict['absolute_pos_embed'] = absolute_pos_embed.view( 1331 | N2, H, W, C2).permute(0, 3, 1, 2).contiguous() 1332 | 1333 | # interpolate position bias table if needed 1334 | relative_position_bias_table_keys = [ 1335 | k for k in state_dict.keys() 1336 | if 'relative_position_bias_table' in k 1337 | ] 1338 | for table_key in relative_position_bias_table_keys: 1339 | table_pretrained = state_dict[table_key] 1340 | table_current = self.state_dict()[table_key] 1341 | L1, nH1 = table_pretrained.size() 1342 | L2, nH2 = table_current.size() 1343 | if nH1 != nH2: 1344 | logger.warning(f'Error in loading {table_key}, pass') 1345 | elif L1 != L2: 1346 | S1 = int(L1**0.5) 1347 | S2 = int(L2**0.5) 1348 | table_pretrained_resized = F.interpolate( 1349 | 
table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1), 1350 | size=(S2, S2), 1351 | mode='bicubic') 1352 | state_dict[table_key] = table_pretrained_resized.view( 1353 | nH2, L2).permute(1, 0).contiguous() 1354 | 1355 | res = self.load_state_dict(state_dict, strict=False) 1356 | print('missing / unexpected keys when loading pretrained weights:', res) 1357 | 1358 | def forward(self, x, semantic_weight=None): 1359 | if self.semantic_weight >= 0 and semantic_weight is None: 1360 | w = torch.ones(x.shape[0], 1) * self.semantic_weight 1361 | w = torch.cat([w, 1 - w], dim=-1) 1362 | semantic_weight = w.to(x.device)  # follow the input device instead of assuming CUDA 1363 | 1364 | x, hw_shape = self.patch_embed(x) 1365 | 1366 | if self.use_abs_pos_embed: 1367 | x = x + self.absolute_pos_embed 1368 | x = self.drop_after_pos(x) 1369 | 1370 | outs = [] 1371 | for i, stage in enumerate(self.stages): 1372 | x, hw_shape, out, out_hw_shape = stage(x, hw_shape) 1373 | if self.semantic_weight >= 0: 1374 | sw = self.semantic_embed_w[i](semantic_weight).unsqueeze(1) 1375 | sb = self.semantic_embed_b[i](semantic_weight).unsqueeze(1) 1376 | x = x * self.softplus(sw) + sb 1377 | if i in self.out_indices: 1378 | norm_layer = getattr(self, f'norm{i}') 1379 | out = norm_layer(out) 1380 | out = out.view(-1, *out_hw_shape, 1381 | self.num_features[i]).permute(0, 3, 1, 1382 | 2).contiguous() 1383 | outs.append(out) 1384 | x = self.avgpool(outs[-1]) 1385 | x = torch.flatten(x, 1) 1386 | return x, outs 1387 | 1388 | def swin_base_patch4_window7_224(img_size=224,drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0., **kwargs): 1389 | model = SwinTransformer(pretrain_img_size = img_size, patch_size=4, window_size=7, embed_dims=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), drop_path_rate=drop_path_rate, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, **kwargs) 1390 | return model 1391 | 1392 | def swin_small_patch4_window7_224(img_size=224,drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0., **kwargs): 1393 | model = SwinTransformer(pretrain_img_size = img_size, patch_size=4, window_size=7, embed_dims=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), drop_path_rate=drop_path_rate, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, **kwargs) 1394 | return model 1395 | 1396 | def swin_tiny_patch4_window7_224(img_size=224,drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0., **kwargs): 1397 | model = SwinTransformer(pretrain_img_size = img_size, patch_size=4, window_size=7, embed_dims=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), drop_path_rate=drop_path_rate, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, **kwargs) 1398 | return model 1399 | --------------------------------------------------------------------------------
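The backbone file above ends with three factory functions. The short sketch below is not part of the repository; it is a minimal, assumed usage example showing how one of those factories could be exercised on a dummy batch. The 0.6 semantic weight, the 224x224 input, and the CPU-only setup are illustrative assumptions; passing semantic_weight explicitly keeps everything on the input device, and importing models.swin_transformer assumes the repo root is on PYTHONPATH and the repository's dependencies (mmcv, numpy, opencv) are installed.

# Hypothetical usage sketch for models/swin_transformer.py (values are illustrative, not from the repo).
import torch
from models.swin_transformer import swin_tiny_patch4_window7_224

model = swin_tiny_patch4_window7_224(semantic_weight=0.6)  # 0.6 is an arbitrary example value
model.eval()

images = torch.randn(2, 3, 224, 224)               # dummy (B, C, H, W) batch
sw = torch.full((2, 1), 0.6)
semantic_weight = torch.cat([sw, 1 - sw], dim=-1)  # shape (B, 2), mirroring the default built in forward()

with torch.no_grad():
    pooled, feats = model(images, semantic_weight)

print(pooled.shape)                      # torch.Size([2, 768]) for the tiny variant
print([tuple(f.shape) for f in feats])   # four stages: 96/192/384/768 channels at strides 4/8/16/32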