├── vehicle_reid_pytorch
├── __init__.py
├── models
│   ├── backbones
│   │   ├── __init__.py
│   │   ├── resnet.py
│   │   ├── resnet_ibn.py
│   │   └── senet.py
│   ├── __init__.py
│   ├── blocks.py
│   ├── aaver.py
│   ├── ram.py
│   └── baseline.py
├── data
│   ├── samplers
│   │   ├── __init__.py
│   │   └── triplet_sampler.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── common.py
│   │   └── bases.py
│   ├── transforms
│   │   ├── __init__.py
│   │   ├── resize_with_kp.py
│   │   ├── pad_to_mul.py
│   │   └── random_erasing.py
│   ├── __init__.py
│   └── demo_transforms.py
├── loss
│   ├── __init__.py
│   ├── test_tuplet_loss.py
│   ├── center_loss.py
│   ├── tuplet_loss.py
│   └── triplet_loss.py
├── utils
│   ├── path.py
│   ├── __init__.py
│   ├── tools.py
│   ├── math.py
│   ├── iotools.py
│   ├── pytorch_tools.py
│   └── visualize.py
└── metrics
│   ├── __init__.py
│   ├── rerank.py
│   ├── R1_mAP.py
│   └── eval_reid.py
├── examples
├── parsing_reid
│   ├── test_vehicleid.py
│   ├── configs
│   │   ├── veri776_b64_baseline.yml
│   │   ├── veri776_b64_pven.yml
│   │   ├── vehicleid_b256_pven.yml
│   │   ├── veriwild_b256_224_baseline.yml
│   │   ├── veriwild_b128_pven.yml
│   │   └── veriwild_b256_224_pven.yml
│   ├── model.py
│   ├── math_tools.py
│   └── main.py
├── preprocess_data
│   ├── preprocess_veriwild2.py
│   └── generate_pkl.py
└── parsing
│   ├── veri776_poly2mask.py
│   ├── generate_masks.py
│   ├── batch_gen_masks.py
│   ├── train_parsing.py
│   └── dataset.py
├── .gitignore
├── requirements.txt
├── setup.py
└── README.md

/vehicle_reid_pytorch/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/examples/parsing_reid/test_vehicleid.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/vehicle_reid_pytorch/models/backbones/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/vehicle_reid_pytorch/data/samplers/__init__.py:
--------------------------------------------------------------------------------
from .triplet_sampler import *
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
examples/EDA/*
examples/outputs/*
*.pth
__pycache__
*.egg-info
runs
.vscode
--------------------------------------------------------------------------------
/vehicle_reid_pytorch/models/__init__.py:
--------------------------------------------------------------------------------
from .baseline import *
from .ram import *
from .aaver import *
--------------------------------------------------------------------------------
/vehicle_reid_pytorch/loss/__init__.py:
--------------------------------------------------------------------------------
from .triplet_loss import *
from .center_loss import *
from .tuplet_loss import *
--------------------------------------------------------------------------------
/vehicle_reid_pytorch/data/datasets/__init__.py:
--------------------------------------------------------------------------------
from .bases import ReIDDataset, ReIDMetaDataset
from .common import CommonReIDDataset
--------------------------------------------------------------------------------
/vehicle_reid_pytorch/utils/path.py:
--------------------------------------------------------------------------------
import os


def mkdir_p(path):
    if not os.path.exists(path):
        os.makedirs(path)
--------------------------------------------------------------------------------
/vehicle_reid_pytorch/metrics/__init__.py:
--------------------------------------------------------------------------------
from .R1_mAP import R1_mAP, CMC10Times
from .eval_reid import eval_func, eval_func_mp
from .rerank import re_ranking
--------------------------------------------------------------------------------
/vehicle_reid_pytorch/utils/__init__.py:
--------------------------------------------------------------------------------
from .tools import *
from .path import *
from .math import *
from .iotools import *
from .pytorch_tools import *
--------------------------------------------------------------------------------
/vehicle_reid_pytorch/data/transforms/__init__.py:
--------------------------------------------------------------------------------
from .random_erasing import AlbuRandomErasing
from .pad_to_mul import AlbuPadImageToMultipliesOf
from .resize_with_kp import ResizeWithKp, MultiScale
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy
albumentations
scipy
logzero
yacs
pandas
torch >= 1.3.0
torchvision >= 0.4.0
segmentation_models_pytorch
opencv-python
tensorboardX
asranger
resnest
click
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages
import glob

setup(
    name="vehicle_reid_pytorch",
    version="1.0",
    # keywords=("pytorch", "vehicle", "ReID"),
    # description="Vehicle ReID utils implemented with pytorch",
    # long_description="",
    packages=find_packages(exclude=('examples', 'examples.*')),
    scripts=glob.glob('scripts/*')
)
--------------------------------------------------------------------------------
/vehicle_reid_pytorch/loss/test_tuplet_loss.py:
--------------------------------------------------------------------------------
import torch

from .tuplet_loss import *


def test_generate_tuplets():
    print(generate_tuplets(4, 8))
    print(generate_tuplets(5, 8))


def test_tuplet_loss():
    K = 4
    P = 8
    s = 32
    C = 16
    tuplet_loss = TupletLoss(K, P, s)
    feats = torch.randn(K * P, C)
    loss = tuplet_loss(feats)
    print(loss)
--------------------------------------------------------------------------------
/examples/parsing_reid/configs/veri776_b64_baseline.yml:
--------------------------------------------------------------------------------
data:
  pkl_path: '../outputs/veri776.pkl'
  train_size: (256, 256)
  valid_size: (256, 256)
  train_num_workers: 16
  test_num_workers: 4
  batch_size: 64
  with_mask: True

loss:
  losses: ["id", "triplet", "center"]

test:
  remove_junk: True
  lambda_: 0.0

device: 'cuda'
output_dir: '../outputs/veri776_b64_baseline/'
--------------------------------------------------------------------------------
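The YAML files above are yacs overrides rather than complete configs: merge_configs in vehicle_reid_pytorch/utils/iotools.py (listed further down) first merges comma-separated config files into a CfgNode of defaults, then applies command-line key/value pairs. A minimal sketch assuming a toy default node (the actual defaults are presumably defined by the example entry points such as main.py):

from yacs.config import CfgNode as CN

from vehicle_reid_pytorch.utils.iotools import merge_configs

# Toy defaults; the real default CfgNode lives in the training script.
cfg = CN()
cfg.device = 'cuda'
cfg.output_dir = ''

cfg = merge_configs(cfg, "", ["device", "cpu"])  # no files, one CLI-style override
assert cfg.device == 'cpu'
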
/examples/parsing_reid/configs/veri776_b64_pven.yml:
--------------------------------------------------------------------------------
data:
  pkl_path: '../outputs/veri776.pkl'
  train_size: (256, 256)
  valid_size: (256, 256)
  train_num_workers: 16
  test_num_workers: 4
  batch_size: 64
  with_mask: True

loss:
  losses: ["id", "triplet", "center", "local-triplet"]

test:
  remove_junk: True
  lambda_: 0.5

device: 'cuda'
output_dir: '../outputs/veri776_b64_pven/'
--------------------------------------------------------------------------------
/examples/parsing_reid/configs/vehicleid_b256_pven.yml:
--------------------------------------------------------------------------------
data:
  pkl_path: '../outputs/vehicleid.pkl'
  train_size: (256, 256)
  valid_size: (256, 256)
  train_num_workers: 16
  test_num_workers: 0
  batch_size: 256
  with_mask: True
  test_ext: "_800"
  name: "VehicleID"

loss:
  losses: ["id", "triplet", "center", "local-triplet"]

test:
  remove_junk: False
  lambda_: 0.5
  period: 1000

device: 'cuda'
output_dir: '../outputs/vehicleid_b256_pven/'
--------------------------------------------------------------------------------
/examples/parsing_reid/configs/veriwild_b256_224_baseline.yml:
--------------------------------------------------------------------------------
data:
  name: 'VERIWild'
  pkl_path: '../outputs/veriwild.pkl'
  train_size: (224, 224)
  valid_size: (224, 224)
  train_num_workers: 32
  test_num_workers: 8
  batch_size: 128
  with_mask: True
  test_ext: '_3000'

train:
  epochs: 60

loss:
  losses: ["id", "triplet", "center"]

test:
  remove_junk: False
  lambda_: 0.0
  device: "cuda"
  split: 100
  model_path: "../outputs/veriwild_b128_pven/model_60.pth"

scheduler:
  milestones: [40, 70]

device: 'cuda'
output_dir: '../outputs/veriwild_b256_224_baseline/'
--------------------------------------------------------------------------------
/examples/parsing_reid/configs/veriwild_b128_pven.yml:
--------------------------------------------------------------------------------
data:
  name: 'VERIWild'
  pkl_path: '../outputs/veriwild.pkl'
  train_size: (256, 256)
  valid_size: (256, 256)
  train_num_workers: 32
  test_num_workers: 16
  batch_size: 128
  with_mask: True
  test_ext: '_3000'

train:
  epochs: 60

loss:
  losses: ["id", "triplet", "center", "local-triplet"]

test:
  remove_junk: False
  lambda_: 0.5
  device: "cuda"
  split: 500
  model_path: "../outputs/veriwild_b128_pven/model_60.pth"

scheduler:
  milestones: [30, 50]

device: 'cuda'
output_dir: '../outputs/veriwild_b128_pven/'
--------------------------------------------------------------------------------
/examples/parsing_reid/configs/veriwild_b256_224_pven.yml:
--------------------------------------------------------------------------------
data:
  name: 'VERIWild'
  pkl_path: '../outputs/veriwild.pkl'
  train_size: (224, 224)
  valid_size: (224, 224)
  train_num_workers: 16
  test_num_workers: 16
  batch_size: 256
  with_mask: True
  test_ext: '_3000'

train:
  epochs: 120

loss:
  losses: ["id", "triplet", "center", "local-triplet"]

test:
  remove_junk: False
  lambda_: 0.5
  device: "cuda"
  split: 100
  model_path: "../outputs/veriwild_b256_224_pven/model_120.pth"

scheduler:
  milestones: [40, 70]

device: 'cuda'
output_dir: '../outputs/veriwild_b256_224_pven/'
--------------------------------------------------------------------------------
/vehicle_reid_pytorch/models/blocks.py:
--------------------------------------------------------------------------------
from torch import nn


def conv_block(num_in, num_out):
    # Integer division keeps the channel counts ints; nn.Conv2d requires
    # integer channel arguments.
    return nn.Sequential(
        nn.BatchNorm2d(num_in),
        nn.ReLU(True),
        nn.Conv2d(num_in, num_out // 2, 1),
        nn.BatchNorm2d(num_out // 2),
        nn.ReLU(True),
        nn.Conv2d(num_out // 2, num_out // 2, 3, 1, 1),
        nn.BatchNorm2d(num_out // 2),
        nn.ReLU(True),
        nn.Conv2d(num_out // 2, num_out, 1)
    )


def skip_layer(num_in, num_out):
    if num_in == num_out:
        return Identity()
    else:
        return nn.Sequential(
            nn.Conv2d(num_in, num_out, 1, 1)
        )


class Residual(nn.Module):
    def __init__(self, num_in, num_out):
        super(Residual, self).__init__()
        self.conv_block = conv_block(num_in, num_out)
        self.skip_layer = skip_layer(num_in, num_out)

    def forward(self, x):
        return self.conv_block(x) + self.skip_layer(x)


class Identity(nn.Module):
    def forward(self, x, **kwargs):
        return x
--------------------------------------------------------------------------------
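A quick shape check for the Residual block in blocks.py above (a minimal sketch; num_out should be even, since conv_block halves it internally):

import torch

from vehicle_reid_pytorch.models.blocks import Residual

block = Residual(64, 128)           # 64 != 128, so skip_layer is a 1x1 conv
x = torch.randn(2, 64, 32, 32)
assert block(x).shape == (2, 128, 32, 32)  # spatial size kept, channels -> num_out
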
"../outputs/veriwild_b256_224_pven/model_120.pth" 24 | 25 | scheduler: 26 | milestones: [40, 70] 27 | 28 | device: 'cuda' 29 | output_dir: '../outputs/veriwild_b256_224_pven/' 30 | 31 | -------------------------------------------------------------------------------- /vehicle_reid_pytorch/models/blocks.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | def conv_block(num_in, num_out): 5 | return nn.Sequential( 6 | nn.BatchNorm2d(num_in), 7 | nn.ReLU(True), 8 | nn.Conv2d(num_in, num_out / 2, 1), 9 | nn.BatchNorm2d(num_out / 2), 10 | nn.ReLU(True), 11 | nn.Conv2d(num_out / 2, num_out / 2, 3, 1, 1), 12 | nn.BatchNorm2d(num_out / 2), 13 | nn.ReLU(True), 14 | nn.Conv2d(num_out / 2, num_out, 1) 15 | ) 16 | 17 | 18 | def skip_layer(num_in, num_out): 19 | if num_in == num_out: 20 | return Identity() 21 | else: 22 | return nn.Sequential( 23 | nn.Conv2d(num_in, num_out, 1, 1) 24 | ) 25 | 26 | 27 | class Residual(nn.Module): 28 | def __init__(self, num_in, num_out): 29 | super(Residual, self).__init__() 30 | self.conv_block = conv_block(num_in, num_out) 31 | self.skip_layer = skip_layer(num_in, num_out) 32 | 33 | def forward(self, x): 34 | return self.conv_block(x) + self.skip_layer(x) 35 | 36 | 37 | class Identity(nn.Module): 38 | def forward(self, x, **kwargs): 39 | return x 40 | -------------------------------------------------------------------------------- /vehicle_reid_pytorch/data/__init__.py: -------------------------------------------------------------------------------- 1 | from . import datasets 2 | from . import demo_transforms as demo_trans 3 | 4 | 5 | def make_basic_dataset(pkl_path, train_size, val_size, pad, *, test_ext='', re_prob=0.5, with_mask=False, for_vis=False): 6 | """ 7 | 构建基础数据集。 8 | """ 9 | 10 | meta_dataset = datasets.CommonReIDDataset(pkl_path=pkl_path, test_ext=test_ext) 11 | train_transform = demo_trans.get_training_albumentations(train_size, pad, re_prob) 12 | val_transform = demo_trans.get_validation_augmentations(val_size) 13 | if for_vis: 14 | preprocessing = None 15 | else: 16 | # baiyan model 17 | preprocessing = demo_trans.get_preprocessing(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) 18 | # preprocessing = demo_trans.get_preprocessing() 19 | 20 | train_dataset = datasets.ReIDDataset( 21 | meta_dataset.train, with_mask=with_mask, transform=train_transform, preprocessing=preprocessing) 22 | 23 | val_dataset = datasets.ReIDDataset(meta_dataset.query + meta_dataset.gallery, with_mask=with_mask, transform=val_transform, 24 | preprocessing=preprocessing) 25 | 26 | return train_dataset, val_dataset, meta_dataset 27 | -------------------------------------------------------------------------------- /vehicle_reid_pytorch/data/transforms/resize_with_kp.py: -------------------------------------------------------------------------------- 1 | import albumentations as albu 2 | from albumentations.augmentations import functional 3 | import numpy as np 4 | import cv2 5 | import random 6 | 7 | 8 | class ResizeWithKp(albu.Resize): 9 | def apply_to_keypoint(self, keypoint, **params): 10 | x = int(keypoint[0] / params["cols"] * self.width) 11 | y = int(keypoint[1] / params["rows"] * self.height) 12 | return (x, y, 0, 0) 13 | 14 | class MultiScale(albu.ImageOnlyTransform): 15 | def __init__(self, interpolation=cv2.INTER_LINEAR, always_apply=False, p=1): 16 | super(MultiScale, self).__init__(always_apply, p) 17 | self.interpolation = interpolation 18 | 19 | def apply(self, image, **params): 20 | height, width, _ = 
/examples/preprocess_data/preprocess_veriwild2.py:
--------------------------------------------------------------------------------
import os.path as osp
import pickle as pkl
import click


"""
For the veriwild2 test set.
The file tree should be like this

.
├── gallery
│   └── gallery
│       └── gallery_final
├── query
│   └── query
│       └── query_final
└── test_split

"""
@click.command()
@click.option('--input-path', default='/home/aa/mengdechao/datasets/veriwild2')
@click.option('--output-path', default='../outputs/veriwild2.pkl')
def veriwild2(input_path, output_path):
    output = {}

    output["train"] = []
    PATH = input_path
    for name in ['A', 'B', 'All']:
        for phase in ['query', 'gallery']:
            raw_metas = open(PATH + f'/test_split/{name}_{phase}.txt').read().strip().split('\n')
            output_list = [
                {
                    'image_path': osp.join(PATH, phase, raw.split(' ')[0]),
                    'id': raw.split(' ')[1],
                    'cam': 1
                }
                for raw in raw_metas
            ]
            output[f'{phase}_{name}'] = output_list

    with open(output_path, 'wb') as f:
        pkl.dump(output, f)


if __name__ == "__main__":
    veriwild2()
--------------------------------------------------------------------------------
/vehicle_reid_pytorch/data/transforms/pad_to_mul.py:
--------------------------------------------------------------------------------
from functools import partial
from vehicle_reid_pytorch.utils.math import pad_image_size_to_multiples_of
import albumentations as albu
import numpy as np


def pad_image_to_shape(img, shape, *, return_padding=False):
    """
    Zero-pad the given image to the given shape while keeping the image
    in the center.
    :param shape: (h, w)
    :param return_padding:
    """
    shape = list(shape[:2])
    if img.ndim > 2:
        shape.extend(img.shape[2:])
    shape = tuple(shape)

    h, w = img.shape[:2]
    assert w <= shape[1] and h <= shape[0]
    pad_width = shape[1] - w
    pad_height = shape[0] - h

    pad_w0 = pad_width // 2
    pad_w1 = shape[1] - (pad_width - pad_w0)
    pad_h0 = pad_height // 2
    pad_h1 = shape[0] - (pad_height - pad_h0)

    ret = np.zeros(shape, dtype=img.dtype)
    ret[pad_h0:pad_h1, pad_w0:pad_w1] = img
    if return_padding:
        return ret, (pad_h0, pad_w0)
    else:
        return ret


def AlbuPadImageToMultipliesOf(multiply=32, align="top-left", **kwargs):
    fun = partial(pad_image_size_to_multiples_of, multiply=multiply, align=align)
    return albu.Lambda(image=fun, mask=fun)
--------------------------------------------------------------------------------
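pad_image_to_shape above centers the image and pushes any odd pixel of padding to the bottom/right; a small worked example:

import numpy as np

from vehicle_reid_pytorch.data.transforms.pad_to_mul import pad_image_to_shape

img = np.ones((5, 6), dtype=np.uint8)
padded, (pad_h0, pad_w0) = pad_image_to_shape(img, (8, 8), return_padding=True)
assert padded.shape == (8, 8)
assert (pad_h0, pad_w0) == (1, 1)  # 3 rows of padding -> 1 on top, 2 at the bottom
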
/vehicle_reid_pytorch/data/datasets/common.py:
--------------------------------------------------------------------------------
# encoding: utf-8
"""
@author: Dechao Meng
@contact: dechao.meng@vipl.ict.ac.cn
"""

import glob
import re
import os.path as osp
from pathlib import Path
import time
import numpy as np
import pickle as pkl
from vehicle_reid_pytorch.data.datasets.bases import ReIDMetaDataset, relabel, get_imagedata_info


class CommonReIDDataset(ReIDMetaDataset):
    def __init__(self, pkl_path, verbose=True, test_ext='', **kwargs):
        """
        test_ext: For VehicleID and VERI-Wild there are multiple test sets. Pass the test ext to select which one to use.
        """
        metas = pkl.load(open(pkl_path, 'rb'))
        self.train = metas["train"]
        self.query = metas["query" + str(test_ext)]
        self.gallery = metas["gallery" + str(test_ext)]

        self.relabel()

        if verbose:
            print("=> Dataset loaded")
            self.print_dataset_statistics()

        self.num_train_ids, self.num_train_imgs, self.num_train_cams = get_imagedata_info(self.train)
        self.num_query_ids, self.num_query_imgs, self.num_query_cams = get_imagedata_info(self.query)
        self.num_gallery_ids, self.num_gallery_imgs, self.num_gallery_cams = get_imagedata_info(self.gallery)
--------------------------------------------------------------------------------
/vehicle_reid_pytorch/models/aaver.py:
--------------------------------------------------------------------------------
import copy
from vehicle_reid_pytorch.models import Baseline
from torch import nn
import torch
from vehicle_reid_pytorch.models.baseline import weights_init_classifier, weights_init_kaiming


class AAVER(Baseline):
    def __init__(self, with_kp=True, with_mask=False, *args, **kwargs):
        """AAVER

        Arguments:
            Baseline {[type]} -- [description]
        """
        super(AAVER, self).__init__(*args, **kwargs)
        conv1_input_ch = 3
        self.with_kp = with_kp
        self.with_mask = with_mask
        if self.with_kp:
            conv1_input_ch += 1
        if self.with_mask:
            conv1_input_ch += 4
        old_conv1 = self.base.conv1
        self.base.conv1 = nn.Conv2d(
            conv1_input_ch, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Copy the pretrained RGB filters into the first three input channels.
        self.base.conv1.weight[:, :3, :, :].data.copy_(old_conv1.weight.data)

    def forward(self, image, **kwargs):
        if self.with_kp:
            kp_heatmap = kwargs["kp_heatmap"]
            image = torch.cat([image, kp_heatmap.unsqueeze(1)], dim=1)
        if self.with_mask:
            mask = kwargs["mask"][:, 1:5, :, :]
            image = torch.cat([image, mask], dim=1)
        return super(AAVER, self).forward(image, **kwargs)


if __name__ == "__main__":
    model = AAVER(333, 1, '/data/models/resnet50.pth',
                  'bnneck', 'after', 'resnet50', 'imagenet')
--------------------------------------------------------------------------------
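For reference, the pkl files consumed by CommonReIDDataset above follow the structure written by examples/preprocess_data/generate_pkl.py: a dict of sample lists keyed by split, with optional test-set suffixes such as query_800 selected via test_ext. A toy illustrative instance (paths and ids are made up):

import pickle as pkl

metas = {
    "train": [
        {"image_path": "/data/VeRi/image_train/0001_c001_00016450_0.jpg",
         "id": "0001", "cam": "001"},
    ],
    "query": [],
    "gallery": [],
}
with open("toy.pkl", "wb") as f:
    pkl.dump(metas, f)
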
/vehicle_reid_pytorch/data/demo_transforms.py:
--------------------------------------------------------------------------------
import albumentations as albu
from albumentations.pytorch import ToTensor
import numpy as np
import cv2
import torch

from .transforms import AlbuRandomErasing, ResizeWithKp, MultiScale


def get_training_albumentations(size=(256, 256), pad=10, re_prob=0.5, with_keypoints=False, ms_prob=0.5):
    h, w = size
    train_transform = [
        MultiScale(p=ms_prob),
        ResizeWithKp(h, w, interpolation=cv2.INTER_CUBIC),
        albu.PadIfNeeded(h + 2 * pad, w + 2 * pad, border_mode=cv2.BORDER_CONSTANT, value=0),
        albu.RandomCrop(height=h, width=w, always_apply=True),
        AlbuRandomErasing(re_prob),
    ]
    if with_keypoints:
        return albu.Compose(train_transform, keypoint_params=albu.KeypointParams(format='xy', remove_invisible=False))
    else:
        return albu.Compose(train_transform)


def get_validation_augmentations(size=(256, 256), with_keypoints=False):
    h, w = size
    test_transform = [
        ResizeWithKp(h, w),
    ]
    if with_keypoints:
        return albu.Compose(test_transform, keypoint_params=albu.KeypointParams(format='xy', remove_invisible=False))
    else:
        return albu.Compose(test_transform)


def to_tensor(x, **kwargs):
    x = np.transpose(x, [2, 0, 1])
    return torch.tensor(x)


def get_preprocessing(mean=(0.485, 0.456, 0.406),
                      std=(0.229, 0.224, 0.225)):
    _transform = [
        albu.Normalize(mean, std),
        albu.Lambda(image=to_tensor, mask=to_tensor)
    ]
    return albu.Compose(_transform)
--------------------------------------------------------------------------------
/examples/parsing/veri776_poly2mask.py:
--------------------------------------------------------------------------------
"""
Change polys to masks

author: Dechao Meng
email: mengdechaolive@qq.com
"""
from tqdm import tqdm
import json
import cv2
import numpy as np
import argparse
from vehicle_reid_pytorch.utils import mkdir_p


def poly2mask(polys, classes, shape):
    mask = np.zeros(shape, dtype=np.uint8)
    for poly, class_ in zip(polys, classes):
        poly = np.array(poly)
        # Coordinates are normalized; scale them to pixel coordinates.
        poly[:, 0] *= mask.shape[1]
        poly[:, 1] *= mask.shape[0]
        poly = poly.astype(int)  # np.int is deprecated; builtin int behaves the same here
        cv2.fillPoly(mask, [poly], class_)
    return mask


def get_metas_dirty(item):
    nori_id = item['uris'][0]
    shape = item['resources'][0]['size']
    shape = [shape['height'], shape["width"]]
    polys_list = item['results']['polys']
    polys = [poly['poly'] for poly in polys_list]
    classes = [int(poly['attr']['side']) for poly in polys_list]
    return nori_id, shape, polys, classes


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--json-path", default="poly.json")
    parser.add_argument("--output-path", default="veri776_parsing3165")
    args = parser.parse_args()

    with open(args.json_path, "r") as f:
        polygons = json.load(f)
    output_path = args.output_path
    mkdir_p(output_path)

    for i, item in tqdm(enumerate(polygons)):
        image_name = item["image_name"]
        shape = item["shape"]
        polys = item["polys"]
        classes = item["classes"]
        mask = poly2mask(polys, classes, shape)
        print(image_name)
        image_name = image_name.split('/')[1].split('.')[0]
        cv2.imwrite('{}/{}.png'.format(output_path, image_name), mask)
--------------------------------------------------------------------------------
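poly2mask above expects polygon coordinates normalized to [0, 1] and rasterizes each polygon with its class id as the fill value. A toy sketch (assuming the script is importable from the working directory):

import numpy as np

from veri776_poly2mask import poly2mask

polys = [[[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]]   # one triangle
mask = poly2mask(polys, classes=[2], shape=(4, 4))
assert set(np.unique(mask)) <= {0, 2}            # background plus class 2
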
/vehicle_reid_pytorch/loss/center_loss.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import

import torch
from torch import nn


class CenterLoss(nn.Module):
    """Center loss.

    Reference:
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.

    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
    """

    def __init__(self, num_classes=751, feat_dim=2048):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth labels with shape (batch_size,).
        """
        assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)"

        batch_size = x.size(0)
        # Squared distances via ||x - c||^2 = ||x||^2 + ||c||^2 - 2 x.c
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)

        classes = torch.arange(self.num_classes, device=x.device).long()
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))

        dist = []
        for i in range(batch_size):
            value = distmat[i][mask[i]]
            value = value.clamp(min=1e-12, max=1e+12)  # for numerical stability
            dist.append(value)
        dist = torch.cat(dist)
        loss = dist.mean()
        return loss


if __name__ == '__main__':
    use_gpu = False
    center_loss = CenterLoss()
    features = torch.rand(16, 2048)
    targets = torch.tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).long()
    if use_gpu:
        features = torch.rand(16, 2048).cuda()
        targets = torch.tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).cuda()

    loss = center_loss(features, targets)
    print(loss)
--------------------------------------------------------------------------------
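The distmat computation in CenterLoss uses the expansion ||x - c||^2 = ||x||^2 + ||c||^2 - 2 x·c; a quick numerical cross-check of that identity against torch.cdist (sketch):

import torch

x = torch.randn(4, 8)
c = torch.randn(6, 8)
d1 = x.pow(2).sum(1, keepdim=True) + c.pow(2).sum(1) - 2 * x @ c.t()
d2 = torch.cdist(x, c).pow(2)
assert torch.allclose(d1, d2, atol=1e-5)
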
/vehicle_reid_pytorch/loss/tuplet_loss.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from vehicle_reid_pytorch.utils.math import euclidean_dist


def generate_tuplets(K, P):
    """
    Generate (anchor, positive, negatives) index tuplets for a minibatch
    with P classes and K images per class.

    :param K: number of images per class
    :param P: number of classes
    :return: LongTensor of shape [K * P * (K - 1), P + 1]
    """
    tuplets = []
    for k in range(K):
        for p in range(P):
            index = k + p * K
            positives = torch.arange(K * p, K * (p + 1))
            positives = positives[positives != index]
            negative_labels = torch.arange(P)
            negative_labels = negative_labels[negative_labels != p]
            for positive in positives:
                negatives = torch.randint(K, (P - 1,)) + (negative_labels * K)
                tuplet = torch.tensor([index, positive, *negatives])
                tuplets.append(tuplet)
    tuplets = torch.stack(tuplets)
    return tuplets


def _tuplet_loss(tuplet_feats, s, beta):
    """

    :param tuplet_feats: [B, P+1, C]
    :param s:
    :return:
    """
    B, P, C = tuplet_feats.shape
    P = P - 1
    anchors = tuplet_feats[:, 0]  # B C
    positives = tuplet_feats[:, 1]  # B C
    negatives = tuplet_feats[:, 2:]  # B P-1 C
    cos_ap = torch.sum(anchors * positives, 1)  # B
    if beta != 0:
        theta_ap = torch.acos(cos_ap)
        cos_ap_beta = torch.cos(theta_ap - beta)
    else:
        # Without a slack margin the positive similarity is used as-is.
        cos_ap_beta = cos_ap
    cos_an = torch.sum(anchors.view(B, 1, -1) * negatives, 2)  # B, P-1
    return torch.log(1 + torch.sum(torch.exp(s * (cos_an - cos_ap_beta.view(-1, 1))), 1))


class TupletLoss(object):
    """
    A reproduction of the tuplet margin loss proposed by
    "Deep Metric Learning with Tuplet Margin Loss", ICCV 2019.

    """

    def __init__(self, K, P, s=32, beta=0.):
        """


        :param K: number of images per class in a minibatch
        :param P: number of classes in a minibatch
        :param s: scale factor
        :param beta: slack margin
        """
        self.K = K
        self.P = P
        self.s = s
        self.beta = beta

    def __call__(self, feats):
        """

        :param torch.Tensor feats: [N, C]
        :return:
        """
        # L2-normalize features so dot products become cosine similarities.
        feats = feats / torch.norm(feats, dim=1).view(-1, 1)
        tuplets = generate_tuplets(self.K, self.P).to(feats.device)
        tuplet_feats = feats[tuplets]

        loss = _tuplet_loss(tuplet_feats, self.s, self.beta).mean()
        return loss
--------------------------------------------------------------------------------
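Shape-wise, generate_tuplets enumerates every (anchor, positive) pair within each class and draws one random negative per other class, so the result has K * P * (K - 1) rows of length P + 1. A sketch:

import torch

from vehicle_reid_pytorch.loss.tuplet_loss import generate_tuplets

K, P = 4, 8
tuplets = generate_tuplets(K, P)
assert tuplets.shape == (K * P * (K - 1), P + 1)
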
/vehicle_reid_pytorch/utils/tools.py:
--------------------------------------------------------------------------------
import logging
import sys
import os
import torch
import time
from logzero import logger
from yacs.config import CfgNode


def setup_logger(name, save_dir, distributed_rank, level="INFO"):
    logger = logging.getLogger(name)
    logger.setLevel(10)
    # don't log results for the non-master process
    if distributed_rank > 0:
        return logger
    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(getattr(logging, level.upper()))
    formatter = logging.Formatter(
        "%(asctime)s %(name)s %(levelname)s: %(message)s")
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    if save_dir:
        fh = logging.FileHandler(os.path.join(save_dir, "log.txt"), mode='w')
        fh.setLevel(getattr(logging, level.upper()))
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    return logger


def tb_log(kv_map, writter, global_steps):
    """
    Take a dict and send each of its values to TensorBoard.
    :param dict kv_map:
    :param writter:
    :param global_steps:
    :return:
    """
    for loss_name, value in kv_map.items():
        if isinstance(value, torch.Tensor):
            value = value.item()
        writter.add_scalar(loss_name, value, global_steps)
        logger.debug(f'{loss_name}: {value}')


class Session():
    def __init__(self):
        pass

    def train(self):
        pass

    def eval(self):
        pass


class TimeCounter():
    """
    Accumulate the running time of a program. Supports use as a context
    manager (with statement).

    """

    def __init__(self, verbose=False):
        self._verbose = verbose
        self.period = 0

    def __enter__(self):
        self._start = time.time()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._end = time.time()
        self.period += self._end - self._start

        if self._verbose:
            print(f'Cost time: {self.period}')


def iter_x(x):
    if isinstance(x, CfgNode):
        for key, value in x.items():
            yield (key, value)


def _flat_cfg(x):
    for key, value in iter_x(x):
        if isinstance(value, (dict, list)):
            for k, v in _flat_cfg(value):
                k = f'{key}.{k}'
                yield (k, v)
        else:
            yield (key, value)


def flat_cfg(x):
    output = {}
    for k, v in _flat_cfg(x):
        output[k] = v
    return output
--------------------------------------------------------------------------------
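TimeCounter above accumulates across entries, so a single instance can total several timed sections; a usage sketch:

import time

from vehicle_reid_pytorch.utils.tools import TimeCounter

timer = TimeCounter(verbose=True)
for _ in range(3):
    with timer:          # each exit adds to timer.period and prints the total
        time.sleep(0.01)
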
/examples/parsing/generate_masks.py:
--------------------------------------------------------------------------------
"""
Generate masks for ReID dataset.
"""
import torch
import os
import sys
import pickle
import tqdm
from dataset import *
from vehicle_reid_pytorch.data.datasets import ReIDMetaDataset
from vehicle_reid_pytorch.utils import mkdir_p
import numpy as np
from pathlib import Path
import time
import segmentation_models_pytorch as smp
import argparse

ENCODER = 'se_resnext50_32x4d'
ENCODER_WEIGHTS = 'imagenet'
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

IMG2MASK = {}


def predict(model, test_dataset, test_dataset_vis, output_path):
    mkdir_p(output_path)
    for i in tqdm.tqdm(range(len(test_dataset))):
        image = test_dataset[i]
        image_vis, extra = test_dataset_vis[i]

        # Duplicate images simply reuse the mask computed earlier.
        image_path = Path(extra["image_path"])
        if str(image_path) in IMG2MASK:
            extra["mask_path"] = str(IMG2MASK[str(image_path)])
            continue
        mask_path = output_path / f"{image_path.name.split('.')[0]}.png"

        x_tensor = torch.from_numpy(image).to("cuda").unsqueeze(0)
        with torch.no_grad():
            pr_mask = model.predict(x_tensor)
        pr_map = pr_mask.squeeze().cpu().numpy().round()
        pr_map = np.argmax(pr_map, axis=0)[:image_vis.shape[0], :image_vis.shape[1]]
        cv2.imwrite(str(mask_path), pr_map.astype(np.uint8))
        extra["mask_path"] = str(mask_path)

        IMG2MASK[str(image_path)] = str(mask_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", default="best_model_trainval.pth")
    parser.add_argument("--reid-pkl-path", type=str, required=True)
    parser.add_argument("--output-path", type=str, required=True)
    args = parser.parse_args()
    model = torch.load(args.model_path)
    model = model.cuda()
    model.eval()

    with open(args.reid_pkl_path, "rb") as f:
        metas = pickle.load(f)
    output_path = Path(args.output_path).absolute()

    for phase in metas.keys():
        sub_path = output_path / phase
        mkdir_p(str(sub_path))
        dataset = VehicleReIDParsingDataset(metas[phase], augmentation=get_validation_augmentation(),
                                            preprocessing=get_preprocessing(preprocessing_fn))
        dataset_vis = VehicleReIDParsingDataset(metas[phase], with_extra=True)
        print('Predict mask to {}'.format(sub_path))
        predict(model, dataset, dataset_vis, sub_path)

    # Write mask path to pkl
    with open(args.reid_pkl_path, "wb") as f:
        pickle.dump(metas, f)
--------------------------------------------------------------------------------
/vehicle_reid_pytorch/utils/math.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import cv2


def euclidean_dist(x, y):
    """

    :param torch.Tensor x:
    :param torch.Tensor y:
    :rtype: torch.Tensor
    :return: dist: pytorch Variable, with shape [m, n]
    """
    m, n = x.size(0), y.size(0)
    xx = torch.pow(x, 2).sum(1, keepdim=True).view(m, 1)
    yy = torch.pow(y, 2).sum(1, keepdim=True).view(n, 1).t()
    dist = xx + yy
    # dist += -2 * x @ y.T, written with the keyword addmm_ signature.
    dist.addmm_(x, y.t(), beta=1, alpha=-2)
    return dist.clamp(0).sqrt()


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def near_convex(xys, threshold=20):
    """
    Check whether a quadrilateral should be rejected as (nearly) non-convex.

    Quadrilaterals that are only approximately convex must be rejected too,
    so the four interior angles are checked and any angle close to 180°
    also fails the check.

    :param np.ndarray xys: 4*2, the xy coordinates of the 4 points
    :param threshold: abs degree of the diff with 180°
    :return:
    """
    vectors = np.empty([4, 2])
    vectors[:3, :] = xys[1:] - xys[:-1]
    vectors[3, :] = xys[0] - xys[-1]

    # Test convexity via the signs of consecutive edge cross products.
    cross = np.cross(vectors, vectors[[3, 0, 1, 2]])
    if np.any(cross > 0) and np.any(cross < 0):
        return True

    # Nearly-convex quads are rejected as well.
    angles = np.empty(4)
    norm_dot = np.sqrt(np.sum(vectors[:3, :] ** 2, axis=1) * np.sum(vectors[1:, :] ** 2, axis=1)).clip(1e-7, None)
    angles[:3] = np.arccos(np.sum(vectors[:3, :] * vectors[1:, :], axis=1) / norm_dot)
    norm_dot: np.ndarray = np.sqrt(np.sum(vectors[0, :] ** 2) * np.sum(vectors[3, :] ** 2)).clip(1e-7, None)
    angles[3] = np.arccos(np.sum(vectors[0, :] * vectors[3, :]) / norm_dot)

    if np.any(np.abs(angles) < threshold / 180 * np.pi):
        return True
    return False


def perspective_transform(image, quad_pts, target_pts=None, output_size=(128, 128)):
    """

    :param image:
    :param quad_pts:
    :param target_pts:
    :param output_size:
    :return:
    """
    quad_pts = quad_pts.astype(np.float32)

    x, y = output_size
    if target_pts is None:
        target_pts = [[0, 0], [x, 0], [x, y], [0, y]]

    target_pts = np.array(target_pts).astype(np.float32)
    m = cv2.getPerspectiveTransform(quad_pts, target_pts)
    warp_img = cv2.warpPerspective(image, m, output_size)

    return warp_img


def pad_image_size_to_multiples_of(img, multiple, *, align):
    """
    Pad an image so that each edge becomes the smallest multiple of
    `multiple` that is not smaller than the original edge. With
    align='center' the image is placed in the center using
    pad_image_to_shape.

    :param multiple: the dividend of the targeting size of the image
    :param align: one of 'top-left' or 'center'
    """

    assert align in {'top-left', 'center'}, align

    h, w = img.shape[:2]
    d = multiple

    def canonicalize(s):
        v = s // d
        return (v + (v * d != s)) * d

    th, tw = map(canonicalize, (h, w))
    if align == 'top-left':
        tshape = (th, tw)
        if img.ndim == 3:
            tshape = tshape + (img.shape[2],)
        ret = np.zeros(tshape, dtype=img.dtype)
        ret[:h, :w] = img
        return ret
    else:
        assert align == 'center', align
        # Imported here to avoid a circular import with
        # vehicle_reid_pytorch.data.transforms.pad_to_mul.
        from vehicle_reid_pytorch.data.transforms.pad_to_mul import pad_image_to_shape
        return pad_image_to_shape(img, (th, tw))
--------------------------------------------------------------------------------
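near_convex above returns True when a quadrilateral should be rejected, i.e. it is non-convex or has an interior angle within `threshold` degrees of 180°. A unit square passes both checks:

import numpy as np

from vehicle_reid_pytorch.utils.math import near_convex

square = np.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=np.float64)
assert not near_convex(square)   # strictly convex, all angles 90 degrees
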
/vehicle_reid_pytorch/data/transforms/random_erasing.py:
--------------------------------------------------------------------------------
# encoding: utf-8
"""
@author: liaoxingyu
@contact: liaoxingyu2@jd.com
"""

import math
import random
import numpy as np
import albumentations as albu


class RandomErasing(object):
    """Randomly selects a rectangle region in an image and erases its pixels.
    'Random Erasing Data Augmentation' by Zhong et al.
    See https://arxiv.org/pdf/1708.04896.pdf
    Args:
        probability: The probability that the Random Erasing operation will be performed.
        sl: Minimum proportion of erased area against input image.
        sh: Maximum proportion of erased area against input image.
        r1: Minimum aspect ratio of erased area.
        mean: Erasing value.
    """

    def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)):
        self.probability = probability
        self.mean = mean
        self.sl = sl
        self.sh = sh
        self.r1 = r1

    def handle_numpy(self, img):
        """
        img is an un-normalized (H, W, C) array; this path is used by
        albumentations.
        :param img:
        :return:
        """
        shape = img.shape
        if random.uniform(0, 1) >= self.probability:
            return img

        for attempt in range(100):
            area = shape[0] * shape[1]

            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, 1 / self.r1)

            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))

            if w < shape[1] and h < shape[0]:
                x1 = random.randint(0, shape[0] - h)
                y1 = random.randint(0, shape[1] - w)
                if shape[2] == 3:
                    img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0] * 255
                    img[x1:x1 + h, y1:y1 + w, 1] = self.mean[1] * 255
                    img[x1:x1 + h, y1:y1 + w, 2] = self.mean[2] * 255
                else:
                    img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0]
                return img

        return img

    def handel_pil(self, img):
        shape = img.size()
        if random.uniform(0, 1) >= self.probability:
            return img

        for attempt in range(100):
            area = shape[1] * shape[2]

            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, 1 / self.r1)

            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))

            if w < shape[2] and h < shape[1]:
                x1 = random.randint(0, shape[1] - h)
                y1 = random.randint(0, shape[2] - w)
                if shape[0] == 3:
                    img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
                    img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
                    img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
                else:
                    img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
                return img

        return img

    def __call__(self, img):
        if isinstance(img, np.ndarray):
            return self.handle_numpy(img)
        else:
            return self.handel_pil(img)


def AlbuRandomErasing(probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)):
    fun = RandomErasing(probability, sl, sh, r1, mean)

    def wrapper(x, **kwargs):
        return fun(x)

    return albu.Lambda(image=wrapper, mask=wrapper)
--------------------------------------------------------------------------------
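A small sketch exercising the numpy path of RandomErasing above, with the probability pinned to 1 so the erase (almost) always fires:

import numpy as np

from vehicle_reid_pytorch.data.transforms.random_erasing import RandomErasing

eraser = RandomErasing(probability=1.0)
img = np.full((64, 64, 3), 255, dtype=np.uint8)
out = eraser(img)
# A random rectangle is now filled with mean * 255 per channel
# (unless no valid rectangle was found within 100 attempts).
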
/vehicle_reid_pytorch/models/ram.py:
--------------------------------------------------------------------------------
from vehicle_reid_pytorch.models import Baseline
from torch import nn
from .baseline import weights_init_classifier, weights_init_kaiming


class RAM(Baseline):
    def __init__(self, divides, *args, **kwargs):
        """RAM

        Arguments:
            Baseline {[type]} -- [description]
            divides {List[(int, int, int, int, int, int)]} -- e.g.
                [(0, 1024, 0, 6, 0, 16), (4, 8, 0, 16), (8, 16, 0, 16)].
                Each tuple is a slice spec over the feature map with
                half-open intervals: a 6-tuple also slices the channel
                dimension, a 4-tuple slices only height and width.
        """
        super(RAM, self).__init__(*args, **kwargs)
        self.divides = divides
        self.num_parts = len(self.divides)
        self.local_bottlenecks = nn.ModuleList([])
        self.local_classifiers = nn.ModuleList([])
        for i, divide in enumerate(divides):
            if len(divide) == 6:
                channels = divide[1] - divide[0]
                print(channels)
            else:
                channels = self.in_planes
            local_bottleneck = nn.BatchNorm1d(channels)
            local_classifier = nn.Linear(channels, self.num_classes, bias=False)
            local_bottleneck.bias.requires_grad = False  # freeze the BN shift, bnneck-style
            local_bottleneck.apply(weights_init_kaiming)
            local_classifier.apply(weights_init_classifier)
            self.local_bottlenecks.append(local_bottleneck)
            self.local_classifiers.append(local_classifier)

    def get_feature(self, feature_map, classifier, bottleneck):
        global_feat = self.gap(feature_map)  # (b, 2048, 1, 1)
        global_feat = global_feat.view(
            global_feat.shape[0], -1)  # flatten to (bs, 2048)

        if self.neck == 'no':
            feat = global_feat
        elif self.neck == 'bnneck':
            feat = bottleneck(global_feat)  # normalize for angular softmax

        if self.training:
            cls_score = classifier(feat)
            return cls_score, global_feat  # global feature for triplet loss
        else:
            if self.neck_feat == 'after':
                # print("Test with feature after BN")
                return feat
            else:
                # print("Test with feature before BN")
                return global_feat

    def forward(self, x, **kwargs):
        x = self.base(x)
        if self.training:
            global_score, global_feat = self.get_feature(x, self.classifier, self.bottleneck)  # (b, 2048, 1, 1)
            local_feats = []
            local_scores = []
            for i, divide in enumerate(self.divides):
                if len(divide) == 6:
                    lc, rc, lh, rh, lw, rw = divide
                elif len(divide) == 4:
                    lh, rh, lw, rw = divide
                    lc, rc = 0, self.in_planes

                local_fmap = x[:, lc:rc, lh:rh, lw:rw]
                local_score, local_feat = self.get_feature(local_fmap, self.local_classifiers[i], self.local_bottlenecks[i])
                local_feats.append(local_feat)
                local_scores.append(local_score)
            return {
                "local_feats": local_feats,
                "local_scores": local_scores,
                "global_feat": global_feat,
                "global_score": global_score
            }
        else:
            global_feat = self.get_feature(x, self.classifier, self.bottleneck)  # (b, 2048, 1, 1)
            # local_feats = []
            # for i, divide in enumerate(self.divides):
            #     local_feat = self.get_feature(x, self.local_classifiers[i], self.local_bottlenecks[i])
            #     local_feats.append(local_feat)
            return {
                # "local_feats": local_feats,
                "global_feat": global_feat,
            }
--------------------------------------------------------------------------------
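To make the divides convention above concrete: a 6-tuple slices (channel, height, width) and a 4-tuple only (height, width), all half-open intervals. A sketch against a dummy ResNet-style feature map:

import torch

fmap = torch.randn(2, 2048, 16, 16)      # (B, C, H, W)
lc, rc, lh, rh, lw, rw = 0, 1024, 0, 8, 0, 16
part6 = fmap[:, lc:rc, lh:rh, lw:rw]     # first 1024 channels, top half
lh, rh, lw, rw = 8, 16, 0, 16
part4 = fmap[:, :, lh:rh, lw:rw]         # all channels, bottom half
assert part6.shape == (2, 1024, 8, 16)
assert part4.shape == (2, 2048, 8, 16)
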
/vehicle_reid_pytorch/utils/iotools.py:
--------------------------------------------------------------------------------
import torch
import glob
import re
import os.path as osp
from PIL import Image
import cv2
import numpy as np
import socket


def read_any_img(img_path: str, format='ndarray'):
    """
    Single-channel images are returned as-is; multi-channel images are
    returned as RGB. Reads from the filesystem.

    :param img_path:
    :param format:
    :return:
    """
    img = read_rgb_image(img_path, format)
    return img


def read_rgb_image(img_path, format='ndarray'):
    """Keep reading the image until it succeeds.
    This can avoid IOError incurred by heavy IO processes."""
    got_img = False
    if not osp.exists(img_path):
        raise IOError("{} does not exist".format(img_path))
    while not got_img:
        try:
            if format == 'PIL':
                img = Image.open(img_path).convert("RGB")
            elif format == 'ndarray':
                img = cv2.imread(img_path)
                if len(img.shape) == 3:
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            got_img = True
        except IOError:
            print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
    return img


def load_checkpoint(output_dir, device="cpu", epoch=0, exclude=None, **kwargs):
    """
    Keyword arguments whose names contain 'model' or 'optimizer' are loaded
    onto the given device. If epoch is not specified, the largest saved
    epoch is read automatically. If exclude is given, parameters whose
    names contain that keyword are dropped from the state dict.

    :param str output_dir:
    :param device:
    :param epoch:
    :param kwargs:
    :return:
    """

    if output_dir[-1] == '/':
        output_dir = output_dir[:-1]

    # If no epoch is given, read the largest saved epoch.
    if epoch == 0:
        for key in kwargs.keys():
            pths = glob.glob(f'{output_dir}/{key}_*.pth')
            epochs = [re.findall(rf'{output_dir}/{key}_([0-9]+)\.pth', name)[0] for name in pths]
            epochs = list(map(int, epochs))
            epoch = max(epochs)
            break

    for key, obj in kwargs.items():
        state_dict = torch.load(f'{output_dir}/{key}_{epoch}.pth', map_location=device)
        if exclude is not None:
            exclude_keys = []
            for k in state_dict.keys():
                if exclude in k:
                    exclude_keys.append(k)
            for k in exclude_keys:
                del state_dict[k]

        obj: torch.nn.Module
        try:
            obj.load_state_dict(state_dict, strict=False)
        except TypeError:
            obj.load_state_dict(state_dict)

        # move to target device
        if 'model' in key:
            obj.to(device)

        elif 'optimizer' in key:
            for state in obj.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(device)

    return epoch + 1


def save_checkpoint(epoch, output_dir, **kwargs):
    for key, obj in kwargs.items():
        try:
            obj = obj.module
        except AttributeError:
            pass

        torch.save(obj.state_dict(), f'{output_dir}/{key}_{epoch}.pth')


def merge_configs(cfg, config_files, cmd_config):
    """
    Merge the configs: the defaults, the config files, and the command-line
    options are loaded in that order. Config files are separated by commas.

    :param CfgNode cfg:
    :param str config_files:
    :param list cmd_config:
    :return:
    """
    if config_files != "":
        config_files = config_files.split(",")
        for config_file in config_files:
            cfg.merge_from_file(config_file)

    cfg.merge_from_list(cmd_config)
    return cfg


def get_host_ip():
    """
    Look up the local IP address.
    :return: ip
    """
    try:
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        s.connect(('8.8.8.8', 80))
        ip = s.getsockname()[0]
    finally:
        s.close()

    return ip
--------------------------------------------------------------------------------
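A typical round trip for the checkpoint helpers above (a sketch; keyword names become filename prefixes, e.g. model_10.pth, and the directory must exist):

import torch

from vehicle_reid_pytorch.utils import mkdir_p
from vehicle_reid_pytorch.utils.iotools import save_checkpoint, load_checkpoint

mkdir_p('/tmp/ckpts')
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
save_checkpoint(10, '/tmp/ckpts', model=model, optimizer=optimizer)
next_epoch = load_checkpoint('/tmp/ckpts', device='cpu', model=model, optimizer=optimizer)
assert next_epoch == 11   # the return value is the epoch to resume from
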
/examples/parsing/batch_gen_masks.py:
--------------------------------------------------------------------------------
"""
Generate masks for ReID dataset.
"""
import math
import torch
import os
from tqdm import tqdm
import numpy as np
from pathlib import Path
import segmentation_models_pytorch as smp
from torch.utils.data import Dataset, DataLoader
from multiprocessing import Pool
import click
import albumentations as albu
from mdc_tools.timer import Timer
import cv2
import pandas as pd
# import debugpy; debugpy.connect(('100.64.158.205', 5678))

ENCODER = 'se_resnext50_32x4d'
ENCODER_WEIGHTS = 'imagenet'
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

IMG2MASK = {}


def to_tensor(x, **kwargs):
    return x.transpose(2, 0, 1).astype('float32')


def travel_imgs(root, exts=('.png', '.jpg', '.jpeg')):
    outputs = []
    for dir_, dirs, files in os.walk(root):
        for file in files:
            file: str
            if os.path.splitext(file)[1] in exts:
                outputs.append(os.path.join(dir_, file))
    return outputs


class ParsingTestDataset(Dataset):
    def __init__(self, root_path) -> None:
        self.transforms = albu.Compose([
            albu.LongestMaxSize(244),
            albu.PadIfNeeded(min_height=256, min_width=256, always_apply=True,
                             border_mode=cv2.BORDER_CONSTANT, position='top_left'),
            albu.Lambda(image=preprocessing_fn),
            albu.Lambda(image=to_tensor, mask=to_tensor)
        ])
        self.root_path = root_path
        with Timer(f"Scanning images in {root_path}..."):
            self.image_paths = travel_imgs(root_path)
        print(f"Total {len(self.image_paths)} images.")

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image_shape = image.shape[:2]

        image = self.transforms(image=image)['image']
        return image, image_path, torch.tensor(image_shape)

    def __len__(self):
        return len(self.image_paths)


def write_worker(args):
    mask, image_path, image_shape, input_path, output_path = args
    # Build the output path.
    image_rel_path = os.path.relpath(image_path, input_path)
    save_rel_path = os.path.splitext(image_rel_path)[0] + '.png'
    image_output_path = os.path.join(output_path, save_rel_path)
    os.makedirs(os.path.split(image_output_path)[0], exist_ok=True)

    # Restore the original shape.
    image_shape = np.array(image_shape)
    target_shape = (244 / max(image_shape) * image_shape).astype(int)
    mask = mask[:target_shape[0], :target_shape[1]].astype(np.uint8)
    # cv2.resize expects dsize as (width, height).
    mask = cv2.resize(mask, (image_shape[1], image_shape[0]), interpolation=cv2.INTER_NEAREST)
    cv2.imwrite(image_output_path, mask)


@click.group()
def main():
    pass


@main.command()
@click.option("--model-path", default="parsing_model.pth")
@click.option("--input-path", type=str, required=True)
@click.option("--output-path", type=str, default='')
@click.option("--batch-size", type=int, default=128)
@click.option("--num-workers", type=int, default=16)
@click.option("--device", type=str, default='cuda:0')
def parse_folder(model_path, input_path, output_path, batch_size, num_workers, device):
    model = torch.load(model_path, map_location=device)
    model = model.to(device)
    model.eval()

    dataset = ParsingTestDataset(input_path)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers)

    pool = Pool(16)
    for batch in tqdm(dataloader, total=math.ceil(len(dataset) / batch_size)):
        images, image_paths, image_shapes = batch
        images = images.to(device)
        with torch.no_grad():
            # [B, C, H, W]
            pr_map = model.predict(images)

        masks: torch.Tensor = pr_map.round().argmax(dim=1).detach().cpu().numpy()
        args = [(mask, image_path, image_shape, input_path, output_path)
                for mask, image_path, image_shape in zip(masks, image_paths, image_shapes)]
        iters = pool.imap(write_worker, args)

        for _ in iters:
            pass
    pool.close()
    pool.join()


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/examples/parsing/train_parsing.py:
--------------------------------------------------------------------------------
from argparse import ArgumentParser
from dataset import VeRi3kParsingDataset, get_preprocessing, get_training_albumentations, get_validation_augmentation
from torch.utils.data import DataLoader
from torch import nn
from pathlib import Path
import matplotlib.pyplot as plt
import cv2
import numpy as np
import segmentation_models_pytorch as smp
import argparse
import torch
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'


parser = ArgumentParser()
parser.add_argument("--train-set", default="trainval")
parser.add_argument("--masks-path", default="veri776_parsing3165")
parser.add_argument("--image-path", default="/data/datasets/VeRi/VeRi/image_train")
args = parser.parse_args()

ENCODER = 'se_resnext50_32x4d'
ENCODER_WEIGHTS = 'imagenet'
DEVICE = 'cuda'

CLASSES = VeRi3kParsingDataset.CLASSES
ACTIVATION = 'sigmoid'

model = smp.Unet(encoder_name=ENCODER,
                 encoder_weights=ENCODER_WEIGHTS,
                 classes=len(CLASSES),
                 activation=ACTIVATION)

preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

# train_dataset = PascalParsingDataset(augmentation=get_training_albumentations(),
#                                      preprocessing=get_preprocessing(preprocessing_fn),
#                                      subset='training')
# valid_dataset = PascalParsingDataset(augmentation=get_validation_augmentation(),
#                                      preprocessing=get_preprocessing(preprocessing_fn),
#                                      subset='validation')

train_dataset = VeRi3kParsingDataset(image_path=args.image_path,
                                     masks_path=args.masks_path,
                                     augmentation=get_training_albumentations(),
                                     preprocessing=get_preprocessing(
                                         preprocessing_fn),
                                     subset=args.train_set)

valid_dataset = VeRi3kParsingDataset(image_path=args.image_path,
                                     masks_path=args.masks_path,
                                     augmentation=get_validation_augmentation(),
                                     preprocessing=get_preprocessing(
                                         preprocessing_fn),
                                     subset='validation')

train_loader = DataLoader(train_dataset, batch_size=8,
                          shuffle=True, num_workers=12)
valid_loader = DataLoader(valid_dataset, batch_size=1,
                          shuffle=False, num_workers=4)
# Dice/F1 score - https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
# IoU/Jaccard score - https://en.wikipedia.org/wiki/Jaccard_index

# loss = smp.utils.losses.BCEDiceLoss(eps=1.)
class BCEDiceLoss(smp.utils.losses.DiceLoss):
    __name__ = 'bce_dice_loss'

    def __init__(self, eps=1e-7, activation='sigmoid'):
        super().__init__(eps=eps, activation=activation)
        self.bce = nn.BCEWithLogitsLoss(reduction='mean')

    def forward(self, y_pr, y_gt):
        dice = super().forward(y_pr, y_gt)
        bce = self.bce(y_pr, y_gt)
        return dice + bce


loss = BCEDiceLoss(eps=1.)
metrics = [
    # smp.utils.metrics.IoUMetric(eps=1.),
    smp.utils.metrics.IoU(eps=1.),
]

optimizer = torch.optim.Adam([
    {'params': model.decoder.parameters(), 'lr': 1e-4},
    # decrease lr for encoder in order not to permute
    # pre-trained weights with large gradients on training start
    {'params': model.encoder.parameters(), 'lr': 1e-6},
])

train_epoch = smp.utils.train.TrainEpoch(
    model,
    loss=loss,
    metrics=metrics,
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)

valid_epoch = smp.utils.train.ValidEpoch(
    model,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)
max_score = 0

for i in range(0, 40):

    print('\nEpoch: {}'.format(i))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)

    # do something (save model, change lr, etc.)
    if max_score < valid_logs['iou_score']:
        max_score = valid_logs['iou_score']
        torch.save(model, './best_model_{}.pth'.format(args.train_set))
        print('Model saved!')

    if i == 25:
        optimizer.param_groups[0]['lr'] = 1e-5
        print('Decrease decoder learning rate to 1e-5!')
--------------------------------------------------------------------------------
/examples/preprocess_data/generate_pkl.py:
--------------------------------------------------------------------------------
import pandas as pd
import os
import pickle as pkl
import click
from pathlib import Path
import re


@click.group()
def main():
    pass


@main.command()
@click.option("--input-path", required=True)
@click.option("--output-path", default="veri776.pkl")
def veri776(input_path, output_path):
    input_path = os.path.abspath(input_path)
    output_dir = os.path.split(output_path)[0]
    if output_dir != '' and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)

    input_path = Path(input_path).absolute()
    output_dict = {}

    pattern = re.compile(r"(\d+)_c(\d+)_.+\.jpg")
    for phase in ["train", "query", "gallery"]:
        output_dict[phase] = []
        sub_path = input_path / f"image_{phase}"
        if phase == "gallery":
            sub_path = input_path / "image_test"
        for image_path in sub_path.iterdir():
            sample = {}
            image_name = image_path.name
            v_id, camera = pattern.match(image_name).groups()
            sample["filename"] = image_name
            sample["image_path"] = str(image_path)
            sample["id"] = v_id
            sample["cam"] = camera
            output_dict[phase].append(sample)
    with open(output_path, "wb") as f:
        pkl.dump(output_dict, f)


@main.command()
@click.option('--input-path', default='/data1/dechao_meng/mengdechao/datasets/VehicleID_V1.0')
@click.option('--output-path', default='../outputs/vehicleid.pkl')
def vehicleid(input_path, output_path):
    input_path = os.path.abspath(input_path)
    PATH = input_path

    images = {}

    images['train'] = open(PATH + '/train_test_split/train_list.txt').read().strip().split('\n')
    images['gallery_800'] = open(PATH + '/train_test_split/test_list_800.txt').read().strip().split('\n')
    images['gallery_1600'] = open(PATH + '/train_test_split/test_list_1600.txt').read().strip().split('\n')
    images['gallery_2400'] = open(PATH + '/train_test_split/test_list_2400.txt').read().strip().split('\n')
    images['query_800'] = []
    images['query_1600'] = []
    images['query_2400'] = []

    outputs = {}
    for key, lists in images.items():
        output = []
        for img_name in lists:
            item = {
                "image_path": f"{PATH}/image/{img_name.split(' ')[0]}.jpg",
                "name": img_name,
                "id": img_name.split(' ')[1],
                "cam": 0
            }
            output.append(item)
        outputs[key] = output

    base_path = os.path.split(output_path)[0]
    if base_path != '' and not os.path.exists(base_path):
        os.makedirs(base_path, exist_ok=True)

    with open(output_path, 'wb') as f:
        pkl.dump(outputs, f)


@main.command()
@click.option('--input-path', default='/home/aa/mengdechao/datasets/veriwild')
@click.option('--output-path', default='../outputs/veriwild.pkl')
def veriwild(input_path, output_path):
    input_path = os.path.abspath(input_path)
    PATH = input_path

    images = {}

    images['train'] = open(PATH + '/train_test_split/train_list.txt').read().strip().split('\n')
    images['query_3000'] = open(PATH + '/train_test_split/test_3000_query.txt').read().strip().split('\n')
    images['gallery_3000'] = open(PATH + '/train_test_split/test_3000.txt').read().strip().split('\n')
    images['query_5000'] = open(PATH + '/train_test_split/test_5000_query.txt').read().strip().split('\n')
    images['gallery_5000'] = open(PATH + '/train_test_split/test_5000.txt').read().strip().split('\n')
    images['query_10000'] = open(PATH + '/train_test_split/test_10000_query.txt').read().strip().split('\n')
    images['gallery_10000'] = open(PATH + '/train_test_split/test_10000.txt').read().strip().split('\n')

    wild_df = pd.read_csv(f'{PATH}/train_test_split/vehicle_info.txt', sep=';', index_col='id/image')

    # Pandas indexing is very slow, change it to dict
    wild_dict = wild_df.to_dict()
    camid_dict = wild_dict['Camera ID']

    outputs = {}
    for key, lists in images.items():
        output = []
        for img_name in lists:
            item = {
                "image_path": f"{PATH}/images/{img_name}.jpg",
                "name": img_name,
                "id": img_name.split('/')[0],
                # "cam": wild_df.loc[img_name]['Camera ID']
                "cam": camid_dict[img_name]
            }
            output.append(item)
        outputs[key] = output

    base_path = os.path.split(output_path)[0]
    if base_path != '' and not os.path.exists(base_path):
        os.makedirs(base_path, exist_ok=True)
    with open(output_path, 'wb') as f:
        pkl.dump(outputs, f)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
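The veri776 command above relies on the VeRi-776 naming convention, <id>_c<cam>_<frame>_<n>.jpg, which the regex splits into identity and camera; a tiny check:

import re

pattern = re.compile(r"(\d+)_c(\d+)_.+\.jpg")
v_id, camera = pattern.match("0002_c002_00030600_0.jpg").groups()
assert (v_id, camera) == ("0002", "002")
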
open(PATH + '/train_test_split/train_list.txt').read().strip().split('\n') 54 | images['gallery_800'] = open(PATH + '/train_test_split/test_list_800.txt').read().strip().split('\n') 55 | images['gallery_1600'] = open(PATH + '/train_test_split/test_list_1600.txt').read().strip().split('\n') 56 | images['gallery_2400'] = open(PATH + '/train_test_split/test_list_2400.txt').read().strip().split('\n') 57 | images['query_800'] = [] 58 | images['query_1600'] = [] 59 | images['query_2400'] = [] 60 | 61 | outputs = {} 62 | for key, lists in images.items(): 63 | output = [] 64 | for img_name in lists: 65 | item = { 66 | "image_path": f"{PATH}/image/{img_name.split(' ')[0]}.jpg", 67 | "name": img_name, 68 | "id": img_name.split(' ')[1], 69 | "cam": 0 70 | } 71 | output.append(item) 72 | outputs[key] = output 73 | 74 | base_path = os.path.split(output_path)[0] 75 | if base_path != '' and not os.path.exists(base_path): 76 | os.makedirs(base_path, exist_ok=True) 77 | 78 | with open(output_path, 'wb') as f: 79 | pkl.dump(outputs, f) 80 | 81 | @main.command() 82 | @click.option('--input-path', default='/home/aa/mengdechao/datasets/veriwild') 83 | @click.option('--output-path', default='../outputs/veriwild.pkl') 84 | def veriwild(input_path, output_path): 85 | input_path = os.path.abspath(input_path) 86 | PATH = input_path 87 | 88 | images = {} 89 | 90 | images['train'] = open(PATH + '/train_test_split/train_list.txt').read().strip().split('\n') 91 | images['query_3000'] = open(PATH + '/train_test_split/test_3000_query.txt').read().strip().split('\n') 92 | images['gallery_3000'] = open(PATH + '/train_test_split/test_3000.txt').read().strip().split('\n') 93 | images['query_5000'] = open(PATH + '/train_test_split/test_5000_query.txt').read().strip().split('\n') 94 | images['gallery_5000'] = open(PATH + '/train_test_split/test_5000.txt').read().strip().split('\n') 95 | images['query_10000'] = open(PATH + '/train_test_split/test_10000_query.txt').read().strip().split('\n') 96 | images['gallery_10000']= open(PATH + '/train_test_split/test_10000.txt').read().strip().split('\n') 97 | 98 | wild_df = pd.read_csv(f'{PATH}/train_test_split/vehicle_info.txt', sep=';', index_col='id/image') 99 | 100 | # Pandas indexing is very slow, change it to dict 101 | wild_dict = wild_df.to_dict() 102 | camid_dict = wild_dict['Camera ID'] 103 | 104 | outputs = {} 105 | for key, lists in images.items(): 106 | output = [] 107 | for img_name in lists: 108 | item = { 109 | "image_path": f"{PATH}/images/{img_name}.jpg", 110 | "name": img_name, 111 | "id": img_name.split('/')[0], 112 | # "cam": wild_df.loc[img_name]['Camera ID'] 113 | "cam": camid_dict[img_name] 114 | } 115 | output.append(item) 116 | outputs[key] = output 117 | 118 | base_path = os.path.split(output_path)[0] 119 | if base_path != '' and not os.path.exists(base_path): 120 | os.makedirs(base_path, exist_ok=True) 121 | with open(output_path, 'wb') as f: 122 | pkl.dump(outputs, f) 123 | 124 | 125 | if __name__ == "__main__": 126 | main() 127 | -------------------------------------------------------------------------------- /vehicle_reid_pytorch/models/backbones/resnet.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import math 8 | 9 | import torch 10 | from torch import nn 11 | import torchvision 12 | 13 | def conv3x3(in_planes, out_planes, stride=1): 14 | """3x3 convolution with padding""" 15 | return nn.Conv2d(in_planes, 
out_planes, kernel_size=3, stride=stride, 16 | padding=1, bias=False) 17 | 18 | 19 | class BasicBlock(nn.Module): 20 | expansion = 1 21 | 22 | def __init__(self, inplanes, planes, stride=1, downsample=None): 23 | super(BasicBlock, self).__init__() 24 | self.conv1 = conv3x3(inplanes, planes, stride) 25 | self.bn1 = nn.BatchNorm2d(planes) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.conv2 = conv3x3(planes, planes) 28 | self.bn2 = nn.BatchNorm2d(planes) 29 | self.downsample = downsample 30 | self.stride = stride 31 | 32 | def forward(self, x): 33 | residual = x 34 | 35 | out = self.conv1(x) 36 | out = self.bn1(out) 37 | out = self.relu(out) 38 | 39 | out = self.conv2(out) 40 | out = self.bn2(out) 41 | 42 | if self.downsample is not None: 43 | residual = self.downsample(x) 44 | 45 | out += residual 46 | out = self.relu(out) 47 | 48 | return out 49 | 50 | 51 | class Bottleneck(nn.Module): 52 | expansion = 4 53 | 54 | def __init__(self, inplanes, planes, stride=1, downsample=None): 55 | super(Bottleneck, self).__init__() 56 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 57 | self.bn1 = nn.BatchNorm2d(planes) 58 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 59 | padding=1, bias=False) 60 | self.bn2 = nn.BatchNorm2d(planes) 61 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 62 | self.bn3 = nn.BatchNorm2d(planes * 4) 63 | self.relu = nn.ReLU(inplace=True) 64 | self.downsample = downsample 65 | self.stride = stride 66 | 67 | def forward(self, x): 68 | residual = x 69 | 70 | out = self.conv1(x) 71 | out = self.bn1(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv2(out) 75 | out = self.bn2(out) 76 | out = self.relu(out) 77 | 78 | out = self.conv3(out) 79 | out = self.bn3(out) 80 | 81 | if self.downsample is not None: 82 | residual = self.downsample(x) 83 | 84 | out += residual 85 | out = self.relu(out) 86 | 87 | return out 88 | 89 | 90 | class ResNet(nn.Module): 91 | def __init__(self, last_stride=2, block=Bottleneck, layers=(3, 4, 6, 3)): 92 | super().__init__() 93 | self.inplanes = 64 94 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 95 | bias=False) 96 | self.bn1 = nn.BatchNorm2d(64) 97 | # self.relu = nn.ReLU(inplace=True) # add missed relu 98 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 99 | self.layer1 = self._make_layer(block, 64, layers[0]) 100 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 101 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 102 | self.layer4 = self._make_layer( 103 | block, 512, layers[3], stride=last_stride) 104 | 105 | def _make_layer(self, block, planes, blocks, stride=1): 106 | downsample = None 107 | if stride != 1 or self.inplanes != planes * block.expansion: 108 | downsample = nn.Sequential( 109 | nn.Conv2d(self.inplanes, planes * block.expansion, 110 | kernel_size=1, stride=stride, bias=False), 111 | nn.BatchNorm2d(planes * block.expansion), 112 | ) 113 | 114 | layers = [] 115 | layers.append(block(self.inplanes, planes, stride, downsample)) 116 | self.inplanes = planes * block.expansion 117 | for i in range(1, blocks): 118 | layers.append(block(self.inplanes, planes)) 119 | 120 | return nn.Sequential(*layers) 121 | 122 | def forward(self, x): 123 | x = self.conv1(x) 124 | x = self.bn1(x) 125 | # x = self.relu(x) # add missed relu 126 | x = self.maxpool(x) 127 | 128 | x = self.layer1(x) 129 | x = self.layer2(x) 130 | x = self.layer3(x) 131 | x = self.layer4(x) 132 | 133 | return x 134 | 135 | def 
load_param(self, model_path): 136 | if model_path == "": 137 | param_dict = torchvision.models.resnet50(pretrained=True).state_dict() 138 | else: 139 | param_dict = torch.load(model_path) 140 | for i in param_dict: 141 | if 'fc' in i: 142 | continue 143 | 144 | self.state_dict()[i].copy_(param_dict[i]) 145 | 146 | def random_init(self): 147 | for m in self.modules(): 148 | if isinstance(m, nn.Conv2d): 149 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 150 | m.weight.data.normal_(0, math.sqrt(2. / n)) 151 | elif isinstance(m, nn.BatchNorm2d): 152 | m.weight.data.fill_(1) 153 | m.bias.data.zero_() 154 | 155 | -------------------------------------------------------------------------------- /vehicle_reid_pytorch/metrics/rerank.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Fri, 25 May 2018 20:29:09 5 | 6 | @author: luohao 7 | """ 8 | 9 | """ 10 | CVPR2017 paper: Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017. 11 | url: http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf 12 | Matlab version: https://github.com/zhunzhong07/person-re-ranking 13 | """ 14 | 15 | """ 16 | API 17 | 18 | probFea: all feature vectors of the query set (torch tensor) 19 | galFea: all feature vectors of the gallery set (torch tensor) 20 | k1, k2, lambda: parameters; the original paper uses (k1=20, k2=6, lambda=0.3) 21 | MemorySave: set to 'True' when using MemorySave mode 22 | Minibatch: available when 'MemorySave' is 'True' 23 | """ 24 | 25 | import numpy as np 26 | import torch 27 | 28 | 29 | def re_ranking(probFea, galFea, k1, k2, lambda_value, local_distmat=None, only_local=False, split_gpu=False): 30 | # if the features are numpy arrays, convert them to tensors with 'torch.tensor' first 31 | query_num = probFea.size(0) 32 | all_num = query_num + galFea.size(0) 33 | if only_local: 34 | original_dist = local_distmat 35 | else: 36 | feat = torch.cat([probFea, galFea]) 37 | 38 | print('using GPU to compute original distance') 39 | 40 | if split_gpu: 41 | distmat1 = torch.pow(feat, 2).sum(dim=1, keepdim=True)[:30000] + \ 42 | torch.pow(feat, 2).sum(dim=1, keepdim=True).t() 43 | 44 | distmat2 = torch.pow(feat, 2).sum(dim=1, keepdim=True)[30000:].to("cuda:1") + \ 45 | torch.pow(feat, 2).sum(dim=1, keepdim=True).t().to("cuda:1") 46 | 47 | distmat1.addmm_(1, -2, feat[:30000], feat.t()) 48 | distmat2.addmm_(1, -2, feat[30000:].to("cuda:1"), feat.t().to("cuda:1")) 49 | original_dist1 = distmat1.cpu().numpy() 50 | original_dist2 = distmat2.cpu().numpy() 51 | 52 | original_dist = np.concatenate([original_dist1, original_dist2]) 53 | else: 54 | distmat = torch.pow(feat, 2).sum(dim=1, keepdim=True) + \ 55 | torch.pow(feat, 2).sum(dim=1, keepdim=True).t() 56 | 57 | distmat.addmm_(1, -2, feat, feat.t()) 58 | original_dist = distmat.cpu().numpy() 59 | 60 | del feat 61 | if local_distmat is not None: 62 | original_dist = original_dist + local_distmat 63 | gallery_num = original_dist.shape[0] 64 | original_dist = np.transpose(original_dist / np.max(original_dist, axis=0)) 65 | V = np.zeros_like(original_dist).astype(np.float16) 66 | initial_rank = np.argsort(original_dist).astype(np.int32) 67 | 68 | print('starting re_ranking') 69 | for i in range(all_num): 70 | # k-reciprocal neighbors 71 | forward_k_neigh_index = initial_rank[i, :k1 + 1] 72 | backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1] 73 | 
fi = np.where(backward_k_neigh_index == i)[0] 74 | k_reciprocal_index = forward_k_neigh_index[fi] 75 | k_reciprocal_expansion_index = k_reciprocal_index 76 | for j in range(len(k_reciprocal_index)): 77 | candidate = k_reciprocal_index[j] 78 | candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1] 79 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index, 80 | :int(np.around(k1 / 2)) + 1] 81 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0] 82 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 83 | if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len( 84 | candidate_k_reciprocal_index): 85 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index) 86 | 87 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) 88 | weight = np.exp(-original_dist[i, k_reciprocal_expansion_index]) 89 | V[i, k_reciprocal_expansion_index] = weight / np.sum(weight) 90 | original_dist = original_dist[:query_num, ] 91 | if k2 != 1: 92 | V_qe = np.zeros_like(V, dtype=np.float16) 93 | for i in range(all_num): 94 | V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0) 95 | V = V_qe 96 | del V_qe 97 | del initial_rank 98 | invIndex = [] 99 | for i in range(gallery_num): 100 | invIndex.append(np.where(V[:, i] != 0)[0]) 101 | 102 | jaccard_dist = np.zeros_like(original_dist, dtype=np.float16) 103 | 104 | for i in range(query_num): 105 | temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16) 106 | indNonZero = np.where(V[i, :] != 0)[0] 107 | indImages = [invIndex[ind] for ind in indNonZero] 108 | for j in range(len(indNonZero)): 109 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]], 110 | V[indImages[j], indNonZero[j]]) 111 | jaccard_dist[i] = 1 - temp_min / (2 - temp_min) 112 | 113 | final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value 114 | del original_dist 115 | del V 116 | del jaccard_dist 117 | final_dist = final_dist[:query_num, query_num:] 118 | return final_dist 119 | -------------------------------------------------------------------------------- /vehicle_reid_pytorch/utils/pytorch_tools.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | from bisect import bisect_right 7 | import torch 8 | from logzero import logger 9 | import cv2 10 | import numpy as np 11 | import time 12 | import asranger as ranger 13 | 14 | # FIXME ideally this would be achieved with a CombinedLRScheduler, 15 | # separating MultiStepLR with WarmupLR 16 | # but the current LRScheduler design doesn't allow it 17 | 18 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 19 | def __init__( 20 | self, 21 | optimizer, 22 | milestones, 23 | gamma=0.1, 24 | warmup_factor=1.0 / 3, 25 | warmup_iters=500, 26 | warmup_method="linear", 27 | last_epoch=-1, 28 | ): 29 | if not list(milestones) == sorted(milestones): 30 | raise ValueError( 31 | "Milestones should be a list of" " increasing integers. 
Got {}", 32 | milestones, 33 | ) 34 | 35 | if warmup_method not in ("constant", "linear"): 36 | raise ValueError( 37 | "Only 'constant' or 'linear' warmup_method accepted" 38 | "got {}".format(warmup_method) 39 | ) 40 | self.milestones = milestones 41 | self.gamma = gamma 42 | self.warmup_factor = warmup_factor 43 | self.warmup_iters = warmup_iters 44 | self.warmup_method = warmup_method 45 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 46 | 47 | def get_lr(self): 48 | warmup_factor = 1 49 | if self.last_epoch < self.warmup_iters: 50 | if self.warmup_method == "constant": 51 | warmup_factor = self.warmup_factor 52 | elif self.warmup_method == "linear": 53 | alpha = self.last_epoch / self.warmup_iters 54 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 55 | # warmup_factor = self.warmup_factor * self.last_epoch 56 | return [ 57 | base_lr 58 | * warmup_factor 59 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 60 | for base_lr in self.base_lrs 61 | ] 62 | 63 | 64 | def make_optimizer(optim_name, model, base_lr, weight_decay, bias_lr_factor, momentum): 65 | """ 66 | 调低所有bias项的学习率。 67 | 68 | :param optim_name: 69 | :param model: 70 | :param base_lr: 71 | :param weight_decay: 72 | :param bias_lr_factor: 73 | :param momentum: 74 | :return: 75 | """ 76 | params = [] 77 | for key, value in model.named_parameters(): 78 | if not value.requires_grad: 79 | continue 80 | lr = base_lr 81 | if "bias" in key: 82 | lr = base_lr * bias_lr_factor 83 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 84 | if optim_name == 'SGD': 85 | optimizer = getattr(torch.optim, optim_name)(params, momentum=momentum) 86 | elif 'Ranger' in optim_name: 87 | optimizer = getattr(ranger, optim_name)(params) 88 | else: 89 | optimizer = getattr(torch.optim, optim_name)(params) 90 | return optimizer 91 | 92 | 93 | def make_warmup_scheduler(optimizer, milestones, gamma=0.1, warmup_factor=1.0 / 3, warmup_iters=500, 94 | warmup_method="linear", 95 | last_epoch=-1): 96 | if last_epoch == 0: 97 | last_epoch = -1 # init时会自动变成0.否则会初始化错误 98 | scheduler = WarmupMultiStepLR(optimizer, milestones, gamma, warmup_factor, warmup_iters, warmup_method, 99 | last_epoch=last_epoch) 100 | return scheduler 101 | 102 | 103 | def featuremap_perspective_transform(featuremap: torch.Tensor, bpts: torch.Tensor, btarget_pts: torch.Tensor, 104 | output_size): 105 | """对一个batch的featuremap做投影变换 106 | 107 | Arguments: 108 | featuremap {torch.Tensor} -- [B, C, H, W] 109 | pts {torch.Tensor} -- [B, 4, 2] xy格式 110 | target_pts {torch.Tensor} -- [B, 4, 2] xy格式 111 | output_shape {torch.Tensor} -- [2] w, h 112 | """ 113 | device = featuremap.device 114 | B, C, H, W = featuremap.shape 115 | w, h = output_size 116 | 117 | # 求解投影矩阵 118 | bpts_np = bpts.cpu().float().numpy() 119 | btarget_pts_np = btarget_pts.cpu().float().numpy() 120 | 121 | trans_mats = [] 122 | 123 | for pts_np, target_pts_np in zip(bpts_np, btarget_pts_np): 124 | trans_mat = cv2.getPerspectiveTransform(pts_np, target_pts_np) 125 | if np.linalg.matrix_rank(trans_mat) < 3: 126 | trans_mat = np.identity(3, dtype=np.float) 127 | trans_mats.append(torch.from_numpy(trans_mat)) 128 | inv_trans_mats = torch.stack(trans_mats).float().inverse().to(device) 129 | 130 | # 坐标反变换 131 | x, y = torch.meshgrid(torch.arange(h), torch.arange(w)) 132 | z = torch.ones_like(x) 133 | cors = torch.stack([x, y, z]).view(1, 3, -1).to(device).float() 134 | cors = cors.repeat(B, 1, 1) 135 | 136 | reversed_cors = torch.bmm(inv_trans_mats, cors) 137 | 
reversed_cors = reversed_cors[:, :2, :] / \ 138 | reversed_cors[:, 2, :].view(B, 1, -1) # [B, 2, wh] 139 | reversed_cors = reversed_cors.view(-1, 2, h, w).permute(0, 2, 3, 1) 140 | norm_cors = ((reversed_cors / reversed_cors.new_tensor([W, H])) - 0.5) * 2 141 | 142 | # sample the result by bilinear interpolation 143 | output = torch.nn.functional.grid_sample(featuremap, norm_cors, padding_mode='border') 144 | assert not torch.any(torch.isnan(output)), "Found NaN" 145 | 146 | return output 147 | 148 | 149 | -------------------------------------------------------------------------------- /examples/parsing_reid/model.py: -------------------------------------------------------------------------------- 1 | from vehicle_reid_pytorch.loss.triplet_loss import normalize, euclidean_dist, hard_example_mining 2 | from vehicle_reid_pytorch.models import Baseline 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from functools import reduce 7 | from math_tools import clck_dist 8 | from pprint import pprint 9 | 10 | 11 | class ParsingReidModel(Baseline): 12 | 13 | def __init__(self, num_classes, last_stride, model_path, neck, neck_feat, model_name, pretrain_choice, num_local_branches=4): 14 | super(ParsingReidModel, self).__init__(num_classes, last_stride, model_path, neck, neck_feat, model_name, 15 | pretrain_choice) 16 | 17 | # self.local_bn_neck = nn.BatchNorm1d(2048*num_local_branches) 18 | self.local_bn_neck = nn.BatchNorm1d(2048*num_local_branches) 19 | self.local_classifier = nn.Conv1d(2048, num_classes, 1) 20 | 21 | def forward(self, image, mask=None, **kwargs): 22 | """ 23 | 24 | :param torch.Tensor image: [B, 3, H, W] 25 | :param torch.Tensor mask: [B, N, H, W] front/back, side, window 26 | :return: 27 | """ 28 | # Remove bg 29 | 30 | if mask is not None: 31 | mask = mask[:, 1:, :, :] 32 | B, N, H, W = mask.shape 33 | else: 34 | B, _, H, W = image.shape 35 | N = 4 36 | mask = image.new_zeros(B, 4, H, W) 37 | 38 | x = self.base(image) 39 | 40 | B, C, h, w = x.shape 41 | mask = F.interpolate(mask, x.shape[2:]) 42 | # mask = F.softmax(mask, dim=1) 43 | # mask = F.adaptive_max_pool2d(mask, output_size=x.shape[2:]).view(B, N, h, w) 44 | 45 | global_feat = self.gap(x) # (b, 2048, 1, 1) 46 | 47 | global_feat = global_feat.view( 48 | global_feat.shape[0], -1) # flatten to (bs, 2048) 49 | 50 | vis_score = mask.sum(dim=[2, 3]) + 1 # Laplace smoothing 51 | local_feat_map = torch.mul(mask.unsqueeze( 52 | dim=2), x.unsqueeze(dim=1)) # (B, N, C, h, w) 53 | local_feat_map = local_feat_map.view(B, -1, h, w) 54 | local_feat_before = F.adaptive_avg_pool2d(local_feat_map, output_size=(1, 1)).view(B, N, C).permute( 55 | [0, 2, 1]) * (h * w / vis_score.unsqueeze(dim=1)) # (B, C, N) 56 | 57 | if self.neck == 'no': 58 | feat = global_feat 59 | elif self.neck == 'bnneck': 60 | # normalize for angular softmax 61 | feat = self.bottleneck(global_feat) 62 | local_feat = self.local_bn_neck( 63 | local_feat_before.contiguous().view(B, -1)).view(B, -1, N) # this step makes the masked entries non-zero 64 | 65 | if self.training: 66 | cls_score = self.classifier(feat) 67 | local_cls_score = self.local_classifier(local_feat) 68 | # global feature for triplet loss 69 | return {"cls_score": cls_score, 70 | "global_feat": global_feat, 71 | "local_cls_score": local_cls_score, 72 | "local_feat": local_feat, 73 | "vis_score": vis_score} 74 | 75 | else: 76 | if self.neck_feat == 'after': 77 | # print("Test with feature after BN") 78 | return {"global_feat": feat, 79 | "local_feat": local_feat, 80 | "vis_score": vis_score} 81 | else: 82 | # print("Test with feature 
before BN") 83 | return {"global_feat": global_feat, 84 | "local_feat": local_feat_before, 85 | "vis_score": vis_score} 86 | 87 | 88 | class ParsingTripletLoss: 89 | def __init__(self, margin=None): 90 | self.margin = margin 91 | if margin is not None: 92 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 93 | else: 94 | self.ranking_loss = nn.SoftMarginLoss() 95 | 96 | def __call__(self, local_feat, vis_score, target, normalize_feature=False): 97 | """ 98 | 99 | :param torch.Tensor local_feature: (B, C, N) 100 | :param torch.Tensor visibility_score: (B, N) 101 | :param torch.Tensor target: (B) 102 | :return: 103 | """ 104 | B, C, _ = local_feat.shape 105 | if normalize_feature: 106 | local_feat = normalize(local_feat, 1) 107 | 108 | dist_mat = clck_dist(local_feat, local_feat, 109 | vis_score, vis_score) 110 | 111 | dist_ap, dist_an = hard_example_mining(dist_mat, target) 112 | y = dist_an.new().resize_as_(dist_an).fill_(1) 113 | 114 | if self.margin is not None: 115 | loss = self.ranking_loss(dist_an, dist_ap, y) 116 | else: 117 | loss = self.ranking_loss(dist_an - dist_ap, y) 118 | 119 | return loss, dist_ap, dist_an 120 | 121 | 122 | def build_model(cfg, num_classes): 123 | # if cfg.MODEL.NAME == 'resnet50': 124 | # model = Baseline(num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.PRETRAIN_PATH, cfg.MODEL.NECK, cfg.TEST.NECK_FEAT) 125 | model = ParsingReidModel(num_classes, cfg.model.last_stride, cfg.model.pretrain_model, cfg.model.neck, 126 | cfg.test.neck_feat, cfg.model.name, cfg.model.pretrain_choice) 127 | return model 128 | 129 | 130 | if __name__ == '__main__': 131 | from tensorboardX import SummaryWriter 132 | 133 | dummy_input = torch.rand(4, 3, 224, 224) 134 | model = Baseline(576, 1, '/home/mengdechao/.cache/torch/checkpoints/resnet50-19c8e357.pth', 'bnneck', 'after', 135 | 'resnet50', 'imagenet') 136 | model.train() 137 | with SummaryWriter(comment="baseline") as w: 138 | w.add_graph(model, [dummy_input, ]) 139 | -------------------------------------------------------------------------------- /examples/parsing/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader, Dataset 2 | import os 3 | import pickle 4 | import numpy as np 5 | import json 6 | from pathlib import Path 7 | 8 | from vehicle_reid_pytorch.utils.iotools import read_rgb_image 9 | from vehicle_reid_pytorch.utils.visualize import visualize_img as visualize_img 10 | from vehicle_reid_pytorch.utils.math import pad_image_size_to_multiples_of 11 | from vehicle_reid_pytorch.data.transforms import AlbuRandomErasing 12 | import matplotlib.pyplot as plt 13 | import albumentations as albu 14 | import cv2 15 | from functools import partial 16 | 17 | 18 | def get_training_albumentations(): 19 | train_transform = [ 20 | albu.LongestMaxSize(244), 21 | albu.HorizontalFlip(), 22 | albu.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0), 23 | albu.PadIfNeeded(min_height=224, min_width=224, always_apply=True, border_mode=0), 24 | albu.RandomCrop(height=224, width=224, always_apply=True), 25 | albu.IAAAdditiveGaussianNoise(p=0.2), 26 | albu.IAAPerspective(p=0.5), 27 | albu.OneOf( 28 | [ 29 | albu.CLAHE(p=1), 30 | albu.RandomBrightness(p=1), 31 | albu.RandomGamma(p=1), 32 | ], 33 | p=0.9, 34 | ), 35 | 36 | albu.OneOf( 37 | [ 38 | albu.IAASharpen(p=1), 39 | albu.Blur(blur_limit=3, p=1), 40 | albu.MotionBlur(blur_limit=3, p=1), 41 | ], 42 | p=0.9, 43 | ), 44 | 45 | albu.OneOf( 46 | [ 47 | 
albu.RandomContrast(p=1), 48 | albu.HueSaturationValue(p=1), 49 | ], 50 | p=0.9, 51 | ), 52 | # AlbuRandomErasing(0.5) 53 | ] 54 | return albu.Compose(train_transform) 55 | 56 | 57 | def get_validation_augmentation(): 58 | test_transform = [ 59 | albu.Lambda(image=pad_image_to_multiplys_of(32), mask=pad_image_to_multiplys_of(32)) 60 | # albu.LongestMaxSize(224), 61 | # albu.Lambda(image=pad_image_to_multiplys_of(32), mask=pad_image_to_multiplys_of(32)) 62 | 63 | # albu.RandomCrop(height=320, width=320, always_apply=True) 64 | ] 65 | return albu.Compose(test_transform) 66 | 67 | 68 | def to_tensor(x, **kwargs): 69 | return x.transpose(2, 0, 1).astype('float32') 70 | 71 | 72 | def pad_image_to_multiplys_of(multiply=32, **kwargs): 73 | def _pad_image_to_multiplys_of(x, **kwargs): 74 | return pad_image_size_to_multiples_of(x, multiply, align='top-left') 75 | 76 | return _pad_image_to_multiplys_of 77 | 78 | 79 | def get_preprocessing(preprocessing_fn): 80 | _transform = [ 81 | albu.Lambda(image=preprocessing_fn), 82 | albu.Lambda(image=to_tensor, mask=to_tensor) 83 | ] 84 | return albu.Compose(_transform) 85 | 86 | 87 | class VeRi3kParsingDataset(Dataset): 88 | CLASSES = ["background", "front", "back", "roof", "side"] 89 | 90 | def __init__(self, image_path, masks_path, augmentation=None, preprocessing=None, 91 | subset='trainval'): 92 | self.metas = [os.path.splitext(fname)[0] for fname in os.listdir(masks_path)] 93 | self.masks_path = Path(masks_path) 94 | self.image_path = Path(image_path) 95 | if subset == 'trainval': 96 | # self.metas = self.metas[:-500] 97 | self.metas = self.metas 98 | elif subset == 'train': 99 | self.metas = self.metas[:-500] 100 | else: 101 | self.metas = self.metas[-500:] 102 | 103 | self.class_values = [self.CLASSES.index(cls) for cls in self.CLASSES] 104 | self.augmentation = augmentation 105 | self.preprocessing = preprocessing 106 | 107 | def __getitem__(self, item): 108 | image_name = self.metas[item] 109 | img = read_rgb_image(f"{self.image_path/image_name}.jpg", format="ndarray") 110 | mask = cv2.imread(f"{self.masks_path/image_name}.png", cv2.IMREAD_UNCHANGED) 111 | masks = [mask == v for v in self.class_values] 112 | mask = np.stack(masks, axis=-1).astype('float32') 113 | 114 | if self.augmentation: 115 | sample = self.augmentation(image=img, mask=mask) 116 | img = sample["image"] 117 | mask = sample["mask"] 118 | 119 | if self.preprocessing: 120 | sample = self.preprocessing(image=img, mask=mask) 121 | img = sample["image"] 122 | mask = sample["mask"] 123 | 124 | return img, mask 125 | 126 | def __len__(self): 127 | return len(self.metas) 128 | 129 | 130 | class VehicleReIDParsingDataset(Dataset): 131 | """ 132 | Converts a ReID dataset into a parsing dataset; used for testing only. 133 | """ 134 | CLASSES = ["background", "back", "front", "side", "roof"] 135 | 136 | def __init__(self, dataset, augmentation=None, preprocessing=None, with_extra=False): 137 | self.augmentation = augmentation 138 | self.preprocessing = preprocessing 139 | self.dataset = dataset 140 | self.with_extra = with_extra 141 | 142 | def __getitem__(self, item): 143 | img_path = self.dataset[item]["image_path"] 144 | assert Path(img_path).exists(), f'{img_path} does not exist!' 
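# NOTE: only the image is returned here (plus the raw meta dict when with_extra=True); no ground-truth mask is loaded, since this dataset exists so that a trained parsing model can predict masks for ReID images.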
145 | image = read_rgb_image(img_path) 146 | image = np.array(image) 147 | if self.augmentation: 148 | sample = self.augmentation(image=image) 149 | image = sample["image"] 150 | if self.preprocessing: 151 | sample = self.preprocessing(image=image) 152 | image = sample["image"] 153 | 154 | if self.with_extra: 155 | return image, self.dataset[item] 156 | 157 | else: 158 | return image 159 | 160 | def __len__(self): 161 | return len(self.dataset) 162 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PVEN 2 | This is the official implementation of the paper "Parsing-based view-aware embedding network for vehicle ReID" [[arxiv]](https://arxiv.org/abs/2004.05021), which has been accepted by **CVPR20** as a poster paper. 3 | 4 | ## Attention! 5 | Since the VERIWild paper does not specify an evaluation protocol, the VERIWild results reported in the PVEN paper were computed without removing junk images. 6 | The results with junk images removed are provided below to facilitate comparative experiments by later researchers. 7 | 8 | | small mAP | small cmc@1 | small cmc@5 | medium mAP | medium cmc@1 | medium cmc@5 | large mAP | large cmc@1 | large cmc@5 | 9 | | --- | --- | --- | --- | --- | --- | --- | --- | --- | 10 | | 79.80 | 94.01 | 98.06 | 73.91 | 92.03 | 97.15 | 66.20 | 88.62 | 95.31 | 11 | 12 | Although the performance is slightly lower after removing the junk images, it is still higher than that of all methods compared at the time of publication, so the core conclusions are unaffected. 13 | 14 | ## Requirements 15 | 1. python 3.6+ 16 | 2. torch 1.3.1+ 17 | 18 | ## Install 19 | ``` 20 | git clone https://github.com/silverbulletmdc/PVEN 21 | cd PVEN 22 | pip install -r requirements.txt 23 | python setup.py install 24 | ``` 25 | 26 | If you want to modify the code of this project, use the following commands instead 27 | ``` 28 | cd PVEN 29 | pip install -r requirements.txt 30 | python setup.py develop 31 | ``` 32 | 33 | ## Preparing dataset 34 | Before running the pipeline, prepare your vehicle ReID dataset. 35 | For each dataset, you need to generate a description pickle file, which is a pickled dict with the following structure: 36 | ```json 37 | { 38 | "train":[ 39 | { 40 | "filename": "0001_c001_00016450_0.jpg", 41 | "image_path": "/data/datasets/VeRi/VeRi/image_train/0001_c001_00016450_0.jpg", 42 | "id": "0001", 43 | "cam": "001", 44 | }, 45 | ... 46 | 47 | ], 48 | "gallery":[ 49 | ... 50 | ], 51 | "query":[ 52 | ... 53 | ] 54 | } 55 | ``` 56 | 57 | For each supported dataset, we provide a script to generate the pickle file. 58 | ```shell 59 | cd examples/preprocess_data 60 | # For VeRi776 61 | python generate_pkl.py veri776 --input-path <veri776-path> --output-path ../outputs/veri776.pkl 62 | # For VERIWild 63 | python generate_pkl.py veriwild --input-path <veriwild-path> --output-path ../outputs/veriwild.pkl 64 | # For VehicleID 65 | python generate_pkl.py vehicleid --input-path <vehicleid-path> --output-path ../outputs/vehicleid.pkl 66 | ``` 67 | 68 | ## Training the parsing model 69 | 70 | 71 | 72 | 73 | ### Convert polygons to parsing masks 74 | As described in the paper, we annotated the parsing information of 3165 images from VeRi776. 75 | We only annotate the vertices of the polygons, since the vehicles are composed of several polygons. 76 | The details of the polygons are in `examples/parsing/poly.json`. 
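For intuition, the sketch below shows how such polygon annotations can be rasterized into a per-pixel label mask. It is illustrative only: the polygon layout assumed here (a list of `(class_id, points)` pairs) is hypothetical, and the actual conversion is implemented by `veri776_poly2mask.py`.

```python
# Illustrative sketch: rasterize one image's annotated polygons into a label mask.
# The polygon format is assumed for illustration; see veri776_poly2mask.py for
# the real poly.json format and conversion logic.
import cv2
import numpy as np

def polygons_to_mask(height, width, polygons):
    """polygons: list of (class_id, [[x1, y1], [x2, y2], ...]) pairs."""
    mask = np.zeros((height, width), dtype=np.uint8)  # 0 = background
    for class_id, points in polygons:
        pts = np.asarray(points, dtype=np.int32).reshape(-1, 1, 2)
        cv2.fillPoly(mask, [pts], color=int(class_id))  # later polygons overwrite earlier ones
    return mask
```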
77 | Run the following command to convert the polygons to parsing masks: 78 | ``` 79 | cd examples/parsing 80 | python veri776_poly2mask.py --json-path poly.json --output-path ../outputs/veri776_parsing3165 81 | ``` 82 | The parsing masks will be generated in the `../outputs/veri776_parsing3165` folder. 83 | 84 | ### Train parsing model 85 | 86 | Run the following command to train the parsing model: 87 | ``` 88 | cd examples/parsing 89 | python train_parsing.py --train-set trainval --masks-path ../outputs/veri776_parsing3165 --image-path <veri776-path>/image_train 90 | ``` 91 | where `<veri776-path>` is the path of your VeRi776 dataset. 92 | 93 | ## Generate parsing masks for ReID dataset 94 | Run the following command to generate masks for the whole ReID dataset and write the `mask_path` of each image into the dataset pickle file. 95 | ``` 96 | cd examples/parsing 97 | python generate_masks.py --model-path best_model_trainval.pth --reid-pkl-path ../outputs/veri776.pkl --output-path ../outputs/veri776_masks 98 | ``` 99 | where `--reid-pkl-path` points to the pickle file generated above. 100 | 101 | ## Train PVEN 102 | Run the following commands to train PVEN. 103 | ```shell 104 | cd examples/parsing_reid 105 | # For VeRi776 106 | CUDA_VISIBLE_DEVICES=0 python main.py train -c configs/veri776_b64_pven.yml 107 | # For vehicleid, use 8 GPUs to train 108 | python main.py train -c configs/vehicleid_b256_pven.yml 109 | # For VERIWild, use 8 GPUs to train 110 | python main.py train -c configs/veriwild_b256_224_pven.yml 111 | ``` 112 | 113 | ## Pretrained Models 114 | We provide the pretrained parsing model, the VeRi776 ReID model, and the VERIWild ReID model (with the classification layer removed) for your convenience. 115 | You can download them from the following link: 116 | Link: https://pan.baidu.com/s/1Q2NMVfGZPCskh-E6vmy9Cw password: iiw1 117 | 118 | ## Evaluate PVEN 119 | ```shell 120 | cd examples/parsing_reid 121 | # For VeRi776 122 | python main.py eval -c configs/veri776_b64_pven.yml 123 | 124 | # For VERIWild 125 | ## small 126 | python main.py eval -c configs/veriwild_b256_224_pven.yml 127 | ## medium 128 | python main.py eval -c configs/veriwild_b256_224_pven.yml test.ext _5000 129 | ## large 130 | python main.py eval -c configs/veriwild_b256_224_pven.yml test.ext _10000 131 | ``` 132 | 133 | ## Citation 134 | If you find our method helpful in your research, please cite our work in your publication. 
135 | ```bibtex 136 | @inproceedings{meng2020parsing, 137 | title={Parsing-based View-aware Embedding Network for Vehicle Re-Identification}, 138 | author={Meng, Dechao and Li, Liang and Liu, Xuejing and Li, Yadong and Yang, Shijie and Zha, Zheng-Jun and Gao, Xingyu and Wang, Shuhui and Huang, Qingming}, 139 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 140 | pages={7103--7112}, 141 | year={2020} 142 | } 143 | ``` 144 | -------------------------------------------------------------------------------- /vehicle_reid_pytorch/models/backbones/resnet_ibn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | 7 | __all__ = ['ResNet_IBN', 'resnet50_ibn_a', 'resnet101_ibn_a', 8 | 'resnet152_ibn_a'] 9 | 10 | 11 | model_urls = { 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | } 16 | 17 | 18 | class IBN(nn.Module): 19 | def __init__(self, planes): 20 | super(IBN, self).__init__() 21 | half1 = int(planes/2) 22 | self.half = half1 23 | half2 = planes - half1 24 | self.IN = nn.InstanceNorm2d(half1, affine=True) 25 | self.BN = nn.BatchNorm2d(half2) 26 | 27 | def forward(self, x): 28 | split = torch.split(x, self.half, 1) 29 | out1 = self.IN(split[0].contiguous()) 30 | out2 = self.BN(split[1].contiguous()) 31 | out = torch.cat((out1, out2), 1) 32 | return out 33 | 34 | 35 | class Bottleneck_IBN(nn.Module): 36 | expansion = 4 37 | 38 | def __init__(self, inplanes, planes, ibn=False, stride=1, downsample=None): 39 | super(Bottleneck_IBN, self).__init__() 40 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 41 | if ibn: 42 | self.bn1 = IBN(planes) 43 | else: 44 | self.bn1 = nn.BatchNorm2d(planes) 45 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 46 | padding=1, bias=False) 47 | self.bn2 = nn.BatchNorm2d(planes) 48 | self.conv3 = nn.Conv2d( 49 | planes, planes * self.expansion, kernel_size=1, bias=False) 50 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 51 | self.relu = nn.ReLU(inplace=True) 52 | self.downsample = downsample 53 | self.stride = stride 54 | 55 | def forward(self, x): 56 | residual = x 57 | 58 | out = self.conv1(x) 59 | out = self.bn1(out) 60 | out = self.relu(out) 61 | 62 | out = self.conv2(out) 63 | out = self.bn2(out) 64 | out = self.relu(out) 65 | 66 | out = self.conv3(out) 67 | out = self.bn3(out) 68 | 69 | if self.downsample is not None: 70 | residual = self.downsample(x) 71 | 72 | out += residual 73 | out = self.relu(out) 74 | 75 | return out 76 | 77 | 78 | class ResNet_IBN(nn.Module): 79 | 80 | def __init__(self, last_stride, block, layers, num_classes=1000): 81 | scale = 64 82 | self.inplanes = scale 83 | super(ResNet_IBN, self).__init__() 84 | self.conv1 = nn.Conv2d(3, scale, kernel_size=7, stride=2, padding=3, 85 | bias=False) 86 | self.bn1 = nn.BatchNorm2d(scale) 87 | self.relu = nn.ReLU(inplace=True) 88 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 89 | self.layer1 = self._make_layer(block, scale, layers[0]) 90 | self.layer2 = self._make_layer(block, scale*2, layers[1], stride=2) 91 | self.layer3 = self._make_layer(block, scale*4, layers[2], stride=2) 92 | self.layer4 = self._make_layer( 93 | block, scale*8, layers[3], 
stride=last_stride) 94 | self.avgpool = nn.AvgPool2d(7) 95 | self.fc = nn.Linear(scale * 8 * block.expansion, num_classes) 96 | 97 | for m in self.modules(): 98 | if isinstance(m, nn.Conv2d): 99 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 100 | m.weight.data.normal_(0, math.sqrt(2. / n)) 101 | elif isinstance(m, nn.BatchNorm2d): 102 | m.weight.data.fill_(1) 103 | m.bias.data.zero_() 104 | elif isinstance(m, nn.InstanceNorm2d): 105 | m.weight.data.fill_(1) 106 | m.bias.data.zero_() 107 | 108 | def _make_layer(self, block, planes, blocks, stride=1): 109 | downsample = None 110 | if stride != 1 or self.inplanes != planes * block.expansion: 111 | downsample = nn.Sequential( 112 | nn.Conv2d(self.inplanes, planes * block.expansion, 113 | kernel_size=1, stride=stride, bias=False), 114 | nn.BatchNorm2d(planes * block.expansion), 115 | ) 116 | 117 | layers = [] 118 | ibn = True 119 | if planes == 512: 120 | ibn = False 121 | layers.append(block(self.inplanes, planes, ibn, stride, downsample)) 122 | self.inplanes = planes * block.expansion 123 | for i in range(1, blocks): 124 | layers.append(block(self.inplanes, planes, ibn)) 125 | 126 | return nn.Sequential(*layers) 127 | 128 | def forward(self, x): 129 | x = self.conv1(x) 130 | x = self.bn1(x) 131 | x = self.relu(x) 132 | x = self.maxpool(x) 133 | 134 | x = self.layer1(x) 135 | x = self.layer2(x) 136 | x = self.layer3(x) 137 | x = self.layer4(x) 138 | 139 | # x = self.avgpool(x) 140 | # x = x.view(x.size(0), -1) 141 | # x = self.fc(x) 142 | 143 | return x 144 | 145 | def load_param(self, model_path): 146 | param_dict = torch.load(model_path) 147 | for i in param_dict: 148 | if 'fc' in i: 149 | continue 150 | self.state_dict()[i].copy_(param_dict[i]) 151 | 152 | 153 | def resnet50_ibn_a(last_stride, pretrained=False, **kwargs): 154 | """Constructs a ResNet-50 model. 155 | Args: 156 | pretrained (bool): If True, returns a model pre-trained on ImageNet 157 | """ 158 | model = ResNet_IBN(last_stride, Bottleneck_IBN, [3, 4, 6, 3], **kwargs) 159 | if pretrained: 160 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 161 | return model 162 | 163 | 164 | def resnet101_ibn_a(last_stride, pretrained=False, **kwargs): 165 | """Constructs a ResNet-101 model. 166 | Args: 167 | pretrained (bool): If True, returns a model pre-trained on ImageNet 168 | """ 169 | model = ResNet_IBN(last_stride, Bottleneck_IBN, [3, 4, 23, 3], **kwargs) 170 | if pretrained: 171 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 172 | return model 173 | 174 | 175 | def resnet152_ibn_a(last_stride, pretrained=False, **kwargs): 176 | """Constructs a ResNet-152 model. 
177 | Args: 178 | pretrained (bool): If True, returns a model pre-trained on ImageNet 179 | """ 180 | model = ResNet_IBN(last_stride, Bottleneck_IBN, [3, 8, 36, 3], **kwargs) 181 | if pretrained: 182 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 183 | return model 184 | -------------------------------------------------------------------------------- /vehicle_reid_pytorch/data/datasets/bases.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import pickle as pkl 4 | from torch.utils.data import Dataset 5 | from vehicle_reid_pytorch.utils.iotools import read_rgb_image 6 | import scipy.stats as st 7 | 8 | 9 | def get_imagedata_info(data): 10 | ids, cams = [], [] 11 | for item in data: 12 | ids.append(item["id"]) 13 | cams.append(item["cam"]) 14 | 15 | 16 | 17 | pids = set(ids) 18 | cams = set(cams) 19 | num_pids = len(pids) 20 | num_cams = len(cams) 21 | num_imgs = len(data) 22 | return num_pids, num_imgs, num_cams 23 | 24 | 25 | def relabel(data): 26 | """ 27 | :param list data: 28 | :return: 29 | """ 30 | raw_ids = set() 31 | data = data.copy() 32 | for item in data: 33 | raw_ids.add(item['id']) 34 | raw_ids = sorted(list(raw_ids)) 35 | rawid2label = {raw_vid: i for i, raw_vid in enumerate(raw_ids)} 36 | label2rawid = {i: raw_vid for i, raw_vid in enumerate(raw_ids)} 37 | for item in data: 38 | item["id"] = rawid2label[item["id"]] 39 | item["cam"] = int(item["cam"]) 40 | return data, rawid2label, label2rawid 41 | 42 | 43 | class ReIDMetaDataset: 44 | """ 45 | Defines the meta information of a ReID dataset; it must contain train, query, and gallery attributes. 46 | Each is a list of dicts. Each dict contains meta information, which is 47 | { 48 | "image_path": str, required 49 | "id": int, required 50 | 51 | "cam"(optional): int, 52 | "keypoints"(optional): extra information 53 | "kp_vis"(optional): whether each keypoint is visible 54 | "mask"(optional): extra information 55 | "box"(optional): extra information 56 | "color"(optional): extra information 57 | "type"(optional): extra information 58 | "view"(optional): extra information 59 | } 60 | """ 61 | def __init__(self, pkl_path, verbose=True, **kwargs): 62 | with open(pkl_path, 'rb') as f: 63 | metas = pkl.load(f) 64 | 65 | self.train = metas["train"] 66 | self.query = metas["query"] 67 | self.gallery = metas["gallery"] 68 | self.relabel() 69 | self._calc_meta_info() 70 | 71 | if verbose: 72 | print("=> Dataset loaded") 73 | self.print_dataset_statistics() 74 | 75 | def relabel(self): 76 | self.train, self.train_rawid2label, self.train_label2rawid = relabel(self.train) 77 | eval_set, self.eval_rawid2label, self.eval_label2rawid = relabel(self.query + self.gallery) 78 | self.query = eval_set[:len(self.query)] 79 | self.gallery = eval_set[len(self.query):] 80 | 81 | def print_dataset_statistics(self): 82 | num_train_pids, num_train_imgs, num_train_cams = get_imagedata_info(self.train) 83 | num_query_pids, num_query_imgs, num_query_cams = get_imagedata_info(self.query) 84 | num_gallery_pids, num_gallery_imgs, num_gallery_cams = get_imagedata_info(self.gallery) 85 | 86 | print("Dataset statistics:") 87 | print(" ----------------------------------------") 88 | print(" subset | # ids | # images | # cameras") 89 | print(" ----------------------------------------") 90 | print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams)) 91 | print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams)) 92 | 
print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams)) 93 | print(" ----------------------------------------") 94 | 95 | def _calc_meta_info(self): 96 | self.num_train_ids, self.num_train_imgs, self.num_train_cams = get_imagedata_info(self.train) 97 | self.num_query_ids, self.num_query_imgs, self.num_query_cams = get_imagedata_info(self.query) 98 | self.num_gallery_ids, self.num_gallery_imgs, self.num_gallery_cams = get_imagedata_info(self.gallery) 99 | 100 | 101 | class ReIDDataset(Dataset): 102 | def __init__(self, meta_dataset, *, with_mask=False, mask_num=5, transform=None, preprocessing=None): 103 | """将元数据集转化为图片数据集,并进行预处理 104 | 105 | Arguments: 106 | Dataset {ReIDMetaDataset} -- self 107 | meta_dataset {ReIDMetaDataset} -- 元数据集 108 | 109 | Keyword Arguments: 110 | with_box {bool} -- [是否使用检测框做crop。从box属性中读取检测框信息] (default: {False}) 111 | with_mask {bool} -- [是否读取mask。为True时从mask_nori_id读取mask] (default: {False}) 112 | mask_num {int} -- [mask数量] (default: {5}) 113 | sub_bg {bool} -- [是否删除背景。with_mask为True时才会生效。将利用第一个mask对图片做背景减除] (default: {False}) 114 | transform {[type]} -- [数据增强] (default: {None}) 115 | preprocessing {[type]} -- [normalize, to tensor等预处理] (default: {None}) 116 | """ 117 | self.meta_dataset = meta_dataset 118 | self.transform = transform 119 | self.preprocessing = preprocessing 120 | self.with_mask = with_mask 121 | self.mask_num = mask_num 122 | 123 | def read_mask(self, sample): 124 | # 读入mask 125 | mask = cv2.imread(sample["mask_path"], cv2.IMREAD_GRAYSCALE) 126 | mask = [mask == v for v in range(self.mask_num)] 127 | mask = np.stack(mask, axis=-1).astype('float32') 128 | sample["mask"] = mask 129 | 130 | 131 | def __getitem__(self, item): 132 | meta: dict = self.meta_dataset[item] 133 | sample = meta.copy() 134 | # 读入图片 135 | 136 | ######## TODO: Remove this before commit 137 | idx = sample["image_path"].find('/home/aa') 138 | if idx != -1: 139 | sample["image_path"] = '/data1/dechao_meng/' + sample["image_path"][idx + 8:] 140 | sample["mask_path"] = '/data1/dechao_meng/' + sample["mask_path"][idx + 8:] 141 | ################################### 142 | 143 | 144 | sample["image"] = read_rgb_image(sample["image_path"]) 145 | 146 | 147 | # 读入mask 148 | if self.with_mask: 149 | self.read_mask(sample) 150 | 151 | # 数据增强 152 | if self.transform: 153 | sample = self.transform(**sample) 154 | 155 | # preprocessing 156 | if self.preprocessing: 157 | sample = self.preprocessing(**sample) 158 | 159 | return sample 160 | 161 | def __len__(self): 162 | return len(self.meta_dataset) 163 | 164 | 165 | if __name__ == "__main__": 166 | from vehicle_reid_pytorch.data.datasets import AICity 167 | from vehicle_reid_pytorch.utils.visualize import visualize_img 168 | from vehicle_reid_pytorch.data.demo_transforms import get_training_albumentations 169 | import matplotlib.pyplot as plt 170 | meta_dataset = ReIDMetaDataset(pkl_path="") 171 | dataset = ReIDDataset(meta_dataset.train, transform=get_training_albumentations(with_keypoints=True)) 172 | images = [] 173 | for idx in np.random.randint(0, len(dataset), 10): 174 | sample = dataset[idx] 175 | image = sample['image'] 176 | image = image[:, :, :3] * 0.5 + sample['kp_heatmap'].reshape(256, 256, 1) * 50 177 | images.append(image.astype('uint8')) 178 | 179 | visualize_img(*images, cols=2, show=False) 180 | plt.savefig('aaver.png') 181 | print('finish') 182 | -------------------------------------------------------------------------------- /vehicle_reid_pytorch/loss/triplet_loss.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def normalize(x, axis=-1): 6 | """Normalizing to unit length along the specified dimension. 7 | Args: 8 | x: pytorch Variable 9 | Returns: 10 | x: pytorch Variable, same shape as input 11 | """ 12 | x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12) 13 | return x 14 | 15 | 16 | def euclidean_dist(x, y, split=0): 17 | """ 18 | Args: 19 | x: pytorch Variable, with shape [m, d] 20 | y: pytorch Variable, with shape [n, d] 21 | split: When the CUDA memory is not sufficient, we can split the dataset into different parts 22 | for the computing of distance. 23 | Returns: 24 | dist: pytorch Variable, with shape [m, n] 25 | """ 26 | m, n = x.size(0), y.size(0) 27 | if split == 0: 28 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 29 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 30 | distmat = xx + yy 31 | distmat.addmm_(x, y.t(), beta=1, alpha=-2) 32 | 33 | else: 34 | distmat = x.new(m, n) 35 | start = 0 36 | x = x.cuda() 37 | 38 | while start < n: 39 | end = start + split if (start + split) < n else n 40 | num = end - start 41 | 42 | sub_distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, num) + \ 43 | torch.pow(y[start:end].cuda(), 2).sum(dim=1, keepdim=True).expand(num, m).t() 44 | # sub_distmat.addmm_(1, -2, x, y[start:end].t()) 45 | sub_distmat.addmm_(x, y[start:end].cuda().t(), beta=1, alpha=-2) 46 | distmat[:, start:end] = sub_distmat.cpu() 47 | start += num 48 | 49 | distmat = distmat.clamp(min=1e-12).sqrt() # for numerical stability 50 | return distmat 51 | 52 | 53 | def hard_example_mining(dist_mat, labels, mask=None, return_inds=False): 54 | """For each anchor, find the hardest positive and negative sample. 55 | Args: 56 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N] 57 | labels: pytorch LongTensor, with shape [N] 58 | mask: pytorch Tensor, with shape [N, N] 59 | return_inds: whether to return the indices. Save time if `False`(?) 60 | Returns: 61 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N] 62 | dist_an: pytorch Variable, distance(anchor, negative); shape [N] 63 | p_inds: pytorch LongTensor, with shape [N]; 64 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1 65 | n_inds: pytorch LongTensor, with shape [N]; 66 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1 67 | NOTE: Only consider the case in which all labels have same num of samples, 68 | thus we can cope with all anchors in parallel. 
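Example (illustrative): with a batch sampled as P identities x K instances (e.g. by RandomIdentitySampler), each row of dist_mat has K positive entries (including the anchor itself) and (P - 1) * K negative entries; dist_ap takes the max over the positives and dist_an the min over the negatives.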
69 | """ 70 | 71 | assert len(dist_mat.size()) == 2 72 | assert dist_mat.size(0) == dist_mat.size(1) 73 | N = dist_mat.size(0) 74 | 75 | # shape [N, N] 76 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()) 77 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t()) 78 | 79 | # `dist_ap` means distance(anchor, positive) 80 | # both `dist_ap` and `relative_p_inds` with shape [N, 1] 81 | if mask is None: 82 | mask = torch.ones_like(dist_mat) 83 | 84 | aux_mat = torch.zeros_like(dist_mat) 85 | aux_mat[mask==0] -= 10 86 | dist_mat = dist_mat + aux_mat 87 | 88 | dist_ap, relative_p_inds = torch.max( 89 | dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True) 90 | 91 | 92 | 93 | # `dist_an` means distance(anchor, negative) 94 | # both `dist_an` and `relative_n_inds` with shape [N, 1] 95 | # dist_mat[dist_mat == 0] += 10000 # handle invalid values; the maximum distance after normalization is 2 96 | aux_mat = torch.zeros_like(dist_mat) 97 | aux_mat[mask==0] += 10000 98 | dist_mat = dist_mat + aux_mat 99 | dist_an, relative_n_inds = torch.min( 100 | dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True) 101 | # shape [N] 102 | 103 | 104 | 105 | dist_ap = dist_ap.squeeze(1) 106 | dist_an = dist_an.squeeze(1) 107 | 108 | if return_inds: 109 | # shape [N, N] 110 | ind = (labels.new().resize_as_(labels) 111 | .copy_(torch.arange(0, N).long()) 112 | .unsqueeze(0).expand(N, N)) 113 | # shape [N, 1] 114 | p_inds = torch.gather( 115 | ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data) 116 | n_inds = torch.gather( 117 | ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data) 118 | # shape [N] 119 | p_inds = p_inds.squeeze(1) 120 | n_inds = n_inds.squeeze(1) 121 | return dist_ap, dist_an, p_inds, n_inds 122 | 123 | return dist_ap, dist_an 124 | 125 | 126 | class TripletLoss(object): 127 | """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid). 128 | Related Triplet Loss theory can be found in paper 'In Defense of the Triplet 129 | Loss for Person Re-Identification'.""" 130 | 131 | def __init__(self, margin=None): 132 | self.margin = margin 133 | if margin is not None: 134 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 135 | else: 136 | self.ranking_loss = nn.SoftMarginLoss() 137 | 138 | def __call__(self, global_feat, labels, mask=None, normalize_feature=False): 139 | """ 140 | 141 | :param global_feat: 142 | :param labels: 143 | :param mask: [N, N] visibility mask; invisible pairs will not be selected. If all pairs are invisible, the result is multiplied by 0 144 | :param normalize_feature: 145 | :return: 146 | """ 147 | if normalize_feature: 148 | global_feat = normalize(global_feat, axis=-1) 149 | dist_mat = euclidean_dist(global_feat, global_feat) 150 | dist_ap, dist_an = hard_example_mining( 151 | dist_mat, labels, mask=mask) 152 | y = dist_an.new().resize_as_(dist_an).fill_(1) 153 | if self.margin is not None: 154 | loss = self.ranking_loss(dist_an, dist_ap, y) 155 | else: 156 | loss = self.ranking_loss(dist_an - dist_ap, y) 157 | return loss, dist_ap, dist_an 158 | 159 | 160 | class CrossEntropyLabelSmooth(nn.Module): 161 | """Cross entropy loss with label smoothing regularizer. 162 | 163 | Reference: 164 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 165 | Equation: y = (1 - epsilon) * y + epsilon / K. 166 | 167 | Args: 168 | num_classes (int): number of classes. 169 | epsilon (float): weight. 
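Example (illustrative): criterion = CrossEntropyLabelSmooth(num_classes=576); loss = criterion(cls_score, target), where cls_score is a (batch_size, num_classes) logit matrix and target holds integer labels of shape (batch_size).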
170 | """ 171 | 172 | def __init__(self, num_classes, epsilon=0.1, keep_dim=False): 173 | super(CrossEntropyLabelSmooth, self).__init__() 174 | self.num_classes = num_classes 175 | self.epsilon = epsilon 176 | self.logsoftmax = nn.LogSoftmax(dim=1) 177 | self.keep_dim = keep_dim 178 | 179 | def forward(self, inputs, targets): 180 | """ 181 | Args: 182 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 183 | targets: ground truth labels with shape (batch_size) 184 | """ 185 | log_probs = self.logsoftmax(inputs) 186 | targets = inputs.new_zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data, 1) 187 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 188 | if self.keep_dim: 189 | loss = (- targets * log_probs).sum(1) 190 | else: 191 | loss = (- targets * log_probs).mean(0).sum() 192 | return loss 193 | -------------------------------------------------------------------------------- /vehicle_reid_pytorch/metrics/R1_mAP.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import pickle 8 | import os 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional 12 | 13 | from .eval_reid import eval_func 14 | from .rerank import re_ranking 15 | 16 | 17 | def build_metric(cfg, num_query): 18 | if 'vehicleid' in cfg.datasets.names[0]: 19 | metric = CMC10Times(feat_norm=cfg.test.FEAT_NORM, output_path=cfg.OUTPUT_DIR, 20 | rerank=cfg.test.RE_RANKING == 'yes') 21 | else: 22 | metric = R1_mAP(num_query, max_rank=50, feat_norm=cfg.test.FEAT_NORM, output_path=cfg.OUTPUT_DIR, 23 | rerank=cfg.test.RE_RANKING == 'yes', remove_junk=cfg.test.REMOVE_JUNK == 'yes') 24 | 25 | return metric 26 | 27 | 28 | class R1_mAP: 29 | def __init__(self, num_query, *, max_rank=0, feat_norm=True, rerank=False, remove_junk=True, output_path=''): 30 | super(R1_mAP, self).__init__() 31 | self.num_query = num_query 32 | self.max_rank = max_rank 33 | self.feat_norm = feat_norm 34 | self.output_path = output_path 35 | self.rerank = rerank 36 | self.remove_junk = remove_junk 37 | self.reset() 38 | 39 | def reset(self): 40 | self.feats = [] 41 | self.pids = [] 42 | self.camids = [] 43 | self.paths = [] 44 | 45 | def update(self, output): 46 | feat, pid, camid, paths = output 47 | self.feats.append(feat) 48 | self.pids.extend(np.asarray(pid)) 49 | self.camids.extend(np.asarray(camid)) 50 | self.paths += paths 51 | 52 | def process_feat(self): 53 | self.feats = torch.cat(self.feats, dim=0) 54 | self.pids = np.asarray(self.pids) 55 | if self.feat_norm: 56 | self.feats = torch.nn.functional.normalize(self.feats, dim=1, p=2) 57 | 58 | 59 | def resplit(self, pids): 60 | # sorted_idxs = np.argsort(pids) 61 | # sorted_pid = pids[sorted_idxs] 62 | num_pid = len(set(pids)) 63 | query = [] 64 | gallery = [] 65 | for i in range(num_pid): 66 | idxs:np.ndarray = (pids==i).nonzero()[0] 67 | choose_idx = np.random.randint(len(idxs)) 68 | gallery.append(idxs[choose_idx]) 69 | query.extend(idxs[:choose_idx]) 70 | query.extend(idxs[choose_idx+1:]) 71 | return query + gallery 72 | 73 | def shuffle_eval(self): 74 | """Reshuffle the query/gallery split before evaluating; used for VehicleID. 75 | """ 76 | indexs = self.resplit(self.pids) 77 | self.feats = self.feats[indexs] 78 | self.pids = self.pids[indexs] 79 | 80 | return self._compute() 81 | 82 | def compute(self): 83 | self.process_feat() 84 | return self._compute() 85 | 86 | def _compute(self): 87 | feats = self.feats 88 | # query 89 | qf = 
feats[:self.num_query] 90 | q_pids = np.asarray(self.pids[:self.num_query]) 91 | q_camids = np.asarray(self.camids[:self.num_query]) 92 | # gallery 93 | gf = feats[self.num_query:] 94 | g_pids = np.asarray(self.pids[self.num_query:]) 95 | g_camids = np.asarray(self.camids[self.num_query:]) 96 | m, n = qf.shape[0], gf.shape[0] 97 | 98 | if self.rerank: 99 | distmat = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3) 100 | 101 | else: 102 | distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 103 | torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() 104 | distmat.addmm_(1, -2, qf, gf.t()) 105 | distmat = torch.sqrt(distmat).cpu().numpy() 106 | 107 | # save the results 108 | query_paths = self.paths[:self.num_query] 109 | gallery_paths = self.paths[self.num_query:] 110 | 111 | self.write_output(gallery_paths, query_paths, g_pids, q_pids, qf, gf, q_camids, g_camids, distmat) 112 | self.distmat = distmat 113 | 114 | cmc, mAP = eval_func(distmat, q_pids, g_pids, 115 | q_camids, g_camids, remove_junk=self.remove_junk) 116 | 117 | return cmc, mAP 118 | 119 | 120 | def write_output(self, gallery_paths, query_paths, g_pids, q_pids, qf, gf, q_camids, g_camids, distmat): 121 | if self.output_path != '': 122 | # gallery_indexs = np.argsort(distmat, axis=1)[:100] + 1 123 | # def int2str(id): 124 | # return f'{id:06d}' 125 | 126 | # with open('submit.txt', 'w') as f: 127 | # for gallerys in gallery_indexs: 128 | # output_str = ' '.join(map(int2str, gallerys)) + '\n' 129 | # f.write(output_str) 130 | 131 | try: 132 | with open(os.path.join(self.output_path, 'test_output.pkl'), 'wb') as f: 133 | torch.save({ 134 | 'gallery_paths': gallery_paths, 135 | 'query_paths': query_paths, 136 | 'gallery_ids': g_pids, 137 | 'query_ids': q_pids, 138 | 'query_features': qf, 139 | 'gallery_features': gf, 140 | 'query_cams': q_camids, 141 | 'gallery_cams': g_camids, 142 | 'distmat': distmat 143 | }, f) 144 | 145 | except OverflowError: 146 | print("Can't save results.") 147 | pass 148 | 149 | 150 | class CMC10Times: 151 | 152 | def __init__(self, feat_norm=True, output_path='', rerank=False): 153 | """ 154 | Evaluation protocol for VehicleID, repeated ten times; in each round, one randomly chosen image per id is placed into the gallery and the rest form the query set. 155 | 156 | :param feat_norm: 157 | :param output_path: 158 | :param rerank: 159 | 160 | 161 | """ 162 | super(CMC10Times, self).__init__() 163 | self.feat_norm = feat_norm 164 | self.output_path = output_path 165 | self.rerank = rerank 166 | self.reset() 167 | 168 | def reset(self): 169 | self.feats = [] 170 | self.pids = [] 171 | self.camids = [] 172 | self.paths = [] 173 | 174 | def update(self, output): 175 | feat = output[0] 176 | pid = output[1] 177 | self.feats.append(feat) 178 | self.pids.extend(np.asarray(pid)) 179 | 180 | def compute(self): 181 | self.feats = torch.cat(self.feats, dim=0) 182 | 183 | if self.feat_norm: 184 | print("The test feature is normalized") 185 | self.feats = torch.nn.functional.normalize(self.feats, dim=1, p=2) 186 | 187 | pids_np = np.array(self.pids) 188 | cmcs = [] 189 | mAPs = [] 190 | for i in range(10): 191 | # sampling 192 | gallery = [] 193 | pid_set = set(self.pids) 194 | for pid in pid_set: 195 | mask = (pids_np == pid) 196 | idxs = np.nonzero(mask)[0] 197 | sample_idx = np.random.choice(idxs) 198 | gallery.append(sample_idx) 199 | 200 | # compute 201 | cmc, mAP = self.compute_once(gallery) 202 | cmcs.append(cmc) 203 | mAPs.append(mAP) 204 | # average over the ten rounds 205 | cmcs = np.array(cmcs) 206 | mean_cmc = cmcs.mean(axis=0) 207 | mAPs = np.array(mAPs) 208 | mean_mAP = mAPs.mean() 209 | return 
149 | 
150 | class CMC10Times:
151 | 
152 |     def __init__(self, feat_norm='yes', output_path='', rerank=False):
153 |         """
154 |         Evaluation protocol for VehicleID, repeated 10 times: in each round one
155 |         randomly chosen image per id is placed into the gallery and all
156 |         remaining images form the query set.
157 | 
158 |         :param feat_norm: 'yes' to L2-normalize features before matching
159 |         :param output_path: directory for saving results
160 |         :param rerank: whether to apply k-reciprocal re-ranking
161 |         """
162 |         super(CMC10Times, self).__init__()
163 |         self.feat_norm = feat_norm
164 |         self.output_path = output_path
165 |         self.rerank = rerank
166 |         self.reset()
167 | 
168 |     def reset(self):
169 |         self.feats = []
170 |         self.pids = []
171 |         self.camids = []
172 |         self.paths = []
173 | 
174 |     def update(self, output):
175 |         feat = output[0]
176 |         pid = output[1]
177 |         self.feats.append(feat)
178 |         self.pids.extend(np.asarray(pid))
179 | 
180 |     def compute(self):
181 |         self.feats = torch.cat(self.feats, dim=0)
182 | 
183 |         if self.feat_norm == 'yes':
184 |             print("The test feature is normalized")
185 |             self.feats = torch.nn.functional.normalize(self.feats, dim=1, p=2)
186 | 
187 |         pids_np = np.array(self.pids)
188 |         cmcs = []
189 |         mAPs = []
190 |         for _ in range(10):
191 |             # sample one image per id into the gallery
192 |             gallery = []
193 |             pid_set = set(self.pids)
194 |             for pid in pid_set:
195 |                 mask = (pids_np == pid)
196 |                 idxs = np.nonzero(mask)[0]
197 |                 sample_idx = np.random.choice(idxs)
198 |                 gallery.append(sample_idx)
199 | 
200 |             # evaluate this split
201 |             cmc, mAP = self.compute_once(gallery)
202 |             cmcs.append(cmc)
203 |             mAPs.append(mAP)
204 |         # average over the 10 runs
205 |         cmcs = np.array(cmcs)
206 |         mean_cmc = cmcs.mean(axis=0)
207 |         mAPs = np.array(mAPs)
208 |         mean_mAP = mAPs.mean()
209 |         return mean_cmc, mean_mAP
210 | 
211 |     def compute_once(self, gallery_idxs):
212 |         gallery_mask = torch.zeros(len(self.feats))
213 |         gallery_mask[gallery_idxs] = 1
214 |         query_mask = 1 - gallery_mask
215 | 
216 |         # query
217 |         qf = self.feats[query_mask.type(torch.bool)]
218 |         q_pids = np.asarray(self.pids)[query_mask.type(torch.bool)]
219 |         # gallery
220 |         gf = self.feats[gallery_mask.type(torch.bool)]
221 |         g_pids = np.asarray(self.pids)[gallery_mask.type(torch.bool)]
222 | 
223 |         if self.rerank:
224 |             distmat = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3)
225 |         else:
226 |             m, n = qf.shape[0], gf.shape[0]
227 |             distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
228 |                       torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
229 |             distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)
230 |             distmat = distmat.cpu().numpy()
231 |         cmc, mAP = eval_func(distmat, q_pids, g_pids,
232 |                              None, None, remove_junk=False)
233 | 
234 |         return cmc, mAP
235 | 
--------------------------------------------------------------------------------
/vehicle_reid_pytorch/data/samplers/triplet_sampler.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import random
3 | import torch
4 | from collections import defaultdict
5 | 
6 | import numpy as np
7 | from torch.utils.data.sampler import Sampler
8 | 
9 | 
10 | class RandomIdentitySampler(Sampler):
11 |     """
12 |     Randomly sample N identities, then for each identity,
13 |     randomly sample K instances, therefore batch size is N*K.
14 |     Args:
15 |     - data_source (list): list of dataset items, each a dict with an 'id' key.
16 |     - num_instances (int): number of instances per identity in a batch.
17 |     - batch_size (int): number of examples in a batch.
18 |     """
19 | 
20 |     def __init__(self, data_source, batch_size, num_instances):
21 |         super(RandomIdentitySampler, self).__init__(data_source)
22 |         self.data_source = data_source
23 |         self.batch_size = batch_size
24 |         self.num_instances = num_instances
25 |         self.num_pids_per_batch = self.batch_size // self.num_instances
26 |         self.index_dic = defaultdict(list)
27 |         for index, item in enumerate(self.data_source):
28 |             pid = item['id']
29 |             self.index_dic[pid].append(index)
30 |         self.pids = list(self.index_dic.keys())
31 | 
32 |         # estimate number of examples in an epoch
33 |         self.length = 0
34 |         for pid in self.pids:
35 |             idxs = self.index_dic[pid]
36 |             num = len(idxs)
37 |             if num < self.num_instances:
38 |                 num = self.num_instances
39 |             self.length += num - num % self.num_instances
40 | 
41 |     def __iter__(self):
42 |         batch_idxs_dict = defaultdict(list)
43 | 
44 |         for pid in self.pids:
45 |             idxs = copy.deepcopy(self.index_dic[pid])
46 |             if len(idxs) < self.num_instances:
47 |                 idxs = np.random.choice(idxs, size=self.num_instances, replace=True)
48 |             random.shuffle(idxs)
49 |             batch_idxs = []
50 |             for idx in idxs:
51 |                 batch_idxs.append(idx)
52 |                 if len(batch_idxs) == self.num_instances:
53 |                     batch_idxs_dict[pid].append(batch_idxs)
54 |                     batch_idxs = []
55 | 
56 |         avai_pids = copy.deepcopy(self.pids)
57 |         final_idxs = []
58 | 
59 |         while len(avai_pids) >= self.num_pids_per_batch:
60 |             selected_pids = random.sample(avai_pids, self.num_pids_per_batch)
61 |             for pid in selected_pids:
62 |                 batch_idxs = batch_idxs_dict[pid].pop(0)
63 |                 final_idxs.extend(batch_idxs)
64 |                 if len(batch_idxs_dict[pid]) == 0:
65 |                     avai_pids.remove(pid)
66 | 
67 |         self.length = len(final_idxs)
68 |         return iter(final_idxs)
69 | 
70 |     def __len__(self):
71 |         return self.length
72 | 
73 | 
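Hypothetical usage of RandomIdentitySampler above: it yields dataset indices arranged so every batch holds P identities with K instances each. The toy dataset and the P/K values here are illustrative; items only need an 'id' key, and this assumes the repo's pinned torch where Sampler still accepts a data_source argument:

from torch.utils.data import DataLoader

dataset = [{'id': i % 100, 'path': f'img_{i}.jpg'} for i in range(1000)]
P, K = 16, 4  # identities per batch, instances per identity
sampler = RandomIdentitySampler(dataset, batch_size=P * K, num_instances=K)
loader = DataLoader(dataset, batch_size=P * K, sampler=sampler,
                    collate_fn=lambda batch: batch)  # keep items as plain dicts
batch = next(iter(loader))
print(len(batch), len({item['id'] for item in batch}))  # 64 items drawn from 16 ids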
74 | class KPSampler(Sampler):
75 |     """
76 |     Randomly sample N identities, then for each identity,
77 |     randomly sample K instances, therefore batch size is N*K.
78 | 
79 |     Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py.
80 | 
81 |     Args:
82 |         data_source (Dataset): dataset to sample from.
83 |         num_instances (int): number of instances per identity.
84 |     """
85 | 
86 |     def __init__(self, data_source, batch_size, num_instances):
87 |         super(KPSampler, self).__init__(data_source)
88 |         self.data_source = data_source
89 |         self.num_instances = num_instances
90 |         self.batch_size = batch_size
91 |         self.index_dic = defaultdict(list)
92 |         for index, item in enumerate(data_source):
93 |             id = item["id"]
94 |             self.index_dic[id].append(index)
95 |         self.pids = list(self.index_dic.keys())
96 |         self.num_identities = len(self.pids)
97 | 
98 |     def __iter__(self):
99 |         indices = torch.randperm(self.num_identities)
100 |         ret = []
101 |         for i in indices:
102 |             pid = self.pids[i]
103 |             t = self.index_dic[pid]
104 |             replace = False if len(t) >= self.num_instances else True
105 |             t = np.random.choice(t, size=self.num_instances, replace=replace)
106 |             ret.extend(t)
107 | 
108 |         return iter(ret)
109 | 
110 |     def __len__(self):
111 |         return self.num_identities * self.num_instances
112 | 
113 | 
114 | class SimilarIdentitySampler(Sampler):
115 |     def __init__(self, data_source, batch_size, num_instances, similarity_matrix):
116 |         """
117 | 
118 |         :param list data_source: list of (path, pid, image_id) tuples
119 |         :param num_instances: number of instances per identity in a batch
120 |         :param np.ndarray similarity_matrix: similarity matrix; entry (i, j) is the similarity between ids i and j.
121 | 
122 |         """
123 |         super(SimilarIdentitySampler, self).__init__(data_source)
124 |         self.data_source = data_source
125 |         self.num_instances = num_instances
126 |         num_ids = similarity_matrix.shape[0]
127 |         similarity_matrix[np.eye(num_ids, dtype=bool)] = 0
128 |         self.similarity_matrix = similarity_matrix
129 | 
130 |         self.index_dic = defaultdict(list)
131 |         self.batch_size = batch_size
132 |         self.num_pids_per_batch = self.batch_size // self.num_instances
133 |         for i, (_, pid, _) in enumerate(data_source):
134 |             self.index_dic[pid].append(i)
135 |         self.pids = list(self.index_dic.keys())
136 |         # estimate number of examples in an epoch
137 |         self.length = 0
138 |         for pid in self.pids:
139 |             idxs = self.index_dic[pid]
140 |             num = len(idxs)
141 |             if num < self.num_instances:
142 |                 num = self.num_instances
143 |             self.length += num - num % self.num_instances
144 | 
145 |     def __iter__(self):
146 |         batch_idxs_dict = defaultdict(list)
147 | 
148 |         for pid in self.pids:
149 |             idxs = copy.deepcopy(self.index_dic[pid])
150 |             if len(idxs) < self.num_instances:
151 |                 idxs = np.random.choice(idxs, size=self.num_instances, replace=True)
152 |             random.shuffle(idxs)
153 |             batch_idxs = []
154 |             for idx in idxs:
155 |                 batch_idxs.append(idx)
156 |                 if len(batch_idxs) == self.num_instances:
157 |                     batch_idxs_dict[pid].append(batch_idxs)
158 |                     batch_idxs = []
159 | 
160 |         avai_pids = copy.deepcopy(self.pids)
161 |         final_idxs = []
162 | 
163 |         while len(avai_pids) >= self.num_pids_per_batch:
164 |             selected_pid = random.sample(avai_pids, 1)[0]
165 |             similarity = self.similarity_matrix[selected_pid, avai_pids]
166 |             similarity /= np.sum(similarity)
167 | 
168 |             ################################################################################
169 |             # Sample only from the top-k most similar identities.
170 |             order = np.argsort(similarity)
171 |             p = np.zeros_like(similarity)
172 |             p[order[-(self.num_pids_per_batch - 1):]] = 1 / (self.num_pids_per_batch - 1)
173 |             ################################################################################
174 | 
175 |             selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch - 1, False, p=p)
176 |             selected_pids = list(selected_pids)
177 |             selected_pids.insert(0, selected_pid)
178 |             assert len(selected_pids) == self.num_pids_per_batch
179 |             for pid in selected_pids:
180 |                 batch_idxs = batch_idxs_dict[pid].pop(0)
181 |                 assert len(batch_idxs) == self.num_instances
182 |                 final_idxs.extend(batch_idxs)
183 |                 if len(batch_idxs_dict[pid]) == 0:
184 |                     avai_pids.remove(pid)
185 | 
186 |         self.length = len(final_idxs)
187 |         return iter(final_idxs)
188 | 
189 |     def __len__(self):
190 |         return self.length
191 | 
192 | 
193 | def test_similarity_sampler():
194 |     """
195 |     Only guarantees the samplers run without raising; does not verify correctness.
196 | 
197 |     :return:
198 |     """
199 |     pids = list(range(100))
200 |     path = "aaa"
201 |     data_source = [(path, pid, idx) for idx, pid in enumerate(np.random.choice(pids, 10000))]
202 |     batch_size = 64
203 |     num_instances = 16
204 |     similarity_matrix = np.random.rand(100, 100)
205 |     sampler = SimilarIdentitySampler(data_source, batch_size, num_instances, similarity_matrix)
206 |     random_sampler = RandomIdentitySampler([{'id': pid} for _, pid, _ in data_source], batch_size, num_instances)
207 |     print("bbb")
208 |     print(len(sampler))
209 |     print(len(random_sampler))
210 |     print("aaa")
211 |     for idx1, idx2 in zip(sampler, random_sampler):
212 |         print(data_source[idx1])
--------------------------------------------------------------------------------
/vehicle_reid_pytorch/models/baseline.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | 
7 | import torch
8 | from torch import nn
9 | 
10 | from resnest.torch import resnest50
11 | from torchvision.models import resnet50
12 | from torchvision.models.resnet import model_urls
13 | from torchvision.models.utils import load_state_dict_from_url
14 | from .backbones.resnet import ResNet, BasicBlock, Bottleneck
15 | from .backbones.senet import SENet, SEResNetBottleneck, SEBottleneck, SEResNeXtBottleneck
16 | from .backbones.resnet_ibn import resnet50_ibn_a
17 | 
18 | 
19 | def weights_init_kaiming(m):
20 |     classname = m.__class__.__name__
21 |     if classname.find('Linear') != -1:
22 |         nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
23 |         nn.init.constant_(m.bias, 0.0)
24 |     elif classname.find('Conv') != -1:
25 |         nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
26 |         if m.bias is not None:
27 |             nn.init.constant_(m.bias, 0.0)
28 |     elif classname.find('BatchNorm') != -1:
29 |         if m.affine:
30 |             nn.init.constant_(m.weight, 1.0)
31 |             nn.init.constant_(m.bias, 0.0)
32 | 
33 | 
34 | def weights_init_classifier(m):
35 |     classname = m.__class__.__name__
36 |     if classname.find('Linear') != -1:
37 |         nn.init.normal_(m.weight, std=0.001)
38 |         if m.bias is not None:
39 |             nn.init.constant_(m.bias, 0.0)
40 | 
41 | 
42 | class Baseline(nn.Module):
43 | 
44 |     def __init__(self, num_classes, last_stride, model_path, neck, neck_feat, model_name, pretrain_choice, **kwargs):
45 |         super(Baseline, self).__init__()
46 |         self.in_planes = 2048
47 |         if model_name == 'resnet18':
48 |             self.in_planes = 512
49 |             self.base = ResNet(last_stride=last_stride,
50 |                                block=BasicBlock,
51 |                                layers=[2, 2, 2, 2])
52 |         elif model_name == 'resnet34':
53 |             self.in_planes = 512
54 | 
self.base = ResNet(last_stride=last_stride, 55 | block=BasicBlock, 56 | layers=[3, 4, 6, 3]) 57 | elif model_name == 'resnet50': 58 | self.base = ResNet(last_stride=last_stride, 59 | block=Bottleneck, 60 | layers=[3, 4, 6, 3]) 61 | elif model_name == 'resnest50': 62 | self.base = resnest50(pretrained=True) 63 | 64 | elif model_name == 'resnet101': 65 | self.base = ResNet(last_stride=last_stride, 66 | block=Bottleneck, 67 | layers=[3, 4, 23, 3]) 68 | elif model_name == 'resnet152': 69 | self.base = ResNet(last_stride=last_stride, 70 | block=Bottleneck, 71 | layers=[3, 8, 36, 3]) 72 | 73 | elif model_name == 'se_resnet50': 74 | self.base = SENet(block=SEResNetBottleneck, 75 | layers=[3, 4, 6, 3], 76 | groups=1, 77 | reduction=16, 78 | dropout_p=None, 79 | inplanes=64, 80 | input_3x3=False, 81 | downsample_kernel_size=1, 82 | downsample_padding=0, 83 | last_stride=last_stride) 84 | elif model_name == 'se_resnet101': 85 | self.base = SENet(block=SEResNetBottleneck, 86 | layers=[3, 4, 23, 3], 87 | groups=1, 88 | reduction=16, 89 | dropout_p=None, 90 | inplanes=64, 91 | input_3x3=False, 92 | downsample_kernel_size=1, 93 | downsample_padding=0, 94 | last_stride=last_stride) 95 | elif model_name == 'se_resnet152': 96 | self.base = SENet(block=SEResNetBottleneck, 97 | layers=[3, 8, 36, 3], 98 | groups=1, 99 | reduction=16, 100 | dropout_p=None, 101 | inplanes=64, 102 | input_3x3=False, 103 | downsample_kernel_size=1, 104 | downsample_padding=0, 105 | last_stride=last_stride) 106 | elif model_name == 'se_resnext50': 107 | self.base = SENet(block=SEResNeXtBottleneck, 108 | layers=[3, 4, 6, 3], 109 | groups=32, 110 | reduction=16, 111 | dropout_p=None, 112 | inplanes=64, 113 | input_3x3=False, 114 | downsample_kernel_size=1, 115 | downsample_padding=0, 116 | last_stride=last_stride) 117 | elif model_name == 'se_resnext101': 118 | self.base = SENet(block=SEResNeXtBottleneck, 119 | layers=[3, 4, 23, 3], 120 | groups=32, 121 | reduction=16, 122 | dropout_p=None, 123 | inplanes=64, 124 | input_3x3=False, 125 | downsample_kernel_size=1, 126 | downsample_padding=0, 127 | last_stride=last_stride) 128 | elif model_name == 'senet154': 129 | self.base = SENet(block=SEBottleneck, 130 | layers=[3, 8, 36, 3], 131 | groups=64, 132 | reduction=16, 133 | dropout_p=0.2, 134 | last_stride=last_stride) 135 | elif model_name == 'resnet50_ibn': 136 | self.base = resnet50_ibn_a(last_stride) 137 | 138 | if pretrain_choice == 'imagenet': 139 | self.base.load_param(model_path) 140 | print('Loading pretrained ImageNet model......') 141 | 142 | self.gap = nn.AdaptiveAvgPool2d(1) 143 | # self.gap = nn.AdaptiveMaxPool2d(1) 144 | self.num_classes = num_classes 145 | self.neck = neck 146 | self.neck_feat = neck_feat 147 | 148 | if self.neck == 'no': 149 | self.classifier = nn.Linear(self.in_planes, self.num_classes) 150 | # self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False) # new add by luo 151 | # self.classifier.apply(weights_init_classifier) # new add by luo 152 | elif self.neck == 'bnneck': 153 | self.bottleneck = nn.BatchNorm1d(self.in_planes) 154 | self.bottleneck.bias.requires_grad_(False) # no shift 155 | self.classifier = nn.Linear( 156 | self.in_planes, self.num_classes, bias=False) 157 | 158 | self.bottleneck.apply(weights_init_kaiming) 159 | self.classifier.apply(weights_init_classifier) 160 | 161 | def forward(self, x, **kwargs): 162 | 163 | global_feat = self.gap(self.base(x)) # (b, 2048, 1, 1) 164 | global_feat = global_feat.view( 165 | global_feat.shape[0], -1) # flatten to (bs, 2048) 166 | 
167 |         if self.neck == 'no':
168 |             feat = global_feat
169 |         elif self.neck == 'bnneck':
170 |             # normalize for angular softmax
171 |             feat = self.bottleneck(global_feat)
172 | 
173 |         if self.training:
174 |             cls_score = self.classifier(feat)
175 |             return cls_score, global_feat  # global feature for triplet loss
176 |         else:
177 |             if self.neck_feat == 'after':
178 |                 # print("Test with feature after BN")
179 |                 return feat
180 |             else:
181 |                 # print("Test with feature before BN")
182 |                 return global_feat
183 | 
184 |     def load_param(self, trained_path):
185 |         param_dict = torch.load(trained_path, map_location="cpu")
186 |         # print(param_dict.keys())
187 |         for i in param_dict:
188 |             if 'classifier' in i:
189 |                 continue
190 |             if i.startswith('module.'):
191 |                 self.state_dict()[i[len('module.'):]].copy_(param_dict[i])  # strip DataParallel prefix
192 |             else:
193 |                 self.state_dict()[i].copy_(param_dict[i])
194 | 
--------------------------------------------------------------------------------
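A minimal sketch of driving the Baseline model above; the constructor arguments mirror its signature, and the concrete values (576 classes, 'none' to skip loading pretrained weights) are illustrative. Under the repo's pinned torch/torchvision this runs on CPU:

import torch

model = Baseline(num_classes=576, last_stride=1, model_path='', neck='bnneck',
                 neck_feat='after', model_name='resnet50', pretrain_choice='none')
model.eval()
with torch.no_grad():
    feat = model(torch.randn(2, 3, 256, 256))  # eval mode returns features only
print(feat.shape)  # torch.Size([2, 2048])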
/vehicle_reid_pytorch/metrics/eval_reid.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | 
7 | import numpy as np
8 | from tqdm import tqdm
9 | from multiprocessing import Pool
10 | import time
11 | 
12 | 
13 | def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50, remove_junk=True):
14 |     """Evaluation with veri776 metric
15 |     Key: for each query identity, its gallery images from the same camera view are discarded.
16 | 
17 |     :param np.ndarray distmat:
18 |     :param np.ndarray q_pids:
19 |     :param np.ndarray g_pids:
20 |     :param np.ndarray q_camids:
21 |     :param np.ndarray g_camids:
22 |     :param int max_rank:
23 |     :param bool remove_junk:
24 |     :return:
25 |     """
26 |     # compute cmc curve for each query
27 |     num_q, num_g = distmat.shape
28 |     if num_g < max_rank:
29 |         max_rank = num_g
30 |         print("Note: number of gallery samples is quite small, got {}".format(num_g))
31 |     all_cmc = []
32 |     all_AP = []
33 |     num_valid_q = 0.  # number of valid query
34 |     for q_idx in tqdm(range(num_q), desc='Calc cmc and mAP'):
35 |         # get query pid and camid
36 |         q_pid = q_pids[q_idx]
37 | 
38 |         # remove gallery samples that have the same pid and camid with query
39 |         order = np.argsort(distmat[q_idx])
40 |         if remove_junk:
41 |             q_camid = q_camids[q_idx]
42 |             remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
43 |         else:
44 |             remove = np.zeros_like(g_pids).astype(bool)
45 |         keep = np.invert(remove)
46 | 
47 |         # compute cmc curve
48 |         # binary vector, positions with value 1 are correct matches
49 |         # orig_cmc = matches[q_idx][keep]
50 |         orig_cmc = (g_pids[order] == q_pid).astype(np.int32)[keep]
51 |         if not np.any(orig_cmc):
52 |             # this condition is true when query identity does not appear in gallery
53 |             continue
54 | 
55 |         cmc = orig_cmc.cumsum()
56 |         cmc[cmc > 1] = 1
57 | 
58 |         all_cmc.append(cmc[:max_rank])
59 |         num_valid_q += 1.
60 | 
61 |         # compute average precision
62 |         # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
63 |         num_rel = orig_cmc.sum()
64 |         tmp_cmc = orig_cmc.cumsum()
65 |         tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
66 |         tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
67 |         AP = tmp_cmc.sum() / num_rel
68 |         all_AP.append(AP)
69 | 
70 |     assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
71 | 
72 |     all_cmc = np.asarray(all_cmc).astype(np.float32)
73 |     all_cmc = all_cmc.sum(0) / num_valid_q
74 |     mAP = np.mean(all_AP)
75 | 
76 |     return all_cmc, mAP
77 | 
78 | 
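A small worked example of the AP computation inside eval_func above, checked by hand: with correct matches at ranks 1 and 3, precision at the hits is 1/1 and 2/3, so AP = (1 + 2/3) / 2 = 5/6:

import numpy as np

orig_cmc = np.array([1, 0, 1], dtype=np.int32)
num_rel = orig_cmc.sum()
tmp_cmc = orig_cmc.cumsum() / (np.arange(len(orig_cmc)) + 1.0)  # precision@k
AP = (tmp_cmc * orig_cmc).sum() / num_rel
assert abs(AP - 5.0 / 6.0) < 1e-9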
79 | def eval_func_mp(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50, remove_junk=True):
80 |     """
81 |     Multiprocess version of eval_func.
82 |     """
83 | 
84 |     num_q, num_g = distmat.shape
85 |     if num_g < max_rank:
86 |         max_rank = num_g
87 |         print("Note: number of gallery samples is quite small, got {}".format(num_g))
88 |     all_cmc = []
89 |     all_AP = []
90 |     print('Generating worker pools')
91 |     t1 = time.time()
92 |     pool = Pool(30)
93 |     res = pool.imap(worker, [
94 |         (
95 |             q_pids[q_idx],
96 |             q_camids[q_idx],
97 |             g_pids,
98 |             g_camids,
99 |             distmat[q_idx],
100 |             max_rank,
101 |             remove_junk
102 |         ) for q_idx in range(num_q)
103 |     ], chunksize=32)
104 |     print(time.time() - t1)
105 | 
106 |     for r in tqdm(res, total=num_q):
107 |         all_AP.append(r[0])
108 |         all_cmc.append(r[1])
109 | 
110 | 
111 |     # unlike eval_func, invalid queries are not skipped here: the worker raises
112 |     # if a query id never appears in the gallery, so every query is assumed valid
113 |     # and the mean is taken over all num_q queries
114 | 
115 | 
116 |     all_cmc = np.asarray(all_cmc).astype(np.float32)
117 |     all_cmc = all_cmc.sum(0) / num_q
118 |     mAP = np.mean(all_AP)
119 | 
120 |     return all_cmc, mAP, all_AP
121 | 
122 | def worker(args):
123 |     q_pid, q_camid, g_pids, g_camids, dist_vec, max_rank, remove_junk = args
124 |     # remove gallery samples that have the same pid and camid with query
125 |     order = np.argsort(dist_vec)
126 |     if remove_junk:
127 |         remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
128 |     else:
129 |         remove = np.zeros_like(g_pids).astype(bool)
130 |     keep = np.invert(remove)
131 | 
132 |     # compute cmc curve
133 |     # binary vector, positions with value 1 are correct matches
134 |     # orig_cmc = matches[q_idx][keep]
135 |     orig_cmc = (g_pids[order] == q_pid).astype(np.int32)[keep]
136 |     AP, cmc = calc_AP(orig_cmc)
137 |     return AP, cmc[:max_rank]
138 | 
139 | def eval_func_th(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50, remove_junk=True):
140 |     """Evaluation with veri776 metric
141 |     Key: for each query identity, its gallery images from the same camera view are discarded.
142 | 
143 |     :param np.ndarray distmat:
144 |     :param np.ndarray q_pids:
145 |     :param np.ndarray g_pids:
146 |     :param np.ndarray q_camids:
147 |     :param np.ndarray g_camids:
148 |     :param int max_rank:
149 |     :param bool remove_junk:
150 |     :return:
151 |     """
152 |     # compute cmc curve for each query
153 |     num_q, num_g = distmat.shape
154 |     if num_g < max_rank:
155 |         max_rank = num_g
156 |         print("Note: number of gallery samples is quite small, got {}".format(num_g))
157 |     all_cmc = []
158 |     all_AP = []
159 |     num_valid_q = 0.  # number of valid query
160 |     for q_idx in tqdm(range(num_q), desc='Calc cmc and mAP'):
161 |         # get query pid and camid
162 |         q_pid = q_pids[q_idx]
163 | 
164 |         # remove gallery samples that have the same pid and camid with query
165 |         order = np.argsort(distmat[q_idx])
166 |         if remove_junk:
167 |             q_camid = q_camids[q_idx]
168 |             remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
169 |         else:
170 |             remove = np.zeros_like(g_pids).astype(bool)
171 |         keep = np.invert(remove)
172 | 
173 |         # compute cmc curve
174 |         # binary vector, positions with value 1 are correct matches
175 |         # orig_cmc = matches[q_idx][keep]
176 |         orig_cmc = (g_pids[order] == q_pid).astype(np.int32)[keep]
177 |         if not np.any(orig_cmc):
178 |             # this condition is true when query identity does not appear in gallery
179 |             continue
180 | 
181 |         cmc = orig_cmc.cumsum()
182 |         cmc[cmc > 1] = 1
183 | 
184 |         all_cmc.append(cmc[:max_rank])
185 |         num_valid_q += 1.
186 | 
187 |         # compute average precision
188 |         # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
189 |         num_rel = orig_cmc.sum()
190 |         tmp_cmc = orig_cmc.cumsum()
191 |         tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
192 |         tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
193 |         AP = tmp_cmc.sum() / num_rel
194 |         all_AP.append(AP)
195 | 
196 |     assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
197 | 
198 |     all_cmc = np.asarray(all_cmc).astype(np.float32)
199 |     all_cmc = all_cmc.sum(0) / num_valid_q
200 |     mAP = np.mean(all_AP)
201 | 
202 |     return all_cmc, mAP
203 | 
204 | def calc_AP(orig_cmc):
205 |     """Evaluation
206 | 
207 |     Compute the AP for a single query (one row of the ranking).
208 |     """
209 |     # orig_cmc = (g_pids[order] == q_pid).astype(np.int32)[keep]
210 |     if not np.any(orig_cmc):
211 |         # this condition is true when query identity does not appear in gallery
212 |         raise ValueError
213 | 
214 |     cmc = orig_cmc.cumsum()
215 |     cmc[cmc > 1] = 1
216 | 
217 |     # compute average precision
218 |     # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
219 |     num_rel = orig_cmc.sum()
220 |     tmp_cmc = orig_cmc.cumsum()
221 |     tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]  # precision at each rank
222 |     tmp_cmc = np.asarray(tmp_cmc) * orig_cmc  # keep precision only at ranks where recall increases
223 |     AP = tmp_cmc.sum() / num_rel
224 |     return AP, cmc
225 | 
226 | def get_expectation_of_AP(N=10, T=3):
227 |     """
228 |     Monte Carlo estimate of the expected AP of a random ranking.
229 | 
230 |     N: total number of samples
231 |     T: number of positives
232 |     """
233 |     APs = []
234 |     for _ in range(1000):
235 |         idxs = np.random.choice(np.arange(N), T, replace=False)
236 |         cmc = np.zeros(N)
237 |         cmc[idxs] = 1
238 |         AP = calc_AP(cmc)[0]
239 |         APs.append(AP)
240 |     print(np.mean(APs))
--------------------------------------------------------------------------------
/examples/parsing_reid/math_tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import ipdb
3 | 
4 | import torch
5 | from torch.nn import functional as F
6 | import numpy as np
7 | import pandas
8 | 
9 | from vehicle_reid_pytorch.metrics import eval_func, eval_func_mp
10 | from vehicle_reid_pytorch.loss.triplet_loss import normalize, euclidean_dist
11 | from functools import reduce
12 | 
13 | from vehicle_reid_pytorch.metrics.rerank import re_ranking
14 | 
15 | 
16 | # def calc_dist_split(qf, gf, split=0):
17 | #     qf = qf
18 | #     m = qf.shape[0]
19 | #     n = gf.shape[0]
20 | #     distmat = gf.new(m, n)
21 | 
22 | #     if split == 0:
23 | #         distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
24 | #                   torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
25 | #         distmat.addmm_(x, y.t(), beta=1, alpha=-2)
26 | 
27 | #     # split to bound GPU memory usage at test time
28 | #     else:
29 | #         start = 0
30 | #         while start < n:
31 | #             end = start + split if (start + split) < n else n
32 | #             num = end - start
33 | 
34 | #             sub_distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, num) + \
35 | #                           torch.pow(gf[start:end], 2).sum(dim=1, keepdim=True).expand(num, m).t()
36 | #             # sub_distmat.addmm_(1, -2, qf, gf[start:end].t())
37 | #             sub_distmat.addmm_(qf, gf[start:end].t(), beta=1, alpha=-2)
38 | #             distmat[:, start:end] = sub_distmat.cpu()
39 | #             start += num
40 | 
41 | #     return distmat
42 | 
43 | 
44 | def clck_dist(feat1, feat2, vis_score1, vis_score2, split=0):
45 |     """
46 |     Compute the CLCK (visibility-weighted part) distance from the VPM paper.
47 | 
48 |     :param torch.Tensor feat1: [B1, C, 3]
49 |     :param torch.Tensor feat2: [B2, C, 3]
50 |     :param torch.Tensor vis_score1: [B1, 3] part visibility scores
51 |     :param torch.Tensor vis_score2: [B2, 3] part visibility scores
52 |     :return: torch.Tensor, clck distance of shape [B1, B2]
53 |     """
54 | 
55 |     B, C, N = feat1.shape
56 |     dist_mat = 0
57 |     ckcl = 0
58 |     for i in range(N):
59 |         parse_feat1 = feat1[:, :, i]
60 |         parse_feat2 = feat2[:, :, i]
61 |         ckcl_ = torch.mm(vis_score1[:, i].view(-1, 1),
62 |                          vis_score2[:, i].view(1, -1))  # [B1, B2]
63 |         ckcl += ckcl_
64 |         dist_mat += euclidean_dist(parse_feat1,
65 |                                    parse_feat2, split=split).sqrt() * ckcl_
66 | 
67 |     return dist_mat / ckcl
68 | 
69 | 
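A shape sketch for clck_dist above with synthetic tensors (all sizes illustrative); it assumes the repo's euclidean_dist supports the default split=0 path:

import torch

B1, B2, C, N = 4, 6, 256, 3
feat1, feat2 = torch.randn(B1, C, N), torch.randn(B2, C, N)
vis1, vis2 = torch.rand(B1, N), torch.rand(B2, N)  # per-part visibility in [0, 1]
dist = clck_dist(feat1, feat2, vis1, vis2)
print(dist.shape)  # torch.Size([4, 6])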
161 | """ 162 | global_feats = torch.stack(self.global_feats, dim=0) 163 | local_feats = torch.stack(self.local_feats, dim=0) 164 | vis_scores = torch.stack(self.vis_scores) 165 | if self.feat_norm: 166 | print("The test feature is normalized") 167 | global_feats = F.normalize(global_feats, dim=1, p=2) 168 | local_feats = F.normalize(local_feats, dim=1, p=2) 169 | # 全局距离 170 | print('Calculate distance matrixs...') 171 | # query 172 | qf = global_feats[:self.num_query] 173 | q_pids = np.asarray(self.pids[:self.num_query]) 174 | q_camids = np.asarray(self.camids[:self.num_query]) 175 | # gallery 176 | gf = global_feats[self.num_query:] 177 | g_pids = np.asarray(self.pids[self.num_query:]) 178 | g_camids = np.asarray(self.camids[self.num_query:]) 179 | 180 | qf = qf 181 | m, n = qf.shape[0], gf.shape[0] 182 | 183 | if self.rerank: 184 | distmat = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3) 185 | 186 | else: 187 | # qf: M, F 188 | # gf: N, F 189 | if split == 0: 190 | distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 191 | torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() 192 | distmat.addmm_(qf, gf.t(), beta=1, alpha=-2) 193 | else: 194 | distmat = gf.new(m, n) 195 | start = 0 196 | while start < n: 197 | end = start + split if (start + split) < n else n 198 | num = end - start 199 | 200 | sub_distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, num) + \ 201 | torch.pow(gf[start:end], 2).sum( 202 | dim=1, keepdim=True).expand(num, m).t() 203 | # sub_distmat.addmm_(1, -2, qf, gf[start:end].t()) 204 | sub_distmat.addmm_(qf, gf[start:end].t(), beta=1, alpha=-2) 205 | distmat[:, start:end] = sub_distmat 206 | 207 | start += num 208 | 209 | distmat = distmat.detach().numpy() 210 | 211 | # 局部距离 212 | print('Calculate local distances...') 213 | local_distmat = clck_dist(local_feats[:self.num_query], local_feats[self.num_query:], 214 | vis_scores[:self.num_query], vis_scores[self.num_query:], split=split) 215 | 216 | local_feats = local_feats 217 | local_distmat = local_distmat.detach().cpu().numpy() 218 | 219 | if self.output_path: 220 | print('Saving results...') 221 | outputs = { 222 | "global_feats": global_feats, 223 | "vis_scores": vis_scores, 224 | "local_feats": local_feats, 225 | "pids": self.pids, 226 | "camids": self.camids, 227 | "paths": self.paths, 228 | "num_query": self.num_query, 229 | "distmat": distmat, 230 | "local_distmat": local_distmat, 231 | } 232 | torch.save(outputs, os.path.join(self.output_path, 233 | 'test_output.pkl'), pickle_protocol=4) 234 | 235 | print('Eval...') 236 | cmc, mAP, all_AP = eval_func_mp(distmat + self.lambda_ * (local_distmat ** 2), q_pids, g_pids, q_camids, g_camids, 237 | remove_junk=self.remove_junk) 238 | 239 | return { 240 | "cmc": cmc, 241 | "mAP": mAP, 242 | "distmat": distmat, 243 | "all_AP": all_AP 244 | } 245 | -------------------------------------------------------------------------------- /vehicle_reid_pytorch/utils/visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | from matplotlib.patches import Rectangle 4 | from vehicle_reid_pytorch.metrics.eval_reid import calc_AP 5 | import cv2 6 | import albumentations as albu 7 | import math 8 | import os 9 | import time 10 | 11 | def time_it(func): 12 | def wrapper(*args, **kwargs): 13 | start = time.time() 14 | print(f'Start {func.__name__}') 15 | output = func(*args, **kwargs) 16 | end = time.time() 17 | print(f'End {func.__name__}. 
/vehicle_reid_pytorch/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | from matplotlib.patches import Rectangle
4 | from vehicle_reid_pytorch.metrics.eval_reid import calc_AP
5 | import cv2
6 | import albumentations as albu
7 | import math
8 | import os
9 | import time
10 | 
11 | def time_it(func):
12 |     def wrapper(*args, **kwargs):
13 |         start = time.time()
14 |         print(f'Start {func.__name__}')
15 |         output = func(*args, **kwargs)
16 |         end = time.time()
17 |         print(f'End {func.__name__}. Elapsed {end - start:.3f} seconds')
18 |         return output
19 | 
20 |     return wrapper
21 | 
22 | 
23 | COLOR_LIST = [
24 |     (0, 0, 0),
25 |     (255, 0, 0),
26 |     (0, 255, 0),
27 |     (0, 0, 255),
28 |     (255, 255, 0),
29 |     (0, 255, 255),
30 |     (255, 0, 255),
31 |     (255, 255, 255),
32 | ]
33 | 
34 | 
35 | # helper function for data visualization
36 | def visualize_img(*no_title_images, cols=1, show=True, **images):
37 |     """Plot images in a grid."""
38 |     n = len(images) + len(no_title_images)
39 |     rows = math.ceil(n / cols)
40 |     plt.figure(figsize=(5 * cols, 5 * rows))
41 |     cols = int(np.ceil(n / rows))
42 |     for i, image in enumerate(no_title_images):
43 |         plt.subplot(rows, cols, i + 1)
44 |         plt.xticks([])
45 |         plt.yticks([])
46 |         plt.imshow(image)
47 | 
48 |     for i, (name, image) in enumerate(images.items()):
49 |         plt.subplot(rows, cols, len(no_title_images) + i + 1)
50 |         plt.xticks([])
51 |         plt.yticks([])
52 |         plt.title(' '.join(name.split('_')).title())
53 |         plt.imshow(image)
54 | 
55 |     if show:
56 |         plt.show()
57 | 
58 | 
59 | def get_heatmap(weights, featuremap, image):
60 |     """
61 |     Render a CAM-style heatmap over an image.
62 |     :param np.ndarray weights: per-channel weights, shape (C,)
63 |     :param np.ndarray featuremap: feature map, shape (C, H, W)
64 |     :param np.ndarray image: original image, shape (H, W, 3)
65 |     :return: uint8 overlay image
66 |     """
67 | 
68 |     heatmap = np.sum(featuremap * weights.reshape([-1, 1, 1]), axis=0)  # (H, W)
69 | 
70 |     heatmap = (heatmap - np.min(heatmap)) / (np.max(heatmap) - np.min(heatmap)) * 255
71 |     heatmap = cv2.resize(heatmap, (image.shape[1], image.shape[0]))
72 |     heatmap = heatmap.astype(np.uint8)
73 |     heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
74 |     heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
75 |     return (heatmap * 0.2 + image * 0.7).astype(np.uint8)
76 | 
77 | 
78 | def visualize_reid(query, galleries, query_pid, gallery_pids):
79 |     """Visualize a ReID ranking result.
80 | 
81 |     Arguments:
82 |         query {np.array} -- query image
83 |         galleries {[np.array]} -- gallery images
84 |         query_pid {int} -- id of the query
85 |         gallery_pids {[int]} -- ids of the gallery images
86 |     """
87 |     transforms = albu.Compose(
88 |         [
89 |             # albu.SmallestMaxSize(256),
90 |             albu.LongestMaxSize(256),
91 |             # albu.CenterCrop(256, 256),
92 |             albu.PadIfNeeded(min_height=256, min_width=256, border_mode=cv2.BORDER_CONSTANT, value=(150, 150, 150))
93 |             # albu.PadIfNeeded(min_height=256, min_width=256, border_mode=cv2.BORDER_REPLICATE)
94 |         ]
95 |     )
96 |     n = len(galleries)
97 |     plt.figure(figsize=(4 * (n + 1), 5))
98 |     plt.subplot(1, n + 1, 1)
99 |     plt.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0, wspace=0, hspace=0)
100 |     plt.xticks([])
101 |     plt.yticks([])
102 |     plt.title(query_pid)
103 |     # query = cv2.resize(query, (256, 256))
104 |     query = transforms(image=query)['image']
105 |     plt.imshow(query)
106 |     # plt.gca().add_patch(Rectangle((0, 0), query.shape[1], query.shape[0], edgecolor='w', linewidth=10, fill=False))
107 |     for i in range(len(galleries)):
108 |         g_img = galleries[i]
109 |         # g_img = cv2.resize(g_img, (256, 256))
110 |         g_img = transforms(image=g_img)['image']
111 |         g_pid = gallery_pids[i]
112 |         plt.subplot(1, n + 1, i + 2)
113 |         plt.xticks([])
114 |         plt.yticks([])
115 |         plt.title(g_pid)
116 |         plt.imshow(g_img)
117 |         if g_pid == query_pid:
118 |             plt.gca().add_patch(
119 |                 Rectangle((0, 0), g_img.shape[1], g_img.shape[0], edgecolor='g', linewidth=10, fill=False))
120 |         else:
121 |             plt.gca().add_patch(
122 |                 Rectangle((0, 0), g_img.shape[1], g_img.shape[0], edgecolor='r', linewidth=10, fill=False))
123 | 
124 | 
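A minimal sketch for get_heatmap above using synthetic data; in real use the weights would come from a classifier row and the feature map from the backbone's last conv layer:

import numpy as np

C, H, W = 8, 16, 16
weights = np.random.rand(C).astype(np.float32)
featuremap = np.random.rand(C, H, W).astype(np.float32)
image = (np.random.rand(128, 128, 3) * 255).astype(np.uint8)
overlay = get_heatmap(weights, featuremap, image)
print(overlay.shape, overlay.dtype)  # (128, 128, 3) uint8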
125 | def render_mask_to_img(img, cls_map, num_classes):
126 |     """
127 |     Overlay a colored segmentation mask onto an image.
128 |     :param img: RGB image (H, W, 3)
129 |     :param cls_map: per-pixel class-index map (H, W)
130 |     :param num_classes: number of classes; class 0 is treated as background
131 |     """
132 |     img = img.copy()
133 |     for i in range(num_classes):
134 |         if i == 0:
135 |             continue
136 |         img[cls_map == i] = img[cls_map == i] * 0.7 + np.array(COLOR_LIST[i]) * 0.3
137 | 
138 |     return img
139 | 
140 | 
141 | def render_keypoints_to_img(image, points, kp_vis=None, diameter=5):
142 |     if kp_vis is not None:
143 |         points = [point for vis, point in zip(kp_vis, points) if vis]
144 |     im = image.copy()
145 | 
146 |     for (x, y) in points:
147 |         cv2.circle(im, (int(x), int(y)), diameter, (0, 255, 0), -1)
148 | 
149 |     return im
150 | 
151 | def render_bboxes_to_img(image, bboxes, color=(255, 0, 0), thickness=5):
152 |     """Draw bounding boxes onto an image.
153 | 
154 |     Arguments:
155 |         image {np.ndarray} -- input image
156 |         bboxes {list} -- list of boxes as (x1, y1, x2, y2): top-left and bottom-right corners
157 | 
158 |     Keyword Arguments:
159 |         color {tuple} -- box color (default: {(255, 0, 0)})
160 |         thickness {int} -- line thickness (default: {5})
161 |     """
162 |     im = image.copy()
163 |     for bbox in bboxes:
164 |         pt1 = (int(bbox[0]), int(bbox[1]))
165 |         pt2 = (int(bbox[2]), int(bbox[3]))
166 |         cv2.rectangle(im, pt1, pt2, color, thickness)
167 |     return im
168 | 
169 | 
170 | def generate_html_table(content_table, image_width='auto', image_height='auto', output_path=''):
171 |     """Generate html table
172 | 
173 |     Args:
174 |         content_table: 2D table of cell contents
175 |         image_width: default image width (CSS size, e.g. 'auto')
176 |         image_height: default image height
177 |         output_path: output html path.
178 |     """
179 |     html = ''
180 |     html += ''
181 | 
182 |     html += """
183 | 
184 | 
185 | 
186 | 
187 | 
188 | 
189 |     """
190 | 
191 |     html += ''
192 |     html += ''
193 |     html += """
194 | 
213 |     """
214 | 
215 |     html += ''
216 |     html += ''
217 |     heads = content_table[0].keys()
218 | 
219 |     for i, h in enumerate(heads):
220 |         html += f'<th>{h}</th>'
221 | 
222 |     html += ""
223 |     html += ""
224 |     html += ''
225 | 
226 |     html += """
227 | 
228 | 
229 | 
230 | 
231 | 
232 |     """
233 | 
234 |     width = image_width
235 |     height = image_height
236 |     all_content_dict = []
237 |     for content_row in content_table:
238 |         content_dict = {}
239 |         for i, head in enumerate(heads):
240 |             content = content_row[head]
241 |             subhtml = ''
242 | 
243 |             if type(content) == dict:  # dict: an image cell with richer styling options
244 |                 src = content['src']
245 |                 alt = '' if not "alt" in content else content['alt']
246 |                 title = '' if not "title" in content else content['title']
247 |                 item_width = width if not "width" in content else content['width']
248 |                 item_height = height if not "height" in content else content['height']
249 |                 text = '' if not "text" in content else content['text']
250 |                 style = '' if not "style" in content else content['style']
251 |                 if text != '':
252 |                     subhtml += f"<div>{text}</div>"
253 |                 subhtml += f"<img src=\"{src}\" alt=\"{alt}\" title=\"{title}\" width=\"{item_width}\" height=\"{item_height}\" style=\"{style}\">"
254 | 
255 |             # image path
256 |             elif type(content) == str and os.path.splitext(content)[-1].lower() in ['.jpg', '.png', '.jpeg', '.gif']:
257 |                 src = content
258 |                 subhtml += f"<img src=\"{src}\" width=\"{width}\" height=\"{height}\">"
259 | 
260 |             # video path
261 |             elif type(content) == str and os.path.splitext(content)[-1].lower() in ['.mp4', '.webm']:
262 |                 src = content
263 |                 subhtml += f"