├── datasets
│   ├── __init__.py
│   ├── base.py
│   ├── prw.py
│   ├── build.py
│   └── cuhk_sysu.py
├── doc
│   ├── title.jpg
│   └── net_arch.jpg
├── demo_imgs
│   ├── demo.jpg
│   ├── query.jpg
│   ├── gallery-1.jpg
│   ├── gallery-2.jpg
│   ├── gallery-3.jpg
│   ├── gallery-4.jpg
│   └── gallery-5.jpg
├── configs
│   ├── prw.yaml
│   └── cuhk_sysu.yaml
├── requirements.txt
├── convert_model.py
├── dev
│   └── linter.sh
├── run.sh
├── utils
│   ├── transforms.py
│   ├── km.py
│   └── utils.py
├── models
│   ├── resnet.py
│   ├── oim.py
│   ├── swin.py
│   ├── backbone.py
│   ├── seqnet.py
│   └── swin_transformer.py
├── README.md
├── demo.py
├── train.py
├── defaults.py
├── engine.py
└── eval_func.py
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_test_loader, build_train_loader
2 |
--------------------------------------------------------------------------------
/doc/title.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tinyvision/SOLIDER-PersonSearch/HEAD/doc/title.jpg
--------------------------------------------------------------------------------
/demo_imgs/demo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tinyvision/SOLIDER-PersonSearch/HEAD/demo_imgs/demo.jpg
--------------------------------------------------------------------------------
/doc/net_arch.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tinyvision/SOLIDER-PersonSearch/HEAD/doc/net_arch.jpg
--------------------------------------------------------------------------------
/demo_imgs/query.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tinyvision/SOLIDER-PersonSearch/HEAD/demo_imgs/query.jpg
--------------------------------------------------------------------------------
/demo_imgs/gallery-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tinyvision/SOLIDER-PersonSearch/HEAD/demo_imgs/gallery-1.jpg
--------------------------------------------------------------------------------
/demo_imgs/gallery-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tinyvision/SOLIDER-PersonSearch/HEAD/demo_imgs/gallery-2.jpg
--------------------------------------------------------------------------------
/demo_imgs/gallery-3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tinyvision/SOLIDER-PersonSearch/HEAD/demo_imgs/gallery-3.jpg
--------------------------------------------------------------------------------
/demo_imgs/gallery-4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tinyvision/SOLIDER-PersonSearch/HEAD/demo_imgs/gallery-4.jpg
--------------------------------------------------------------------------------
/demo_imgs/gallery-5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tinyvision/SOLIDER-PersonSearch/HEAD/demo_imgs/gallery-5.jpg
--------------------------------------------------------------------------------
/configs/prw.yaml:
--------------------------------------------------------------------------------
1 | OUTPUT_DIR: "./exp_prw"
2 | INPUT:
3 |   DATASET: "PRW"
4 |   DATA_ROOT: "data/PRW"
5 | SOLVER:
6 |   MAX_EPOCHS: 18
7 | MODEL:
8 |   LOSS:
9 |     LUT_SIZE: 482
10 |     CQ_SIZE: 500
--------------------------------------------------------------------------------
/configs/cuhk_sysu.yaml:
--------------------------------------------------------------------------------
1 | OUTPUT_DIR: "./exp_cuhk"
2 | INPUT:
3 |   DATASET: "CUHK-SYSU"
4 |   DATA_ROOT: "data/CUHK-SYSU"
5 | SOLVER:
6 |   MAX_EPOCHS: 20
7 | MODEL:
8 |   LOSS:
9 |     LUT_SIZE: 5532
10 |     CQ_SIZE: 5000
--------------------------------------------------------------------------------
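These YAML files only override the defaults defined in `defaults.py` (yacs); a minimal sketch of how they are consumed, mirroring `train.py` and assuming the repository root as the working directory:

```python
# Minimal sketch: merge a dataset config and command-line-style opts onto the yacs defaults
# (see defaults.py and train.py). Paths assume the repository root as the working directory.
from defaults import get_default_cfg

cfg = get_default_cfg()
cfg.merge_from_file("configs/cuhk_sysu.yaml")       # dataset-specific overrides
cfg.merge_from_list(["SOLVER.BASE_LR", 0.0003])     # same mechanism as the opts passed in run.sh
cfg.freeze()
print(cfg.MODEL.LOSS.LUT_SIZE, cfg.SOLVER.MAX_EPOCHS)  # 5532 20
```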
/requirements.txt:
--------------------------------------------------------------------------------
1 | black==20.8b1
2 | flake8==3.9.0
3 | isort==5.8.0
4 | numpy==1.16.4
5 | Pillow==6.1.0
6 | scikit-learn==0.23.1
7 | scipy==1.5.1
8 | tabulate==0.8.7
9 | torch==1.7.1
10 | torchvision==0.8.2
11 | tqdm==4.48.2
12 | yacs==0.1.8
13 | future==0.18.2
14 | tensorboard==2.4.1
15 |
--------------------------------------------------------------------------------
/convert_model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3 |
4 | import pickle as pkl
5 | import sys
6 | import torch
7 |
8 | if __name__ == "__main__":
9 |     input_path = sys.argv[1]  # usage: python convert_model.py <solider_ckpt.pth> <output.pth>
10 |     obj = torch.load(input_path, map_location="cpu")
11 |     obj = obj["teacher"]  # keep only the teacher weights from the SOLIDER checkpoint
12 |     torch.save(obj, sys.argv[2])
13 |
--------------------------------------------------------------------------------
/dev/linter.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -e
2 |
3 | # Run this script at project root by "./dev/linter.sh" before you commit
4 |
5 | {
6 | black --version | grep -E "20.8b1" > /dev/null
7 | } || {
8 | echo "Linter requires 'black==20.8b1' !"
9 | exit 1
10 | }
11 |
12 | ISORT_VERSION=$(isort --version-number)
13 | if [[ "$ISORT_VERSION" != 5.8.0 ]]; then
14 | echo "Linter requires isort==5.8.0 !"
15 | exit 1
16 | fi
17 |
18 | echo "Running isort ..."
19 | isort --line-length=100 --profile=black .
20 |
21 | echo "Running black ..."
22 | black --line-length=100 .
23 |
24 | echo "Running flake8 ..."
25 | if [ -x "$(command -v flake8-3)" ]; then
26 | flake8-3 .
27 | else
28 | python3 -m flake8 .
29 | fi
30 |
31 | command -v arc > /dev/null && arc lint
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | # Swin Base
2 | #CUDA_VISIBLE_DEVICES=0 python train.py --cfg configs/cuhk_sysu.yaml --resume --ckpt path/to/SOLIDER/log/lup/swin_base/checkpoint_tea.pth OUTPUT_DIR './results/cuhk_sysu/swin_base' SOLVER.BASE_LR 0.0003 EVAL_PERIOD 5 MODEL.BONE 'swin_base' INPUT.BATCH_SIZE_TRAIN 2 MODEL.SEMANTIC_WEIGHT 0.6
3 |
4 | # Swin Small
5 | #CUDA_VISIBLE_DEVICES=0 python train.py --cfg configs/cuhk_sysu.yaml --resume --ckpt path/to/SOLIDER/log/lup/swin_small/checkpoint_tea.pth OUTPUT_DIR './results/cuhk_sysu/swin_small' SOLVER.BASE_LR 0.0003 EVAL_PERIOD 5 MODEL.BONE 'swin_small' INPUT.BATCH_SIZE_TRAIN 3 MODEL.SEMANTIC_WEIGHT 0.6
6 |
7 | # Swin Tiny
8 | CUDA_VISIBLE_DEVICES=0 python train.py --cfg configs/cuhk_sysu.yaml --resume --ckpt path/to/SOLIDER/log/lup/swin_tiny/checkpoint_tea.pth OUTPUT_DIR './results/cuhk_sysu/swin_tiny' SOLVER.BASE_LR 0.0003 EVAL_PERIOD 5 MODEL.BONE 'swin_tiny' INPUT.BATCH_SIZE_TRAIN 4 MODEL.SEMANTIC_WEIGHT 0.6
9 |
--------------------------------------------------------------------------------
/utils/transforms.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from torchvision.transforms import functional as F
4 |
5 |
6 | class Compose:
7 | def __init__(self, transforms):
8 | self.transforms = transforms
9 |
10 | def __call__(self, image, target):
11 | for t in self.transforms:
12 | image, target = t(image, target)
13 | return image, target
14 |
15 |
16 | class RandomHorizontalFlip:
17 | def __init__(self, prob=0.5):
18 | self.prob = prob
19 |
20 | def __call__(self, image, target):
21 | if random.random() < self.prob:
22 | height, width = image.shape[-2:]
23 | image = image.flip(-1)
24 | bbox = target["boxes"]
25 | bbox[:, [0, 2]] = width - bbox[:, [2, 0]]
26 | target["boxes"] = bbox
27 | return image, target
28 |
29 |
30 | class ToTensor:
31 | def __call__(self, image, target):
32 | # convert [0, 255] to [0, 1]
33 | image = F.to_tensor(image)
34 | return image, target
35 |
36 |
37 | def build_transforms(is_train):
38 | transforms = []
39 | transforms.append(ToTensor())
40 | if is_train:
41 | transforms.append(RandomHorizontalFlip())
42 | return Compose(transforms)
43 |
--------------------------------------------------------------------------------
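A minimal usage sketch of these transforms, mirroring how `BaseDataset.__getitem__` calls them (the image size and box coordinates below are made up):

```python
# Toy example: build the training pipeline and apply it to a PIL image plus its target dict.
import torch
from PIL import Image
from utils.transforms import build_transforms

transforms = build_transforms(is_train=True)
image = Image.new("RGB", (640, 480))  # dummy image
target = {"boxes": torch.tensor([[10.0, 20.0, 110.0, 220.0]]), "labels": torch.tensor([1])}
image, target = transforms(image, target)
print(image.shape)  # torch.Size([3, 480, 640]); boxes are x-flipped with probability 0.5
```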
/datasets/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 |
4 |
5 | class BaseDataset:
6 | """
7 | Base class of person search dataset.
8 | """
9 |
10 | def __init__(self, root, transforms, split):
11 | self.root = root
12 | self.transforms = transforms
13 | self.split = split
14 | assert self.split in ("train", "gallery", "query")
15 | self.annotations = self._load_annotations()
16 |
17 | def _load_annotations(self):
18 | """
19 | For each image, load its annotation, which is a dictionary with the following keys:
20 | img_name (str): image name
21 | img_path (str): image path
22 | boxes (np.array[N, 4]): ground-truth boxes in (x1, y1, x2, y2) format
23 | pids (np.array[N]): person IDs corresponding to these boxes
24 | cam_id (int): camera ID (only for PRW dataset)
25 | """
26 | raise NotImplementedError
27 |
28 | def __getitem__(self, index):
29 | anno = self.annotations[index]
30 | img = Image.open(anno["img_path"]).convert("RGB")
31 | boxes = torch.as_tensor(anno["boxes"], dtype=torch.float32)
32 | labels = torch.as_tensor(anno["pids"], dtype=torch.int64)
33 | target = {"img_name": anno["img_name"], "boxes": boxes, "labels": labels}
34 | if self.transforms is not None:
35 | img, target = self.transforms(img, target)
36 | return img, target
37 |
38 | def __len__(self):
39 | return len(self.annotations)
40 |
--------------------------------------------------------------------------------
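For reference, an illustrative annotation entry matching the `_load_annotations()` contract above (all values are made up; `cam_id` is only present for PRW):

```python
import numpy as np

# Hypothetical annotation dict in the format BaseDataset expects from its subclasses.
annotation = {
    "img_name": "c1s1_000151.jpg",
    "img_path": "data/PRW/frames/c1s1_000151.jpg",
    "boxes": np.array([[35, 50, 120, 300]], dtype=np.int32),  # (x1, y1, x2, y2)
    "pids": np.array([7], dtype=np.int32),                    # person ID per box
    "cam_id": 1,                                              # PRW only
}
```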
/models/resnet.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch.nn.functional as F
4 | import torchvision
5 | from torch import nn
6 |
7 | from .backbone import resnet50
8 |
9 | class Backbone(nn.Sequential):
10 | def __init__(self, resnet):
11 | super(Backbone, self).__init__(
12 | OrderedDict(
13 | [
14 | ["conv1", resnet.conv1],
15 | ["bn1", resnet.bn1],
16 | ["relu", resnet.relu],
17 | ["maxpool", resnet.maxpool],
18 | ["layer1", resnet.layer1], # res2
19 | ["layer2", resnet.layer2], # res3
20 | ["layer3", resnet.layer3], # res4
21 | ]
22 | )
23 | )
24 | self.out_channels = 1024
25 |
26 | def forward(self, x):
27 | # using the forward method from nn.Sequential
28 | feat = super(Backbone, self).forward(x)
29 | return OrderedDict([["feat_res4", feat]])
30 |
31 |
32 | class Res5Head(nn.Sequential):
33 | def __init__(self, resnet):
34 | super(Res5Head, self).__init__(OrderedDict([["layer4", resnet.layer4]])) # res5
35 | self.out_channels = [1024, 2048]
36 |
37 | def forward(self, x):
38 | feat = super(Res5Head, self).forward(x)
39 | x = F.adaptive_max_pool2d(x, 1)
40 | feat = F.adaptive_max_pool2d(feat, 1)
41 | # print(x.shape, feat.shape)  # leftover debug print, disabled
42 | return OrderedDict([["feat_res4", x], ["feat_res5", feat]])
43 |
44 |
45 | def build_resnet(name="resnet50", pretrained=True):
46 | #resnet = torchvision.models.resnet.__dict__[name](pretrained=pretrained)
47 | resnet = resnet50(pretrained=pretrained)
48 |
49 | # freeze layers
50 | resnet.conv1.weight.requires_grad_(False)
51 | resnet.bn1.weight.requires_grad_(False)
52 | resnet.bn1.bias.requires_grad_(False)
53 |
54 | return Backbone(resnet), Res5Head(resnet)
55 |
--------------------------------------------------------------------------------
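A quick sketch of the two pieces returned by `build_resnet`: the stem through `layer3` serves as the detection backbone and `layer4` as the RoI head. The input size is arbitrary, and loading pretrained weights depends on the `resnet50` defined in `models/backbone.py`:

```python
# Sketch only: feed a dummy batch through the truncated ResNet-50 backbone.
# feat_res4 is the 1024-channel layer3 output consumed by the detection heads.
import torch
from models.resnet import build_resnet

backbone, roi_head = build_resnet("resnet50", pretrained=True)
with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 900, 1500))
print(feats["feat_res4"].shape[1])  # 1024 (== backbone.out_channels)
```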
/README.md:
--------------------------------------------------------------------------------
1 | # SOLIDER on [Person Search]
2 |
3 | [](https://paperswithcode.com/sota/person-search-on-cuhk-sysu?p=beyond-appearance-a-semantic-controllable)
4 | [](https://paperswithcode.com/sota/person-search-on-prw?p=beyond-appearance-a-semantic-controllable)
5 |
6 | This repo provides details about how to use the [SOLIDER](https://github.com/tinyvision/SOLIDER) pretrained representation on the person search task.
7 | We modify the code from [SeqNet](https://github.com/serend1p1ty/SeqNet); please refer to the original repo for more details.
8 |
9 | ## Installation and Datasets
10 |
11 | Details of installation and dataset preparation can be found in [SeqNet](https://github.com/serend1p1ty/SeqNet).
12 |
13 | ## Prepare Pre-trained Models
14 | You can download models from [SOLIDER](https://github.com/tinyvision/SOLIDER), or use [SOLIDER](https://github.com/tinyvision/SOLIDER) to train your own models.
15 | Before training, you should convert the models first.
16 |
17 | ```bash
18 | python convert_model.py path/to/SOLIDER/log/lup/swin_tiny/checkpoint.pth path/to/SOLIDER/log/lup/swin_tiny/checkpoint_tea.pth
19 | ```
20 |
21 | ## Training
22 |
23 | We utilize a single GPU for training. Please modify the `--ckpt` path and `OUTPUT_DIR` in `run.sh`.
24 |
25 | ```bash
26 | sh run.sh
27 | ```
28 |
29 | ## Performance
30 |
31 | | Method | Model | CUHK-SYSU (mAP/R1) | PRW (mAP/R1) |
32 | | ------ | :---: | :---: | :---: |
33 | | SOLIDER | Swin Tiny | 94.91/95.72 | 56.84/86.78 |
34 | | SOLIDER | Swin Small | 95.46/95.79 | 59.84/86.73 |
35 | | SOLIDER | Swin Base | 94.93/95.52 | 59.72/86.83 |
36 |
37 | - We use the pretrained models from [SOLIDER](https://github.com/tinyvision/SOLIDER).
38 | - The semantic weight is set to 0.6 in these experiments.
39 |
40 | ## Citation
41 |
42 | If you find this code useful for your research, please cite our paper:
43 |
44 | ```
45 | @inproceedings{chen2023beyond,
46 | title={Beyond Appearance: a Semantic Controllable Self-Supervised Learning Framework for Human-Centric Visual Tasks},
47 | author={Weihua Chen and Xianzhe Xu and Jian Jia and Hao Luo and Yaohua Wang and Fan Wang and Rong Jin and Xiuyu Sun},
48 | booktitle={The IEEE/CVF Conference on Computer Vision and Pattern Recognition},
49 | year={2023},
50 | }
51 | ```
52 |
--------------------------------------------------------------------------------
/models/oim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import autograd, nn
4 |
5 | # from utils.distributed import tensor_gather
6 |
7 |
8 | class OIM(autograd.Function):
9 | @staticmethod
10 | def forward(ctx, inputs, targets, lut, cq, header, momentum):
11 | ctx.save_for_backward(inputs, targets, lut, cq, header, momentum)
12 | outputs_labeled = inputs.mm(lut.t())
13 | outputs_unlabeled = inputs.mm(cq.t())
14 | return torch.cat([outputs_labeled, outputs_unlabeled], dim=1)
15 |
16 | @staticmethod
17 | def backward(ctx, grad_outputs):
18 | inputs, targets, lut, cq, header, momentum = ctx.saved_tensors
19 |
20 | # inputs, targets = tensor_gather((inputs, targets))
21 |
22 | grad_inputs = None
23 | if ctx.needs_input_grad[0]:
24 | grad_inputs = grad_outputs.mm(torch.cat([lut, cq], dim=0))
25 | if grad_inputs.dtype == torch.float16:
26 | grad_inputs = grad_inputs.to(torch.float32)
27 |
28 | for x, y in zip(inputs, targets):
29 | if y < len(lut):
30 | lut[y] = momentum * lut[y] + (1.0 - momentum) * x
31 | lut[y] /= lut[y].norm()
32 | else:
33 | cq[header] = x
34 | header = (header + 1) % cq.size(0)
35 | return grad_inputs, None, None, None, None, None
36 |
37 |
38 | def oim(inputs, targets, lut, cq, header, momentum=0.5):
39 | return OIM.apply(inputs, targets, lut, cq, torch.tensor(header), torch.tensor(momentum))
40 |
41 |
42 | class OIMLoss(nn.Module):
43 | def __init__(self, num_features, num_pids, num_cq_size, oim_momentum, oim_scalar):
44 | super(OIMLoss, self).__init__()
45 | self.num_features = num_features
46 | self.num_pids = num_pids
47 | self.num_unlabeled = num_cq_size
48 | self.momentum = oim_momentum
49 | self.oim_scalar = oim_scalar
50 |
51 | self.register_buffer("lut", torch.zeros(self.num_pids, self.num_features))
52 | self.register_buffer("cq", torch.zeros(self.num_unlabeled, self.num_features))
53 |
54 | self.header_cq = 0
55 |
56 | def forward(self, inputs, roi_label):
57 | # merge into one batch, background label = 0
58 | targets = torch.cat(roi_label)
59 | label = targets - 1 # background label = -1
60 |
61 | inds = label >= 0
62 | label = label[inds]
63 | inputs = inputs[inds.unsqueeze(1).expand_as(inputs)].view(-1, self.num_features)
64 |
65 | projected = oim(inputs, label, self.lut, self.cq, self.header_cq, momentum=self.momentum)
66 | projected *= self.oim_scalar
67 |
68 | self.header_cq = (
69 | self.header_cq + (label >= self.num_pids).long().sum().item()
70 | ) % self.num_unlabeled
71 | loss_oim = F.cross_entropy(projected, label, ignore_index=5554)  # 5554 == unlabeled pid (5555) - 1
72 | return loss_oim
73 |
--------------------------------------------------------------------------------
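A toy sketch of `OIMLoss` with made-up sizes; in this repo `num_pids`/`num_cq_size` come from `MODEL.LOSS.LUT_SIZE` and `CQ_SIZE` in the configs, and the feature dimension from the re-id head:

```python
# Toy OIM example: 10 labeled identities, a circular queue of 5, 256-d embeddings (all made up).
import torch
import torch.nn.functional as F
from models.oim import OIMLoss

criterion = OIMLoss(num_features=256, num_pids=10, num_cq_size=5,
                    oim_momentum=0.5, oim_scalar=30.0)
embeddings = F.normalize(torch.randn(8, 256, requires_grad=True), dim=1)
roi_label = [torch.tensor([0, 1, 2, 3, 10, 11, 11, 5])]  # 0 = background, 1..10 labeled, 11 = unlabeled
loss = criterion(embeddings, roi_label)
loss.backward()  # the lut/cq buffers are updated inside OIM.backward
```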
/demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from glob import glob
3 |
4 | import matplotlib.pyplot as plt
5 | import torch
6 | import torch.utils.data
7 | from PIL import Image
8 | from torchvision.transforms import functional as F
9 |
10 | from defaults import get_default_cfg
11 | from models.seqnet import SeqNet
12 | from utils.utils import resume_from_ckpt
13 |
14 |
15 | def visualize_result(img_path, detections, similarities):
16 | fig, ax = plt.subplots(figsize=(16, 9))
17 | ax.imshow(plt.imread(img_path))
18 | plt.axis("off")
19 | for detection, sim in zip(detections, similarities):
20 | x1, y1, x2, y2 = detection
21 | ax.add_patch(
22 | plt.Rectangle(
23 | (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="#4CAF50", linewidth=3.5
24 | )
25 | )
26 | ax.add_patch(
27 | plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="white", linewidth=1)
28 | )
29 | ax.text(
30 | x1 + 5,
31 | y1 - 18,
32 | "{:.2f}".format(sim),
33 | bbox=dict(facecolor="#4CAF50", linewidth=0),
34 | fontsize=20,
35 | color="white",
36 | )
37 | plt.tight_layout()
38 | fig.savefig(img_path.replace("gallery", "result"))
39 | plt.show()
40 | plt.close(fig)
41 |
42 |
43 | def main(args):
44 | cfg = get_default_cfg()
45 | if args.cfg_file:
46 | cfg.merge_from_file(args.cfg_file)
47 | cfg.merge_from_list(args.opts)
48 | cfg.freeze()
49 |
50 | device = torch.device(cfg.DEVICE)
51 |
52 | print("Creating model")
53 | model = SeqNet(cfg)
54 | model.to(device)
55 | model.eval()
56 |
57 | resume_from_ckpt(args.ckpt, model)
58 |
59 | query_img = [F.to_tensor(Image.open("demo_imgs/query.jpg").convert("RGB")).to(device)]
60 | query_target = [{"boxes": torch.tensor([[0, 0, 466, 943]]).to(device)}]
61 | query_feat = model(query_img, query_target)[0]
62 |
63 | gallery_img_paths = sorted(glob("demo_imgs/gallery-*.jpg"))
64 | for gallery_img_path in gallery_img_paths:
65 | print(f"Processing {gallery_img_path}")
66 | gallery_img = [F.to_tensor(Image.open(gallery_img_path).convert("RGB")).to(device)]
67 | gallery_output = model(gallery_img)[0]
68 | detections = gallery_output["boxes"]
69 | gallery_feats = gallery_output["embeddings"]
70 |
71 | # Compute pairwise cosine similarities,
72 | # which equal inner products, since the features are already L2-normalized
73 | similarities = gallery_feats.mm(query_feat.view(-1, 1)).squeeze()
74 |
75 | visualize_result(gallery_img_path, detections, similarities)
76 |
77 |
78 | if __name__ == "__main__":
79 | parser = argparse.ArgumentParser(description="Person search demo.")
80 | parser.add_argument("--cfg", dest="cfg_file", help="Path to configuration file.")
81 | parser.add_argument("--ckpt", required=True, help="Path to checkpoint to resume or evaluate.")
82 | parser.add_argument(
83 | "opts", nargs=argparse.REMAINDER, help="Modify config options using the command-line"
84 | )
85 | args = parser.parse_args()
86 | main(args)
87 |
--------------------------------------------------------------------------------
/datasets/prw.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import re
3 |
4 | import numpy as np
5 | from scipy.io import loadmat
6 |
7 | from .base import BaseDataset
8 |
9 |
10 | class PRW(BaseDataset):
11 | def __init__(self, root, transforms, split):
12 | self.name = "PRW"
13 | self.img_prefix = osp.join(root, "frames")
14 | super(PRW, self).__init__(root, transforms, split)
15 |
16 | def _get_cam_id(self, img_name):
17 | match = re.search(r"c\d", img_name).group().replace("c", "")
18 | return int(match)
19 |
20 | def _load_queries(self):
21 | query_info = osp.join(self.root, "query_info.txt")
22 | with open(query_info, "rb") as f:
23 | raw = f.readlines()
24 |
25 | queries = []
26 | for line in raw:
27 | linelist = str(line, "utf-8").split(" ")
28 | pid = int(linelist[0])
29 | x, y, w, h = (
30 | float(linelist[1]),
31 | float(linelist[2]),
32 | float(linelist[3]),
33 | float(linelist[4]),
34 | )
35 | roi = np.array([x, y, x + w, y + h]).astype(np.int32)
36 | roi = np.clip(roi, 0, None) # several coordinates are negative
37 | img_name = linelist[5][:-2] + ".jpg"
38 | queries.append(
39 | {
40 | "img_name": img_name,
41 | "img_path": osp.join(self.img_prefix, img_name),
42 | "boxes": roi[np.newaxis, :],
43 | "pids": np.array([pid]),
44 | "cam_id": self._get_cam_id(img_name),
45 | }
46 | )
47 | return queries
48 |
49 | def _load_split_img_names(self):
50 | """
51 | Load the image names for the specific split.
52 | """
53 | assert self.split in ("train", "gallery")
54 | if self.split == "train":
55 | imgs = loadmat(osp.join(self.root, "frame_train.mat"))["img_index_train"]
56 | else:
57 | imgs = loadmat(osp.join(self.root, "frame_test.mat"))["img_index_test"]
58 | return [img[0][0] + ".jpg" for img in imgs]
59 |
60 | def _load_annotations(self):
61 | if self.split == "query":
62 | return self._load_queries()
63 |
64 | annotations = []
65 | imgs = self._load_split_img_names()
66 | for img_name in imgs:
67 | anno_path = osp.join(self.root, "annotations", img_name)
68 | anno = loadmat(anno_path)
69 | box_key = "box_new"
70 | if box_key not in anno.keys():
71 | box_key = "anno_file"
72 | if box_key not in anno.keys():
73 | box_key = "anno_previous"
74 |
75 | rois = anno[box_key][:, 1:]
76 | ids = anno[box_key][:, 0]
77 | rois = np.clip(rois, 0, None) # several coordinates are negative
78 |
79 | assert len(rois) == len(ids)
80 |
81 | rois[:, 2:] += rois[:, :2]
82 | ids[ids == -2] = 5555 # assign pid = 5555 for unlabeled people
83 | annotations.append(
84 | {
85 | "img_name": img_name,
86 | "img_path": osp.join(self.img_prefix, img_name),
87 | "boxes": rois.astype(np.int32),
88 | # FIXME: (training pids) 1, 2,..., 478, 480, 481, 482, 483, 932, 5555
89 | "pids": ids.astype(np.int32),
90 | "cam_id": self._get_cam_id(img_name),
91 | }
92 | )
93 | return annotations
94 |
--------------------------------------------------------------------------------
/datasets/build.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from utils.transforms import build_transforms
4 | from utils.utils import create_small_table
5 |
6 | from .cuhk_sysu import CUHKSYSU
7 | from .prw import PRW
8 |
9 |
10 | def print_statistics(dataset):
11 | """
12 | Print dataset statistics.
13 | """
14 | num_imgs = len(dataset.annotations)
15 | num_boxes = 0
16 | pid_set = set()
17 | for anno in dataset.annotations:
18 | num_boxes += anno["boxes"].shape[0]
19 | for pid in anno["pids"]:
20 | pid_set.add(pid)
21 | statistics = {
22 | "dataset": dataset.name,
23 | "split": dataset.split,
24 | "num_images": num_imgs,
25 | "num_boxes": num_boxes,
26 | }
27 | if dataset.name != "CUHK-SYSU" or dataset.split != "query":
28 | pid_list = sorted(list(pid_set))
29 | if dataset.split == "query":
30 | num_pids, min_pid, max_pid = len(pid_list), min(pid_list), max(pid_list)
31 | statistics.update(
32 | {
33 | "num_labeled_pids": num_pids,
34 | "min_labeled_pid": int(min_pid),
35 | "max_labeled_pid": int(max_pid),
36 | }
37 | )
38 | else:
39 | unlabeled_pid = pid_list[-1]
40 | pid_list = pid_list[:-1] # remove unlabeled pid
41 | num_pids, min_pid, max_pid = len(pid_list), min(pid_list), max(pid_list)
42 | statistics.update(
43 | {
44 | "num_labeled_pids": num_pids,
45 | "min_labeled_pid": int(min_pid),
46 | "max_labeled_pid": int(max_pid),
47 | "unlabeled_pid": int(unlabeled_pid),
48 | }
49 | )
50 | print(f"=> {dataset.name}-{dataset.split} loaded:\n" + create_small_table(statistics))
51 |
52 |
53 | def build_dataset(dataset_name, root, transforms, split, verbose=True):
54 | if dataset_name == "CUHK-SYSU":
55 | dataset = CUHKSYSU(root, transforms, split)
56 | elif dataset_name == "PRW":
57 | dataset = PRW(root, transforms, split)
58 | else:
59 | raise NotImplementedError(f"Unknown dataset: {dataset_name}")
60 | if verbose:
61 | print_statistics(dataset)
62 | return dataset
63 |
64 |
65 | def collate_fn(batch):
66 | return tuple(zip(*batch))
67 |
68 |
69 | def build_train_loader(cfg):
70 | transforms = build_transforms(is_train=True)
71 | dataset = build_dataset(cfg.INPUT.DATASET, cfg.INPUT.DATA_ROOT, transforms, "train")
72 | return torch.utils.data.DataLoader(
73 | dataset,
74 | batch_size=cfg.INPUT.BATCH_SIZE_TRAIN,
75 | shuffle=True,
76 | num_workers=cfg.INPUT.NUM_WORKERS_TRAIN,
77 | pin_memory=True,
78 | drop_last=True,
79 | collate_fn=collate_fn,
80 | )
81 |
82 |
83 | def build_test_loader(cfg):
84 | transforms = build_transforms(is_train=False)
85 | gallery_set = build_dataset(cfg.INPUT.DATASET, cfg.INPUT.DATA_ROOT, transforms, "gallery")
86 | query_set = build_dataset(cfg.INPUT.DATASET, cfg.INPUT.DATA_ROOT, transforms, "query")
87 | gallery_loader = torch.utils.data.DataLoader(
88 | gallery_set,
89 | batch_size=cfg.INPUT.BATCH_SIZE_TEST,
90 | shuffle=False,
91 | num_workers=cfg.INPUT.NUM_WORKERS_TEST,
92 | pin_memory=True,
93 | collate_fn=collate_fn,
94 | )
95 | query_loader = torch.utils.data.DataLoader(
96 | query_set,
97 | batch_size=cfg.INPUT.BATCH_SIZE_TEST,
98 | shuffle=False,
99 | num_workers=cfg.INPUT.NUM_WORKERS_TEST,
100 | pin_memory=True,
101 | collate_fn=collate_fn,
102 | )
103 | return gallery_loader, query_loader
104 |
--------------------------------------------------------------------------------
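Putting the pieces together, a minimal sketch of building the loaders from a config; it assumes the CUHK-SYSU data has been prepared under `data/CUHK-SYSU` as described in the SeqNet instructions:

```python
# Sketch: build the train/test loaders the same way train.py does.
from datasets import build_test_loader, build_train_loader
from defaults import get_default_cfg

cfg = get_default_cfg()
cfg.merge_from_file("configs/cuhk_sysu.yaml")
train_loader = build_train_loader(cfg)                 # shuffled, drop_last, tuple-zip collate
gallery_loader, query_loader = build_test_loader(cfg)  # batch size 1 at test time by default
images, targets = next(iter(train_loader))             # tuples of length INPUT.BATCH_SIZE_TRAIN
```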
/models/swin.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | import torchvision
6 | from torch import nn
7 |
8 | from .swin_transformer import swin_tiny_patch4_window7_224,swin_small_patch4_window7_224,swin_base_patch4_window7_224
9 |
10 | bonenum = 3
11 |
12 | class Backbone(nn.Sequential):
13 | def __init__(self, swin, out_channels=384):
14 | super().__init__()
15 | self.swin = swin
16 | self.out_channels = out_channels
17 |
18 | def forward(self, x):
19 | if self.swin.semantic_weight >= 0:
20 | w = torch.ones(x.shape[0],1) * self.swin.semantic_weight
21 | w = torch.cat([w, 1-w], axis=-1)
22 | semantic_weight = w.cuda()
23 |
24 | x, hw_shape = self.swin.patch_embed(x)
25 |
26 | if self.swin.use_abs_pos_embed:
27 | x = x + self.swin.absolute_pos_embed
28 | x = self.swin.drop_after_pos(x)
29 |
30 | outs = []
31 | for i, stage in enumerate(self.swin.stages[:bonenum]):
32 | x, hw_shape, out, out_hw_shape = stage(x, hw_shape)
33 | if self.swin.semantic_weight >= 0:
34 | sw = self.swin.semantic_embed_w[i](semantic_weight).unsqueeze(1)
35 | sb = self.swin.semantic_embed_b[i](semantic_weight).unsqueeze(1)
36 | x = x * self.swin.softplus(sw) + sb
37 | if i == bonenum-1:
38 | norm_layer = getattr(self.swin, f'norm{i}')
39 | out = norm_layer(out)
40 | out = out.view(-1, *out_hw_shape,
41 | self.swin.num_features[i]).permute(0, 3, 1,
42 | 2).contiguous()
43 | outs.append(out)
44 | return OrderedDict([["feat_res4", outs[-1]]])
45 |
46 | class Res5Head(nn.Sequential):
47 | def __init__(self, swin, out_channels=384):
48 | super().__init__() # last block
49 | self.swin = swin
50 | self.out_channels = [out_channels, out_channels*2]
51 |
52 | def forward(self, x):
53 | if self.swin.semantic_weight >= 0:
54 | w = torch.ones(x.shape[0],1) * self.swin.semantic_weight
55 | w = torch.cat([w, 1-w], axis=-1)
56 | semantic_weight = w.cuda()
57 |
58 | feat = x
59 | hw_shape = x.shape[-2:]
60 | x = torch.flatten(x, 2)
61 | x = x.permute(0, 2, 1)
62 | x,hw_shape = self.swin.stages[bonenum-1].downsample(x,hw_shape)
63 | if self.swin.semantic_weight >= 0:
64 | sw = self.swin.semantic_embed_w[bonenum-1](semantic_weight).unsqueeze(1)
65 | sb = self.swin.semantic_embed_b[bonenum-1](semantic_weight).unsqueeze(1)
66 | x = x * self.swin.softplus(sw) + sb
67 | for i, stage in enumerate(self.swin.stages[bonenum:]):
68 | x, hw_shape, out, out_hw_shape = stage(x, hw_shape)
69 | if self.swin.semantic_weight >= 0:
70 | sw = self.swin.semantic_embed_w[bonenum+i](semantic_weight).unsqueeze(1)
71 | sb = self.swin.semantic_embed_b[bonenum+i](semantic_weight).unsqueeze(1)
72 | x = x * self.swin.softplus(sw) + sb
73 | if i == len(self.swin.stages) - bonenum - 1:
74 | norm_layer = getattr(self.swin, f'norm{bonenum+i}')
75 | out = norm_layer(out)
76 | out = out.view(-1, *out_hw_shape,
77 | self.swin.num_features[bonenum+i]).permute(0, 3, 1,
78 | 2).contiguous()
79 | feat = self.swin.avgpool(feat)
80 | out = self.swin.avgpool(out)
81 | return OrderedDict([["feat_res4", feat], ["feat_res5", out]])
82 |
83 | def build_swin(name="swin_tiny", semantic_weight=1.0):
84 | if 'tiny' in name:
85 | swin = swin_tiny_patch4_window7_224(drop_path_rate=0.1,semantic_weight=semantic_weight)
86 | out_channels = 384
87 | elif 'small' in name:
88 | swin = swin_small_patch4_window7_224(drop_path_rate=0.1,semantic_weight=semantic_weight)
89 | out_channels = 384
90 | elif 'base' in name:
91 | swin = swin_base_patch4_window7_224(drop_path_rate=0.1,semantic_weight=semantic_weight)
92 | out_channels = 512
93 |
94 | return Backbone(swin,out_channels), Res5Head(swin,out_channels), out_channels*2
95 |
--------------------------------------------------------------------------------
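A hedged sketch of `build_swin`; the shapes are illustrative, the weights here are randomly initialized (in training they come from the converted SOLIDER checkpoint via `--ckpt`), and a CUDA device is assumed because the semantic-weight branch calls `.cuda()` inside `forward`:

```python
# Sketch only: split Swin-Tiny into the backbone (first three stages) and the RoI head.
import torch
from models.swin import build_swin

backbone, roi_head, reid_dim = build_swin("swin_tiny", semantic_weight=0.6)
backbone = backbone.cuda().eval()
with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224).cuda())
print(feats["feat_res4"].shape, reid_dim)  # stride-16 map with 384 channels; reid_dim == 768
```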
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import os.path as osp
4 | import time
5 |
6 | import torch
7 | import torch.utils.data
8 |
9 | from datasets import build_test_loader, build_train_loader
10 | from defaults import get_default_cfg
11 | from engine import evaluate_performance, train_one_epoch
12 | from models.seqnet import SeqNet
13 | from utils.utils import mkdir, resume_from_ckpt, save_on_master, set_random_seed
14 |
15 |
16 | def main(args):
17 | cfg = get_default_cfg()
18 | if args.cfg_file:
19 | cfg.merge_from_file(args.cfg_file)
20 | cfg.merge_from_list(args.opts)
21 | cfg.freeze()
22 |
23 | device = torch.device(cfg.DEVICE)
24 | if cfg.SEED >= 0:
25 | set_random_seed(cfg.SEED)
26 |
27 | print("Creating model")
28 | model = SeqNet(cfg)
29 | model.to(device)
30 |
31 | print("Loading data")
32 | train_loader = build_train_loader(cfg)
33 | gallery_loader, query_loader = build_test_loader(cfg)
34 |
35 | if args.eval:
36 | assert args.ckpt, "--ckpt must be specified when --eval enabled"
37 | resume_from_ckpt(args.ckpt, model)
38 | evaluate_performance(
39 | model,
40 | gallery_loader,
41 | query_loader,
42 | device,
43 | use_gt=cfg.EVAL_USE_GT,
44 | use_cache=cfg.EVAL_USE_CACHE,
45 | use_cbgm=cfg.EVAL_USE_CBGM,
46 | )
47 | exit(0)
48 |
49 | params = [p for p in model.parameters() if p.requires_grad]
50 | optimizer = torch.optim.SGD(
51 | params,
52 | lr=cfg.SOLVER.BASE_LR,
53 | momentum=cfg.SOLVER.SGD_MOMENTUM,
54 | weight_decay=cfg.SOLVER.WEIGHT_DECAY,
55 | )
56 |
57 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
58 | optimizer, milestones=cfg.SOLVER.LR_DECAY_MILESTONES, gamma=0.1
59 | )
60 |
61 | start_epoch = 0
62 | if args.resume:
63 | assert args.ckpt, "--ckpt must be specified when --resume enabled"
64 | start_epoch = resume_from_ckpt(args.ckpt, model, optimizer, lr_scheduler)
65 |
66 | print("Creating output folder")
67 | output_dir = cfg.OUTPUT_DIR
68 | mkdir(output_dir)
69 | path = osp.join(output_dir, "config.yaml")
70 | with open(path, "w") as f:
71 | f.write(cfg.dump())
72 | print(f"Full config is saved to {path}")
73 | tfboard = None
74 | if cfg.TF_BOARD:
75 | from torch.utils.tensorboard import SummaryWriter
76 |
77 | tf_log_path = osp.join(output_dir, "tf_log")
78 | mkdir(tf_log_path)
79 | tfboard = SummaryWriter(log_dir=tf_log_path)
80 | print(f"TensorBoard files are saved to {tf_log_path}")
81 |
82 | print("Start training")
83 | start_time = time.time()
84 | for epoch in range(start_epoch, cfg.SOLVER.MAX_EPOCHS):
85 | train_one_epoch(cfg, model, optimizer, train_loader, device, epoch, tfboard)
86 | lr_scheduler.step()
87 |
88 | if (epoch + 1) % cfg.EVAL_PERIOD == 0 or epoch == cfg.SOLVER.MAX_EPOCHS - 1:
89 | evaluate_performance(
90 | model,
91 | gallery_loader,
92 | query_loader,
93 | device,
94 | use_gt=cfg.EVAL_USE_GT,
95 | use_cache=cfg.EVAL_USE_CACHE,
96 | use_cbgm=cfg.EVAL_USE_CBGM,
97 | )
98 |
99 | if (epoch + 1) % cfg.CKPT_PERIOD == 0 or epoch == cfg.SOLVER.MAX_EPOCHS - 1:
100 | save_on_master(
101 | {
102 | "model": model.state_dict(),
103 | "optimizer": optimizer.state_dict(),
104 | "lr_scheduler": lr_scheduler.state_dict(),
105 | "epoch": epoch,
106 | },
107 | osp.join(output_dir, f"epoch_{epoch}.pth"),
108 | )
109 |
110 | if tfboard:
111 | tfboard.close()
112 | total_time = time.time() - start_time
113 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
114 | print(f"Total training time {total_time_str}")
115 |
116 |
117 | if __name__ == "__main__":
118 | parser = argparse.ArgumentParser(description="Train a person search network.")
119 | parser.add_argument("--cfg", dest="cfg_file", help="Path to configuration file.")
120 | parser.add_argument(
121 | "--eval", action="store_true", help="Evaluate the performance of a given checkpoint."
122 | )
123 | parser.add_argument(
124 | "--resume", action="store_true", help="Resume from the specified checkpoint."
125 | )
126 | parser.add_argument("--ckpt", help="Path to checkpoint to resume or evaluate.")
127 | parser.add_argument(
128 | "opts", nargs=argparse.REMAINDER, help="Modify config options using the command-line"
129 | )
130 | args = parser.parse_args()
131 | main(args)
132 |
--------------------------------------------------------------------------------
/utils/km.py:
--------------------------------------------------------------------------------
1 | # encoding=utf-8
2 |
3 | import random
4 |
5 | import numpy as np
6 |
7 | zero_threshold = 0.00000001
8 |
9 |
10 | class KMNode(object):
11 | def __init__(self, id, exception=0, match=None, visit=False):
12 | self.id = id
13 | self.exception = exception
14 | self.match = match
15 | self.visit = visit
16 |
17 |
18 | class KuhnMunkres(object):
19 | def __init__(self):
20 | self.matrix = None
21 | self.x_nodes = []
22 | self.y_nodes = []
23 | self.minz = float("inf")
24 | self.x_length = 0
25 | self.y_length = 0
26 | self.index_x = 0
27 | self.index_y = 1
28 |
29 | def __del__(self):
30 | pass
31 |
32 | def set_matrix(self, x_y_values):
33 | xs = set()
34 | ys = set()
35 | for x, y, value in x_y_values:
36 | xs.add(x)
37 | ys.add(y)
38 |
39 | # use the smaller of the two sets as x
40 | if len(xs) < len(ys):
41 | self.index_x = 0
42 | self.index_y = 1
43 | else:
44 | self.index_x = 1
45 | self.index_y = 0
46 | xs, ys = ys, xs
47 |
48 | x_dic = {x: i for i, x in enumerate(xs)}
49 | y_dic = {y: j for j, y in enumerate(ys)}
50 | self.x_nodes = [KMNode(x) for x in xs]
51 | self.y_nodes = [KMNode(y) for y in ys]
52 | self.x_length = len(xs)
53 | self.y_length = len(ys)
54 |
55 | self.matrix = np.zeros((self.x_length, self.y_length))
56 | for row in x_y_values:
57 | x = row[self.index_x]
58 | y = row[self.index_y]
59 | value = row[2]
60 | x_index = x_dic[x]
61 | y_index = y_dic[y]
62 | self.matrix[x_index, y_index] = value
63 |
64 | for i in range(self.x_length):
65 | self.x_nodes[i].exception = max(self.matrix[i, :])
66 |
67 | def km(self):
68 | for i in range(self.x_length):
69 | while True:
70 | self.minz = float("inf")
71 | self.set_false(self.x_nodes)
72 | self.set_false(self.y_nodes)
73 |
74 | if self.dfs(i):
75 | break
76 |
77 | self.change_exception(self.x_nodes, -self.minz)
78 | self.change_exception(self.y_nodes, self.minz)
79 |
80 | def dfs(self, i):
81 | x_node = self.x_nodes[i]
82 | x_node.visit = True
83 | for j in range(self.y_length):
84 | y_node = self.y_nodes[j]
85 | if not y_node.visit:
86 | t = x_node.exception + y_node.exception - self.matrix[i][j]
87 | if abs(t) < zero_threshold:
88 | y_node.visit = True
89 | if y_node.match is None or self.dfs(y_node.match):
90 | x_node.match = j
91 | y_node.match = i
92 | return True
93 | else:
94 | if t >= zero_threshold:
95 | self.minz = min(self.minz, t)
96 | return False
97 |
98 | def set_false(self, nodes):
99 | for node in nodes:
100 | node.visit = False
101 |
102 | def change_exception(self, nodes, change):
103 | for node in nodes:
104 | if node.visit:
105 | node.exception += change
106 |
107 | def get_connect_result(self):
108 | ret = []
109 | for i in range(self.x_length):
110 | x_node = self.x_nodes[i]
111 | j = x_node.match
112 | y_node = self.y_nodes[j]
113 | x_id = x_node.id
114 | y_id = y_node.id
115 | value = self.matrix[i][j]
116 |
117 | if self.index_x == 1 and self.index_y == 0:
118 | x_id, y_id = y_id, x_id
119 | ret.append((x_id, y_id, value))
120 |
121 | return ret
122 |
123 | def get_max_value_result(self):
124 | # ret = 0
125 | ret = -100
126 | # ret = []
127 | for i in range(self.x_length):
128 | j = self.x_nodes[i].match
129 | # ret += self.matrix[i][j]
130 | ret = max(ret, self.matrix[i][j])
131 | # ret.append(self.matrix[i][j])
132 | # ret.sort()
133 | # ret = ret[-1:]
134 | # ret = np.array(ret).mean()
135 | return ret
136 |
137 |
138 | def run_kuhn_munkres(x_y_values):
139 | process = KuhnMunkres()
140 | process.set_matrix(x_y_values)
141 | process.km()
142 | return process.get_connect_result(), process.get_max_value_result()
143 |
144 |
145 | def test():
146 | values = []
147 | random.seed(0)
148 | for i in range(500):
149 | for j in range(1000):
150 | value = random.random()
151 | values.append((i, j, value))
152 |
153 | return run_kuhn_munkres(values)
154 |
155 |
156 | if __name__ == "__main__":
157 | # s_time = time.time()
158 | # ret = test()
159 | # print "time usage: %s " % str(time.time() - s_time)
160 | values = [(1, 1, 3), (1, 3, 4), (2, 1, 2), (2, 2, 1), (2, 3, 3), (3, 2, 4), (3, 3, 5)]
161 | print(run_kuhn_munkres(values))
162 |
--------------------------------------------------------------------------------
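A tiny usage sketch of `run_kuhn_munkres`: it returns the maximum-weight bipartite matching as `(x, y, value)` triples plus the largest matched edge value (the similarity scores below are made up):

```python
# Toy example with made-up detection/ground-truth similarity scores.
from utils.km import run_kuhn_munkres

edges = [("det-0", "gt-0", 0.9), ("det-0", "gt-1", 0.2), ("det-1", "gt-1", 0.8)]
matches, best = run_kuhn_munkres(edges)
print(matches)  # e.g. [('det-0', 'gt-0', 0.9), ('det-1', 'gt-1', 0.8)] (order may vary)
print(best)     # 0.9, the largest value among the matched pairs
```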
/datasets/cuhk_sysu.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import numpy as np
4 | from scipy.io import loadmat
5 |
6 | from .base import BaseDataset
7 |
8 |
9 | class CUHKSYSU(BaseDataset):
10 | def __init__(self, root, transforms, split):
11 | self.name = "CUHK-SYSU"
12 | self.img_prefix = osp.join(root, "Image", "SSM")
13 | super(CUHKSYSU, self).__init__(root, transforms, split)
14 |
15 | def _load_queries(self):
16 | # TestG50: a test protocol, 50 gallery images per query
17 | protoc = loadmat(osp.join(self.root, "annotation/test/train_test/TestG50.mat"))
18 | protoc = protoc["TestG50"].squeeze()
19 | queries = []
20 | for item in protoc["Query"]:
21 | img_name = str(item["imname"][0, 0][0])
22 | roi = item["idlocate"][0, 0][0].astype(np.int32)
23 | roi[2:] += roi[:2]
24 | queries.append(
25 | {
26 | "img_name": img_name,
27 | "img_path": osp.join(self.img_prefix, img_name),
28 | "boxes": roi[np.newaxis, :],
29 | "pids": np.array([-100]), # dummy pid
30 | }
31 | )
32 | return queries
33 |
34 | def _load_split_img_names(self):
35 | """
36 | Load the image names for the specific split.
37 | """
38 | assert self.split in ("train", "gallery")
39 | # gallery images
40 | gallery_imgs = loadmat(osp.join(self.root, "annotation", "pool.mat"))
41 | gallery_imgs = gallery_imgs["pool"].squeeze()
42 | gallery_imgs = [str(a[0]) for a in gallery_imgs]
43 | if self.split == "gallery":
44 | return gallery_imgs
45 | # all images
46 | all_imgs = loadmat(osp.join(self.root, "annotation", "Images.mat"))
47 | all_imgs = all_imgs["Img"].squeeze()
48 | all_imgs = [str(a[0][0]) for a in all_imgs]
49 | # training images = all images - gallery images
50 | training_imgs = sorted(list(set(all_imgs) - set(gallery_imgs)))
51 | return training_imgs
52 |
53 | def _load_annotations(self):
54 | if self.split == "query":
55 | return self._load_queries()
56 |
57 | # load all images and build a dict from image to boxes
58 | all_imgs = loadmat(osp.join(self.root, "annotation", "Images.mat"))
59 | all_imgs = all_imgs["Img"].squeeze()
60 | name_to_boxes = {}
61 | name_to_pids = {}
62 | unlabeled_pid = 5555 # default pid for unlabeled people
63 | for img_name, _, boxes in all_imgs:
64 | img_name = str(img_name[0])
65 | boxes = np.asarray([b[0] for b in boxes[0]])
66 | boxes = boxes.reshape(boxes.shape[0], 4) # (x1, y1, w, h)
67 | valid_index = np.where((boxes[:, 2] > 0) & (boxes[:, 3] > 0))[0]
68 | assert valid_index.size > 0, "Warning: {} has no valid boxes.".format(img_name)
69 | boxes = boxes[valid_index]
70 | name_to_boxes[img_name] = boxes.astype(np.int32)
71 | name_to_pids[img_name] = unlabeled_pid * np.ones(boxes.shape[0], dtype=np.int32)
72 |
73 | def set_box_pid(boxes, box, pids, pid):
74 | for i in range(boxes.shape[0]):
75 | if np.all(boxes[i] == box):
76 | pids[i] = pid
77 | return
78 |
79 | # assign a unique pid from 1 to N for each identity
80 | if self.split == "train":
81 | train = loadmat(osp.join(self.root, "annotation/test/train_test/Train.mat"))
82 | train = train["Train"].squeeze()
83 | for index, item in enumerate(train):
84 | scenes = item[0, 0][2].squeeze()
85 | for img_name, box, _ in scenes:
86 | img_name = str(img_name[0])
87 | box = box.squeeze().astype(np.int32)
88 | set_box_pid(name_to_boxes[img_name], box, name_to_pids[img_name], index + 1)
89 | else:
90 | protoc = loadmat(osp.join(self.root, "annotation/test/train_test/TestG50.mat"))
91 | protoc = protoc["TestG50"].squeeze()
92 | for index, item in enumerate(protoc):
93 | # query
94 | im_name = str(item["Query"][0, 0][0][0])
95 | box = item["Query"][0, 0][1].squeeze().astype(np.int32)
96 | set_box_pid(name_to_boxes[im_name], box, name_to_pids[im_name], index + 1)
97 | # gallery
98 | gallery = item["Gallery"].squeeze()
99 | for im_name, box, _ in gallery:
100 | im_name = str(im_name[0])
101 | if box.size == 0:
102 | break
103 | box = box.squeeze().astype(np.int32)
104 | set_box_pid(name_to_boxes[im_name], box, name_to_pids[im_name], index + 1)
105 |
106 | annotations = []
107 | imgs = self._load_split_img_names()
108 | for img_name in imgs:
109 | boxes = name_to_boxes[img_name]
110 | boxes[:, 2:] += boxes[:, :2] # (x1, y1, w, h) -> (x1, y1, x2, y2)
111 | pids = name_to_pids[img_name]
112 | annotations.append(
113 | {
114 | "img_name": img_name,
115 | "img_path": osp.join(self.img_prefix, img_name),
116 | "boxes": boxes,
117 | "pids": pids,
118 | }
119 | )
120 | return annotations
121 |
--------------------------------------------------------------------------------
/defaults.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | _C = CN()
4 |
5 | # -------------------------------------------------------- #
6 | # Input #
7 | # -------------------------------------------------------- #
8 | _C.INPUT = CN()
9 | _C.INPUT.DATASET = "CUHK-SYSU"
10 | _C.INPUT.DATA_ROOT = "data/CUHK-SYSU"
11 |
12 | # Size of the smallest side of the image
13 | _C.INPUT.MIN_SIZE = 900
14 | # Maximum size of the side of the image
15 | _C.INPUT.MAX_SIZE = 1500
16 |
17 | # TODO: support aspect ratio grouping
18 | # Whether to use aspect ratio grouping for saving GPU memory
19 | # _C.INPUT.ASPECT_RATIO_GROUPING_TRAIN = False
20 |
21 | # Number of images per batch
22 | _C.INPUT.BATCH_SIZE_TRAIN = 2
23 | _C.INPUT.BATCH_SIZE_TEST = 1
24 |
25 | # Number of data loading threads
26 | _C.INPUT.NUM_WORKERS_TRAIN = 5
27 | _C.INPUT.NUM_WORKERS_TEST = 1
28 |
29 | # -------------------------------------------------------- #
30 | # Solver #
31 | # -------------------------------------------------------- #
32 | _C.SOLVER = CN()
33 | _C.SOLVER.MAX_EPOCHS = 20
34 |
35 | # Learning rate settings
36 | _C.SOLVER.BASE_LR = 0.003
37 |
38 | # TODO: add config option WARMUP_EPOCHS
39 | _C.SOLVER.WARMUP_FACTOR = 1.0 / 1000
40 | # _C.SOLVER.WARMUP_EPOCHS = 1
41 |
42 | # The epoch milestones to decrease the learning rate by GAMMA
43 | _C.SOLVER.LR_DECAY_MILESTONES = [16]
44 | _C.SOLVER.GAMMA = 0.1
45 |
46 | _C.SOLVER.WEIGHT_DECAY = 0.0005
47 | _C.SOLVER.SGD_MOMENTUM = 0.9
48 |
49 | # Loss weight of RPN regression
50 | _C.SOLVER.LW_RPN_REG = 1
51 | # Loss weight of RPN classification
52 | _C.SOLVER.LW_RPN_CLS = 1
53 | # Loss weight of proposal regression
54 | _C.SOLVER.LW_PROPOSAL_REG = 10
55 | # Loss weight of proposal classification
56 | _C.SOLVER.LW_PROPOSAL_CLS = 1
57 | # Loss weight of box regression
58 | _C.SOLVER.LW_BOX_REG = 1
59 | # Loss weight of box classification
60 | _C.SOLVER.LW_BOX_CLS = 1
61 | # Loss weight of box OIM (i.e. Online Instance Matching)
62 | _C.SOLVER.LW_BOX_REID = 1
63 |
64 | # Set to negative value to disable gradient clipping
65 | _C.SOLVER.CLIP_GRADIENTS = 10.0
66 |
67 | # -------------------------------------------------------- #
68 | # RPN #
69 | # -------------------------------------------------------- #
70 | _C.MODEL = CN()
71 | _C.MODEL.BONE = "swin_tiny"
72 | _C.MODEL.SEMANTIC_WEIGHT = 1.0
73 |
74 | _C.MODEL.RPN = CN()
75 | # NMS threshold used on RoIs
76 | _C.MODEL.RPN.NMS_THRESH = 0.7
77 | # Number of anchors per image used to train RPN
78 | _C.MODEL.RPN.BATCH_SIZE_TRAIN = 256
79 | # Target fraction of foreground examples per RPN minibatch
80 | _C.MODEL.RPN.POS_FRAC_TRAIN = 0.5
81 | # Overlap threshold for an anchor to be considered foreground (if >= POS_THRESH_TRAIN)
82 | _C.MODEL.RPN.POS_THRESH_TRAIN = 0.7
83 | # Overlap threshold for an anchor to be considered background (if < NEG_THRESH_TRAIN)
84 | _C.MODEL.RPN.NEG_THRESH_TRAIN = 0.3
85 | # Number of top scoring RPN RoIs to keep before applying NMS
86 | _C.MODEL.RPN.PRE_NMS_TOPN_TRAIN = 12000
87 | _C.MODEL.RPN.PRE_NMS_TOPN_TEST = 6000
88 | # Number of top scoring RPN RoIs to keep after applying NMS
89 | _C.MODEL.RPN.POST_NMS_TOPN_TRAIN = 2000
90 | _C.MODEL.RPN.POST_NMS_TOPN_TEST = 300
91 |
92 | # -------------------------------------------------------- #
93 | # RoI head #
94 | # -------------------------------------------------------- #
95 | _C.MODEL.ROI_HEAD = CN()
96 | # Whether to use bn neck (i.e. batch normalization after linear)
97 | _C.MODEL.ROI_HEAD.BN_NECK = True
98 | # Number of RoIs per image used to train RoI head
99 | _C.MODEL.ROI_HEAD.BATCH_SIZE_TRAIN = 128
100 | # Target fraction of foreground examples per RoI minibatch
101 | _C.MODEL.ROI_HEAD.POS_FRAC_TRAIN = 0.5
102 | # Overlap threshold for an RoI to be considered foreground (if >= POS_THRESH_TRAIN)
103 | _C.MODEL.ROI_HEAD.POS_THRESH_TRAIN = 0.5
104 | # Overlap threshold for an RoI to be considered background (if < NEG_THRESH_TRAIN)
105 | _C.MODEL.ROI_HEAD.NEG_THRESH_TRAIN = 0.5
106 | # Minimum score threshold
107 | _C.MODEL.ROI_HEAD.SCORE_THRESH_TEST = 0.5
108 | # NMS threshold used on boxes
109 | _C.MODEL.ROI_HEAD.NMS_THRESH_TEST = 0.4
110 | # Maximum number of detected objects
111 | _C.MODEL.ROI_HEAD.DETECTIONS_PER_IMAGE_TEST = 300
112 |
113 | # -------------------------------------------------------- #
114 | # Loss #
115 | # -------------------------------------------------------- #
116 | _C.MODEL.LOSS = CN()
117 | # Size of the lookup table in OIM
118 | _C.MODEL.LOSS.LUT_SIZE = 5532
119 | # Size of the circular queue in OIM
120 | _C.MODEL.LOSS.CQ_SIZE = 5000
121 | _C.MODEL.LOSS.OIM_MOMENTUM = 0.5
122 | _C.MODEL.LOSS.OIM_SCALAR = 30.0
123 |
124 | # -------------------------------------------------------- #
125 | # Evaluation #
126 | # -------------------------------------------------------- #
127 | # The period to evaluate the model during training
128 | _C.EVAL_PERIOD = 1
129 | # Evaluation with GT boxes to verify the upper bound of person search performance
130 | _C.EVAL_USE_GT = False
131 | # Fast evaluation with cached features
132 | _C.EVAL_USE_CACHE = False
133 | # Evaluation with Context Bipartite Graph Matching (CBGM) algorithm
134 | _C.EVAL_USE_CBGM = False
135 |
136 | # -------------------------------------------------------- #
137 | # Miscs #
138 | # -------------------------------------------------------- #
139 | # Save a checkpoint after every this number of epochs
140 | _C.CKPT_PERIOD = 1
141 | # The period (in terms of iterations) to display training losses
142 | _C.DISP_PERIOD = 10
143 | # Whether to use tensorboard for visualization
144 | _C.TF_BOARD = True
145 | # The device loading the model
146 | _C.DEVICE = "cuda"
147 | # Set seed to negative to fully randomize everything
148 | _C.SEED = 1
149 | # Directory where output files are written
150 | _C.OUTPUT_DIR = "./output"
151 |
152 |
153 | def get_default_cfg():
154 | """
155 | Get a copy of the default config.
156 | """
157 | return _C.clone()
158 |
--------------------------------------------------------------------------------
/engine.py:
--------------------------------------------------------------------------------
1 | import math
2 | import sys
3 | from copy import deepcopy
4 |
5 | import torch
6 | from torch.nn.utils import clip_grad_norm_
7 | from tqdm import tqdm
8 |
9 | from eval_func import eval_detection, eval_search_cuhk, eval_search_prw
10 | from utils.utils import MetricLogger, SmoothedValue, mkdir, reduce_dict, warmup_lr_scheduler
11 |
12 |
13 | def to_device(images, targets, device):
14 | images = [image.to(device) for image in images]
15 | for t in targets:
16 | t["boxes"] = t["boxes"].to(device)
17 | t["labels"] = t["labels"].to(device)
18 | return images, targets
19 |
20 |
21 | def train_one_epoch(cfg, model, optimizer, data_loader, device, epoch, tfboard=None):
22 | model.train()
23 | metric_logger = MetricLogger(delimiter=" ")
24 | metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
25 | header = "Epoch: [{}]".format(epoch)
26 |
27 | # warmup learning rate in the first epoch
28 | if epoch == 0:
29 | warmup_factor = 1.0 / 1000
30 | # FIXME: min(1000, len(data_loader) - 1)
31 | warmup_iters = len(data_loader) - 1
32 | warmup_scheduler = warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)
33 |
34 | for i, (images, targets) in enumerate(
35 | metric_logger.log_every(data_loader, cfg.DISP_PERIOD, header)
36 | ):
37 | images, targets = to_device(images, targets, device)
38 |
39 | loss_dict = model(images, targets)
40 | losses = sum(loss for loss in loss_dict.values())
41 |
42 | # reduce losses over all GPUs for logging purposes
43 | loss_dict_reduced = reduce_dict(loss_dict)
44 | losses_reduced = sum(loss for loss in loss_dict_reduced.values())
45 | loss_value = losses_reduced.item()
46 |
47 | if not math.isfinite(loss_value):
48 | print(f"Loss is {loss_value}, stopping training")
49 | print(loss_dict_reduced)
50 | sys.exit(1)
51 |
52 | optimizer.zero_grad()
53 | losses.backward()
54 | if cfg.SOLVER.CLIP_GRADIENTS > 0:
55 | clip_grad_norm_(model.parameters(), cfg.SOLVER.CLIP_GRADIENTS)
56 | optimizer.step()
57 |
58 | if epoch == 0:
59 | warmup_scheduler.step()
60 |
61 | metric_logger.update(loss=loss_value, **loss_dict_reduced)
62 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
63 | if tfboard:
64 | iter = epoch * len(data_loader) + i
65 | for k, v in loss_dict_reduced.items():
66 | tfboard.add_scalars("train", {k: v}, iter)
67 |
68 |
69 | @torch.no_grad()
70 | def evaluate_performance(
71 | model, gallery_loader, query_loader, device, use_gt=False, use_cache=False, use_cbgm=False
72 | ):
73 | """
74 | Args:
75 | use_gt (bool, optional): Whether to use GT as detection results to verify the upper
76 | bound of person search performance. Defaults to False.
77 | use_cache (bool, optional): Whether to use the cached features. Defaults to False.
78 | use_cbgm (bool, optional): Whether to use Context Bipartite Graph Matching algorithm.
79 | Defaults to False.
80 | """
81 | model.eval()
82 | if use_cache:
83 | eval_cache = torch.load("data/eval_cache/eval_cache.pth")
84 | gallery_dets = eval_cache["gallery_dets"]
85 | gallery_feats = eval_cache["gallery_feats"]
86 | query_dets = eval_cache["query_dets"]
87 | query_feats = eval_cache["query_feats"]
88 | query_box_feats = eval_cache["query_box_feats"]
89 | else:
90 | gallery_dets, gallery_feats = [], []
91 | for images, targets in tqdm(gallery_loader, ncols=0):
92 | images, targets = to_device(images, targets, device)
93 | if not use_gt:
94 | outputs = model(images)
95 | else:
96 | boxes = targets[0]["boxes"]
97 | n_boxes = boxes.size(0)
98 | embeddings = model(images, targets)
99 | outputs = [
100 | {
101 | "boxes": boxes,
102 | "embeddings": torch.cat(embeddings),
103 | "labels": torch.ones(n_boxes).to(device),
104 | "scores": torch.ones(n_boxes).to(device),
105 | }
106 | ]
107 |
108 | for output in outputs:
109 | box_w_scores = torch.cat([output["boxes"], output["scores"].unsqueeze(1)], dim=1)
110 | gallery_dets.append(box_w_scores.cpu().numpy())
111 | gallery_feats.append(output["embeddings"].cpu().numpy())
112 |
113 | # regarding query image as gallery to detect all people
114 | # i.e. query person + surrounding people (context information)
115 | query_dets, query_feats = [], []
116 | for images, targets in tqdm(query_loader, ncols=0):
117 | images, targets = to_device(images, targets, device)
118 | # targets will be modified in the model, so deepcopy it
119 | outputs = model(images, deepcopy(targets), query_img_as_gallery=True)
120 |
121 | # consistency check
122 | gt_box = targets[0]["boxes"].squeeze()
123 | assert (
124 |     gt_box - outputs[0]["boxes"][0]
125 | ).abs().sum() <= 0.001, "GT box must be the first one in the detected boxes of query image"
126 |
127 | for output in outputs:
128 | box_w_scores = torch.cat([output["boxes"], output["scores"].unsqueeze(1)], dim=1)
129 | query_dets.append(box_w_scores.cpu().numpy())
130 | query_feats.append(output["embeddings"].cpu().numpy())
131 |
132 | # extract the features of query boxes
133 | query_box_feats = []
134 | for images, targets in tqdm(query_loader, ncols=0):
135 | images, targets = to_device(images, targets, device)
136 | embeddings = model(images, targets)
137 | assert len(embeddings) == 1, "batch size in test phase should be 1"
138 | query_box_feats.append(embeddings[0].cpu().numpy())
139 |
140 | mkdir("data/eval_cache")
141 | save_dict = {
142 | "gallery_dets": gallery_dets,
143 | "gallery_feats": gallery_feats,
144 | "query_dets": query_dets,
145 | "query_feats": query_feats,
146 | "query_box_feats": query_box_feats,
147 | }
148 | torch.save(save_dict, "data/eval_cache/eval_cache.pth")
149 |
150 | eval_detection(gallery_loader.dataset, gallery_dets, det_thresh=0.01)
151 | eval_search_func = (
152 | eval_search_cuhk if gallery_loader.dataset.name == "CUHK-SYSU" else eval_search_prw
153 | )
154 | eval_search_func(
155 | gallery_loader.dataset,
156 | query_loader.dataset,
157 | gallery_dets,
158 | gallery_feats,
159 | query_box_feats,
160 | query_dets,
161 | query_feats,
162 | cbgm=use_cbgm,
163 | )
164 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import errno
3 | import json
4 | import os
5 | import os.path as osp
6 | import pickle
7 | import random
8 | import time
9 | from collections import defaultdict, deque
10 |
11 | import numpy as np
12 | import torch
13 | import torch.distributed as dist
14 | from tabulate import tabulate
15 |
16 |
17 | # -------------------------------------------------------- #
18 | # Logger #
19 | # -------------------------------------------------------- #
20 | class SmoothedValue(object):
21 | """
22 | Track a series of values and provide access to smoothed values over a
23 | window or the global series average.
24 | """
25 |
26 | def __init__(self, window_size=20, fmt=None):
27 | if fmt is None:
28 | fmt = "{median:.4f} ({global_avg:.4f})"
29 | self.deque = deque(maxlen=window_size)
30 | self.total = 0.0
31 | self.count = 0
32 | self.fmt = fmt
33 |
34 | def update(self, value, n=1):
35 | self.deque.append(value)
36 | self.count += n
37 | self.total += value * n
38 |
39 | def synchronize_between_processes(self):
40 | """
41 | Warning: does not synchronize the deque!
42 | """
43 | if not is_dist_avail_and_initialized():
44 | return
45 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
46 | dist.barrier()
47 | dist.all_reduce(t)
48 | t = t.tolist()
49 | self.count = int(t[0])
50 | self.total = t[1]
51 |
52 | @property
53 | def median(self):
54 | d = torch.tensor(list(self.deque))
55 | return d.median().item()
56 |
57 | @property
58 | def avg(self):
59 | d = torch.tensor(list(self.deque), dtype=torch.float32)
60 | return d.mean().item()
61 |
62 | @property
63 | def global_avg(self):
64 | return self.total / self.count
65 |
66 | @property
67 | def max(self):
68 | return max(self.deque)
69 |
70 | @property
71 | def value(self):
72 | return self.deque[-1]
73 |
74 | def __str__(self):
75 | return self.fmt.format(
76 | median=self.median,
77 | avg=self.avg,
78 | global_avg=self.global_avg,
79 | max=self.max,
80 | value=self.value,
81 | )
82 |
83 |
84 | class MetricLogger(object):
85 | def __init__(self, delimiter="\t"):
86 | self.meters = defaultdict(SmoothedValue)
87 | self.delimiter = delimiter
88 |
89 | def update(self, **kwargs):
90 | for k, v in kwargs.items():
91 | if isinstance(v, torch.Tensor):
92 | v = v.item()
93 | assert isinstance(v, (float, int))
94 | self.meters[k].update(v)
95 |
96 | def __getattr__(self, attr):
97 | if attr in self.meters:
98 | return self.meters[attr]
99 | if attr in self.__dict__:
100 | return self.__dict__[attr]
101 | raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
102 |
103 | def __str__(self):
104 | loss_str = []
105 | for name, meter in self.meters.items():
106 | loss_str.append("{}: {}".format(name, str(meter)))
107 | return self.delimiter.join(loss_str)
108 |
109 | def synchronize_between_processes(self):
110 | for meter in self.meters.values():
111 | meter.synchronize_between_processes()
112 |
113 | def add_meter(self, name, meter):
114 | self.meters[name] = meter
115 |
116 | def log_every(self, iterable, print_freq, header=None):
117 | i = 0
118 | if not header:
119 | header = ""
120 | start_time = time.time()
121 | end = time.time()
122 | iter_time = SmoothedValue(fmt="{avg:.4f}")
123 | data_time = SmoothedValue(fmt="{avg:.4f}")
124 | space_fmt = ":" + str(len(str(len(iterable)))) + "d"
125 | if torch.cuda.is_available():
126 | log_msg = self.delimiter.join(
127 | [
128 | header,
129 | "[{0" + space_fmt + "}/{1}]",
130 | "eta: {eta}",
131 | "{meters}",
132 | "time: {time}",
133 | "data: {data}",
134 | "max mem: {memory:.0f}",
135 | ]
136 | )
137 | else:
138 | log_msg = self.delimiter.join(
139 | [
140 | header,
141 | "[{0" + space_fmt + "}/{1}]",
142 | "eta: {eta}",
143 | "{meters}",
144 | "time: {time}",
145 | "data: {data}",
146 | ]
147 | )
148 | MB = 1024.0 * 1024.0
149 | for obj in iterable:
150 | data_time.update(time.time() - end)
151 | yield obj
152 | iter_time.update(time.time() - end)
153 | if i % print_freq == 0 or i == len(iterable) - 1:
154 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
155 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
156 | if torch.cuda.is_available():
157 | print(
158 | log_msg.format(
159 | i,
160 | len(iterable),
161 | eta=eta_string,
162 | meters=str(self),
163 | time=str(iter_time),
164 | data=str(data_time),
165 | memory=torch.cuda.max_memory_allocated() / MB,
166 | )
167 | )
168 | else:
169 | print(
170 | log_msg.format(
171 | i,
172 | len(iterable),
173 | eta=eta_string,
174 | meters=str(self),
175 | time=str(iter_time),
176 | data=str(data_time),
177 | )
178 | )
179 | i += 1
180 | end = time.time()
181 | total_time = time.time() - start_time
182 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
183 | print(
184 | "{} Total time: {} ({:.4f} s / it)".format(
185 | header, total_time_str, total_time / len(iterable)
186 | )
187 | )
188 |
189 |
190 | # -------------------------------------------------------- #
191 | # Distributed training #
192 | # -------------------------------------------------------- #
193 | def all_gather(data):
194 | """
195 | Run all_gather on arbitrary picklable data (not necessarily tensors)
196 |
197 | Args:
198 | data: any picklable object
199 |
200 | Returns:
201 | list[data]: list of data gathered from each rank
202 | """
203 | world_size = get_world_size()
204 | if world_size == 1:
205 | return [data]
206 |
207 | # serialized to a Tensor
208 | buffer = pickle.dumps(data)
209 | storage = torch.ByteStorage.from_buffer(buffer)
210 | tensor = torch.ByteTensor(storage).to("cuda")
211 |
212 | # obtain Tensor size of each rank
213 | local_size = torch.tensor([tensor.numel()], device="cuda")
214 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
215 | dist.all_gather(size_list, local_size)
216 | size_list = [int(size.item()) for size in size_list]
217 | max_size = max(size_list)
218 |
219 | # receiving Tensor from all ranks
220 | # we pad the tensor because torch all_gather does not support
221 | # gathering tensors of different shapes
222 | tensor_list = []
223 | for _ in size_list:
224 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
225 | if local_size != max_size:
226 | padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
227 | tensor = torch.cat((tensor, padding), dim=0)
228 | dist.all_gather(tensor_list, tensor)
229 |
230 | data_list = []
231 | for size, tensor in zip(size_list, tensor_list):
232 | buffer = tensor.cpu().numpy().tobytes()[:size]
233 | data_list.append(pickle.loads(buffer))
234 |
235 | return data_list
236 |
237 |
238 | def reduce_dict(input_dict, average=True):
239 | """
240 | Reduce the values in the dictionary from all processes so that all processes
241 | have the averaged results. Returns a dict with the same fields as
242 | input_dict, after reduction.
243 |
244 | Args:
245 | input_dict (dict): all the values will be reduced
246 | average (bool): whether to do average or sum
247 | """
248 | world_size = get_world_size()
249 | if world_size < 2:
250 | return input_dict
251 | with torch.no_grad():
252 | names = []
253 | values = []
254 | # sort the keys so that they are consistent across processes
255 | for k in sorted(input_dict.keys()):
256 | names.append(k)
257 | values.append(input_dict[k])
258 | values = torch.stack(values, dim=0)
259 | dist.all_reduce(values)
260 | if average:
261 | values /= world_size
262 | reduced_dict = {k: v for k, v in zip(names, values)}
263 | return reduced_dict
264 |
265 |
266 | def setup_for_distributed(is_master):
267 | """
268 | This function disables printing when not in master process
269 | """
270 | import builtins as __builtin__
271 |
272 | builtin_print = __builtin__.print
273 |
274 | def print(*args, **kwargs):
275 | force = kwargs.pop("force", False)
276 | if is_master or force:
277 | builtin_print(*args, **kwargs)
278 |
279 | __builtin__.print = print
280 |
281 |
282 | def is_dist_avail_and_initialized():
283 | if not dist.is_available():
284 | return False
285 | if not dist.is_initialized():
286 | return False
287 | return True
288 |
289 |
290 | def get_world_size():
291 | if not is_dist_avail_and_initialized():
292 | return 1
293 | return dist.get_world_size()
294 |
295 |
296 | def get_rank():
297 | if not is_dist_avail_and_initialized():
298 | return 0
299 | return dist.get_rank()
300 |
301 |
302 | def is_main_process():
303 | return get_rank() == 0
304 |
305 |
306 | def save_on_master(*args, **kwargs):
307 | if is_main_process():
308 | torch.save(*args, **kwargs)
309 |
310 |
311 | def init_distributed_mode(args):
312 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
313 | args.rank = int(os.environ["RANK"])
314 | args.world_size = int(os.environ["WORLD_SIZE"])
315 | args.gpu = int(os.environ["LOCAL_RANK"])
316 | elif "SLURM_PROCID" in os.environ:
317 | args.rank = int(os.environ["SLURM_PROCID"])
318 | args.gpu = args.rank % torch.cuda.device_count()
319 | else:
320 | print("Not using distributed mode")
321 | args.distributed = False
322 | return
323 |
324 | args.distributed = True
325 |
326 | torch.cuda.set_device(args.gpu)
327 | args.dist_backend = "nccl"
328 | print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
329 | torch.distributed.init_process_group(
330 | backend=args.dist_backend,
331 | init_method=args.dist_url,
332 | world_size=args.world_size,
333 | rank=args.rank,
334 | )
335 | torch.distributed.barrier()
336 | setup_for_distributed(args.rank == 0)
337 |
338 |
339 | # -------------------------------------------------------- #
340 | # File operation #
341 | # -------------------------------------------------------- #
342 | def filename(path):
343 | return osp.splitext(osp.basename(path))[0]
344 |
345 |
346 | def mkdir(path):
347 | try:
348 | os.makedirs(path)
349 | except OSError as e:
350 | if e.errno != errno.EEXIST:
351 | raise
352 |
353 |
354 | def read_json(fpath):
355 | with open(fpath, "r") as f:
356 | obj = json.load(f)
357 | return obj
358 |
359 |
360 | def write_json(obj, fpath):
361 | mkdir(osp.dirname(fpath))
362 | _obj = obj.copy()
363 | for k, v in list(_obj.items()):
364 | if isinstance(v, np.ndarray):
365 | _obj.pop(k)
366 | with open(fpath, "w") as f:
367 | json.dump(_obj, f, indent=4, separators=(",", ": "))
368 |
369 |
370 | def symlink(src, dst, overwrite=True, **kwargs):
371 | if os.path.lexists(dst) and overwrite:
372 | os.remove(dst)
373 | os.symlink(src, dst, **kwargs)
374 |
375 |
376 | # -------------------------------------------------------- #
377 | # Misc #
378 | # -------------------------------------------------------- #
379 | def create_small_table(small_dict):
380 | """
381 | Create a small table using the keys of small_dict as headers. This is only
382 | suitable for small dictionaries.
383 |
384 | Args:
385 | small_dict (dict): a result dictionary of only a few items.
386 |
387 | Returns:
388 | str: the table as a string.
389 | """
390 | keys, values = tuple(zip(*small_dict.items()))
391 | table = tabulate(
392 | [values],
393 | headers=keys,
394 | tablefmt="pipe",
395 | floatfmt=".3f",
396 | stralign="center",
397 | numalign="center",
398 | )
399 | return table
400 |
401 |
402 | def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor):
403 | def f(x):
404 | if x >= warmup_iters:
405 | return 1
406 | alpha = float(x) / warmup_iters
407 | return warmup_factor * (1 - alpha) + alpha
408 |
409 | return torch.optim.lr_scheduler.LambdaLR(optimizer, f)
410 |
411 | def resume_from_ckpt(ckpt_path, model, optimizer=None, lr_scheduler=None):
412 | ckpt = torch.load(ckpt_path)
413 | if "state_dict" in ckpt:
414 | ckpt = ckpt['state_dict']
415 |
416 | count = 0
417 | miss = []
418 | for i in ckpt:
419 | if 'backbone' in i:
420 | model.state_dict()[i.replace("backbone.", "backbone.swin.")].copy_(ckpt[i])
421 | model.state_dict()[i.replace("backbone.", "roi_heads.reid_head.swin.")].copy_(ckpt[i])
422 | count += 1
423 | else:
424 | miss.append(i)
425 | print("%d loaded, %d missed:" % (count, len(miss)), miss)
426 | return 0
427 |
428 | def set_random_seed(seed):
429 | torch.manual_seed(seed)
430 | torch.cuda.manual_seed(seed)
431 | torch.cuda.manual_seed_all(seed)
432 | torch.backends.cudnn.benchmark = False
433 | torch.backends.cudnn.deterministic = True
434 | random.seed(seed)
435 | np.random.seed(seed)
436 | os.environ["PYTHONHASHSEED"] = str(seed)
437 |
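A minimal usage sketch for the SmoothedValue / MetricLogger utilities defined above; the loop body and hyperparameters are placeholders, not part of the training code:

import time
import torch

logger = MetricLogger(delimiter="  ")
logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
for step in logger.log_every(range(100), print_freq=20, header="Epoch: [0]"):
    fake_loss = torch.rand(1)   # stand-in for a real training loss
    logger.update(loss=fake_loss, lr=3e-3)
    time.sleep(0.01)            # stand-in for a forward/backward pass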
--------------------------------------------------------------------------------
/models/backbone.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torchvision.models.utils import load_state_dict_from_url
3 |
4 |
5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
6 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
7 | 'wide_resnet50_2', 'wide_resnet101_2']
8 |
9 |
10 | model_urls = {
11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
16 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
17 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
18 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
19 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
20 | }
21 |
22 |
23 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
24 | """3x3 convolution with padding"""
25 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
26 | padding=dilation, groups=groups, bias=False, dilation=dilation)
27 |
28 |
29 | def conv1x1(in_planes, out_planes, stride=1):
30 | """1x1 convolution"""
31 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
32 |
33 |
34 | class BasicBlock(nn.Module):
35 | expansion = 1
36 |
37 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
38 | base_width=64, dilation=1, norm_layer=None):
39 | super(BasicBlock, self).__init__()
40 | if norm_layer is None:
41 | norm_layer = nn.BatchNorm2d
42 | if groups != 1 or base_width != 64:
43 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
44 | if dilation > 1:
45 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
46 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
47 | self.conv1 = conv3x3(inplanes, planes, stride)
48 | self.bn1 = norm_layer(planes)
49 | self.relu = nn.ReLU(inplace=True)
50 | self.conv2 = conv3x3(planes, planes)
51 | self.bn2 = norm_layer(planes)
52 | self.downsample = downsample
53 | self.stride = stride
54 |
55 | def forward(self, x):
56 | identity = x
57 |
58 | out = self.conv1(x)
59 | out = self.bn1(out)
60 | out = self.relu(out)
61 |
62 | out = self.conv2(out)
63 | out = self.bn2(out)
64 |
65 | if self.downsample is not None:
66 | identity = self.downsample(x)
67 |
68 | out += identity
69 | out = self.relu(out)
70 |
71 | return out
72 |
73 |
74 | class Bottleneck(nn.Module):
75 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
76 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
77 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
78 | # This variant is also known as ResNet V1.5 and improves accuracy according to
79 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
80 |
81 | expansion = 4
82 |
83 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
84 | base_width=64, dilation=1, norm_layer=None):
85 | super(Bottleneck, self).__init__()
86 | if norm_layer is None:
87 | norm_layer = nn.BatchNorm2d
88 | width = int(planes * (base_width / 64.)) * groups
89 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
90 | self.conv1 = conv1x1(inplanes, width)
91 | self.bn1 = norm_layer(width)
92 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
93 | self.bn2 = norm_layer(width)
94 | self.conv3 = conv1x1(width, planes * self.expansion)
95 | self.bn3 = norm_layer(planes * self.expansion)
96 | self.relu = nn.ReLU(inplace=True)
97 | self.downsample = downsample
98 | self.stride = stride
99 |
100 | def forward(self, x):
101 | identity = x
102 |
103 | out = self.conv1(x)
104 | out = self.bn1(out)
105 | out = self.relu(out)
106 |
107 | out = self.conv2(out)
108 | out = self.bn2(out)
109 | out = self.relu(out)
110 |
111 | out = self.conv3(out)
112 | out = self.bn3(out)
113 |
114 | if self.downsample is not None:
115 | identity = self.downsample(x)
116 |
117 | out += identity
118 | out = self.relu(out)
119 |
120 | return out
121 |
122 |
123 | class ResNet(nn.Module):
124 |
125 | def __init__(self, block, layers, zero_init_residual=False,
126 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
127 | norm_layer=None):
128 | super(ResNet, self).__init__()
129 | if norm_layer is None:
130 | norm_layer = nn.BatchNorm2d
131 | self._norm_layer = norm_layer
132 |
133 | self.inplanes = 64
134 | self.dilation = 1
135 | if replace_stride_with_dilation is None:
136 | # each element in the tuple indicates if we should replace
137 | # the 2x2 stride with a dilated convolution instead
138 | replace_stride_with_dilation = [False, False, False]
139 | if len(replace_stride_with_dilation) != 3:
140 | raise ValueError("replace_stride_with_dilation should be None "
141 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
142 | self.groups = groups
143 | self.base_width = width_per_group
144 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
145 | bias=False)
146 | self.bn1 = norm_layer(self.inplanes)
147 | self.relu = nn.ReLU(inplace=True)
148 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
149 | self.layer1 = self._make_layer(block, 64, layers[0])
150 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
151 | dilate=replace_stride_with_dilation[0])
152 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
153 | dilate=replace_stride_with_dilation[1])
154 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
155 | dilate=replace_stride_with_dilation[2])
156 | # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
157 | # self.fc = nn.Linear(512 * block.expansion, num_classes)
158 |
159 | for m in self.modules():
160 | if isinstance(m, nn.Conv2d):
161 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
162 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
163 | nn.init.constant_(m.weight, 1)
164 | nn.init.constant_(m.bias, 0)
165 |
166 | # Zero-initialize the last BN in each residual branch,
167 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
168 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
169 | if zero_init_residual:
170 | for m in self.modules():
171 | if isinstance(m, Bottleneck):
172 | nn.init.constant_(m.bn3.weight, 0)
173 | elif isinstance(m, BasicBlock):
174 | nn.init.constant_(m.bn2.weight, 0)
175 |
176 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
177 | norm_layer = self._norm_layer
178 | downsample = None
179 | previous_dilation = self.dilation
180 | if dilate:
181 | self.dilation *= stride
182 | stride = 1
183 | if stride != 1 or self.inplanes != planes * block.expansion:
184 | downsample = nn.Sequential(
185 | conv1x1(self.inplanes, planes * block.expansion, stride),
186 | norm_layer(planes * block.expansion),
187 | )
188 |
189 | layers = []
190 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
191 | self.base_width, previous_dilation, norm_layer))
192 | self.inplanes = planes * block.expansion
193 | for _ in range(1, blocks):
194 | layers.append(block(self.inplanes, planes, groups=self.groups,
195 | base_width=self.base_width, dilation=self.dilation,
196 | norm_layer=norm_layer))
197 |
198 | return nn.Sequential(*layers)
199 |
200 | def _forward_impl(self, x):
201 | outputs = {}
202 | # See note [TorchScript super()]
203 | x = self.conv1(x)
204 | x = self.bn1(x)
205 | x = self.relu(x)
206 | x = self.maxpool(x)
207 | #outputs['stem'] = x
208 |
209 | x = self.layer1(x) # 1/4
210 | outputs['res2'] = x
211 |
212 | x = self.layer2(x) # 1/8
213 | outputs['res3'] = x
214 |
215 | x = self.layer3(x) # 1/16
216 | outputs['res4'] = x
217 |
218 | x = self.layer4(x) # 1/32
219 | outputs['res5'] = x
220 |
221 | return outputs['res5']
222 |
223 | def forward(self, x):
224 | return self._forward_impl(x)
225 |
226 |
227 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
228 | model = ResNet(block, layers, **kwargs)
229 | if pretrained:
230 | state_dict = load_state_dict_from_url(model_urls[arch],
231 | progress=progress)
232 | model.load_state_dict(state_dict, strict=False)
233 | return model
234 |
235 |
236 | def resnet18(pretrained=False, progress=True, **kwargs):
237 | r"""ResNet-18 model from
238 | `"Deep Residual Learning for Image Recognition" `_
239 | Args:
240 | pretrained (bool): If True, returns a model pre-trained on ImageNet
241 | progress (bool): If True, displays a progress bar of the download to stderr
242 | """
243 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
244 | **kwargs)
245 |
246 |
247 | def resnet34(pretrained=False, progress=True, **kwargs):
248 | r"""ResNet-34 model from
249 | `"Deep Residual Learning for Image Recognition" `_
250 | Args:
251 | pretrained (bool): If True, returns a model pre-trained on ImageNet
252 | progress (bool): If True, displays a progress bar of the download to stderr
253 | """
254 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
255 | **kwargs)
256 |
257 |
258 | def resnet50(pretrained=False, progress=True, **kwargs):
259 | r"""ResNet-50 model from
260 | `"Deep Residual Learning for Image Recognition" `_
261 | Args:
262 | pretrained (bool): If True, returns a model pre-trained on ImageNet
263 | progress (bool): If True, displays a progress bar of the download to stderr
264 | """
265 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
266 | **kwargs)
267 |
268 |
269 | def resnet101(pretrained=False, progress=True, **kwargs):
270 | r"""ResNet-101 model from
271 | `"Deep Residual Learning for Image Recognition" `_
272 | Args:
273 | pretrained (bool): If True, returns a model pre-trained on ImageNet
274 | progress (bool): If True, displays a progress bar of the download to stderr
275 | """
276 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
277 | **kwargs)
278 |
279 |
280 | def resnet152(pretrained=False, progress=True, **kwargs):
281 | r"""ResNet-152 model from
282 | `"Deep Residual Learning for Image Recognition" `_
283 | Args:
284 | pretrained (bool): If True, returns a model pre-trained on ImageNet
285 | progress (bool): If True, displays a progress bar of the download to stderr
286 | """
287 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
288 | **kwargs)
289 |
290 |
291 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
292 | r"""ResNeXt-50 32x4d model from
293 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
294 | Args:
295 | pretrained (bool): If True, returns a model pre-trained on ImageNet
296 | progress (bool): If True, displays a progress bar of the download to stderr
297 | """
298 | kwargs['groups'] = 32
299 | kwargs['width_per_group'] = 4
300 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
301 | pretrained, progress, **kwargs)
302 |
303 |
304 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
305 | r"""ResNeXt-101 32x8d model from
306 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
307 | Args:
308 | pretrained (bool): If True, returns a model pre-trained on ImageNet
309 | progress (bool): If True, displays a progress bar of the download to stderr
310 | """
311 | kwargs['groups'] = 32
312 | kwargs['width_per_group'] = 8
313 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
314 | pretrained, progress, **kwargs)
315 |
316 |
317 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
318 | r"""Wide ResNet-50-2 model from
319 | `"Wide Residual Networks" `_
320 | The model is the same as ResNet except for the bottleneck number of channels
321 | which is twice larger in every block. The number of channels in outer 1x1
322 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
323 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
324 | Args:
325 | pretrained (bool): If True, returns a model pre-trained on ImageNet
326 | progress (bool): If True, displays a progress bar of the download to stderr
327 | """
328 | kwargs['width_per_group'] = 64 * 2
329 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
330 | pretrained, progress, **kwargs)
331 |
332 |
333 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
334 | r"""Wide ResNet-101-2 model from
335 | `"Wide Residual Networks" `_
336 | The model is the same as ResNet except for the bottleneck number of channels
337 | which is twice larger in every block. The number of channels in outer 1x1
338 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
339 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
340 | Args:
341 | pretrained (bool): If True, returns a model pre-trained on ImageNet
342 | progress (bool): If True, displays a progress bar of the download to stderr
343 | """
344 | kwargs['width_per_group'] = 64 * 2
345 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
346 | pretrained, progress, **kwargs)
347 |
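A quick shape check for the truncated ResNet above, whose forward pass returns the stride-32 res5 feature map instead of classification logits; the input size is illustrative:

import torch

net = resnet50(pretrained=False)   # random weights, avoids the download path
x = torch.randn(1, 3, 224, 224)
feat = net(x)                      # 'res5' output of _forward_impl
print(feat.shape)                  # torch.Size([1, 2048, 7, 7])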
--------------------------------------------------------------------------------
/eval_func.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import numpy as np
4 | from scipy.io import loadmat
5 | from sklearn.metrics import average_precision_score
6 |
7 | from utils.km import run_kuhn_munkres
8 | from utils.utils import write_json
9 |
10 |
11 | def _compute_iou(a, b):
12 | x1 = max(a[0], b[0])
13 | y1 = max(a[1], b[1])
14 | x2 = min(a[2], b[2])
15 | y2 = min(a[3], b[3])
16 | inter = max(0, x2 - x1) * max(0, y2 - y1)
17 | union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
18 | return inter * 1.0 / union
19 |
20 |
21 | def eval_detection(
22 | gallery_dataset, gallery_dets, det_thresh=0.5, iou_thresh=0.5, labeled_only=False
23 | ):
24 | """
25 | gallery_det (list of ndarray): n_det x [x1, y1, x2, y2, score] per image
26 | det_thresh (float): filter out gallery detections whose scores are below this
27 | iou_thresh (float): treat as true positive if IoU is above this threshold
28 | labeled_only (bool): filter out unlabeled background people
29 | """
30 | assert len(gallery_dataset) == len(gallery_dets)
31 | annos = gallery_dataset.annotations
32 |
33 | y_true, y_score = [], []
34 | count_gt, count_tp = 0, 0
35 | for anno, det in zip(annos, gallery_dets):
36 | gt_boxes = anno["boxes"]
37 | if labeled_only:
38 | # exclude the unlabeled people (pid == 5555)
39 | inds = np.where(anno["pids"].ravel() != 5555)[0]
40 | if len(inds) == 0:
41 | continue
42 | gt_boxes = gt_boxes[inds]
43 | num_gt = gt_boxes.shape[0]
44 |
45 | if len(det) != 0:
46 | det = np.asarray(det)
47 | inds = np.where(det[:, 4].ravel() >= det_thresh)[0]
48 | det = det[inds]
49 | num_det = det.shape[0]
50 | else:
51 | num_det = 0
52 | if num_det == 0:
53 | count_gt += num_gt
54 | continue
55 |
56 | ious = np.zeros((num_gt, num_det), dtype=np.float32)
57 | for i in range(num_gt):
58 | for j in range(num_det):
59 | ious[i, j] = _compute_iou(gt_boxes[i], det[j, :4])
60 | tfmat = ious >= iou_thresh
61 | # for each det, keep only the largest iou of all the gt
62 | for j in range(num_det):
63 | largest_ind = np.argmax(ious[:, j])
64 | for i in range(num_gt):
65 | if i != largest_ind:
66 | tfmat[i, j] = False
67 | # for each gt, keep only the largest iou of all the det
68 | for i in range(num_gt):
69 | largest_ind = np.argmax(ious[i, :])
70 | for j in range(num_det):
71 | if j != largest_ind:
72 | tfmat[i, j] = False
73 | for j in range(num_det):
74 | y_score.append(det[j, -1])
75 | y_true.append(tfmat[:, j].any())
76 | count_tp += tfmat.sum()
77 | count_gt += num_gt
78 |
79 | det_rate = count_tp * 1.0 / count_gt
80 | ap = average_precision_score(y_true, y_score) * det_rate
81 |
82 | print("{} detection:".format("labeled only" if labeled_only else "all"))
83 | print(" recall = {:.2%}".format(det_rate))
84 | if not labeled_only:
85 | print(" ap = {:.2%}".format(ap))
86 | return det_rate, ap
87 |
88 |
89 | def eval_search_cuhk(
90 | gallery_dataset,
91 | query_dataset,
92 | gallery_dets,
93 | gallery_feats,
94 | query_box_feats,
95 | query_dets,
96 | query_feats,
97 | k1=10,
98 | k2=3,
99 | det_thresh=0.5,
100 | cbgm=False,
101 | gallery_size=100,
102 | ):
103 | """
104 | gallery_dataset/query_dataset: an instance of BaseDataset
105 | gallery_det (list of ndarray): n_det x [x1, y1, x2, y2, score] per image
106 | gallery_feat (list of ndarray): n_det x D features per image
107 | query_feat (list of ndarray): D dimensional features per query image
108 | det_thresh (float): filter out gallery detections whose scores are below this
109 | gallery_size (int): gallery size [-1, 50, 100, 500, 1000, 2000, 4000]
110 | -1 for using full set
111 | """
112 | assert len(gallery_dataset) == len(gallery_dets)
113 | assert len(gallery_dataset) == len(gallery_feats)
114 | assert len(query_dataset) == len(query_box_feats)
115 |
116 | use_full_set = gallery_size == -1
117 | fname = "TestG{}".format(gallery_size if not use_full_set else 50)
118 | protoc = loadmat(osp.join(gallery_dataset.root, "annotation/test/train_test", fname + ".mat"))
119 | protoc = protoc[fname].squeeze()
120 |
121 | # mapping from gallery image to (det, feat)
122 | annos = gallery_dataset.annotations
123 | name_to_det_feat = {}
124 | for anno, det, feat in zip(annos, gallery_dets, gallery_feats):
125 | name = anno["img_name"]
126 | if len(det) != 0:
127 | scores = det[:, 4].ravel()
128 | inds = np.where(scores >= det_thresh)[0]
129 | if len(inds) > 0:
130 | name_to_det_feat[name] = (det[inds], feat[inds])
131 |
132 | aps = []
133 | accs = []
134 | topk = [1, 5, 10]
135 | ret = {"image_root": gallery_dataset.img_prefix, "results": []}
136 | for i in range(len(query_dataset)):
137 | y_true, y_score = [], []
138 | imgs, rois = [], []
139 | count_gt, count_tp = 0, 0
140 | # get L2-normalized feature vector
141 | feat_q = query_box_feats[i].ravel()
142 | # ignore the query image
143 | query_imname = str(protoc["Query"][i]["imname"][0, 0][0])
144 | query_roi = protoc["Query"][i]["idlocate"][0, 0][0].astype(np.int32)
145 | query_roi[2:] += query_roi[:2]
146 | query_gt = []
147 | tested = set([query_imname])
148 |
149 | name2sim = {}
150 | name2gt = {}
151 | sims = []
152 | imgs_cbgm = []
153 | # 1. Go through the gallery samples defined by the protocol
154 | for item in protoc["Gallery"][i].squeeze():
155 | gallery_imname = str(item[0][0])
156 | # some contain the query (gt not empty), some not
157 | gt = item[1][0].astype(np.int32)
158 | count_gt += gt.size > 0
159 | # compute distance between query and gallery dets
160 | if gallery_imname not in name_to_det_feat:
161 | continue
162 | det, feat_g = name_to_det_feat[gallery_imname]
163 | # no detection in this gallery, skip it
164 | if det.shape[0] == 0:
165 | continue
166 | # get L2-normalized feature matrix NxD
167 | assert feat_g.size == np.prod(feat_g.shape[:2])
168 | feat_g = feat_g.reshape(feat_g.shape[:2])
169 | # compute cosine similarities
170 | sim = feat_g.dot(feat_q).ravel()
171 |
172 | if gallery_imname in name2sim:
173 | continue
174 | name2sim[gallery_imname] = sim
175 | name2gt[gallery_imname] = gt
176 | sims.extend(list(sim))
177 | imgs_cbgm.extend([gallery_imname] * len(sim))
178 | # 2. Go through the remaining gallery images if using full set
179 | if use_full_set:
180 | # TODO: support CBGM when using full set
181 | for gallery_imname in gallery_dataset.imgs:
182 | if gallery_imname in tested:
183 | continue
184 | if gallery_imname not in name_to_det_feat:
185 | continue
186 | det, feat_g = name_to_det_feat[gallery_imname]
187 | # get L2-normalized feature matrix NxD
188 | assert feat_g.size == np.prod(feat_g.shape[:2])
189 | feat_g = feat_g.reshape(feat_g.shape[:2])
190 | # compute cosine similarities
191 | sim = feat_g.dot(feat_q).ravel()
192 | # guaranteed no target query in these gallery images
193 | label = np.zeros(len(sim), dtype=np.int32)
194 | y_true.extend(list(label))
195 | y_score.extend(list(sim))
196 | imgs.extend([gallery_imname] * len(sim))
197 | rois.extend(list(det))
198 |
199 | if cbgm:
200 | # -------- Context Bipartite Graph Matching (CBGM) ------- #
201 | sims = np.array(sims)
202 | imgs_cbgm = np.array(imgs_cbgm)
203 | # only process the top-k1 gallery images for efficiency
204 | inds = np.argsort(sims)[-k1:]
205 | imgs_cbgm = set(imgs_cbgm[inds])
206 | for img in imgs_cbgm:
207 | sim = name2sim[img]
208 | det, feat_g = name_to_det_feat[img]
209 | # only regard the people with top-k2 detection confidence
210 | # in the query image as context information
211 | qboxes = query_dets[i][:k2]
212 | qfeats = query_feats[i][:k2]
213 | assert (
214 | query_roi - qboxes[0][:4]
215 | ).abs().sum() <= 0.001, "query_roi must be the first one in qboxes"
216 |
217 | # build the bipartite graph and run Kuhn-Munkres (K-M) algorithm
218 | # to find the best match
219 | graph = []
220 | for indx_i, pfeat in enumerate(qfeats):
221 | for indx_j, gfeat in enumerate(feat_g):
222 | graph.append((indx_i, indx_j, (pfeat * gfeat).sum()))
223 | km_res, max_val = run_kuhn_munkres(graph)
224 |
225 | # revise the similarity between query person and its matching
226 | for indx_i, indx_j, _ in km_res:
227 | # 0 denotes the query roi
228 | if indx_i == 0:
229 | sim[indx_j] = max_val
230 | break
231 | for gallery_imname, sim in name2sim.items():
232 | gt = name2gt[gallery_imname]
233 | det, feat_g = name_to_det_feat[gallery_imname]
234 | # assign label for each det
235 | label = np.zeros(len(sim), dtype=np.int32)
236 | if gt.size > 0:
237 | w, h = gt[2], gt[3]
238 | gt[2:] += gt[:2]
239 | query_gt.append({"img": str(gallery_imname), "roi": list(map(float, list(gt)))})
240 | iou_thresh = min(0.5, (w * h * 1.0) / ((w + 10) * (h + 10)))
241 | inds = np.argsort(sim)[::-1]
242 | sim = sim[inds]
243 | det = det[inds]
244 | # only set the first matched det as true positive
245 | for j, roi in enumerate(det[:, :4]):
246 | if _compute_iou(roi, gt) >= iou_thresh:
247 | label[j] = 1
248 | count_tp += 1
249 | break
250 | y_true.extend(list(label))
251 | y_score.extend(list(sim))
252 | imgs.extend([gallery_imname] * len(sim))
253 | rois.extend(list(det))
254 | tested.add(gallery_imname)
255 | # 3. Compute AP for this query (need to scale by recall rate)
256 | y_score = np.asarray(y_score)
257 | y_true = np.asarray(y_true)
258 | assert count_tp <= count_gt
259 | recall_rate = count_tp * 1.0 / count_gt
260 | ap = 0 if count_tp == 0 else average_precision_score(y_true, y_score) * recall_rate
261 | aps.append(ap)
262 | inds = np.argsort(y_score)[::-1]
263 | y_score = y_score[inds]
264 | y_true = y_true[inds]
265 | accs.append([min(1, sum(y_true[:k])) for k in topk])
266 | # 4. Save result for JSON dump
267 | new_entry = {
268 | "query_img": str(query_imname),
269 | "query_roi": list(map(float, list(query_roi))),
270 | "query_gt": query_gt,
271 | "gallery": [],
272 | }
273 | # only record wrong results
274 | if int(y_true[0]):
275 | continue
276 | # only save top-10 predictions
277 | for k in range(10):
278 | new_entry["gallery"].append(
279 | {
280 | "img": str(imgs[inds[k]]),
281 | "roi": list(map(float, list(rois[inds[k]]))),
282 | "score": float(y_score[k]),
283 | "correct": int(y_true[k]),
284 | }
285 | )
286 | ret["results"].append(new_entry)
287 |
288 | print("search ranking:")
289 | print(" mAP = {:.2%}".format(np.mean(aps)))
290 | accs = np.mean(accs, axis=0)
291 | for i, k in enumerate(topk):
292 | print(" top-{:2d} = {:.2%}".format(k, accs[i]))
293 |
294 | write_json(ret, "vis/results.json")
295 |
296 | ret["mAP"] = np.mean(aps)
297 | ret["accs"] = accs
298 | return ret
299 |
300 |
301 | def eval_search_prw(
302 | gallery_dataset,
303 | query_dataset,
304 | gallery_dets,
305 | gallery_feats,
306 | query_box_feats,
307 | query_dets,
308 | query_feats,
309 | k1=30,
310 | k2=4,
311 | det_thresh=0.5,
312 | cbgm=False,
313 | ignore_cam_id=True,
314 | ):
315 | """
316 | gallery_det (list of ndarray): n_det x [x1, y1, x2, y2, score] per image
317 | gallery_feat (list of ndarray): n_det x D features per image
318 | query_feat (list of ndarray): D dimensional features per query image
319 | det_thresh (float): filter out gallery detections whose scores are below this
320 | cbgm (bool): whether to apply Context Bipartite Graph Matching (CBGM)
321 | ignore_cam_id (bool): Set to True following the CUHK-SYSU protocol,
322 | although it is common practice to focus on cross-camera matches only.
323 | """
324 | assert len(gallery_dataset) == len(gallery_dets)
325 | assert len(gallery_dataset) == len(gallery_feats)
326 | assert len(query_dataset) == len(query_box_feats)
327 |
328 | annos = gallery_dataset.annotations
329 | name_to_det_feat = {}
330 | for anno, det, feat in zip(annos, gallery_dets, gallery_feats):
331 | name = anno["img_name"]
332 | scores = det[:, 4].ravel()
333 | inds = np.where(scores >= det_thresh)[0]
334 | if len(inds) > 0:
335 | name_to_det_feat[name] = (det[inds], feat[inds])
336 |
337 | aps = []
338 | accs = []
339 | topk = [1, 5, 10]
340 | ret = {"image_root": gallery_dataset.img_prefix, "results": []}
341 | for i in range(len(query_dataset)):
342 | y_true, y_score = [], []
343 | imgs, rois = [], []
344 | count_gt, count_tp = 0, 0
345 |
346 | feat_p = query_box_feats[i].ravel()
347 |
348 | query_imname = query_dataset.annotations[i]["img_name"]
349 | query_roi = query_dataset.annotations[i]["boxes"]
350 | query_pid = query_dataset.annotations[i]["pids"]
351 | query_cam = query_dataset.annotations[i]["cam_id"]
352 |
353 | # Find all occurrences of this query
354 | gallery_imgs = []
355 | for x in annos:
356 | if query_pid in x["pids"] and x["img_name"] != query_imname:
357 | gallery_imgs.append(x)
358 | query_gts = {}
359 | for item in gallery_imgs:
360 | query_gts[item["img_name"]] = item["boxes"][item["pids"] == query_pid]
361 |
362 | # Construct gallery set for this query
363 | if ignore_cam_id:
364 | gallery_imgs = []
365 | for x in annos:
366 | if x["img_name"] != query_imname:
367 | gallery_imgs.append(x)
368 | else:
369 | gallery_imgs = []
370 | for x in annos:
371 | if x["img_name"] != query_imname and x["cam_id"] != query_cam:
372 | gallery_imgs.append(x)
373 |
374 | name2sim = {}
375 | sims = []
376 | imgs_cbgm = []
377 | # 1. Go through all gallery samples
378 | for item in gallery_imgs:
379 | gallery_imname = item["img_name"]
380 | # some contain the query (gt not empty), some not
381 | count_gt += gallery_imname in query_gts
382 | # compute distance between query and gallery dets
383 | if gallery_imname not in name_to_det_feat:
384 | continue
385 | det, feat_g = name_to_det_feat[gallery_imname]
386 | # get L2-normalized feature matrix NxD
387 | assert feat_g.size == np.prod(feat_g.shape[:2])
388 | feat_g = feat_g.reshape(feat_g.shape[:2])
389 | # compute cosine similarities
390 | sim = feat_g.dot(feat_p).ravel()
391 |
392 | if gallery_imname in name2sim:
393 | continue
394 | name2sim[gallery_imname] = sim
395 | sims.extend(list(sim))
396 | imgs_cbgm.extend([gallery_imname] * len(sim))
397 |
398 | if cbgm:
399 | sims = np.array(sims)
400 | imgs_cbgm = np.array(imgs_cbgm)
401 | inds = np.argsort(sims)[-k1:]
402 | imgs_cbgm = set(imgs_cbgm[inds])
403 | for img in imgs_cbgm:
404 | sim = name2sim[img]
405 | det, feat_g = name_to_det_feat[img]
406 | qboxes = query_dets[i][:k2]
407 | qfeats = query_feats[i][:k2]
408 | assert (
409 | query_roi - qboxes[0][:4]
410 | ).abs().sum() <= 0.001, "query_roi must be the first one in qboxes"
411 |
412 | graph = []
413 | for indx_i, pfeat in enumerate(qfeats):
414 | for indx_j, gfeat in enumerate(feat_g):
415 | graph.append((indx_i, indx_j, (pfeat * gfeat).sum()))
416 | km_res, max_val = run_kuhn_munkres(graph)
417 |
418 | for indx_i, indx_j, _ in km_res:
419 | if indx_i == 0:
420 | sim[indx_j] = max_val
421 | break
422 | for gallery_imname, sim in name2sim.items():
423 | det, feat_g = name_to_det_feat[gallery_imname]
424 | # assign label for each det
425 | label = np.zeros(len(sim), dtype=np.int32)
426 | if gallery_imname in query_gts:
427 | gt = query_gts[gallery_imname].ravel()
428 | w, h = gt[2] - gt[0], gt[3] - gt[1]
429 | iou_thresh = min(0.5, (w * h * 1.0) / ((w + 10) * (h + 10)))
430 | inds = np.argsort(sim)[::-1]
431 | sim = sim[inds]
432 | det = det[inds]
433 | # only set the first matched det as true positive
434 | for j, roi in enumerate(det[:, :4]):
435 | if _compute_iou(roi, gt) >= iou_thresh:
436 | label[j] = 1
437 | count_tp += 1
438 | break
439 | y_true.extend(list(label))
440 | y_score.extend(list(sim))
441 | imgs.extend([gallery_imname] * len(sim))
442 | rois.extend(list(det))
443 |
444 | # 2. Compute AP for this query (need to scale by recall rate)
445 | y_score = np.asarray(y_score)
446 | y_true = np.asarray(y_true)
447 | assert count_tp <= count_gt
448 | recall_rate = count_tp * 1.0 / count_gt
449 | ap = 0 if count_tp == 0 else average_precision_score(y_true, y_score) * recall_rate
450 | aps.append(ap)
451 | inds = np.argsort(y_score)[::-1]
452 | y_score = y_score[inds]
453 | y_true = y_true[inds]
454 | accs.append([min(1, sum(y_true[:k])) for k in topk])
455 | # 3. Save result for JSON dump
456 | new_entry = {
457 | "query_img": str(query_imname),
458 | "query_roi": list(map(float, list(query_roi.squeeze()))),
459 | "query_gt": query_gts,
460 | "gallery": [],
461 | }
462 | # only save top-10 predictions
463 | for k in range(10):
464 | new_entry["gallery"].append(
465 | {
466 | "img": str(imgs[inds[k]]),
467 | "roi": list(map(float, list(rois[inds[k]]))),
468 | "score": float(y_score[k]),
469 | "correct": int(y_true[k]),
470 | }
471 | )
472 | ret["results"].append(new_entry)
473 |
474 | print("search ranking:")
475 | mAP = np.mean(aps)
476 | print(" mAP = {:.2%}".format(mAP))
477 | accs = np.mean(accs, axis=0)
478 | for i, k in enumerate(topk):
479 | print(" top-{:2d} = {:.2%}".format(k, accs[i]))
480 |
481 | # write_json(ret, "vis/results.json")
482 |
483 | ret["mAP"] = np.mean(aps)
484 | ret["accs"] = accs
485 | return ret
486 |
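A small numeric check of _compute_iou and of the adaptive IoU threshold used when matching a ground-truth box during search evaluation; the boxes are made up:

import numpy as np

gt = np.array([10.0, 10.0, 60.0, 110.0])   # x1, y1, x2, y2
det = np.array([12.0, 8.0, 58.0, 105.0])
print(_compute_iou(gt, det))               # roughly 0.86

w, h = gt[2] - gt[0], gt[3] - gt[1]
print(min(0.5, (w * h) / ((w + 10) * (h + 10))))   # prints 0.5: 5000 / 6600 ~ 0.76 is capped at 0.5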
--------------------------------------------------------------------------------
/models/seqnet.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn import init
7 | from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
8 | from torchvision.models.detection.roi_heads import RoIHeads
9 | from torchvision.models.detection.rpn import AnchorGenerator, RegionProposalNetwork, RPNHead
10 | from torchvision.models.detection.transform import GeneralizedRCNNTransform
11 | from torchvision.ops import MultiScaleRoIAlign
12 | from torchvision.ops import boxes as box_ops
13 |
14 | from models.oim import OIMLoss
15 | from models.resnet import build_resnet
16 | from models.swin import build_swin
17 |
18 | class SeqNet(nn.Module):
19 | def __init__(self, cfg):
20 | super(SeqNet, self).__init__()
21 |
22 | backbone_name = cfg.MODEL.BONE
23 | semantic_weight = cfg.MODEL.SEMANTIC_WEIGHT
24 | if backbone_name == 'resnet50':
25 | backbone, box_head = build_resnet(name="resnet50", pretrained=True)
26 | feat_len = 2048
27 | elif 'swin' in backbone_name:
28 | backbone, box_head, feat_len = build_swin(name=backbone_name, semantic_weight=semantic_weight)
29 |
30 | anchor_generator = AnchorGenerator(
31 | sizes=((32, 64, 128, 256, 512),), aspect_ratios=((0.5, 1.0, 2.0),)
32 | )
33 | head = RPNHead(
34 | in_channels=backbone.out_channels,
35 | num_anchors=anchor_generator.num_anchors_per_location()[0],
36 | )
37 | pre_nms_top_n = dict(
38 | training=cfg.MODEL.RPN.PRE_NMS_TOPN_TRAIN, testing=cfg.MODEL.RPN.PRE_NMS_TOPN_TEST
39 | )
40 | post_nms_top_n = dict(
41 | training=cfg.MODEL.RPN.POST_NMS_TOPN_TRAIN, testing=cfg.MODEL.RPN.POST_NMS_TOPN_TEST
42 | )
43 | rpn = RegionProposalNetwork(
44 | anchor_generator=anchor_generator,
45 | head=head,
46 | fg_iou_thresh=cfg.MODEL.RPN.POS_THRESH_TRAIN,
47 | bg_iou_thresh=cfg.MODEL.RPN.NEG_THRESH_TRAIN,
48 | batch_size_per_image=cfg.MODEL.RPN.BATCH_SIZE_TRAIN,
49 | positive_fraction=cfg.MODEL.RPN.POS_FRAC_TRAIN,
50 | pre_nms_top_n=pre_nms_top_n,
51 | post_nms_top_n=post_nms_top_n,
52 | nms_thresh=cfg.MODEL.RPN.NMS_THRESH,
53 | )
54 |
55 | faster_rcnn_predictor = FastRCNNPredictor(feat_len, 2)
56 | reid_head = deepcopy(box_head)
57 | box_roi_pool = MultiScaleRoIAlign(
58 | featmap_names=["feat_res4"], output_size=14, sampling_ratio=2
59 | )
60 | box_predictor = BBoxRegressor(feat_len, num_classes=2, bn_neck=cfg.MODEL.ROI_HEAD.BN_NECK)
61 | roi_heads = SeqRoIHeads(
62 | # OIM
63 | num_pids=cfg.MODEL.LOSS.LUT_SIZE,
64 | num_cq_size=cfg.MODEL.LOSS.CQ_SIZE,
65 | oim_momentum=cfg.MODEL.LOSS.OIM_MOMENTUM,
66 | oim_scalar=cfg.MODEL.LOSS.OIM_SCALAR,
67 | # SeqNet
68 | faster_rcnn_predictor=faster_rcnn_predictor,
69 | reid_head=reid_head,
70 | # parent class
71 | box_roi_pool=box_roi_pool,
72 | box_head=box_head,
73 | box_predictor=box_predictor,
74 | fg_iou_thresh=cfg.MODEL.ROI_HEAD.POS_THRESH_TRAIN,
75 | bg_iou_thresh=cfg.MODEL.ROI_HEAD.NEG_THRESH_TRAIN,
76 | batch_size_per_image=cfg.MODEL.ROI_HEAD.BATCH_SIZE_TRAIN,
77 | positive_fraction=cfg.MODEL.ROI_HEAD.POS_FRAC_TRAIN,
78 | bbox_reg_weights=None,
79 | score_thresh=cfg.MODEL.ROI_HEAD.SCORE_THRESH_TEST,
80 | nms_thresh=cfg.MODEL.ROI_HEAD.NMS_THRESH_TEST,
81 | detections_per_img=cfg.MODEL.ROI_HEAD.DETECTIONS_PER_IMAGE_TEST,
82 | feat_len=feat_len,
83 | )
84 |
85 | transform = GeneralizedRCNNTransform(
86 | min_size=cfg.INPUT.MIN_SIZE,
87 | max_size=cfg.INPUT.MAX_SIZE,
88 | image_mean=[0.485, 0.456, 0.406],
89 | image_std=[0.229, 0.224, 0.225],
90 | )
91 |
92 | self.backbone = backbone
93 | self.rpn = rpn
94 | self.roi_heads = roi_heads
95 | self.transform = transform
96 |
97 | # loss weights
98 | self.lw_rpn_reg = cfg.SOLVER.LW_RPN_REG
99 | self.lw_rpn_cls = cfg.SOLVER.LW_RPN_CLS
100 | self.lw_proposal_reg = cfg.SOLVER.LW_PROPOSAL_REG
101 | self.lw_proposal_cls = cfg.SOLVER.LW_PROPOSAL_CLS
102 | self.lw_box_reg = cfg.SOLVER.LW_BOX_REG
103 | self.lw_box_cls = cfg.SOLVER.LW_BOX_CLS
104 | self.lw_box_reid = cfg.SOLVER.LW_BOX_REID
105 |
106 | def inference(self, images, targets=None, query_img_as_gallery=False):
107 | """
108 | query_img_as_gallery: Set to True to detect all people in the query image.
109 | Meanwhile, the gt box should be the first of the detected boxes.
110 | This option serves CBGM.
111 | """
112 | original_image_sizes = [img.shape[-2:] for img in images]
113 | images, targets = self.transform(images, targets)
114 | features = self.backbone(images.tensors)
115 |
116 | if query_img_as_gallery:
117 | assert targets is not None
118 |
119 | if targets is not None and not query_img_as_gallery:
120 | # query
121 | boxes = [t["boxes"] for t in targets]
122 | box_features = self.roi_heads.box_roi_pool(features, boxes, images.image_sizes)
123 | box_features = self.roi_heads.reid_head(box_features)
124 | embeddings, _ = self.roi_heads.embedding_head(box_features)
125 | return embeddings.split(1, 0)
126 | else:
127 | # gallery
128 | proposals, _ = self.rpn(images, features, targets)
129 | detections, _ = self.roi_heads(
130 | features, proposals, images.image_sizes, targets, query_img_as_gallery
131 | )
132 | detections = self.transform.postprocess(
133 | detections, images.image_sizes, original_image_sizes
134 | )
135 | return detections
136 |
137 | def forward(self, images, targets=None, query_img_as_gallery=False):
138 | if not self.training:
139 | return self.inference(images, targets, query_img_as_gallery)
140 |
141 | images, targets = self.transform(images, targets)
142 | features = self.backbone(images.tensors)
143 | proposals, proposal_losses = self.rpn(images, features, targets)
144 | _, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
145 |
146 | # rename rpn losses to be consistent with detection losses
147 | proposal_losses["loss_rpn_reg"] = proposal_losses.pop("loss_rpn_box_reg")
148 | proposal_losses["loss_rpn_cls"] = proposal_losses.pop("loss_objectness")
149 |
150 | losses = {}
151 | losses.update(detector_losses)
152 | losses.update(proposal_losses)
153 |
154 | # apply loss weights
155 | losses["loss_rpn_reg"] *= self.lw_rpn_reg
156 | losses["loss_rpn_cls"] *= self.lw_rpn_cls
157 | losses["loss_proposal_reg"] *= self.lw_proposal_reg
158 | losses["loss_proposal_cls"] *= self.lw_proposal_cls
159 | losses["loss_box_reg"] *= self.lw_box_reg
160 | losses["loss_box_cls"] *= self.lw_box_cls
161 | losses["loss_box_reid"] *= self.lw_box_reid
162 | return losses
163 |
164 |
165 | class SeqRoIHeads(RoIHeads):
166 | def __init__(
167 | self,
168 | num_pids,
169 | num_cq_size,
170 | oim_momentum,
171 | oim_scalar,
172 | faster_rcnn_predictor,
173 | reid_head,
174 | feat_len,
175 | *args,
176 | **kwargs
177 | ):
178 | super(SeqRoIHeads, self).__init__(*args, **kwargs)
179 | self.embedding_head = NormAwareEmbedding(in_channels=[feat_len // 2, feat_len])
180 | self.reid_loss = OIMLoss(256, num_pids, num_cq_size, oim_momentum, oim_scalar)
181 | self.faster_rcnn_predictor = faster_rcnn_predictor
182 | self.reid_head = reid_head
183 | # rename the method inherited from parent class
184 | self.postprocess_proposals = self.postprocess_detections
185 |
186 | def forward(self, features, proposals, image_shapes, targets=None, query_img_as_gallery=False):
187 | """
188 | Arguments:
189 | features (List[Tensor])
190 | proposals (List[Tensor[N, 4]])
191 | image_shapes (List[Tuple[H, W]])
192 | targets (List[Dict])
193 | """
194 | if self.training:
195 | proposals, _, proposal_pid_labels, proposal_reg_targets = self.select_training_samples(
196 | proposals, targets
197 | )
198 |
199 | # ------------------- Faster R-CNN head ------------------ #
200 | proposal_features = self.box_roi_pool(features, proposals, image_shapes)
201 | proposal_features = self.box_head(proposal_features)
202 | proposal_cls_scores, proposal_regs = self.faster_rcnn_predictor(
203 | proposal_features["feat_res5"]
204 | )
205 |
206 | if self.training:
207 | boxes = self.get_boxes(proposal_regs, proposals, image_shapes)
208 | boxes = [boxes_per_image.detach() for boxes_per_image in boxes]
209 | boxes, _, box_pid_labels, box_reg_targets = self.select_training_samples(boxes, targets)
210 | else:
211 | # invoke the postprocess method inherited from parent class to process proposals
212 | boxes, scores, _ = self.postprocess_proposals(
213 | proposal_cls_scores, proposal_regs, proposals, image_shapes
214 | )
215 |
216 | cws = True
217 | gt_det = None
218 | if not self.training and query_img_as_gallery:
219 | # When regarding the query image as gallery, GT boxes may be excluded
220 | # from the detected boxes. To avoid this, we always include the GT in the
221 | # detection results. Additionally, CWS should be disabled, as the
222 | # confidences of these people in the query image are 1.
223 | cws = False
224 | gt_box = [targets[0]["boxes"]]
225 | gt_box_features = self.box_roi_pool(features, gt_box, image_shapes)
226 | gt_box_features = self.reid_head(gt_box_features)
227 | embeddings, _ = self.embedding_head(gt_box_features)
228 | gt_det = {"boxes": targets[0]["boxes"], "embeddings": embeddings}
229 |
230 | # no detection predicted by Faster R-CNN head in test phase
231 | if boxes[0].shape[0] == 0:
232 | assert not self.training
233 | boxes = gt_det["boxes"] if gt_det else torch.zeros(0, 4)
234 | labels = torch.ones(1).type_as(boxes) if gt_det else torch.zeros(0)
235 | scores = torch.ones(1).type_as(boxes) if gt_det else torch.zeros(0)
236 | embeddings = gt_det["embeddings"] if gt_det else torch.zeros(0, 256)
237 | return [dict(boxes=boxes, labels=labels, scores=scores, embeddings=embeddings)], []
238 |
239 | # --------------------- Baseline head -------------------- #
240 | box_features = self.box_roi_pool(features, boxes, image_shapes)
241 | box_features = self.reid_head(box_features)
242 | box_regs = self.box_predictor(box_features["feat_res5"])
243 | box_embeddings, box_cls_scores = self.embedding_head(box_features)
244 | if box_cls_scores.dim() == 0:
245 | box_cls_scores = box_cls_scores.unsqueeze(0)
246 |
247 | result, losses = [], {}
248 | if self.training:
249 | proposal_labels = [y.clamp(0, 1) for y in proposal_pid_labels]
250 | box_labels = [y.clamp(0, 1) for y in box_pid_labels]
251 | losses = detection_losses(
252 | proposal_cls_scores,
253 | proposal_regs,
254 | proposal_labels,
255 | proposal_reg_targets,
256 | box_cls_scores,
257 | box_regs,
258 | box_labels,
259 | box_reg_targets,
260 | )
261 | loss_box_reid = self.reid_loss(box_embeddings, box_pid_labels)
262 | losses.update(loss_box_reid=loss_box_reid)
263 | else:
264 | # The IoUs of these boxes are higher than those of the proposals,
265 | # so a higher NMS threshold is needed
266 | orig_thresh = self.nms_thresh
267 | self.nms_thresh = 0.5
268 | boxes, scores, embeddings, labels = self.postprocess_boxes(
269 | box_cls_scores,
270 | box_regs,
271 | box_embeddings,
272 | boxes,
273 | image_shapes,
274 | fcs=scores,
275 | gt_det=gt_det,
276 | cws=cws,
277 | )
278 | # set to original thresh after finishing postprocess
279 | self.nms_thresh = orig_thresh
280 | num_images = len(boxes)
281 | for i in range(num_images):
282 | result.append(
283 | dict(
284 | boxes=boxes[i], labels=labels[i], scores=scores[i], embeddings=embeddings[i]
285 | )
286 | )
287 | return result, losses
288 |
289 | def get_boxes(self, box_regression, proposals, image_shapes):
290 | """
291 | Get boxes from proposals.
292 | """
293 | boxes_per_image = [len(boxes_in_image) for boxes_in_image in proposals]
294 | pred_boxes = self.box_coder.decode(box_regression, proposals)
295 | pred_boxes = pred_boxes.split(boxes_per_image, 0)
296 |
297 | all_boxes = []
298 | for boxes, image_shape in zip(pred_boxes, image_shapes):
299 | boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
300 | # remove predictions with the background label
301 | boxes = boxes[:, 1:].reshape(-1, 4)
302 | all_boxes.append(boxes)
303 |
304 | return all_boxes
305 |
306 | def postprocess_boxes(
307 | self,
308 | class_logits,
309 | box_regression,
310 | embeddings,
311 | proposals,
312 | image_shapes,
313 | fcs=None,
314 | gt_det=None,
315 | cws=True,
316 | ):
317 | """
318 | Similar to RoIHeads.postprocess_detections, but can handle embeddings and implement
319 | First Classification Score (FCS).
320 | """
321 | device = class_logits.device
322 |
323 | boxes_per_image = [len(boxes_in_image) for boxes_in_image in proposals]
324 | pred_boxes = self.box_coder.decode(box_regression, proposals)
325 |
326 | if fcs is not None:
327 | # Fist Classification Score (FCS)
328 | pred_scores = fcs[0]
329 | else:
330 | pred_scores = torch.sigmoid(class_logits)
331 | if cws:
332 | # Confidence Weighted Similarity (CWS)
333 | embeddings = embeddings * pred_scores.view(-1, 1)
334 |
335 | # split boxes and scores per image
336 | pred_boxes = pred_boxes.split(boxes_per_image, 0)
337 | pred_scores = pred_scores.split(boxes_per_image, 0)
338 | pred_embeddings = embeddings.split(boxes_per_image, 0)
339 |
340 | all_boxes = []
341 | all_scores = []
342 | all_labels = []
343 | all_embeddings = []
344 | for boxes, scores, embeddings, image_shape in zip(
345 | pred_boxes, pred_scores, pred_embeddings, image_shapes
346 | ):
347 | boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
348 |
349 | # create labels for each prediction
350 | labels = torch.ones(scores.size(0), device=device)
351 |
352 | # remove predictions with the background label
353 | boxes = boxes[:, 1:]
354 | scores = scores.unsqueeze(1)
355 | labels = labels.unsqueeze(1)
356 |
357 | # batch everything, by making every class prediction be a separate instance
358 | boxes = boxes.reshape(-1, 4)
359 | scores = scores.flatten()
360 | labels = labels.flatten()
361 | embeddings = embeddings.reshape(-1, self.embedding_head.dim)
362 |
363 | # remove low scoring boxes
364 | inds = torch.nonzero(scores > self.score_thresh).squeeze(1)
365 | boxes, scores, labels, embeddings = (
366 | boxes[inds],
367 | scores[inds],
368 | labels[inds],
369 | embeddings[inds],
370 | )
371 |
372 | # remove empty boxes
373 | keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
374 | boxes, scores, labels, embeddings = (
375 | boxes[keep],
376 | scores[keep],
377 | labels[keep],
378 | embeddings[keep],
379 | )
380 |
381 | if gt_det is not None:
382 | # include GT into the detection results
383 | boxes = torch.cat((boxes, gt_det["boxes"]), dim=0)
384 | labels = torch.cat((labels, torch.tensor([1.0]).to(device)), dim=0)
385 | scores = torch.cat((scores, torch.tensor([1.0]).to(device)), dim=0)
386 | embeddings = torch.cat((embeddings, gt_det["embeddings"]), dim=0)
387 |
388 | # non-maximum suppression, independently done per class
389 | keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
390 | # keep only topk scoring predictions
391 | keep = keep[: self.detections_per_img]
392 | boxes, scores, labels, embeddings = (
393 | boxes[keep],
394 | scores[keep],
395 | labels[keep],
396 | embeddings[keep],
397 | )
398 |
399 | all_boxes.append(boxes)
400 | all_scores.append(scores)
401 | all_labels.append(labels)
402 | all_embeddings.append(embeddings)
403 |
404 | return all_boxes, all_scores, all_embeddings, all_labels
405 |
406 |
407 | class NormAwareEmbedding(nn.Module):
408 | """
409 | Implements the Norm-Aware Embedding proposed in
410 | Chen, Di, et al. "Norm-aware embedding for efficient person search." CVPR 2020.
411 | """
412 |
413 | def __init__(self, featmap_names=["feat_res4", "feat_res5"], in_channels=[1024, 2048], dim=256):
414 | super(NormAwareEmbedding, self).__init__()
415 | self.featmap_names = featmap_names
416 | self.in_channels = in_channels
417 | self.dim = dim
418 |
419 | self.projectors = nn.ModuleDict()
420 | indv_dims = self._split_embedding_dim()
421 | for ftname, in_channel, indv_dim in zip(self.featmap_names, self.in_channels, indv_dims):
422 | proj = nn.Sequential(nn.Linear(in_channel, indv_dim), nn.BatchNorm1d(indv_dim))
423 | init.normal_(proj[0].weight, std=0.01)
424 | init.normal_(proj[1].weight, std=0.01)
425 | init.constant_(proj[0].bias, 0)
426 | init.constant_(proj[1].bias, 0)
427 | self.projectors[ftname] = proj
428 |
429 | self.rescaler = nn.BatchNorm1d(1, affine=True)
430 |
431 | def forward(self, featmaps):
432 | """
433 | Arguments:
434 |             featmaps: OrderedDict[Tensor] containing the feature maps listed in
435 |                       featmap_names.
436 |         Returns:
437 |             tensor of size (BatchSize, dim): L2-normalized embeddings.
438 |             tensor of size (BatchSize, ): rescaled norms of the embeddings, used as class_logits.
439 | """
440 | assert len(featmaps) == len(self.featmap_names)
441 | if len(featmaps) == 1:
442 |             k, v = list(featmaps.items())[0]
443 | v = self._flatten_fc_input(v)
444 | embeddings = self.projectors[k](v)
445 | norms = embeddings.norm(2, 1, keepdim=True)
446 | embeddings = embeddings / norms.expand_as(embeddings).clamp(min=1e-12)
447 | norms = self.rescaler(norms).squeeze()
448 | return embeddings, norms
449 | else:
450 | outputs = []
451 | for k, v in featmaps.items():
452 | v = self._flatten_fc_input(v)
453 | outputs.append(self.projectors[k](v))
454 | embeddings = torch.cat(outputs, dim=1)
455 | norms = embeddings.norm(2, 1, keepdim=True)
456 | embeddings = embeddings / norms.expand_as(embeddings).clamp(min=1e-12)
457 | norms = self.rescaler(norms).squeeze()
458 | return embeddings, norms
459 |
460 | def _flatten_fc_input(self, x):
461 | if x.ndimension() == 4:
462 | assert list(x.shape[2:]) == [1, 1]
463 | return x.flatten(start_dim=1)
464 | return x
465 |
466 | def _split_embedding_dim(self):
467 | parts = len(self.in_channels)
468 | tmp = [self.dim // parts] * parts
469 | if sum(tmp) == self.dim:
470 | return tmp
471 | else:
472 | res = self.dim % parts
473 | for i in range(1, res + 1):
474 | tmp[-i] += 1
475 | assert sum(tmp) == self.dim
476 | return tmp
477 |
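# --- Hedged usage sketch (assumption: 1x1 pooled RoI features per feature map) ---
# Shows the expected inputs/outputs of NormAwareEmbedding defined above: an
# OrderedDict of pooled feature maps goes in, L2-normalized embeddings and
# their rescaled norms (used as class logits) come out.
def _demo_norm_aware_embedding():
    from collections import OrderedDict

    head = NormAwareEmbedding(
        featmap_names=["feat_res4", "feat_res5"], in_channels=[1024, 2048], dim=256
    )
    head.eval()
    feats = OrderedDict(
        feat_res4=torch.randn(8, 1024, 1, 1),
        feat_res5=torch.randn(8, 2048, 1, 1),
    )
    embeddings, norms = head(feats)
    # embeddings: (8, 256), unit L2 norm; norms: (8,), rescaled by BatchNorm1d(1)
    return embeddings.shape, norms.shape
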
478 |
479 | class BBoxRegressor(nn.Module):
480 | """
481 | Bounding box regression layer.
482 | """
483 |
484 | def __init__(self, in_channels, num_classes=2, bn_neck=True):
485 | """
486 | Args:
487 | in_channels (int): Input channels.
488 | num_classes (int, optional): Defaults to 2 (background and pedestrian).
489 | bn_neck (bool, optional): Whether to use BN after Linear. Defaults to True.
490 | """
491 | super(BBoxRegressor, self).__init__()
492 | if bn_neck:
493 | self.bbox_pred = nn.Sequential(
494 | nn.Linear(in_channels, 4 * num_classes), nn.BatchNorm1d(4 * num_classes)
495 | )
496 | init.normal_(self.bbox_pred[0].weight, std=0.01)
497 | init.normal_(self.bbox_pred[1].weight, std=0.01)
498 | init.constant_(self.bbox_pred[0].bias, 0)
499 | init.constant_(self.bbox_pred[1].bias, 0)
500 | else:
501 | self.bbox_pred = nn.Linear(in_channels, 4 * num_classes)
502 | init.normal_(self.bbox_pred.weight, std=0.01)
503 | init.constant_(self.bbox_pred.bias, 0)
504 |
505 | def forward(self, x):
506 | if x.ndimension() == 4:
507 | if list(x.shape[2:]) != [1, 1]:
508 | x = F.adaptive_avg_pool2d(x, output_size=1)
509 | x = x.flatten(start_dim=1)
510 | bbox_deltas = self.bbox_pred(x)
511 | return bbox_deltas
512 |
513 |
514 | def detection_losses(
515 | proposal_cls_scores,
516 | proposal_regs,
517 | proposal_labels,
518 | proposal_reg_targets,
519 | box_cls_scores,
520 | box_regs,
521 | box_labels,
522 | box_reg_targets,
523 | ):
524 | proposal_labels = torch.cat(proposal_labels, dim=0)
525 | box_labels = torch.cat(box_labels, dim=0)
526 | proposal_reg_targets = torch.cat(proposal_reg_targets, dim=0)
527 | box_reg_targets = torch.cat(box_reg_targets, dim=0)
528 |
529 | loss_proposal_cls = F.cross_entropy(proposal_cls_scores, proposal_labels)
530 | loss_box_cls = F.binary_cross_entropy_with_logits(box_cls_scores, box_labels.float())
531 |
532 | # get indices that correspond to the regression targets for the
533 | # corresponding ground truth labels, to be used with advanced indexing
534 | sampled_pos_inds_subset = torch.nonzero(proposal_labels > 0).squeeze(1)
535 | labels_pos = proposal_labels[sampled_pos_inds_subset]
536 | N = proposal_cls_scores.size(0)
537 | proposal_regs = proposal_regs.reshape(N, -1, 4)
538 |
539 | loss_proposal_reg = F.smooth_l1_loss(
540 | proposal_regs[sampled_pos_inds_subset, labels_pos],
541 | proposal_reg_targets[sampled_pos_inds_subset],
542 | reduction="sum",
543 | )
544 | loss_proposal_reg = loss_proposal_reg / proposal_labels.numel()
545 |
546 | sampled_pos_inds_subset = torch.nonzero(box_labels > 0).squeeze(1)
547 | labels_pos = box_labels[sampled_pos_inds_subset]
548 | N = box_cls_scores.size(0)
549 | box_regs = box_regs.reshape(N, -1, 4)
550 |
551 | loss_box_reg = F.smooth_l1_loss(
552 | box_regs[sampled_pos_inds_subset, labels_pos],
553 | box_reg_targets[sampled_pos_inds_subset],
554 | reduction="sum",
555 | )
556 | loss_box_reg = loss_box_reg / box_labels.numel()
557 |
558 | return dict(
559 | loss_proposal_cls=loss_proposal_cls,
560 | loss_proposal_reg=loss_proposal_reg,
561 | loss_box_cls=loss_box_cls,
562 | loss_box_reg=loss_box_reg,
563 | )
564 |
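# --- Hedged illustration of the loss normalization used in detection_losses above ---
# The smooth-L1 regression loss is summed over positive samples only and then
# divided by the total number of sampled proposals (positives + negatives),
# following the Fast R-CNN convention. Toy tensors below are illustrative only.
def _demo_reg_loss_normalization():
    labels = torch.tensor([1, 1, 0, 0, 0, 0])        # 2 positives, 4 negatives
    regs = torch.randn(6, 2, 4)                      # per-class box deltas
    targets = torch.randn(6, 4)
    pos = torch.nonzero(labels > 0).squeeze(1)
    loss = F.smooth_l1_loss(regs[pos, labels[pos]], targets[pos], reduction="sum")
    return loss / labels.numel()
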
--------------------------------------------------------------------------------
/models/swin_transformer.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from collections import OrderedDict
3 | from copy import deepcopy
4 | import logging
5 |
6 | import math
7 | from typing import Sequence
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch.utils.checkpoint as cp
12 | import numpy as np
13 | import cv2
14 |
15 | from torch.nn import Module as BaseModule
16 | from torch.nn import ModuleList
17 | from torch.nn import Sequential
18 | from torch.nn import Linear
19 | from torch import Tensor
20 | from mmcv.runner import load_checkpoint as _load_checkpoint
21 |
22 | from itertools import repeat
23 | import collections.abc
24 | def _ntuple(n):
25 |
26 | def parse(x):
27 | if isinstance(x, collections.abc.Iterable):
28 | return x
29 | return tuple(repeat(x, n))
30 |
31 | return parse
32 | to_2tuple = _ntuple(2)
33 |
34 | def trunc_normal_init(module: nn.Module,
35 | mean: float = 0,
36 | std: float = 1,
37 | a: float = -2,
38 | b: float = 2,
39 | bias: float = 0) -> None:
40 | if hasattr(module, 'weight') and module.weight is not None:
41 | #trunc_normal_(module.weight, mean, std, a, b) # type: ignore
42 | _no_grad_trunc_normal_(module.weight, mean, std, a, b) # type: ignore
43 | if hasattr(module, 'bias') and module.bias is not None:
44 | nn.init.constant_(module.bias, bias) # type: ignore
45 |
46 | def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float,
47 | b: float) -> Tensor:
48 | # Method based on
49 | # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
50 | # Modified from
51 | # https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
52 | def norm_cdf(x):
53 | # Computes standard normal cumulative distribution function
54 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
55 |
56 | if (mean < a - 2 * std) or (mean > b + 2 * std):
57 | warnings.warn(
58 | 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
59 | 'The distribution of values may be incorrect.',
60 | stacklevel=2)
61 |
62 | with torch.no_grad():
63 | # Values are generated by using a truncated uniform distribution and
64 | # then using the inverse CDF for the normal distribution.
65 | # Get upper and lower cdf values
66 | lower = norm_cdf((a - mean) / std)
67 | upper = norm_cdf((b - mean) / std)
68 |
69 | # Uniformly fill tensor with values from [lower, upper], then translate
70 | # to [2lower-1, 2upper-1].
71 | tensor.uniform_(2 * lower - 1, 2 * upper - 1)
72 |
73 | # Use inverse cdf transform for normal distribution to get truncated
74 | # standard normal
75 | tensor.erfinv_()
76 |
77 | # Transform to proper mean, std
78 | tensor.mul_(std * math.sqrt(2.))
79 | tensor.add_(mean)
80 |
81 | # Clamp to ensure it's in the proper range
82 | tensor.clamp_(min=a, max=b)
83 | return tensor
84 |
85 | def trunc_normal_(tensor: Tensor,
86 | mean: float = 0.,
87 | std: float = 1.,
88 | a: float = -2.,
89 | b: float = 2.) -> Tensor:
90 | r"""Fills the input Tensor with values drawn from a truncated
91 | normal distribution. The values are effectively drawn from the
92 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
93 | with values outside :math:`[a, b]` redrawn until they are within
94 | the bounds. The method used for generating the random values works
95 | best when :math:`a \leq \text{mean} \leq b`.
96 |
97 | Modified from
98 | https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
99 |
100 | Args:
101 | tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`.
102 | mean (float): the mean of the normal distribution.
103 | std (float): the standard deviation of the normal distribution.
104 | a (float): the minimum cutoff value.
105 | b (float): the maximum cutoff value.
106 | """
107 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
108 |
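# --- Hedged sanity check (not part of the original file) ---
# With the defaults used by trunc_normal_init (std=0.02, a=-2, b=2), every
# value produced by the inverse-CDF sampling above stays inside [a, b].
def _demo_trunc_normal():
    t = torch.empty(10000)
    trunc_normal_(t, mean=0.0, std=0.02, a=-2.0, b=2.0)
    assert float(t.min()) >= -2.0 and float(t.max()) <= 2.0
    return t.mean(), t.std()
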
109 | def constant_init(module, val, bias=0):
110 | if hasattr(module, 'weight') and module.weight is not None:
111 | nn.init.constant_(module.weight, val)
112 | if hasattr(module, 'bias') and module.bias is not None:
113 | nn.init.constant_(module.bias, bias)
114 |
115 | def build_norm_layer(norm_cfg,embed_dims):
116 | assert norm_cfg['type'] == 'LN'
117 | norm_layer = nn.LayerNorm(embed_dims)
118 | return norm_cfg['type'],norm_layer
119 |
120 | class GELU(nn.Module):
121 | r"""Applies the Gaussian Error Linear Units function:
122 |
123 | .. math::
124 | \text{GELU}(x) = x * \Phi(x)
125 | where :math:`\Phi(x)` is the Cumulative Distribution Function for
126 | Gaussian Distribution.
127 |
128 | Shape:
129 | - Input: :math:`(N, *)` where `*` means, any number of additional
130 | dimensions
131 | - Output: :math:`(N, *)`, same shape as the input
132 |
133 | .. image:: scripts/activation_images/GELU.png
134 |
135 | Examples::
136 |
137 | >>> m = nn.GELU()
138 | >>> input = torch.randn(2)
139 | >>> output = m(input)
140 | """
141 |
142 | def forward(self, input):
143 | return F.gelu(input)
144 |
145 | def build_activation_layer(act_cfg):
146 | if act_cfg['type'] == 'ReLU':
147 | act_layer = nn.ReLU(inplace=act_cfg['inplace'])
148 | elif act_cfg['type'] == 'GELU':
149 | act_layer = GELU()
150 | return act_layer
151 |
152 | def build_conv_layer(conv_cfg,
153 | in_channels,
154 | out_channels,
155 | kernel_size,
156 | stride,
157 | padding,
158 | dilation,
159 | bias):
160 | conv_layer = nn.Conv2d(
161 | in_channels=in_channels,
162 | out_channels=out_channels,
163 | kernel_size=kernel_size,
164 | stride=stride,
165 | padding=padding,
166 | dilation=dilation,
167 | bias=bias)
168 | return conv_layer
169 |
170 | def drop_path(x, drop_prob=0., training=False):
171 | """Drop paths (Stochastic Depth) per sample (when applied in main path of
172 | residual blocks).
173 |
174 | We follow the implementation
175 | https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501
176 | """
177 | if drop_prob == 0. or not training:
178 | return x
179 | keep_prob = 1 - drop_prob
180 | # handle tensors with different dimensions, not just 4D tensors.
181 | shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
182 | random_tensor = keep_prob + torch.rand(
183 | shape, dtype=x.dtype, device=x.device)
184 | output = x.div(keep_prob) * random_tensor.floor()
185 | return output
186 |
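# --- Hedged illustration of stochastic depth (drop_path) above ---
# During training each sample's residual branch is zeroed with probability
# drop_prob, and surviving samples are rescaled by 1 / keep_prob so the
# expected value is unchanged.
def _demo_drop_path():
    x = torch.ones(8, 4)
    y = drop_path(x, drop_prob=0.5, training=True)
    # each row is either all zeros (dropped) or all 2.0 (kept, rescaled by 1/0.5)
    return y
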
187 | class DropPath(nn.Module):
188 | """Drop paths (Stochastic Depth) per sample (when applied in main path of
189 | residual blocks).
190 |
191 | We follow the implementation
192 | https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501
193 |
194 | Args:
195 | drop_prob (float): Probability of the path to be zeroed. Default: 0.1
196 | """
197 |
198 | def __init__(self, drop_prob=0.1):
199 | super(DropPath, self).__init__()
200 | self.drop_prob = drop_prob
201 |
202 | def forward(self, x):
203 | return drop_path(x, self.drop_prob, self.training)
204 |
205 | def build_dropout(drop_cfg):
206 | drop_layer = DropPath(drop_cfg['drop_prob'])
207 | return drop_layer
208 |
209 | class FFN(BaseModule):
210 | def __init__(self,
211 | embed_dims=256,
212 | feedforward_channels=1024,
213 | num_fcs=2,
214 | act_cfg=dict(type='ReLU', inplace=True),
215 | ffn_drop=0.,
216 | dropout_layer=None,
217 | add_identity=True,
218 | init_cfg=None,
219 | **kwargs):
220 | super(FFN, self).__init__()
221 | assert num_fcs >= 2, 'num_fcs should be no less ' \
222 | f'than 2. got {num_fcs}.'
223 | self.embed_dims = embed_dims
224 | self.feedforward_channels = feedforward_channels
225 | self.num_fcs = num_fcs
226 | self.act_cfg = act_cfg
227 | self.activate = build_activation_layer(act_cfg)
228 |
229 | layers = []
230 | in_channels = embed_dims
231 | for _ in range(num_fcs - 1):
232 | layers.append(
233 | Sequential(
234 | Linear(in_channels, feedforward_channels), self.activate,
235 | nn.Dropout(ffn_drop)))
236 | in_channels = feedforward_channels
237 | layers.append(Linear(feedforward_channels, embed_dims))
238 | layers.append(nn.Dropout(ffn_drop))
239 | self.layers = Sequential(*layers)
240 | self.dropout_layer = build_dropout(
241 | dropout_layer) if dropout_layer else torch.nn.Identity()
242 | self.add_identity = add_identity
243 |
244 | def forward(self, x, identity=None):
245 | """Forward function for `FFN`.
246 |
247 | The function would add x to the output tensor if residue is None.
248 | """
249 | out = self.layers(x)
250 | if not self.add_identity:
251 | return self.dropout_layer(out)
252 | if identity is None:
253 | identity = x
254 | return identity + self.dropout_layer(out)
255 |
256 | def swin_converter(ckpt):
257 |
258 | new_ckpt = OrderedDict()
259 |
260 | def correct_unfold_reduction_order(x):
261 | out_channel, in_channel = x.shape
262 | x = x.reshape(out_channel, 4, in_channel // 4)
263 | x = x[:, [0, 2, 1, 3], :].transpose(1,
264 | 2).reshape(out_channel, in_channel)
265 | return x
266 |
267 | def correct_unfold_norm_order(x):
268 | in_channel = x.shape[0]
269 | x = x.reshape(4, in_channel // 4)
270 | x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
271 | return x
272 |
273 | for k, v in ckpt.items():
274 | if k.startswith('head'):
275 | continue
276 | elif k.startswith('layers'):
277 | new_v = v
278 | if 'attn.' in k:
279 | new_k = k.replace('attn.', 'attn.w_msa.')
280 | elif 'mlp.' in k:
281 | if 'mlp.fc1.' in k:
282 | new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.')
283 | elif 'mlp.fc2.' in k:
284 | new_k = k.replace('mlp.fc2.', 'ffn.layers.1.')
285 | else:
286 | new_k = k.replace('mlp.', 'ffn.')
287 | elif 'downsample' in k:
288 | new_k = k
289 | if 'reduction.' in k:
290 | new_v = correct_unfold_reduction_order(v)
291 | elif 'norm.' in k:
292 | new_v = correct_unfold_norm_order(v)
293 | else:
294 | new_k = k
295 | new_k = new_k.replace('layers', 'stages', 1)
296 | elif k.startswith('patch_embed'):
297 | new_v = v
298 | if 'proj' in k:
299 | new_k = k.replace('proj', 'projection')
300 | else:
301 | new_k = k
302 | else:
303 | new_v = v
304 | new_k = k
305 |
306 | new_ckpt['backbone.' + new_k] = new_v
307 |
308 | return new_ckpt
309 |
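# --- Hedged usage sketch for swin_converter above ---
# It remaps keys of a checkpoint from the original microsoft/Swin-Transformer
# repo ('layers.*', 'attn.*', 'mlp.fc1/fc2.*', 'patch_embed.proj.*') to the
# naming used by this backbone. The tiny fake checkpoint below only
# illustrates the key renaming; tensor shapes are placeholders.
def _demo_swin_converter():
    fake_ckpt = OrderedDict({
        "patch_embed.proj.weight": torch.zeros(96, 3, 4, 4),
        "layers.0.blocks.0.attn.qkv.weight": torch.zeros(288, 96),
        "layers.0.blocks.0.mlp.fc1.weight": torch.zeros(384, 96),
    })
    converted = swin_converter(fake_ckpt)
    # e.g. 'backbone.patch_embed.projection.weight',
    #      'backbone.stages.0.blocks.0.attn.w_msa.qkv.weight',
    #      'backbone.stages.0.blocks.0.ffn.layers.0.0.weight'
    return list(converted.keys())
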
310 | class AdaptivePadding(nn.Module):
311 | """Applies padding to the input (if needed) so that the input can be fully
312 |     covered by the specified filter. It supports two modes, "same" and
313 |     "corner": the "same" mode matches the "SAME" padding mode in TensorFlow
314 |     (zeros padded around the input), while "corner" pads zeros to the bottom right.
315 |     Args:
316 |         kernel_size (int | tuple): Size of the kernel.
317 |         stride (int | tuple): Stride of the filter. Default: 1.
318 |         dilation (int | tuple): Spacing between kernel elements.
319 |             Default: 1.
320 |         padding (str): Support "same" and "corner". The "corner" mode
321 |             pads zeros to the bottom right, and the "same" mode pads
322 |             zeros around the input. Default: "corner".
323 | Example:
324 | >>> kernel_size = 16
325 | >>> stride = 16
326 | >>> dilation = 1
327 | >>> input = torch.rand(1, 1, 15, 17)
328 | >>> adap_pad = AdaptivePadding(
329 | >>> kernel_size=kernel_size,
330 | >>> stride=stride,
331 | >>> dilation=dilation,
332 | >>> padding="corner")
333 | >>> out = adap_pad(input)
334 | >>> assert (out.shape[2], out.shape[3]) == (16, 32)
335 | >>> input = torch.rand(1, 1, 16, 17)
336 | >>> out = adap_pad(input)
337 | >>> assert (out.shape[2], out.shape[3]) == (16, 32)
338 | """
339 |
340 | def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
341 |
342 | super(AdaptivePadding, self).__init__()
343 |
344 | assert padding in ('same', 'corner')
345 |
346 | kernel_size = to_2tuple(kernel_size)
347 | stride = to_2tuple(stride)
348 | padding = to_2tuple(padding)
349 | dilation = to_2tuple(dilation)
350 |
351 | self.padding = padding
352 | self.kernel_size = kernel_size
353 | self.stride = stride
354 | self.dilation = dilation
355 |
356 | def get_pad_shape(self, input_shape):
357 | input_h, input_w = input_shape
358 | kernel_h, kernel_w = self.kernel_size
359 | stride_h, stride_w = self.stride
360 | output_h = math.ceil(input_h / stride_h)
361 | output_w = math.ceil(input_w / stride_w)
362 | pad_h = max((output_h - 1) * stride_h +
363 | (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
364 | pad_w = max((output_w - 1) * stride_w +
365 | (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
366 | return pad_h, pad_w
367 |
368 | def forward(self, x):
369 | pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
370 | if pad_h > 0 or pad_w > 0:
371 | if self.padding == 'corner':
372 | x = F.pad(x, [0, pad_w, 0, pad_h])
373 | elif self.padding == 'same':
374 | x = F.pad(x, [
375 | pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
376 | pad_h - pad_h // 2
377 | ])
378 | return x
379 |
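# --- Hedged check of the "corner" padding mode, mirroring the docstring example above ---
def _demo_adaptive_padding():
    pad = AdaptivePadding(kernel_size=16, stride=16, dilation=1, padding="corner")
    out = pad(torch.rand(1, 1, 15, 17))
    # a 15x17 input is padded on the bottom/right to 16x32 so a 16x16 kernel
    # with stride 16 covers it exactly
    assert (out.shape[2], out.shape[3]) == (16, 32)
    return out.shape
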
380 | class PatchEmbed(BaseModule):
381 | """Image to Patch Embedding.
382 | We use a conv layer to implement PatchEmbed.
383 | Args:
384 | in_channels (int): The num of input channels. Default: 3
385 | embed_dims (int): The dimensions of embedding. Default: 768
386 |         conv_type (str): The type of convolution used for the embedding
387 |             conv layer. Default: "Conv2d".
388 |         kernel_size (int): The kernel_size of embedding conv. Default: 16.
389 |         stride (int): The stride of the embedding conv.
390 |             Default: 16 (if set to None, it falls back to `kernel_size`).
391 | padding (int | tuple | string ): The padding length of
392 | embedding conv. When it is a string, it means the mode
393 | of adaptive padding, support "same" and "corner" now.
394 | Default: "corner".
395 | dilation (int): The dilation rate of embedding conv. Default: 1.
396 | bias (bool): Bias of embed conv. Default: True.
397 | norm_cfg (dict, optional): Config dict for normalization layer.
398 | Default: None.
399 | input_size (int | tuple | None): The size of input, which will be
400 | used to calculate the out size. Only work when `dynamic_size`
401 | is False. Default: None.
402 | init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
403 | Default: None.
404 | """
405 |
406 | def __init__(
407 | self,
408 | in_channels=3,
409 | embed_dims=768,
410 | conv_type='Conv2d',
411 | kernel_size=16,
412 | stride=16,
413 | padding='corner',
414 | dilation=1,
415 | bias=True,
416 | norm_cfg=None,
417 | input_size=None,
418 | init_cfg=None,
419 | ):
420 | super(PatchEmbed, self).__init__()
421 |
422 | self.embed_dims = embed_dims
423 | if stride is None:
424 | stride = kernel_size
425 |
426 | kernel_size = to_2tuple(kernel_size)
427 | stride = to_2tuple(stride)
428 | dilation = to_2tuple(dilation)
429 |
430 | if isinstance(padding, str):
431 | self.adap_padding = AdaptivePadding(
432 | kernel_size=kernel_size,
433 | stride=stride,
434 | dilation=dilation,
435 | padding=padding)
436 | # disable the padding of conv
437 | padding = 0
438 | else:
439 | self.adap_padding = None
440 | padding = to_2tuple(padding)
441 |
442 | self.projection = build_conv_layer(
443 | dict(type=conv_type),
444 | in_channels=in_channels,
445 | out_channels=embed_dims,
446 | kernel_size=kernel_size,
447 | stride=stride,
448 | padding=padding,
449 | dilation=dilation,
450 | bias=bias)
451 |
452 | if norm_cfg is not None:
453 | self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
454 | else:
455 | self.norm = None
456 |
457 | if input_size:
458 | input_size = to_2tuple(input_size)
459 | # `init_out_size` would be used outside to
460 | # calculate the num_patches
461 | # when `use_abs_pos_embed` outside
462 | self.init_input_size = input_size
463 | if self.adap_padding:
464 | pad_h, pad_w = self.adap_padding.get_pad_shape(input_size)
465 | input_h, input_w = input_size
466 | input_h = input_h + pad_h
467 | input_w = input_w + pad_w
468 | input_size = (input_h, input_w)
469 |
470 | # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
471 | h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
472 | (kernel_size[0] - 1) - 1) // stride[0] + 1
473 | w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
474 | (kernel_size[1] - 1) - 1) // stride[1] + 1
475 | self.init_out_size = (h_out, w_out)
476 | else:
477 | self.init_input_size = None
478 | self.init_out_size = None
479 |
480 | def forward(self, x):
481 | """
482 | Args:
483 | x (Tensor): Has shape (B, C, H, W). In most case, C is 3.
484 | Returns:
485 | tuple: Contains merged results and its spatial shape.
486 | - x (Tensor): Has shape (B, out_h * out_w, embed_dims)
487 | - out_size (tuple[int]): Spatial shape of x, arrange as
488 | (out_h, out_w).
489 | """
490 |
491 | if self.adap_padding:
492 | x = self.adap_padding(x)
493 |
494 | x = self.projection(x)
495 | out_size = (x.shape[2], x.shape[3])
496 | x = x.flatten(2).transpose(1, 2)
497 | if self.norm is not None:
498 | x = self.norm(x)
499 | return x, out_size
500 |
501 |
502 | class PatchMerging(BaseModule):
503 | """Merge patch feature map.
504 | This layer groups feature map by kernel_size, and applies norm and linear
505 | layers to the grouped feature map. Our implementation uses `nn.Unfold` to
506 |     merge patches, which is about 25% faster than the original implementation;
507 |     as a trade-off, pretrained models need to be modified for compatibility.
508 | Args:
509 |         in_channels (int): The num of input channels,
510 |             i.e. the channel dimension of the feature map that is grouped
511 |             by the unfold kernel and stride specified below.
512 |         out_channels (int): The num of output channels.
513 | kernel_size (int | tuple, optional): the kernel size in the unfold
514 | layer. Defaults to 2.
515 | stride (int | tuple, optional): the stride of the sliding blocks in the
516 | unfold layer. Default: None. (Would be set as `kernel_size`)
517 | padding (int | tuple | string ): The padding length of
518 | embedding conv. When it is a string, it means the mode
519 | of adaptive padding, support "same" and "corner" now.
520 | Default: "corner".
521 | dilation (int | tuple, optional): dilation parameter in the unfold
522 | layer. Default: 1.
523 | bias (bool, optional): Whether to add bias in linear layer or not.
524 | Defaults: False.
525 | norm_cfg (dict, optional): Config dict for normalization layer.
526 | Default: dict(type='LN').
527 | init_cfg (dict, optional): The extra config for initialization.
528 | Default: None.
529 | """
530 |
531 | def __init__(self,
532 | in_channels,
533 | out_channels,
534 | kernel_size=2,
535 | stride=None,
536 | padding='corner',
537 | dilation=1,
538 | bias=False,
539 | norm_cfg=dict(type='LN'),
540 | init_cfg=None):
541 | super().__init__()
542 | self.in_channels = in_channels
543 | self.out_channels = out_channels
544 | if stride:
545 | stride = stride
546 | else:
547 | stride = kernel_size
548 |
549 | kernel_size = to_2tuple(kernel_size)
550 | stride = to_2tuple(stride)
551 | dilation = to_2tuple(dilation)
552 |
553 | if isinstance(padding, str):
554 | self.adap_padding = AdaptivePadding(
555 | kernel_size=kernel_size,
556 | stride=stride,
557 | dilation=dilation,
558 | padding=padding)
559 | # disable the padding of unfold
560 | padding = 0
561 | else:
562 | self.adap_padding = None
563 |
564 | padding = to_2tuple(padding)
565 | self.sampler = nn.Unfold(
566 | kernel_size=kernel_size,
567 | dilation=dilation,
568 | padding=padding,
569 | stride=stride)
570 |
571 | sample_dim = kernel_size[0] * kernel_size[1] * in_channels
572 |
573 | if norm_cfg is not None:
574 | self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
575 | else:
576 | self.norm = None
577 |
578 | self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)
579 |
580 | def forward(self, x, input_size):
581 | """
582 | Args:
583 | x (Tensor): Has shape (B, H*W, C_in).
584 | input_size (tuple[int]): The spatial shape of x, arrange as (H, W).
585 | Default: None.
586 | Returns:
587 | tuple: Contains merged results and its spatial shape.
588 | - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
589 | - out_size (tuple[int]): Spatial shape of x, arrange as
590 | (Merged_H, Merged_W).
591 | """
592 | B, L, C = x.shape
593 | assert isinstance(input_size, Sequence), f'Expect ' \
594 | f'input_size is ' \
595 | f'`Sequence` ' \
596 | f'but get {input_size}'
597 |
598 | H, W = input_size
599 | assert L == H * W, 'input feature has wrong size'
600 |
601 | x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W
602 | # Use nn.Unfold to merge patch. About 25% faster than original method,
603 | # but need to modify pretrained model for compatibility
604 |
605 | if self.adap_padding:
606 | x = self.adap_padding(x)
607 | H, W = x.shape[-2:]
608 |
609 | x = self.sampler(x)
610 |         # if kernel_size=2 and stride=2, x should have shape (B, 4*C, H/2*W/2)
611 |
612 | out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *
613 | (self.sampler.kernel_size[0] - 1) -
614 | 1) // self.sampler.stride[0] + 1
615 | out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *
616 | (self.sampler.kernel_size[1] - 1) -
617 | 1) // self.sampler.stride[1] + 1
618 |
619 | output_size = (out_h, out_w)
620 | x = x.transpose(1, 2) # B, H/2*W/2, 4*C
621 | x = self.norm(x) if self.norm else x
622 | x = self.reduction(x)
623 | return x, output_size
624 |
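# --- Hedged shape check for PatchMerging above ---
# A 56x56 token map with 96 channels is merged 2x2 into a 28x28 map with the
# requested 192 output channels (2*2*96 unfolded channels -> Linear -> 192).
def _demo_patch_merging():
    merge = PatchMerging(in_channels=96, out_channels=192)
    merge.eval()
    x = torch.randn(2, 56 * 56, 96)
    y, out_size = merge(x, (56, 56))
    assert y.shape == (2, 28 * 28, 192) and out_size == (28, 28)
    return y.shape, out_size
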
625 | class WindowMSA(BaseModule):
626 | """Window based multi-head self-attention (W-MSA) module with relative
627 | position bias.
628 | Args:
629 | embed_dims (int): Number of input channels.
630 | num_heads (int): Number of attention heads.
631 | window_size (tuple[int]): The height and width of the window.
632 | qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
633 | Default: True.
634 | qk_scale (float | None, optional): Override default qk scale of
635 | head_dim ** -0.5 if set. Default: None.
636 | attn_drop_rate (float, optional): Dropout ratio of attention weight.
637 | Default: 0.0
638 | proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
639 | init_cfg (dict | None, optional): The Config for initialization.
640 | Default: None.
641 | """
642 |
643 | def __init__(self,
644 | embed_dims,
645 | num_heads,
646 | window_size,
647 | qkv_bias=True,
648 | qk_scale=None,
649 | attn_drop_rate=0.,
650 | proj_drop_rate=0.,
651 | init_cfg=None):
652 |
653 | super().__init__()
654 | self.embed_dims = embed_dims
655 | self.window_size = window_size # Wh, Ww
656 | self.num_heads = num_heads
657 | head_embed_dims = embed_dims // num_heads
658 | self.scale = qk_scale or head_embed_dims**-0.5
659 | self.init_cfg = init_cfg
660 |
661 | # define a parameter table of relative position bias
662 | self.relative_position_bias_table = nn.Parameter(
663 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
664 | num_heads)) # 2*Wh-1 * 2*Ww-1, nH
665 |
666 | # About 2x faster than original impl
667 | Wh, Ww = self.window_size
668 | rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
669 | rel_position_index = rel_index_coords + rel_index_coords.T
670 | rel_position_index = rel_position_index.flip(1).contiguous()
671 | self.register_buffer('relative_position_index', rel_position_index)
672 |
673 | self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
674 | self.attn_drop = nn.Dropout(attn_drop_rate)
675 | self.proj = nn.Linear(embed_dims, embed_dims)
676 | self.proj_drop = nn.Dropout(proj_drop_rate)
677 |
678 | self.softmax = nn.Softmax(dim=-1)
679 |
680 | def init_weights(self):
681 | trunc_normal_(self.relative_position_bias_table, std=0.02)
682 |
683 | def forward(self, x, mask=None):
684 | """
685 | Args:
686 | x (tensor): input features with shape of (num_windows*B, N, C)
687 | mask (tensor | None, Optional): mask with shape of (num_windows,
688 | Wh*Ww, Wh*Ww), value should be between (-inf, 0].
689 | """
690 | B, N, C = x.shape
691 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
692 | C // self.num_heads).permute(2, 0, 3, 1, 4)
693 | # make torchscript happy (cannot use tensor as tuple)
694 | q, k, v = qkv[0], qkv[1], qkv[2]
695 |
696 | q = q * self.scale
697 | attn = (q @ k.transpose(-2, -1))
698 |
699 | relative_position_bias = self.relative_position_bias_table[
700 | self.relative_position_index.view(-1)].view(
701 | self.window_size[0] * self.window_size[1],
702 | self.window_size[0] * self.window_size[1],
703 | -1) # Wh*Ww,Wh*Ww,nH
704 | relative_position_bias = relative_position_bias.permute(
705 | 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
706 | attn = attn + relative_position_bias.unsqueeze(0)
707 |
708 | if mask is not None:
709 | nW = mask.shape[0]
710 | attn = attn.view(B // nW, nW, self.num_heads, N,
711 | N) + mask.unsqueeze(1).unsqueeze(0)
712 | attn = attn.view(-1, self.num_heads, N, N)
713 | attn = self.softmax(attn)
714 |
715 | attn = self.attn_drop(attn)
716 |
717 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
718 | x = self.proj(x)
719 | x = self.proj_drop(x)
720 | return x
721 |
722 | @staticmethod
723 | def double_step_seq(step1, len1, step2, len2):
724 | seq1 = torch.arange(0, step1 * len1, step1)
725 | seq2 = torch.arange(0, step2 * len2, step2)
726 | return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
727 |
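# --- Hedged sketch of the relative-position index trick above ---
# For a Wh x Ww window, double_step_seq builds per-token offsets whose pairwise
# sums (after the flip) index the (2*Wh-1)*(2*Ww-1) relative position bias table.
def _demo_relative_position_index():
    Wh, Ww = 2, 2
    coords = WindowMSA.double_step_seq(2 * Ww - 1, Wh, 1, Ww)    # shape (1, Wh*Ww)
    index = (coords + coords.T).flip(1)                          # shape (Wh*Ww, Wh*Ww)
    assert index.shape == (Wh * Ww, Wh * Ww)
    assert int(index.max()) < (2 * Wh - 1) * (2 * Ww - 1)
    return index
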
728 |
729 | class ShiftWindowMSA(BaseModule):
730 | """Shifted Window Multihead Self-Attention Module.
731 | Args:
732 | embed_dims (int): Number of input channels.
733 | num_heads (int): Number of attention heads.
734 | window_size (int): The height and width of the window.
735 | shift_size (int, optional): The shift step of each window towards
736 | right-bottom. If zero, act as regular window-msa. Defaults to 0.
737 | qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
738 | Default: True
739 | qk_scale (float | None, optional): Override default qk scale of
740 | head_dim ** -0.5 if set. Defaults: None.
741 | attn_drop_rate (float, optional): Dropout ratio of attention weight.
742 | Defaults: 0.
743 | proj_drop_rate (float, optional): Dropout ratio of output.
744 | Defaults: 0.
745 | dropout_layer (dict, optional): The dropout_layer used before output.
746 | Defaults: dict(type='DropPath', drop_prob=0.).
747 | init_cfg (dict, optional): The extra config for initialization.
748 | Default: None.
749 | """
750 |
751 | def __init__(self,
752 | embed_dims,
753 | num_heads,
754 | window_size,
755 | shift_size=0,
756 | qkv_bias=True,
757 | qk_scale=None,
758 | attn_drop_rate=0,
759 | proj_drop_rate=0,
760 | dropout_layer=dict(type='DropPath', drop_prob=0.),
761 | init_cfg=None):
762 | super().__init__()
763 |
764 | self.window_size = window_size
765 | self.shift_size = shift_size
766 | assert 0 <= self.shift_size < self.window_size
767 |
768 | self.w_msa = WindowMSA(
769 | embed_dims=embed_dims,
770 | num_heads=num_heads,
771 | window_size=to_2tuple(window_size),
772 | qkv_bias=qkv_bias,
773 | qk_scale=qk_scale,
774 | attn_drop_rate=attn_drop_rate,
775 | proj_drop_rate=proj_drop_rate,
776 | init_cfg=None)
777 |
778 | self.drop = build_dropout(dropout_layer)
779 |
780 | def forward(self, query, hw_shape):
781 | B, L, C = query.shape
782 | H, W = hw_shape
783 | assert L == H * W, 'input feature has wrong size'
784 | query = query.view(B, H, W, C)
785 |
786 | # pad feature maps to multiples of window size
787 | pad_r = (self.window_size - W % self.window_size) % self.window_size
788 | pad_b = (self.window_size - H % self.window_size) % self.window_size
789 | query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))
790 | H_pad, W_pad = query.shape[1], query.shape[2]
791 |
792 | # cyclic shift
793 | if self.shift_size > 0:
794 | shifted_query = torch.roll(
795 | query,
796 | shifts=(-self.shift_size, -self.shift_size),
797 | dims=(1, 2))
798 |
799 | # calculate attention mask for SW-MSA
800 | img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device)
801 | h_slices = (slice(0, -self.window_size),
802 | slice(-self.window_size,
803 | -self.shift_size), slice(-self.shift_size, None))
804 | w_slices = (slice(0, -self.window_size),
805 | slice(-self.window_size,
806 | -self.shift_size), slice(-self.shift_size, None))
807 | cnt = 0
808 | for h in h_slices:
809 | for w in w_slices:
810 | img_mask[:, h, w, :] = cnt
811 | cnt += 1
812 |
813 | # nW, window_size, window_size, 1
814 | mask_windows = self.window_partition(img_mask)
815 | mask_windows = mask_windows.view(
816 | -1, self.window_size * self.window_size)
817 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
818 | attn_mask = attn_mask.masked_fill(attn_mask != 0,
819 | float(-100.0)).masked_fill(
820 | attn_mask == 0, float(0.0))
821 | else:
822 | shifted_query = query
823 | attn_mask = None
824 |
825 | # nW*B, window_size, window_size, C
826 | query_windows = self.window_partition(shifted_query)
827 | # nW*B, window_size*window_size, C
828 | query_windows = query_windows.view(-1, self.window_size**2, C)
829 |
830 | # W-MSA/SW-MSA (nW*B, window_size*window_size, C)
831 | attn_windows = self.w_msa(query_windows, mask=attn_mask)
832 |
833 | # merge windows
834 | attn_windows = attn_windows.view(-1, self.window_size,
835 | self.window_size, C)
836 |
837 | # B H' W' C
838 | shifted_x = self.window_reverse(attn_windows, H_pad, W_pad)
839 | # reverse cyclic shift
840 | if self.shift_size > 0:
841 | x = torch.roll(
842 | shifted_x,
843 | shifts=(self.shift_size, self.shift_size),
844 | dims=(1, 2))
845 | else:
846 | x = shifted_x
847 |
848 |         if pad_r > 0 or pad_b > 0:
849 | x = x[:, :H, :W, :].contiguous()
850 |
851 | x = x.view(B, H * W, C)
852 |
853 | x = self.drop(x)
854 | return x
855 |
856 | def window_reverse(self, windows, H, W):
857 | """
858 | Args:
859 | windows: (num_windows*B, window_size, window_size, C)
860 | H (int): Height of image
861 | W (int): Width of image
862 | Returns:
863 | x: (B, H, W, C)
864 | """
865 | window_size = self.window_size
866 | B = int(windows.shape[0] / (H * W / window_size / window_size))
867 | x = windows.view(B, H // window_size, W // window_size, window_size,
868 | window_size, -1)
869 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
870 | return x
871 |
872 | def window_partition(self, x):
873 | """
874 | Args:
875 | x: (B, H, W, C)
876 | Returns:
877 | windows: (num_windows*B, window_size, window_size, C)
878 | """
879 | B, H, W, C = x.shape
880 | window_size = self.window_size
881 | x = x.view(B, H // window_size, window_size, W // window_size,
882 | window_size, C)
883 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
884 | windows = windows.view(-1, window_size, window_size, C)
885 | return windows
886 |
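# --- Hedged round-trip check for window_partition / window_reverse above ---
# When H and W are multiples of the window size, partitioning into windows and
# reversing reconstructs the original (B, H, W, C) feature map exactly.
def _demo_window_roundtrip():
    msa = ShiftWindowMSA(embed_dims=32, num_heads=4, window_size=7)
    x = torch.randn(2, 14, 14, 32)
    windows = msa.window_partition(x)            # (B * 4 windows = 8, 7, 7, 32)
    x_back = msa.window_reverse(windows, 14, 14)
    assert torch.allclose(x, x_back)
    return windows.shape
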
887 |
888 | class SwinBlock(BaseModule):
889 |     """
890 | Args:
891 | embed_dims (int): The feature dimension.
892 | num_heads (int): Parallel attention heads.
893 | feedforward_channels (int): The hidden dimension for FFNs.
894 | window_size (int, optional): The local window scale. Default: 7.
895 | shift (bool, optional): whether to shift window or not. Default False.
896 | qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
897 | qk_scale (float | None, optional): Override default qk scale of
898 | head_dim ** -0.5 if set. Default: None.
899 | drop_rate (float, optional): Dropout rate. Default: 0.
900 | attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
901 | drop_path_rate (float, optional): Stochastic depth rate. Default: 0.
902 | act_cfg (dict, optional): The config dict of activation function.
903 | Default: dict(type='GELU').
904 | norm_cfg (dict, optional): The config dict of normalization.
905 | Default: dict(type='LN').
906 | with_cp (bool, optional): Use checkpoint or not. Using checkpoint
907 | will save some memory while slowing down the training speed.
908 | Default: False.
909 | init_cfg (dict | list | None, optional): The init config.
910 | Default: None.
911 | """
912 |
913 | def __init__(self,
914 | embed_dims,
915 | num_heads,
916 | feedforward_channels,
917 | window_size=7,
918 | shift=False,
919 | qkv_bias=True,
920 | qk_scale=None,
921 | drop_rate=0.,
922 | attn_drop_rate=0.,
923 | drop_path_rate=0.,
924 | act_cfg=dict(type='GELU'),
925 | norm_cfg=dict(type='LN'),
926 | with_cp=False,
927 | init_cfg=None):
928 |
929 | super(SwinBlock, self).__init__()
930 |
931 | self.init_cfg = init_cfg
932 | self.with_cp = with_cp
933 |
934 | self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
935 | self.attn = ShiftWindowMSA(
936 | embed_dims=embed_dims,
937 | num_heads=num_heads,
938 | window_size=window_size,
939 | shift_size=window_size // 2 if shift else 0,
940 | qkv_bias=qkv_bias,
941 | qk_scale=qk_scale,
942 | attn_drop_rate=attn_drop_rate,
943 | proj_drop_rate=drop_rate,
944 | dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
945 | init_cfg=None)
946 |
947 | self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
948 | self.ffn = FFN(
949 | embed_dims=embed_dims,
950 | feedforward_channels=feedforward_channels,
951 | num_fcs=2,
952 | ffn_drop=drop_rate,
953 | dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
954 | act_cfg=act_cfg,
955 | add_identity=True,
956 | init_cfg=None)
957 |
958 | def forward(self, x, hw_shape):
959 |
960 | def _inner_forward(x):
961 | identity = x
962 | x = self.norm1(x)
963 | x = self.attn(x, hw_shape)
964 |
965 | x = x + identity
966 |
967 | identity = x
968 | x = self.norm2(x)
969 | x = self.ffn(x, identity=identity)
970 |
971 | return x
972 |
973 | if self.with_cp and x.requires_grad:
974 | x = cp.checkpoint(_inner_forward, x)
975 | else:
976 | x = _inner_forward(x)
977 |
978 | return x
979 |
980 |
981 | class SwinBlockSequence(BaseModule):
982 | """Implements one stage in Swin Transformer.
983 | Args:
984 | embed_dims (int): The feature dimension.
985 | num_heads (int): Parallel attention heads.
986 | feedforward_channels (int): The hidden dimension for FFNs.
987 | depth (int): The number of blocks in this stage.
988 | window_size (int, optional): The local window scale. Default: 7.
989 | qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
990 | qk_scale (float | None, optional): Override default qk scale of
991 | head_dim ** -0.5 if set. Default: None.
992 | drop_rate (float, optional): Dropout rate. Default: 0.
993 | attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
994 | drop_path_rate (float | list[float], optional): Stochastic depth
995 | rate. Default: 0.
996 | downsample (BaseModule | None, optional): The downsample operation
997 | module. Default: None.
998 | act_cfg (dict, optional): The config dict of activation function.
999 | Default: dict(type='GELU').
1000 | norm_cfg (dict, optional): The config dict of normalization.
1001 | Default: dict(type='LN').
1002 | with_cp (bool, optional): Use checkpoint or not. Using checkpoint
1003 | will save some memory while slowing down the training speed.
1004 | Default: False.
1005 | init_cfg (dict | list | None, optional): The init config.
1006 | Default: None.
1007 | """
1008 |
1009 | def __init__(self,
1010 | embed_dims,
1011 | num_heads,
1012 | feedforward_channels,
1013 | depth,
1014 | window_size=7,
1015 | qkv_bias=True,
1016 | qk_scale=None,
1017 | drop_rate=0.,
1018 | attn_drop_rate=0.,
1019 | drop_path_rate=0.,
1020 | downsample=None,
1021 | act_cfg=dict(type='GELU'),
1022 | norm_cfg=dict(type='LN'),
1023 | with_cp=False,
1024 | init_cfg=None):
1025 | super().__init__()
1026 |
1027 | if isinstance(drop_path_rate, list):
1028 | drop_path_rates = drop_path_rate
1029 | assert len(drop_path_rates) == depth
1030 | else:
1031 | drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)]
1032 |
1033 | self.blocks = ModuleList()
1034 | for i in range(depth):
1035 | block = SwinBlock(
1036 | embed_dims=embed_dims,
1037 | num_heads=num_heads,
1038 | feedforward_channels=feedforward_channels,
1039 | window_size=window_size,
1040 | shift=False if i % 2 == 0 else True,
1041 | qkv_bias=qkv_bias,
1042 | qk_scale=qk_scale,
1043 | drop_rate=drop_rate,
1044 | attn_drop_rate=attn_drop_rate,
1045 | drop_path_rate=drop_path_rates[i],
1046 | act_cfg=act_cfg,
1047 | norm_cfg=norm_cfg,
1048 | with_cp=with_cp,
1049 | init_cfg=None)
1050 | self.blocks.append(block)
1051 |
1052 | self.downsample = downsample
1053 |
1054 | def forward(self, x, hw_shape):
1055 | for block in self.blocks:
1056 | x = block(x, hw_shape)
1057 |
1058 | if self.downsample:
1059 | x_down, down_hw_shape = self.downsample(x, hw_shape)
1060 | return x_down, down_hw_shape, x, hw_shape
1061 | else:
1062 | return x, hw_shape, x, hw_shape
1063 |
1064 | class SwinTransformer(BaseModule):
1065 | """ Swin Transformer
1066 |     A PyTorch implementation of: `Swin Transformer:
1067 | Hierarchical Vision Transformer using Shifted Windows` -
1068 | https://arxiv.org/abs/2103.14030
1069 | Inspiration from
1070 | https://github.com/microsoft/Swin-Transformer
1071 | Args:
1072 | pretrain_img_size (int | tuple[int]): The size of input image when
1073 | pretrain. Defaults: 224.
1074 | in_channels (int): The num of input channels.
1075 | Defaults: 3.
1076 | embed_dims (int): The feature dimension. Default: 96.
1077 | patch_size (int | tuple[int]): Patch size. Default: 4.
1078 | window_size (int): Window size. Default: 7.
1079 | mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
1080 | Default: 4.
1081 | depths (tuple[int]): Depths of each Swin Transformer stage.
1082 | Default: (2, 2, 6, 2).
1083 | num_heads (tuple[int]): Parallel attention heads of each Swin
1084 | Transformer stage. Default: (3, 6, 12, 24).
1085 | strides (tuple[int]): The patch merging or patch embedding stride of
1086 | each Swin Transformer stage. (In swin, we set kernel size equal to
1087 | stride.) Default: (4, 2, 2, 2).
1088 | out_indices (tuple[int]): Output from which stages.
1089 | Default: (0, 1, 2, 3).
1090 | qkv_bias (bool, optional): If True, add a learnable bias to query, key,
1091 | value. Default: True
1092 | qk_scale (float | None, optional): Override default qk scale of
1093 | head_dim ** -0.5 if set. Default: None.
1094 |         patch_norm (bool): Whether to add a norm layer for patch embed and patch
1095 | merging. Default: True.
1096 | drop_rate (float): Dropout rate. Defaults: 0.
1097 | attn_drop_rate (float): Attention dropout rate. Default: 0.
1098 | drop_path_rate (float): Stochastic depth rate. Defaults: 0.1.
1099 | use_abs_pos_embed (bool): If True, add absolute position embedding to
1100 | the patch embedding. Defaults: False.
1101 | act_cfg (dict): Config dict for activation layer.
1102 |             Default: dict(type='GELU').
1103 | norm_cfg (dict): Config dict for normalization layer at
1104 |             output of the backbone. Defaults: dict(type='LN').
1105 | with_cp (bool, optional): Use checkpoint or not. Using checkpoint
1106 | will save some memory while slowing down the training speed.
1107 | Default: False.
1108 | pretrained (str, optional): model pretrained path. Default: None.
1109 | convert_weights (bool): The flag indicates whether the
1110 | pre-trained model is from the original repo. We may need
1111 | to convert some keys to make it compatible.
1112 | Default: False.
1113 | frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
1114 | -1 means not freezing any parameters.
1115 | init_cfg (dict, optional): The Config for initialization.
1116 | Defaults to None.
1117 | """
1118 |
1119 | def __init__(self,
1120 | pretrain_img_size=224,
1121 | in_channels=3,
1122 | embed_dims=96,
1123 | patch_size=4,
1124 | window_size=7,
1125 | mlp_ratio=4,
1126 | depths=(2, 2, 6, 2),
1127 | num_heads=(3, 6, 12, 24),
1128 | strides=(4, 2, 2, 2),
1129 | out_indices=(0, 1, 2, 3),
1130 | qkv_bias=True,
1131 | qk_scale=None,
1132 | patch_norm=True,
1133 | drop_rate=0.,
1134 | attn_drop_rate=0.,
1135 | drop_path_rate=0.1,
1136 | use_abs_pos_embed=False,
1137 | act_cfg=dict(type='GELU'),
1138 | norm_cfg=dict(type='LN'),
1139 | with_cp=False,
1140 | pretrained=None,
1141 | convert_weights=False,
1142 | frozen_stages=-1,
1143 | init_cfg=None,
1144 | semantic_weight=0.0):
1145 | self.convert_weights = convert_weights
1146 | self.frozen_stages = frozen_stages
1147 | if isinstance(pretrain_img_size, int):
1148 | pretrain_img_size = to_2tuple(pretrain_img_size)
1149 | elif isinstance(pretrain_img_size, tuple):
1150 | if len(pretrain_img_size) == 1:
1151 | pretrain_img_size = to_2tuple(pretrain_img_size[0])
1152 | assert len(pretrain_img_size) == 2, \
1153 | f'The size of image should have length 1 or 2, ' \
1154 | f'but got {len(pretrain_img_size)}'
1155 |
1156 | assert not (init_cfg and pretrained), \
1157 | 'init_cfg and pretrained cannot be specified at the same time'
1158 | if isinstance(pretrained, str):
1159 | warnings.warn('DeprecationWarning: pretrained is deprecated, '
1160 | 'please use "init_cfg" instead')
1161 | self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
1162 | elif pretrained is None:
1163 | self.init_cfg = init_cfg
1164 | else:
1165 | raise TypeError('pretrained must be a str or None')
1166 |
1167 | super(SwinTransformer, self).__init__()
1168 |
1169 | num_layers = len(depths)
1170 | self.out_indices = out_indices
1171 | self.use_abs_pos_embed = use_abs_pos_embed
1172 |
1173 | assert strides[0] == patch_size, 'Use non-overlapping patch embed.'
1174 |
1175 | self.patch_embed = PatchEmbed(
1176 | in_channels=in_channels,
1177 | embed_dims=embed_dims,
1178 | conv_type='Conv2d',
1179 | kernel_size=patch_size,
1180 | stride=strides[0],
1181 | norm_cfg=norm_cfg if patch_norm else None,
1182 | init_cfg=None)
1183 |
1184 | if self.use_abs_pos_embed:
1185 | patch_row = pretrain_img_size[0] // patch_size
1186 | patch_col = pretrain_img_size[1] // patch_size
1187 | num_patches = patch_row * patch_col
1188 | self.absolute_pos_embed = nn.Parameter(
1189 | torch.zeros((1, num_patches, embed_dims)))
1190 |
1191 | self.drop_after_pos = nn.Dropout(p=drop_rate)
1192 |
1193 | # set stochastic depth decay rule
1194 | total_depth = sum(depths)
1195 | dpr = [
1196 | x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
1197 | ]
1198 |
1199 | self.stages = ModuleList()
1200 | in_channels = embed_dims
1201 | for i in range(num_layers):
1202 | if i < num_layers - 1:
1203 | downsample = PatchMerging(
1204 | in_channels=in_channels,
1205 | out_channels=2 * in_channels,
1206 | stride=strides[i + 1],
1207 | norm_cfg=norm_cfg if patch_norm else None,
1208 | init_cfg=None)
1209 | else:
1210 | downsample = None
1211 |
1212 | stage = SwinBlockSequence(
1213 | embed_dims=in_channels,
1214 | num_heads=num_heads[i],
1215 | feedforward_channels=mlp_ratio * in_channels,
1216 | depth=depths[i],
1217 | window_size=window_size,
1218 | qkv_bias=qkv_bias,
1219 | qk_scale=qk_scale,
1220 | drop_rate=drop_rate,
1221 | attn_drop_rate=attn_drop_rate,
1222 | drop_path_rate=dpr[sum(depths[:i]):sum(depths[:i + 1])],
1223 | downsample=downsample,
1224 | act_cfg=act_cfg,
1225 | norm_cfg=norm_cfg,
1226 | with_cp=with_cp,
1227 | init_cfg=None)
1228 | self.stages.append(stage)
1229 | if downsample:
1230 | in_channels = downsample.out_channels
1231 |
1232 | self.num_features = [int(embed_dims * 2**i) for i in range(num_layers)]
1233 | # Add a norm layer for each output
1234 | for i in out_indices:
1235 | layer = build_norm_layer(norm_cfg, self.num_features[i])[1]
1236 | layer_name = f'norm{i}'
1237 | self.add_module(layer_name, layer)
1238 |
1239 | self.avgpool = nn.AdaptiveAvgPool2d((1,1))
1240 |
1241 | # semantic embedding
1242 | self.semantic_weight = semantic_weight
1243 | if self.semantic_weight >= 0:
1244 | self.semantic_embed_w = ModuleList()
1245 | self.semantic_embed_b = ModuleList()
1246 | for i in range(len(depths)):
1247 | if i >= len(depths) - 1:
1248 | i = len(depths) - 2
1249 | semantic_embed_w = nn.Linear(2, self.num_features[i+1])
1250 | semantic_embed_b = nn.Linear(2, self.num_features[i+1])
1251 | trunc_normal_init(semantic_embed_w, std=.02, bias=0.)
1252 | trunc_normal_init(semantic_embed_b, std=.02, bias=0.)
1253 | self.semantic_embed_w.append(semantic_embed_w)
1254 | self.semantic_embed_b.append(semantic_embed_b)
1255 | self.softplus = nn.Softplus()
1256 |
1257 | def train(self, mode=True):
1258 |         """Convert the model into training mode while keeping the frozen stages in eval mode."""
1259 | super(SwinTransformer, self).train(mode)
1260 | self._freeze_stages()
1261 |
1262 | def _freeze_stages(self):
1263 | if self.frozen_stages >= 0:
1264 | self.patch_embed.eval()
1265 | for param in self.patch_embed.parameters():
1266 | param.requires_grad = False
1267 | if self.use_abs_pos_embed:
1268 | self.absolute_pos_embed.requires_grad = False
1269 | self.drop_after_pos.eval()
1270 |
1271 | for i in range(1, self.frozen_stages + 1):
1272 |
1273 | if (i - 1) in self.out_indices:
1274 | norm_layer = getattr(self, f'norm{i-1}')
1275 | norm_layer.eval()
1276 | for param in norm_layer.parameters():
1277 | param.requires_grad = False
1278 |
1279 | m = self.stages[i - 1]
1280 | m.eval()
1281 | for param in m.parameters():
1282 | param.requires_grad = False
1283 |
1284 | def init_weights(self, pretrained=None):
1285 | logger = logging.getLogger("loading parameters.")
1286 | if pretrained is None:
1287 |             logger.warning(f'No pre-trained weights for '
1288 |                            f'{self.__class__.__name__}, '
1289 |                            f'training starts from scratch')
1290 | if self.use_abs_pos_embed:
1291 | trunc_normal_(self.absolute_pos_embed, std=0.02)
1292 | for m in self.modules():
1293 | if isinstance(m, nn.Linear):
1294 | trunc_normal_init(m, std=.02, bias=0.)
1295 | elif isinstance(m, nn.LayerNorm):
1296 |                     # constant_init expects a module: set LayerNorm weight=1, bias=0
1297 |                     constant_init(m, 1.0, bias=0)
1298 | else:
1299 | ckpt = torch.load(pretrained,map_location='cpu')
1300 | if 'teacher' in ckpt:
1301 | ckpt = ckpt['teacher']
1302 |
1303 | if 'state_dict' in ckpt:
1304 | _state_dict = ckpt['state_dict']
1305 | elif 'model' in ckpt:
1306 | _state_dict = ckpt['model']
1307 | else:
1308 | _state_dict = ckpt
1309 | if self.convert_weights:
1310 |                 # support loading weights from the original repo
1311 | _state_dict = swin_converter(_state_dict)
1312 |
1313 | state_dict = OrderedDict()
1314 | for k, v in _state_dict.items():
1315 | if k.startswith('backbone.'):
1316 | state_dict[k[9:]] = v
1317 |
1318 | # strip prefix of state_dict
1319 | if list(state_dict.keys())[0].startswith('module.'):
1320 | state_dict = {k[7:]: v for k, v in state_dict.items()}
1321 |
1322 | # reshape absolute position embedding
1323 | if state_dict.get('absolute_pos_embed') is not None:
1324 | absolute_pos_embed = state_dict['absolute_pos_embed']
1325 | N1, L, C1 = absolute_pos_embed.size()
1326 | N2, C2, H, W = self.absolute_pos_embed.size()
1327 | if N1 != N2 or C1 != C2 or L != H * W:
1328 | logger.warning('Error in loading absolute_pos_embed, pass')
1329 | else:
1330 | state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
1331 | N2, H, W, C2).permute(0, 3, 1, 2).contiguous()
1332 |
1333 | # interpolate position bias table if needed
1334 | relative_position_bias_table_keys = [
1335 | k for k in state_dict.keys()
1336 | if 'relative_position_bias_table' in k
1337 | ]
1338 | for table_key in relative_position_bias_table_keys:
1339 | table_pretrained = state_dict[table_key]
1340 | table_current = self.state_dict()[table_key]
1341 | L1, nH1 = table_pretrained.size()
1342 | L2, nH2 = table_current.size()
1343 | if nH1 != nH2:
1344 | logger.warning(f'Error in loading {table_key}, pass')
1345 | elif L1 != L2:
1346 | S1 = int(L1**0.5)
1347 | S2 = int(L2**0.5)
1348 | table_pretrained_resized = F.interpolate(
1349 | table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),
1350 | size=(S2, S2),
1351 | mode='bicubic')
1352 | state_dict[table_key] = table_pretrained_resized.view(
1353 | nH2, L2).permute(1, 0).contiguous()
1354 |
1355 | res = self.load_state_dict(state_dict, False)
1356 |         print('missing/unexpected keys when loading pretrained weights:', res)
1357 |
1358 | def forward(self, x, semantic_weight=None):
1359 |         if self.semantic_weight >= 0 and semantic_weight is None:
1360 |             w = torch.ones(x.shape[0], 1) * self.semantic_weight
1361 |             w = torch.cat([w, 1 - w], dim=-1)
1362 |             semantic_weight = w.to(x.device)  # same device as the input
1363 |
1364 | x, hw_shape = self.patch_embed(x)
1365 |
1366 | if self.use_abs_pos_embed:
1367 | x = x + self.absolute_pos_embed
1368 | x = self.drop_after_pos(x)
1369 |
1370 | outs = []
1371 | for i, stage in enumerate(self.stages):
1372 | x, hw_shape, out, out_hw_shape = stage(x, hw_shape)
1373 | if self.semantic_weight >= 0:
1374 | sw = self.semantic_embed_w[i](semantic_weight).unsqueeze(1)
1375 | sb = self.semantic_embed_b[i](semantic_weight).unsqueeze(1)
1376 | x = x * self.softplus(sw) + sb
1377 | if i in self.out_indices:
1378 | norm_layer = getattr(self, f'norm{i}')
1379 | out = norm_layer(out)
1380 | out = out.view(-1, *out_hw_shape,
1381 | self.num_features[i]).permute(0, 3, 1,
1382 | 2).contiguous()
1383 | outs.append(out)
1384 | x = self.avgpool(outs[-1])
1385 | x = torch.flatten(x, 1)
1386 | return x, outs
1387 |
1388 | def swin_base_patch4_window7_224(img_size=224,drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0., **kwargs):
1389 | model = SwinTransformer(pretrain_img_size = img_size, patch_size=4, window_size=7, embed_dims=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), drop_path_rate=drop_path_rate, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, **kwargs)
1390 | return model
1391 |
1392 | def swin_small_patch4_window7_224(img_size=224,drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0., **kwargs):
1393 | model = SwinTransformer(pretrain_img_size = img_size, patch_size=4, window_size=7, embed_dims=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), drop_path_rate=drop_path_rate, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, **kwargs)
1394 | return model
1395 |
1396 | def swin_tiny_patch4_window7_224(img_size=224,drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0., **kwargs):
1397 | model = SwinTransformer(pretrain_img_size = img_size, patch_size=4, window_size=7, embed_dims=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), drop_path_rate=drop_path_rate, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, **kwargs)
1398 | return model
1399 |
--------------------------------------------------------------------------------