├── src ├── __init__.py ├── models │ ├── __init__.py │ ├── mvcnn.py │ ├── multiview_base.py │ ├── mvselect.py │ ├── mvdet.py │ ├── shufflenetv2.py │ └── resnet.py ├── utils │ ├── __init__.py │ ├── str2bool.py │ ├── meters.py │ ├── tensor_utils.py │ ├── draw_curve.py │ ├── logger.py │ ├── nms.py │ ├── projection.py │ ├── decode.py │ ├── image_utils.py │ └── mvrender.py ├── loss │ ├── __init__.py │ ├── gaussian_mse.py │ └── losses.py ├── datasets │ ├── __init__.py │ ├── modelnet40.py │ ├── multiviewx.py │ ├── wildtrack.py │ ├── scanobjectnn.py │ └── frameDataset.py ├── evaluation │ ├── README.md │ ├── evaluate.py │ ├── pyeval │ │ ├── evaluateDetection.py │ │ ├── README.md │ │ └── CLEAR_MOD_HUN.py │ └── test-demo.txt └── trainer_mvcnn.py ├── .gitattributes ├── requirements.txt ├── visualize_img.py ├── README.md ├── .gitignore ├── show_coverage.py ├── visualize_grid.py ├── mvcnn_speed_test.py ├── mvdet_speed_test.py └── main.py /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /src/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .gaussian_mse import GaussianMSE 2 | from .losses import focal_loss, entropy, regL1loss, regCEloss 3 | -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .wildtrack import Wildtrack 2 | from .multiviewx import MultiviewX 3 | from .frameDataset import frameDataset 4 | from .modelnet40 import ModelNet40 5 | from .scanobjectnn import ScanObjectNN 6 | -------------------------------------------------------------------------------- /src/utils/str2bool.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def str2bool(v): 5 | if isinstance(v, bool): 6 | return v 7 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 8 | return True 9 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 10 | return False 11 | else: 12 | raise argparse.ArgumentTypeError('Boolean value expected.') 13 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | h5py==3.8.0 2 | kornia==0.6.12 3 | matlab==0.1 4 | matplotlib==3.7.1 5 | mvtorch==0.1.0 6 | numpy==1.24.3 7 | open3d==0.17.0 8 | opencv_python==4.7.0.72 9 | opencv_python_headless==4.9.0.80 10 | pandas==2.0.1 11 | Pillow==9.4.0 12 | Pillow==10.2.0 13 | pytorch3d==0.7.4 14 | scipy==1.12.0 15 | thop==0.1.1.post2209072238 16 | torch==1.13.0 17 | torchvision==0.14.0 18 | tqdm==4.65.0 19 | -------------------------------------------------------------------------------- /src/utils/meters.py: 
-------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | """Computes and stores the average and current value""" 3 | 4 | def __init__(self): 5 | self.val = 0 6 | self.avg = 0 7 | self.sum = 0 8 | self.count = 0 9 | 10 | def reset(self): 11 | self.val = 0 12 | self.avg = 0 13 | self.sum = 0 14 | self.count = 0 15 | 16 | def update(self, val, n=1): 17 | self.val = val 18 | self.sum += val * n 19 | self.count += n 20 | self.avg = self.sum / self.count 21 | -------------------------------------------------------------------------------- /src/loss/gaussian_mse.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class GaussianMSE(nn.Module): 8 | 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def forward(self, x, target, kernel): 13 | target = self._traget_transform(x, target, kernel) 14 | return F.mse_loss(x, target) 15 | 16 | def _traget_transform(self, x, target, kernel): 17 | target = F.adaptive_max_pool2d(target, x.shape[2:]) 18 | with torch.no_grad(): 19 | target = F.conv2d(target, kernel.float().to(target.device), padding=int((kernel.shape[-1] - 1) / 2)) 20 | return target 21 | -------------------------------------------------------------------------------- /src/utils/tensor_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def _sigmoid(x): 5 | y = torch.clamp(x.sigmoid(), min=1e-4, max=1 - 1e-4) 6 | return y 7 | 8 | 9 | def _gather_feat(feat, ind, mask=None): 10 | dim = feat.size(2) 11 | ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim) 12 | feat = feat.gather(1, ind) 13 | if mask is not None: 14 | mask = mask.unsqueeze(2).expand_as(feat) 15 | feat = feat[mask] 16 | feat = feat.view(-1, dim) 17 | return feat 18 | 19 | 20 | def _transpose_and_gather_feat(feat, ind): 21 | feat = feat.permute(0, 2, 3, 1).contiguous() 22 | feat = feat.view(feat.size(0), -1, feat.size(3)) 23 | feat = _gather_feat(feat, ind) 24 | return feat 25 | -------------------------------------------------------------------------------- /src/utils/draw_curve.py: -------------------------------------------------------------------------------- 1 | # import matplotlib 2 | # 3 | # matplotlib.use('agg') 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | def draw_curve(path, x_epoch, train_loss, test_loss, train_result, test_result): 8 | fig = plt.figure() 9 | ax1 = fig.add_subplot(121, title="loss") 10 | ax1.plot(x_epoch, train_loss, 'bo-', label='train' + ': {:.3f}'.format(train_loss[-1])) 11 | ax1.plot(x_epoch, test_loss, 'ro-', label='test' + ': {:.3f}'.format(test_loss[-1])) 12 | ax1.legend() 13 | if train_result is not None and None not in train_result: 14 | ax2 = fig.add_subplot(122, title="result") 15 | ax2.plot(x_epoch, train_result, 'bo-', label='train' + ': {:.1f}'.format(train_result[-1])) 16 | else: 17 | ax2 = fig.add_subplot(122, title="result") 18 | ax2.plot(x_epoch, test_result, 'ro-', label='test' + ': {:.1f}'.format(test_result[-1])) 19 | ax2.legend() 20 | fig.savefig(path) 21 | plt.close(fig) 22 | -------------------------------------------------------------------------------- /src/utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | 5 | class Logger(object): 6 | def __init__(self, fpath=None): 7 | self.console = sys.stdout 8 | self.file = 
None 9 | if fpath is not None: 10 | os.makedirs(os.path.dirname(fpath), exist_ok=True) 11 | self.file = open(fpath, 'w') 12 | 13 | def __del__(self): 14 | self.close() 15 | 16 | def __enter__(self): 17 | pass 18 | 19 | def __exit__(self, *args): 20 | self.close() 21 | 22 | def write(self, msg): 23 | self.console.write(msg) 24 | if self.file is not None: 25 | self.file.write(msg) 26 | 27 | def flush(self): 28 | self.console.flush() 29 | if self.file is not None: 30 | self.file.flush() 31 | os.fsync(self.file.fileno()) 32 | 33 | def close(self): 34 | self.console.close() 35 | if self.file is not None: 36 | self.file.close() 37 | -------------------------------------------------------------------------------- /src/evaluation/README.md: -------------------------------------------------------------------------------- 1 | # Evaluation for MultiView detection 2 | 3 | ## Preparation 4 | 5 | First, you have to install matlab. 6 | 7 | Then, please follow the installation guide for [matlab-engine](https://au.mathworks.com/help/matlab/matlab_external/install-matlab-engine-api-for-python-in-nondefault-locations.html). 8 | 9 | e.g. 10 | 11 | ```shell script 12 | cd /usr/local/MATLAB/R2019a/extern/engines/python 13 | python setup.py build --build-base="/home/houyz/matlab/" install --prefix="/home/houyz/miniconda3" 14 | ``` 15 | 16 | ## Demo 17 | 18 | First, after the installation of matlab, you can run the code ```motchallenge-devkit/eval_demo.m```. 19 | 20 | Then, once the set up of matlab-engine is finished, you can run the following 21 | ```shell script 22 | cd ../.. # this should bring you to the code root folder 23 | python 24 | ``` 25 | 26 | ## File format 27 | 28 | ground truth file: ```motchallenge-devkit/gt.txt``` 29 | 30 | detection result file: ```motchallenge-devkit/test.txt``` 31 | 32 | - The first column in ground truth / detection file should be frame number 33 | - The second and third column should be x and y coordinate 34 | 35 | ## Alternative Tools 36 | A Python version of the official MATLAB API is provided in ```src/evaluation/pyeval``` 37 | 38 | -------------------------------------------------------------------------------- /src/evaluation/evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from src.evaluation.pyeval.evaluateDetection import evaluateDetection_py 3 | 4 | 5 | def evaluate(res_fpath, gt_fpath, dataset='wildtrack', frames=None): 6 | try: 7 | import matlab.engine 8 | eng = matlab.engine.start_matlab() 9 | eng.cd('src/evaluation/motchallenge-devkit') 10 | res = eng.evaluateDetection(res_fpath, gt_fpath, dataset) 11 | recall, precision, moda, modp = np.array(res['detMets']).squeeze()[[0, 1, -2, -1]] 12 | except: 13 | recall, precision, moda, modp, stats = evaluateDetection_py(res_fpath, gt_fpath, frames) 14 | return recall, precision, moda, modp 15 | 16 | 17 | if __name__ == "__main__": 18 | import os 19 | 20 | res_fpath = os.path.abspath('test-demo.txt') 21 | gt_fpath = os.path.abspath('gt-demo.txt') 22 | os.chdir('../..') 23 | print(os.path.abspath('.')) 24 | 25 | # recall, precision, moda, modp = matlab_eval(res_fpath, gt_fpath, 'Wildtrack') 26 | # print(f'matlab eval: MODA {moda:.1f}, MODP {modp:.1f}, prec {precision:.1f}, rcll {recall:.1f}') 27 | # recall, precision, moda, modp = python_eval(res_fpath, gt_fpath, 'Wildtrack') 28 | # print(f'python eval: MODA {moda:.1f}, MODP {modp:.1f}, prec {precision:.1f}, rcll {recall:.1f}') 29 | 30 | recall, precision, moda, modp = evaluate(res_fpath, gt_fpath, 
dataset='Wildtrack') 31 | print(f'eval: MODA {moda:.1f}, MODP {modp:.1f}, prec {precision:.1f}, rcll {recall:.1f}') 32 | -------------------------------------------------------------------------------- /src/evaluation/pyeval/evaluateDetection.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from src.evaluation.pyeval.CLEAR_MOD_HUN import CLEAR_MOD_HUN 3 | 4 | 5 | def evaluateDetection_py(det, gt, frames=None): 6 | """ 7 | This is simply the python translation of a MATLAB Evaluation tool created by P. Dollar. 8 | Translated by Zicheng Duan 9 | Modified by Yunzhong Hou 10 | 11 | The purpose of this API: 12 | 1. To allow the project to run purely in Python without using MATLAB Engine. 13 | 14 | @param det: detection result file path 15 | @param gt: ground truth result file path 16 | @return: MODP, MODA, recall, precision 17 | """ 18 | 19 | if isinstance(gt, str): 20 | gt = np.loadtxt(gt) 21 | else: 22 | gt = np.array(gt) 23 | if isinstance(det, str): 24 | det = np.loadtxt(det) 25 | else: 26 | det = np.array(det) 27 | if det.shape == (3,): 28 | det = det[None, :] 29 | elif det.shape == (0,): 30 | det = det.reshape([0, 3]) 31 | 32 | if frames is None: 33 | frames = np.unique(det[:, 0]) if len(det) else np.unique(gt[:, 0]) 34 | gt = gt[np.isin(gt[:, 0], frames), :] 35 | 36 | MODA, MODP, precision, recall, (tp, fp, fn, gt, dist) = CLEAR_MOD_HUN(gt, det) 37 | return MODA, MODP, precision, recall, (tp, fp, fn, gt, dist) 38 | 39 | 40 | if __name__ == "__main__": 41 | res_fpath = "../test-demo.txt" 42 | gt_fpath = "../gt-demo.txt" 43 | moda, modp, precision, recall, stats = evaluateDetection_py(res_fpath, gt_fpath, np.arange(1800, 2000)) 44 | print(f'python eval: MODA {moda:.1f}, MODP {modp:.1f}, prec {precision:.1f}, rcll {recall:.1f}') 45 | -------------------------------------------------------------------------------- /src/evaluation/pyeval/README.md: -------------------------------------------------------------------------------- 1 | ## Python Evaluation Tool for MVDet 2 | 3 | This is simply the Python translation of a MATLAB Evaluation tool used to evaluate detection result created by P. Dollar. 4 | Translated by [Zicheng Duan](https://github.com/ZichengDuan). 5 | 6 | ### Purpose 7 | Allowing the project to run purely in Python without using MATLAB Engine. 8 | 9 | 10 | ### Critical information before usage 11 | 1. This API is only tested and deployed in this project: [hou-yz/MVDet](https://github.com/hou-yz/MVDet), might not be compatible with other projects. 12 | 2. The detection result using this API **is a little bit lower** (approximately 0~2% decrease in MODA, MODP) than that when using official MATLAB evaluation tool, the reason might be the Hungarian Algorithm implemented in sklearn is a little bit different from the one implemented by P. Dollar, hence resulting in different results. 13 | Therefore, **please use the official MATLAB API if you want to obtain the same evaluation result shown in the paper**. This Python API is only used for convenience. 14 | 3. The training process would not be affected by this API. 15 | 16 | ### Usage 17 | Please go to ```test()``` function in ```trainer.py``` for more details. 18 | 19 | ``` 20 | recall, precision, moda, modp = matlab_eval(os.path.abspath(res_fpath), os.path.abspath(gt_fpath), 21 | data_loader.dataset.base.__name__) 22 | 23 | # If you want to use the unofiicial python evaluation tool for convenient purposes. 
24 | # recall, precision, modp, moda = python_eval(os.path.abspath(res_fpath), os.path.abspath(gt_fpath), 25 | # data_loader.dataset.base.__name__) 26 | ``` 27 | -------------------------------------------------------------------------------- /visualize_img.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | from src.datasets import frameDataset, MultiviewX, Wildtrack, ModelNet40 5 | from src.utils.image_utils import img_color_denormalize 6 | import matplotlib.pyplot as plt 7 | from torchvision.utils import make_grid, save_image 8 | 9 | 10 | def set_border(img, width=5, fill=(0, 255, 0)): 11 | C, H, W = img.shape 12 | fill = torch.tensor(fill, dtype=img.dtype, device=img.device)[:, None, None] / 255 13 | img[:, :, :width] = fill 14 | img[:, :, -width:] = fill 15 | img[:, :width, :] = fill 16 | img[:, -width:, :] = fill 17 | return img 18 | 19 | 20 | if __name__ == '__main__': 21 | denorm = img_color_denormalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 22 | # dataset = frameDataset(Wildtrack(os.path.expanduser('~/Data/Wildtrack')), split='test', ) 23 | # imgs, world_gt, imgs_gt, affine_mats, frame, keep_cams = dataset[0] 24 | 25 | # dataset = imgDataset(os.path.expanduser('~/Data/modelnet/modelnet40v1png'), 12, split='test') 26 | dataset = ModelNet40(os.path.expanduser('~/Data/modelnet/modelnet40v2png_ori4'), 20, split='test') 27 | for i in range(30): 28 | # index=np.random.randint(len(dataset)) 29 | index = i 30 | imgs, tgt, keep_cams = dataset[index+100] 31 | imgs = denorm(imgs) 32 | # imgs[0] = set_border(imgs[0]) # , fill=(255, 192, 0) 33 | # imgs[9] = set_border(imgs[9]) # , fill=(0, 176, 80) 34 | # imgs[0] = set_border(imgs[0], width=20) # , fill=(255, 192, 0) 35 | # imgs[1] = set_border(imgs[1], width=20) # , fill=(0, 176, 80) 36 | # imgs[5] = set_border(imgs[5], width=20) # , fill=(0, 176, 80) 37 | 38 | imgs_grid = make_grid(imgs, nrow=5) 39 | # save_image(imgs_grid, 'imgs_grid.png') 40 | plt.imshow(imgs_grid.permute([1, 2, 0])) 41 | plt.show() 42 | pass 43 | -------------------------------------------------------------------------------- /src/utils/nms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | # Original author: Francisco Massa: 5 | # https://github.com/fmassa/object-detection.torch 6 | # Ported to PyTorch by Max deGroot (02/01/2017) 7 | def nms(points, scores, dist_thres=50 / 2.5, top_k=50): 8 | """Apply non-maximum suppression at test time to avoid detecting too many 9 | overlapping bounding boxes for a given object. 10 | Args: 11 | points: (tensor) The location preds for the img, Shape: [num_priors,2]. 12 | scores: (tensor) The class predscores for the img, Shape:[num_priors]. 13 | dist_thres: (float) The overlap thresh for suppressing unnecessary boxes. 14 | top_k: (int) The Maximum number of box preds to consider. 15 | Return: 16 | The indices of the kept boxes with respect to num_priors. 
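        Also returns count, the number of points actually kept.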
17 | """ 18 | 19 | assert points.shape[0] == scores.shape[0], 'make sure same points and scores have the same size' 20 | keep = torch.zeros_like(scores).long() 21 | if points.numel() == 0: 22 | return keep, 0 23 | _, indices = scores.sort(0) # sort in ascending order 24 | # I = I[v >= 0.01] 25 | top_k = min(top_k, len(indices)) 26 | indices = indices[-top_k:] # indices of the top-k largest vals 27 | 28 | # keep = torch.Tensor() 29 | count = 0 30 | while indices.numel() > 0: 31 | idx = indices[-1] # index of current largest val 32 | # keep.append(i) 33 | keep[count] = idx 34 | count += 1 35 | if indices.numel() == 1: 36 | break 37 | indices = indices[:-1] # remove kept element from view 38 | target_point = points[idx, :] 39 | # load bboxes of next highest vals 40 | remaining_points = points[indices, :] 41 | dists = torch.norm(target_point - remaining_points, dim=1) # store result in distances 42 | # keep only elements with an dists > dist_thres 43 | indices = indices[dists > dist_thres] 44 | return keep, count 45 | -------------------------------------------------------------------------------- /src/utils/projection.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def project_2d_points(project_mat, input_points): 5 | vertical_flag = 0 6 | if input_points.shape[1] == 2: 7 | vertical_flag = 1 8 | input_points = np.transpose(input_points) 9 | input_points = np.concatenate([input_points, np.ones([1, input_points.shape[1]])], axis=0) 10 | output_points = project_mat @ input_points 11 | output_points = output_points[:2, :] / output_points[2, :] 12 | if vertical_flag: 13 | output_points = np.transpose(output_points) 14 | return output_points 15 | 16 | 17 | def get_worldcoord_from_imagecoord(image_coord, intrinsic_mat, extrinsic_mat, z=0): 18 | project_mat = get_worldcoord_from_imgcoord_mat(intrinsic_mat, extrinsic_mat, z) 19 | return project_2d_points(project_mat, image_coord) 20 | 21 | 22 | def get_imagecoord_from_worldcoord(world_coord, intrinsic_mat, extrinsic_mat, z=0): 23 | project_mat = get_imgcoord_from_worldcoord_mat(intrinsic_mat, extrinsic_mat, z) 24 | return project_2d_points(project_mat, world_coord) 25 | 26 | 27 | def get_imgcoord_from_worldcoord_mat(intrinsic_mat, extrinsic_mat, z=0): 28 | """image of shape C,H,W (C,N_row,N_col); xy indexging; x,y (w,h) (n_col,n_row) 29 | world of shape N_row, N_col; indexed as specified in the dataset attribute (xy or ij) 30 | z in meters by default 31 | """ 32 | threeD2twoD = np.array([[1, 0, 0], [0, 1, 0], [0, 0, z], [0, 0, 1]]) 33 | project_mat = intrinsic_mat @ extrinsic_mat @ threeD2twoD 34 | return project_mat 35 | 36 | 37 | def get_worldcoord_from_imgcoord_mat(intrinsic_mat, extrinsic_mat, z=0): 38 | """image of shape C,H,W (C,N_row,N_col); xy indexging; x,y (w,h) (n_col,n_row) 39 | world of shape N_row, N_col; indexed as specified in the dataset attribute (xy or ij) 40 | z in meters by default 41 | """ 42 | project_mat = np.linalg.inv(get_imgcoord_from_worldcoord_mat(intrinsic_mat, extrinsic_mat, z)) 43 | return project_mat 44 | -------------------------------------------------------------------------------- /src/evaluation/pyeval/CLEAR_MOD_HUN.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.optimize import linear_sum_assignment 3 | 4 | 5 | def CLEAR_MOD_HUN(gt, det, dist_thres=50 / 2.5): 6 | frames, num_gt_per_frame = np.unique(gt[:, 0], return_counts=True) 7 | matches = -np.ones([len(frames), 
np.max(num_gt_per_frame)]) # matching result for each GT target in each frame 8 | num_matches = np.zeros([len(frames)]) # c in original code 9 | fp = np.zeros([len(frames)]) 10 | fn = np.zeros([len(frames)]) # m in original code 11 | distances = np.inf * np.ones([len(frames), np.max(num_gt_per_frame)]) 12 | 13 | for frame_idx, t in enumerate(frames): 14 | gt_idx, = np.where(gt[:, 0] == t) 15 | det_idx, = np.where(det[:, 0] == t) 16 | 17 | if gt_idx is not None and det_idx is not None: 18 | dist = np.linalg.norm(gt[gt_idx, 1:][:, None, :] - det[det_idx, 1:][None, :, :], axis=2) 19 | 20 | # Please notice that the price/distance of are set to 100000 instead of np.inf, 21 | # since the Hungarian Algorithm implemented in sklearn will be slow if we use np.inf. 22 | dist[dist > dist_thres] = 1e6 23 | HUN_res = np.array(linear_sum_assignment(dist)) 24 | # filter out true matches 25 | HUN_res = HUN_res[:, dist[HUN_res[0], HUN_res[1]] < dist_thres] 26 | matches[frame_idx, HUN_res[0]] = HUN_res[1] 27 | distances[frame_idx, HUN_res[0]] = dist[HUN_res[0], HUN_res[1]] 28 | 29 | num_matches[frame_idx] = (matches[frame_idx, :] != -1).sum() 30 | fp[frame_idx] = len(det_idx) - num_matches[frame_idx] 31 | fn[frame_idx] = num_gt_per_frame[frame_idx] - num_matches[frame_idx] 32 | 33 | MODA = (1 - ((np.sum(fn) + np.sum(fp)) / np.sum(num_gt_per_frame))) * 100 34 | MODP = sum(1 - distances[distances < dist_thres] / dist_thres) / (np.sum(num_matches) + 1e-8) * 100 35 | precision = np.sum(num_matches) / (np.sum(fp) + np.sum(num_matches) + 1e-8) * 100 36 | recall = np.sum(num_matches) / np.sum(num_gt_per_frame) * 100 37 | 38 | return MODA, MODP, precision, recall, (num_matches, fp, fn, num_gt_per_frame, distances) 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning to Select Camera Views: Efficient Multiview Understanding at Few Glances 2 | 3 | ## Overview 4 | We release code for **MVSelect**, a view selection module for efficient multiview understanding. Parallel to reducing the image resolution or using lighter network backbones, the proposed approach reduces the computational cost for multiview understanding by limiting the number of views considered. 5 | 6 | 7 | 8 | ## Content 9 | - [Dependencies](#dependencies) 10 | - [Data Preparation](#data-preparation) 11 | - [Training](#training) 12 | 13 | 14 | ## Dependencies 15 | Please install dependencies with 16 | ``` 17 | pip install -r requirements.txt 18 | ``` 19 | 20 | ## Data Preparation 21 | 22 | For multiview classification, we use ModelNet40 dataset with the circular 12-view setup [[link](https://github.com/jongchyisu/mvcnn_pytorch)][[download](http://supermoe.cs.umass.edu/shape_recog/shaded_images.tar.gz)] and the dodecahedral 20-view setup [[link](https://github.com/kanezaki/pytorch-rotationnet)][[download](https://data.airc.aist.go.jp/kanezaki.asako/data/modelnet40v2png_ori4.tar)]. 23 | 24 | For multiview detection, we use MultiviewX [[link](https://github.com/hou-yz/MultiviewX)][[download](https://1drv.ms/u/s!AtzsQybTubHfgP9BJt2g7R_Ku4X3Pg?e=GFGeVn)] and Wildtrack [[link](https://www.epfl.ch/labs/cvlab/data/data-wildtrack/)][[download](http://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/Wildtrack/Wildtrack_dataset_full.zip)] in this project. 25 | 26 | Your `~/Data/` folder should look like this 27 | ``` 28 | Data/ 29 | ├── modelnet/ 30 | │ ├── modelnet40_images_new_12x/ 31 | │ │ └── ... 
32 | │ └── modelnet40v2png_ori4/ 33 | | └── ... 34 | ├── MultiviewX/ 35 | │ └── ... 36 | └── Wildtrack/ 37 | └── ... 38 | ``` 39 | 40 | 41 | ## Training 42 | In order to train the task networks, please run the following 43 | ```shell script 44 | # task network 45 | python main.py 46 | ``` 47 | This should automatically return the full N-view results, as well as the oracle performances. 48 | 49 | To train the MVSelect, please run 50 | ```shell script 51 | # MVSelect only 52 | python main.py --step 2 --base_lr 0 --other_lr 0 53 | # joint training 54 | python main.py --step 2 55 | ``` 56 | The default dataset is Wildtrack. For other datasets, please specify with the `-d` argument. 57 | 58 | 59 | ## Pre-trained models 60 | You can download the checkpoints at this [link](https://1drv.ms/u/s!AtzsQybTubHfhNRCxKzkaOiLCKkIIA?e=fQxfhI). 61 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # celery beat schedule file 95 | celerybeat-schedule 96 | 97 | # SageMath parsed files 98 | *.sage.py 99 | 100 | # Environments 101 | .env 102 | .venv 103 | env/ 104 | venv/ 105 | ENV/ 106 | env.bak/ 107 | venv.bak/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | .spyproject 112 | 113 | # Rope project settings 114 | .ropeproject 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ 121 | .dmypy.json 122 | dmypy.json 123 | 124 | # Pyre type checker 125 | .pyre/ 126 | 127 | .idea 128 | imgs* 129 | .vscode/ 130 | logs/* 131 | -------------------------------------------------------------------------------- /src/loss/losses.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Portions of this code are from 3 | # CornerNet (https://github.com/princeton-vl/CornerNet) 4 | # Copyright (c) 2018, University of Michigan 5 | # Licensed under the BSD 3-Clause License 6 | # ------------------------------------------------------------------------------ 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import torch 12 | from src.utils.tensor_utils import _transpose_and_gather_feat, _sigmoid 13 | import torch.nn.functional as F 14 | 15 | 16 | def entropy(logits, dim=-1): 17 | return -torch.sum(F.softmax(logits, dim=dim) * F.log_softmax(logits, dim=dim), dim=dim) 18 | 19 | 20 | def focal_loss(output, target, mask=None, gamma=2, alpha=0.9, reduction='mean'): 21 | if mask is None: 22 | mask = torch.ones_like(target) 23 | output = _sigmoid(output) 24 | target = target.to(output.device) 25 | is_positive = target == 1 26 | mask = mask.to(output.device) 27 | 28 | neg_weights = torch.pow(1 - target, 4) 29 | 30 | pos_loss = torch.log(output) * torch.pow(1 - output, gamma) * is_positive * mask 31 | neg_loss = torch.log(1 - output) * torch.pow(output, gamma) * neg_weights * ~is_positive * mask 32 | 33 | if alpha >= 0: 34 | pos_loss *= alpha 35 | neg_loss *= (1 - alpha) 36 | 37 | loss = -(pos_loss + neg_loss).sum([1, 2, 3]) / is_positive.sum([1, 2, 3]).clamp(1) 38 | if reduction == 'none': 39 | return loss 40 | elif reduction == 'sum': 41 | return loss.sum() 42 | elif reduction == 'mean': 43 | return loss.mean() 44 | else: 45 | raise Exception 46 | 47 | 48 | def regL1loss(output, mask, ind, target): 49 | mask, ind, target = mask.to(output.device), ind.to(output.device), target.to(output.device) 50 | pred = _transpose_and_gather_feat(output, ind) 51 | mask = mask.unsqueeze(2).expand_as(pred).float() 52 | loss = F.l1_loss(pred * mask, target * mask, reduction='sum') 53 | loss = loss / (mask.sum() + 1e-4) 54 | return loss 55 | 56 | 57 | def regCEloss(output, mask, ind, target): 58 | mask, ind, target = mask.to(output.device), ind.to(output.device), target.to(output.device) 59 | pred = _transpose_and_gather_feat(output, ind) 60 | if len(target[mask]) != 0: 61 | loss = F.cross_entropy(pred[mask], target[mask], reduction='sum') 62 | loss = loss / (mask.sum() + 1e-4) 63 | else: 64 | loss = 0 65 | return loss 66 | -------------------------------------------------------------------------------- /show_coverage.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | from src.datasets import frameDataset, MultiviewX, Wildtrack 5 | from src.evaluation.evaluate import 
evaluate 6 | import matplotlib.pyplot as plt 7 | 8 | 9 | def show_coverage(base): 10 | dataset = frameDataset(base, split='test', ) 11 | N = dataset.num_cam 12 | gts = [np.loadtxt(f'{base.root}/gt_{cam}.txt') for cam in range(N)] 13 | cover_area = np.zeros([N, N]) 14 | performances = np.zeros([N, N, 4]) 15 | recall, precision, moda, modp = evaluate(f'{base.root}/gt.txt', f'{base.root}/gt.txt') 16 | for init_cam in range(N): 17 | for selected_cam in range(N): 18 | cover_area[init_cam, selected_cam] = (dataset.Rworld_coverage[init_cam] + 19 | dataset.Rworld_coverage[selected_cam]).bool().float().mean() 20 | gt_in_cam = np.unique(np.concatenate([gts[init_cam], gts[selected_cam]]), axis=0) 21 | gt_in_cam = gt_in_cam[gt_in_cam[:, 0] > dataset.num_frame * 0.9] 22 | np.savetxt('temp.txt', gt_in_cam) 23 | recall, precision, moda, modp = evaluate('temp.txt', f'{base.root}/gt.txt') 24 | performances[init_cam, selected_cam] = [moda, modp, precision, recall] 25 | 26 | pass 27 | 28 | plt.figure(figsize=(5, 5)) 29 | plt.imshow(cover_area, cmap='Blues') 30 | plt.xticks(np.arange(N), np.arange(N) + 1) 31 | plt.xlabel('second view') 32 | plt.yticks(np.arange(N), np.arange(N) + 1) 33 | plt.ylabel('initial view') 34 | # Loop over data dimensions and create text annotations. 35 | for i in range(N): 36 | for j in range(N): 37 | plt.text(j, i, f'{cover_area[i, j]:.2f}', ha="center", va="center", ) 38 | plt.tight_layout() 39 | plt.show() 40 | 41 | print(cover_area.max(axis=0).mean()) 42 | 43 | plt.figure(figsize=(5, 5)) 44 | plt.imshow(performances[:, :, 0], cmap='Blues') 45 | plt.xticks(np.arange(N), np.arange(N) + 1) 46 | plt.xlabel('second view') 47 | plt.yticks(np.arange(N), np.arange(N) + 1) 48 | plt.ylabel('initial view') 49 | # Loop over data dimensions and create text annotations. 
50 | for i in range(N): 51 | for j in range(N): 52 | plt.text(j, i, f'{performances[i, j, 0]:.1f}', ha="center", va="center", ) 53 | plt.tight_layout() 54 | plt.show() 55 | 56 | print(performances[:, :, 0].max(axis=0).mean()) 57 | 58 | 59 | if __name__ == '__main__': 60 | show_coverage(Wildtrack(os.path.expanduser('~/Data/Wildtrack'))) 61 | show_coverage(MultiviewX(os.path.expanduser('~/Data/MultiviewX'))) 62 | -------------------------------------------------------------------------------- /src/models/mvcnn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.transforms as T 7 | import torchvision.models as models 8 | import matplotlib.pyplot as plt 9 | from src.models.multiview_base import MultiviewBase 10 | from src.models.mvselect import CamSelect 11 | 12 | 13 | class MVCNN(MultiviewBase): 14 | def __init__(self, dataset, arch='resnet18', aggregation='max'): 15 | super().__init__(dataset, aggregation) 16 | if arch == 'resnet18': 17 | self.base = nn.Sequential(*list(models.resnet18(pretrained=True).children())[:-2]) 18 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 19 | self.classifier = nn.Linear(512, dataset.num_class) 20 | base_dim = 512 21 | elif arch == 'vgg11': 22 | self.base = models.vgg11(pretrained=True).features 23 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 24 | self.classifier = models.vgg11(pretrained=True).classifier 25 | self.classifier[-1] = nn.Linear(4096, dataset.num_class) 26 | base_dim = 512 27 | else: 28 | raise Exception('architecture currently support [vgg11, resnet18]') 29 | 30 | # select camera based on initialization 31 | self.select_module = CamSelect(dataset.num_cam, base_dim, 1, aggregation) 32 | pass 33 | 34 | def get_feat(self, imgs, M=None, down=1, visualize=False): 35 | B, N, _, H, W = imgs.shape 36 | imgs = F.interpolate(imgs.flatten(0, 1), scale_factor=1 / down) 37 | imgs_feat = self.base(imgs) 38 | imgs_feat = self.avgpool(imgs_feat) 39 | _, C, H, W = imgs_feat.shape 40 | return imgs_feat.unflatten(0, [B, N]), None 41 | 42 | def get_output(self, overall_feat, visualize=False): 43 | overall_result = self.classifier(torch.flatten(overall_feat, 1)) 44 | return overall_result 45 | 46 | 47 | if __name__ == '__main__': 48 | from src.datasets import ModelNet40 49 | from torch.utils.data import DataLoader 50 | from thop import profile 51 | import itertools 52 | 53 | dataset = ModelNet40('/home/houyz/Data/modelnet/modelnet40_images_new_12x', 12) 54 | dataloader = DataLoader(dataset, 1, False, num_workers=0) 55 | imgs, tgt, keep_cams = next(iter(dataloader)) 56 | model = MVCNN(dataset).cuda() 57 | init_prob = F.one_hot(torch.tensor([0, 1]), num_classes=dataset.num_cam) 58 | keep_cams[0, 3] = 0 59 | model.train() 60 | res = model(imgs.cuda(), None, 2, init_prob, 3, keep_cams) 61 | # macs, params = profile(model, inputs=(imgs[:, :2].cuda(),)) 62 | # macs, params = profile(model.select_module, inputs=(torch.randn([1, 12, 512, 1, 1]).cuda(), 63 | # F.one_hot(torch.tensor([1]), num_classes=20).cuda())) 64 | # macs, params = profile(model, inputs=(torch.randn([1, 512, 1, 1]).cuda(),)) 65 | pass 66 | -------------------------------------------------------------------------------- /visualize_grid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | if __name__ == '__main__': 5 | precs = np.loadtxt( 6 | 
'/home/houyz/Code/MVselect/logs/modelnet40_12/resnet18_max_down1_lr5e-05_b8_e10_dropcam0.0_2023-02-26_04-03-42/prec_94.0_Lstrategy85.8_Rstrategy85.5_theory86.0_avg69.1.txt') 7 | precs = precs[len(precs) // 2:] 8 | N = len(precs) 9 | 10 | plt.figure(figsize=(5, 5)) 11 | plt.imshow(precs, cmap='Blues') 12 | plt.xticks(np.arange(N), np.arange(N) + 1) 13 | plt.xlabel('second view') 14 | plt.yticks(np.arange(N), np.arange(N) + 1) 15 | plt.ylabel('initial view') 16 | # Loop over data dimensions and create text annotations. 17 | for i in range(N): 18 | for j in range(N): 19 | text = plt.text(j, i, precs[i, j], ha="center", va="center", ) 20 | plt.tight_layout() 21 | plt.show() 22 | 23 | # only mvselect 24 | # prob = np.zeros([N, N]) 25 | # prob[0, 10] = 1 26 | # prob[1, 10] = 1 27 | # prob[2, 10] = 1 28 | # prob[3, 10] = 1 29 | # prob[4, [1, 9]] = [0.88, 0.12] 30 | # prob[5, 10] = 1 31 | # prob[6, 10] = 1 32 | # prob[7, 10] = 1 33 | # prob[8, 10] = 1 34 | # prob[9, 1] = 1 35 | # prob[10, 1] = 1 36 | # prob[11, 10] = 1 37 | 38 | # joint training 39 | # prob = np.zeros([N, N]) 40 | # prob[0, 10] = 1 41 | # prob[1, [9, 10]] = [0.98, 0.02] 42 | # prob[2, [9, 10]] = [0.31, 0.69] 43 | # prob[3, 1] = 1 44 | # prob[4, [0, 10]] = [0.45, 0.55] 45 | # prob[5, 10] = 1 46 | # prob[6, 10] = 1 47 | # prob[7, [9, 10]] = [0.18, 0.82] 48 | # prob[8, 0] = 1 49 | # prob[9, 1] = 1 50 | # prob[10, 0] = 1 51 | # prob[11, 9] = 1 52 | 53 | prob = np.zeros([N, N]) 54 | prob[0, [1, 4, 7, 10]] = [0.59, 0.13, 0.01, 0.26] 55 | prob[1, [4, 7, 9, 10]] = [0.19, 0.04, 0.01, 0.74] 56 | prob[2, [1, 4, 7, 9, 10]] = [0.51, 0.11, 0.01, 0.01, 0.36] 57 | prob[3, [1, 4, 7, 10]] = [0.64, 0.1, 0.03, 0.22] 58 | prob[4, [1, 3, 7, 9, 10]] = [0.64, 0.01, 0.12, 0.03, 0.22] 59 | prob[5, [1, 3, 7, 10]] = [0.53, 0.10, 0.03, 0.36] 60 | prob[6, [1, 4, 7, 10]] = [0.58, 0.12, 0.01, 0.29] 61 | prob[7, [1, 4, 9, 10]] = [0.49, 0.12, 0.01, 0.32] 62 | prob[8, [1, 4, 7, 9, 10]] = [0.53, 0.10, 0.01, 0.01, 0.35] 63 | prob[9, [1, 4, 7, 10]] = [0.62, 0.11, 0.04, 0.23] 64 | prob[10, [1, 3, 4, 7, 9, 11]] = [0.69, 0.01, 0.18, 0.10, 0.01, 0.01] 65 | prob[11, [1, 4, 7, 10]] = [0.53, 0.11, 0.01, 0.35] 66 | 67 | plt.figure(figsize=(5, 5)) 68 | plt.imshow(prob, cmap='Blues') 69 | plt.xticks(np.arange(N), np.arange(N) + 1) 70 | plt.xlabel('second view') 71 | plt.yticks(np.arange(N), np.arange(N) + 1) 72 | plt.ylabel('initial view') 73 | # Loop over data dimensions and create text annotations. 
74 | for i in range(N): 75 | for j in range(N): 76 | text = plt.text(j, i, prob[i, j], ha="center", va="center", ) 77 | plt.tight_layout() 78 | plt.show() 79 | pass 80 | -------------------------------------------------------------------------------- /src/utils/decode.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from src.utils.tensor_utils import _gather_feat, _transpose_and_gather_feat 5 | 6 | 7 | def _nms(heatmap, kernel_size=3): 8 | # kernel_size = kernel_size * 2 + 1 9 | hmax = F.max_pool2d(heatmap, (kernel_size, kernel_size), stride=1, padding=(kernel_size - 1) // 2) 10 | keep = (hmax == heatmap).float() 11 | return heatmap * keep 12 | 13 | 14 | ''' 15 | # Slow for large number of categories 16 | def _topk(scores, K=40): 17 | batch, cat, height, width = scores.size() 18 | topk_scores, topk_inds = torch.topk(scores.view(batch, -1), K) 19 | 20 | topk_clses = (topk_inds / (height * width)).int() 21 | 22 | topk_inds = topk_inds % (height * width) 23 | topk_ys = (topk_inds / width).int().float() 24 | topk_xs = (topk_inds % width).int().float() 25 | return topk_scores, topk_inds, topk_clses, topk_ys, topk_xs 26 | ''' 27 | 28 | 29 | def _topk(scores, top_K): 30 | batch, cat, height, width = scores.size() 31 | 32 | topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), top_K) 33 | 34 | topk_inds = topk_inds % (height * width) 35 | topk_ys = (topk_inds / width).int().float() 36 | topk_xs = (topk_inds % width).int().float() 37 | 38 | topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), top_K) 39 | topk_clses = (topk_ind / top_K).int() 40 | topk_inds = _gather_feat(topk_inds.view(batch, -1, 1), topk_ind).view(batch, top_K) 41 | topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, top_K) 42 | topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, top_K) 43 | 44 | return topk_score, topk_inds, topk_clses, topk_ys, topk_xs 45 | 46 | 47 | def ctdet_decode(heatmap, offset=None, wh=None, id=None, top_K=100): 48 | B, C, H, W = heatmap.shape 49 | 50 | scoremap = torch.sigmoid(heatmap) 51 | # perform nms on heatmaps 52 | scoremap = _nms(scoremap) 53 | 54 | scores, inds, _, ys, xs = _topk(scoremap, top_K=top_K) 55 | xy = torch.stack([xs, ys], dim=2) 56 | if offset is not None: 57 | offset = _transpose_and_gather_feat(offset, inds) 58 | offset = offset.view(B, top_K, 2) 59 | xy = xy + offset 60 | else: 61 | xy = xy + 0.5 62 | scores = scores.view(B, top_K, 1) 63 | 64 | # xywh 65 | if wh is None: 66 | detections = torch.cat([xy, scores], dim=2) 67 | else: 68 | wh = _transpose_and_gather_feat(wh, inds) 69 | wh = wh.view(B, top_K, 2) 70 | detections = torch.cat([xy, wh, scores], dim=2) 71 | 72 | if id is not None: 73 | id = torch.argmax(id, dim=1) 74 | id = _transpose_and_gather_feat(id, inds) 75 | id = id.view(B, top_K, 2) 76 | detections = torch.cat([detections, id], dim=2) 77 | return detections 78 | 79 | 80 | def mvdet_decode(scoremap, offset=None, reduce=4): 81 | B, C, H, W = scoremap.shape 82 | # scoremap = _nms(scoremap) 83 | 84 | xy = torch.nonzero(torch.ones_like(scoremap.detach()[:, 0])).view([B, H * W, 3])[:, :, [2, 1]].float() 85 | if offset is not None: 86 | offset = offset.detach().permute(0, 2, 3, 1).reshape(B, H * W, 2) 87 | xy = xy + offset 88 | else: 89 | xy = xy + 0.5 90 | xy *= reduce 91 | scores = scoremap.detach().permute(0, 2, 3, 1).reshape(B, H * W, 1) 92 | 93 | return torch.cat([xy, scores], dim=2) 94 | 
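# Minimal usage sketch appended for illustration (the shapes, thresholds, and random score map
# below are assumptions, not values taken from this repository):
if __name__ == '__main__':
    from src.utils.nms import nms

    # hypothetical 1x1x120x360 ground-plane occupancy scores
    scoremap = torch.sigmoid(torch.randn([1, 1, 120, 360]))
    dets = mvdet_decode(scoremap, offset=None, reduce=4)  # [B, H*W, 3] rows of (x, y, score)
    positions, scores = dets[0, :, :2], dets[0, :, 2]
    # keep at most 50 peaks that are mutually farther than 20 grid units apart
    keep, count = nms(positions, scores, dist_thres=20, top_k=50)
    print(positions[keep[:count]].shape)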
-------------------------------------------------------------------------------- /src/datasets/modelnet40.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import glob 4 | import re 5 | from PIL import Image 6 | import matplotlib.pyplot as plt 7 | import torch 8 | import torchvision.transforms as T 9 | from torchvision.datasets import VisionDataset 10 | 11 | 12 | class ModelNet40(VisionDataset): 13 | classnames = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl', 'car', 'chair', 14 | 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser', 'flower_pot', 'glass_box', 15 | 'guitar', 'keyboard', 'lamp', 'laptop', 'mantel', 'monitor', 'night_stand', 16 | 'person', 'piano', 'plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs', 17 | 'stool', 'table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox'] 18 | 19 | def __init__(self, root, num_cam, split='train', per_cls_instances=0, dropout=0.0): 20 | super().__init__(root) 21 | self.num_cam, self.num_class = num_cam, len(self.classnames) 22 | self.split = split 23 | self.transform = T.Compose([T.Resize([224, 224]), T.ToTensor(), 24 | T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ]) 25 | self.dropout = dropout 26 | 27 | self.img_fpaths = {cam: [] for cam in range(self.num_cam)} 28 | self.targets = [] 29 | self.class_idx_dict = {cls: [] for cls in self.classnames} 30 | for cls in self.classnames: 31 | for fname in sorted(glob.glob(f'{root}/{cls}/{split}/*.png')): 32 | fname = os.path.basename(fname) 33 | id, cam = map(int, re.findall(r'\d+', fname)) 34 | if len(self.class_idx_dict[cls]) >= (per_cls_instances if per_cls_instances else np.inf) and \ 35 | id not in self.class_idx_dict[cls]: 36 | break 37 | self.img_fpaths[cam - 1].append(f'{root}/{cls}/{split}/{fname}') 38 | if id not in self.class_idx_dict[cls]: 39 | self.targets.append(self.classnames.index(cls)) 40 | self.class_idx_dict[cls].append(id) 41 | assert np.prod([len(i) == len(self.targets) for i in self.img_fpaths.values()]), \ 42 | 'plz ensure all models appear {num_cam} times!' 
43 | print(f'{split}: {self.num_class} classes, {num_cam} views, {len(self.targets)} instances') 44 | 45 | def __len__(self): 46 | return len(self.targets) 47 | 48 | def __getitem__(self, idx, visualize=False): 49 | imgs = [] 50 | for cam in range(self.num_cam): 51 | img = Image.open(self.img_fpaths[cam][idx]).convert('RGB') 52 | if visualize: 53 | plt.imshow(img) 54 | plt.show() 55 | imgs.append(self.transform(img)) 56 | imgs = torch.stack(imgs) 57 | # random_idx = np.random.randint(self.num_cam) 58 | # imgs = imgs[torch.cat([torch.arange(random_idx, self.num_cam), torch.arange(0, random_idx)])] 59 | tgt = self.targets[idx] 60 | # dropout 61 | drop, keep_cams = np.random.rand() < self.dropout, torch.ones(self.num_cam, dtype=torch.bool) 62 | if drop: 63 | num_drop = np.random.randint(self.num_cam - 1) 64 | drop_cams = np.random.choice(self.num_cam, num_drop, replace=False) 65 | for cam in drop_cams: 66 | keep_cams[cam] = 0 67 | # keep_cams = np.ones(self.num_cam, dtype=bool) 68 | # keep_cams[[0, 2, 5, 6, 8, 11]] = 0 69 | 70 | return imgs, tgt, keep_cams 71 | 72 | 73 | if __name__ == '__main__': 74 | dataset = ModelNet40('/home/houyz/Data/modelnet/modelnet40_images_new_12x', 12) 75 | dataset.__getitem__(0) 76 | dataset.__getitem__(len(dataset) - 1, visualize=True) 77 | -------------------------------------------------------------------------------- /src/models/multiview_base.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from src.models.mvselect import aggregate_feat, setup_args 7 | 8 | 9 | class MultiviewBase(nn.Module): 10 | def __init__(self, dataset, aggregation='max'): 11 | super().__init__() 12 | self.num_cam = dataset.num_cam 13 | self.aggregation = aggregation 14 | self.select_module = None 15 | 16 | def forward(self, imgs, M=None, down=1, init_prob=None, steps=0, keep_cams=None, visualize=False): 17 | feat, aux_res = self.get_feat(imgs, M, down, visualize) 18 | if self.select_module is None or init_prob is None or steps == 0: 19 | B, N, C, H, W = feat.shape 20 | if keep_cams is None: 21 | keep_cams = torch.ones([B, N], dtype=torch.bool) 22 | keep_cams = keep_cams.to(feat.device) 23 | overall_feat = aggregate_feat(feat, keep_cams, aggregation=self.aggregation) 24 | selection_res = (None, None, None, None) 25 | else: 26 | overall_feat, selection_res = self.do_steps(feat, init_prob, steps, keep_cams) 27 | overall_res = self.get_output(overall_feat, visualize) 28 | return overall_res, aux_res, selection_res 29 | 30 | def do_steps(self, feat, init_prob, steps, keep_cams): 31 | assert steps > 0 32 | init_prob, _, _ = setup_args(feat, init_prob) 33 | log_probs, values, actions, entropies = [], [], [], [] 34 | for _ in range(steps): 35 | overall_feat, (log_prob, state_value, action, entropy) = self.select_module(feat, init_prob, keep_cams) 36 | init_prob += action 37 | log_probs.append(log_prob) 38 | values.append(state_value) 39 | actions.append(action) 40 | entropies.append(entropy) 41 | selection_res = (log_probs, values, actions, entropies) 42 | return overall_feat, selection_res 43 | 44 | def get_feat(self, imgs, M, down=1, visualize=False): 45 | raise NotImplementedError 46 | 47 | def get_output(self, overall_feat, visualize=False): 48 | raise NotImplementedError 49 | 50 | def forward_combination(self, imgs, M, down, combinations, keep_cams): 51 | B, N, C, H, W = imgs.shape 52 | K, N = combinations.shape 53 | 54 | feat, 
aux_res = self.get_feat(imgs, M, down) 55 | 56 | # K, B, N 57 | combinations = torch.tensor(combinations).float().unsqueeze(1).repeat([1, B, 1]) * keep_cams[None, :] 58 | 59 | # K, B, N, C, H, W 60 | overall_feat_s = [aggregate_feat(feat, combinations[k], self.aggregation) for k in range(K)] 61 | overall_result_s = [self.get_output(overall_feat_s[k]) for k in range(K)] 62 | if isinstance(overall_result_s[0], tuple): 63 | overall_result_s = tuple([torch.stack([overall_result_s[k][i] for k in range(K)]) 64 | for i in range(len(overall_result_s[0]))]) 65 | else: 66 | overall_result_s = torch.stack(overall_result_s) 67 | return overall_result_s, aux_res 68 | 69 | 70 | if __name__ == '__main__': 71 | B, N, C, H, W = 2, 7, 512, 16, 16 72 | steps = 3 73 | feat = torch.randn([B, N, C, H, W], device='cuda') 74 | 75 | candidates = np.eye(N) 76 | combinations = np.array(list(itertools.combinations(candidates, steps + 1))).sum(1) 77 | K = len(combinations) 78 | combinations = torch.from_numpy(combinations).to(feat.device).unsqueeze(0).repeat([B, 1, 1]) 79 | feat = feat.unsqueeze(1).repeat([1, K, 1, 1, 1, 1]) 80 | 81 | overall_feat_s = [aggregate_feat(feat[:, k], combinations[:, k]) for k in range(K)] 82 | print((torch.stack(overall_feat_s, dim=1).flatten(0, 1) == 83 | aggregate_feat(feat.flatten(0, 1), combinations.flatten(0, 1))).prod().item()) 84 | -------------------------------------------------------------------------------- /mvcnn_speed_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ['OMP_NUM_THREADS'] = '1' 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torchvision.transforms as T 9 | import torchvision.models as models 10 | import matplotlib.pyplot as plt 11 | from src.models.multiview_base import MultiviewBase 12 | from src.models.mvselect import CamSelect, setup_args 13 | 14 | 15 | class MVCNN(nn.Module): 16 | def __init__(self, dataset, arch='resnet18', aggregation='max'): 17 | super().__init__() 18 | if arch == 'resnet18': 19 | self.base = nn.Sequential(*list(models.resnet18(pretrained=True).children())[:-2]) 20 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 21 | self.classifier = nn.Linear(512, dataset.num_class) 22 | base_dim = 512 23 | elif arch == 'vgg11': 24 | self.base = models.vgg11(pretrained=True).features 25 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 26 | self.classifier = models.vgg11(pretrained=True).classifier 27 | self.classifier[-1] = nn.Linear(4096, dataset.num_class) 28 | base_dim = 512 29 | else: 30 | raise Exception('architecture currently support [vgg11, resnet18]') 31 | 32 | # select camera based on initialization 33 | self.select_module = CamSelect(dataset.num_cam, base_dim, 1, aggregation) 34 | pass 35 | 36 | def forward(self, imgs, init_prob=None, steps=0, ): 37 | B, N, _, _, _ = imgs.shape 38 | if self.select_module is None or init_prob is None or steps == 0: 39 | feat = [] 40 | for i in range(N): 41 | feat_i = self.get_feat(imgs, i) 42 | feat.append(feat_i) 43 | else: 44 | init_feat = self.get_feat(imgs, init_prob) 45 | feat = [init_feat] 46 | for _ in range(steps): 47 | init_prob, _, cam_candidate = setup_args(imgs, init_prob) 48 | cam_emb = init_prob.float() @ self.select_module.cam_emb 49 | cam_emb = self.select_module.emb_branch(cam_emb) 50 | cam_feat = self.select_module.feat_branch(init_feat).amax(dim=[2, 3]) 51 | action_value = self.select_module.value_head(cam_emb + cam_feat) 52 | action = torch.argmax(action_value 
+ (cam_candidate.float() - 1) * 1e3, dim=-1) 53 | 54 | feat.append(self.get_feat(imgs, action)) 55 | init_feat, _ = torch.stack(feat, 1).max(1) 56 | init_prob += F.one_hot(action, num_classes=N).bool() 57 | overall_feat, _ = torch.stack(feat, 1).max(1) 58 | overall_res = self.get_output(overall_feat) 59 | return overall_res 60 | 61 | def get_feat(self, imgs, view_id): 62 | B, N, _, H, W = imgs.shape 63 | if isinstance(view_id, int): 64 | view_id = torch.tensor([view_id] * B) 65 | batch_id = torch.arange(B).to(imgs.device) 66 | view_id = view_id.to(imgs.device) 67 | 68 | imgs_feat = self.base(imgs[batch_id, view_id]) 69 | imgs_feat = self.avgpool(imgs_feat) 70 | _, C, H, W = imgs_feat.shape 71 | return imgs_feat 72 | 73 | def get_output(self, overall_feat): 74 | overall_result = self.classifier(torch.flatten(overall_feat, 1)) 75 | return overall_result 76 | 77 | 78 | if __name__ == '__main__': 79 | from src.datasets import ModelNet40 80 | from torch.utils.data import DataLoader 81 | from thop import profile 82 | import itertools 83 | import time 84 | import tqdm 85 | 86 | dataset = ModelNet40('/home/houyz/Data/modelnet/modelnet40v2png_ori4', 20) 87 | dataloader = DataLoader(dataset, 8, num_workers=0) 88 | imgs, tgt, keep_cams = next(iter(dataloader)) 89 | 90 | torch.backends.cudnn.benchmark = False 91 | model = MVCNN(dataset).cuda() 92 | model.eval() 93 | # init_cam, step = 0, 1 94 | init_cam, step = None, None 95 | t0 = time.time() 96 | # avoid bottleneck @ dataloader 97 | for _ in tqdm.tqdm(range(1000)): 98 | model(imgs.cuda(), init_cam, step) 99 | print(time.time() - t0) 100 | pass 101 | -------------------------------------------------------------------------------- /src/utils/image_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | import numpy as np 4 | import cv2 5 | from PIL import Image 6 | import torch 7 | 8 | 9 | def random_affine(img, bboxs, pids, hflip=0.5, degrees=(-0, 0), translate=(.2, .2), scale=(0.6, 1.4), shear=(-0, 0), 10 | borderValue=(128, 128, 128)): 11 | # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10)) 12 | # https://medium.com/uruvideo/dataset-augmentation-with-random-homographies-a8f4b44830d4 13 | 14 | border = 0 # width of added border (optional) 15 | height = img.shape[0] 16 | width = img.shape[1] 17 | 18 | # flipping 19 | F = np.eye(3) 20 | hflip = np.random.rand() < hflip 21 | if hflip: 22 | F[0, 0] = -1 23 | F[0, 2] = width 24 | 25 | # Rotation and Scale 26 | R = np.eye(3) 27 | a = random.random() * (degrees[1] - degrees[0]) + degrees[0] 28 | # a += random.choice([-180, -90, 0, 90]) # 90deg rotations added to small rotations 29 | s = random.random() * (scale[1] - scale[0]) + scale[0] 30 | R[:2] = cv2.getRotationMatrix2D(angle=a, center=(width / 2, height / 2), scale=s) 31 | 32 | # Translation 33 | T = np.eye(3) 34 | T[0, 2] = (random.random() * 2 - 1) * translate[0] * width + border # x translation (pixels) 35 | T[1, 2] = (random.random() * 2 - 1) * translate[1] * height + border # y translation (pixels) 36 | 37 | # Shear 38 | S = np.eye(3) 39 | S[0, 1] = math.tan((random.random() * (shear[1] - shear[0]) + shear[0]) * math.pi / 180) # x shear (deg) 40 | S[1, 0] = math.tan((random.random() * (shear[1] - shear[0]) + shear[0]) * math.pi / 180) # y shear (deg) 41 | 42 | M = S @ T @ R @ F # Combined rotation matrix. ORDER IS IMPORTANT HERE!! 
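    # (applied right-to-left to each point: flip F first, then rotation/scale R, then translation T, then shear S)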
43 | imw = cv2.warpPerspective(img, M, dsize=(width, height), flags=cv2.INTER_LINEAR, 44 | borderValue=borderValue) # BGR order borderValue 45 | 46 | # Return warped points also 47 | n = bboxs.shape[0] 48 | area0 = (bboxs[:, 2] - bboxs[:, 0]) * (bboxs[:, 3] - bboxs[:, 1]) 49 | 50 | # warp points 51 | xy = np.ones((n * 4, 3)) 52 | xy[:, :2] = bboxs[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1 53 | xy = (xy @ M.T)[:, :2].reshape(n, 8) 54 | 55 | # create new boxes 56 | x = xy[:, [0, 2, 4, 6]] 57 | y = xy[:, [1, 3, 5, 7]] 58 | xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T 59 | 60 | # apply angle-based reduction 61 | radians = a * math.pi / 180 62 | reduction = max(abs(math.sin(radians)), abs(math.cos(radians))) ** 0.5 63 | x = (xy[:, 2] + xy[:, 0]) / 2 64 | y = (xy[:, 3] + xy[:, 1]) / 2 65 | w = (xy[:, 2] - xy[:, 0]) * reduction 66 | h = (xy[:, 3] - xy[:, 1]) * reduction 67 | xy = np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T 68 | 69 | # reject warped points outside of image 70 | np.clip(xy[:, 0], 0, width - 1, out=xy[:, 0]) 71 | np.clip(xy[:, 2], 0, width - 1, out=xy[:, 2]) 72 | np.clip(xy[:, 1], 0, height - 1, out=xy[:, 1]) 73 | np.clip(xy[:, 3], 0, height - 1, out=xy[:, 3]) 74 | w = xy[:, 2] - xy[:, 0] 75 | h = xy[:, 3] - xy[:, 1] 76 | area = w * h 77 | ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16)) 78 | i = (w > 4) & (h > 4) & (area / (area0 + 1e-16) > 0.1) & (ar < 10) 79 | 80 | bboxs = xy[i] 81 | pids = pids[i] 82 | 83 | return imw, bboxs, pids, M 84 | 85 | 86 | def gaussian2D(shape, sigma=1): 87 | m, n = [(ss - 1.) / 2. for ss in shape] 88 | y, x = np.ogrid[-m:m + 1, -n:n + 1] 89 | 90 | h = np.exp(-(x ** 2 + y ** 2) / (2 * sigma ** 2)) 91 | h[h < np.finfo(h.dtype).eps * h.max()] = 0 92 | return h 93 | 94 | 95 | def draw_umich_gaussian(heatmap, center, sigma, k=1): 96 | radius = int(3 * sigma) 97 | diameter = 2 * radius + 1 98 | gaussian = gaussian2D((diameter, diameter), sigma=sigma) 99 | 100 | x, y = int(center[0]), int(center[1]) 101 | 102 | H, W = heatmap.shape 103 | 104 | left, right = min(x, radius), min(W - x, radius + 1) 105 | top, bottom = min(y, radius), min(H - y, radius + 1) 106 | 107 | masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right] 108 | masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right] 109 | if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0: 110 | np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap) 111 | return heatmap 112 | 113 | 114 | class img_color_denormalize(object): 115 | def __init__(self, mean, std): 116 | self.mean = torch.FloatTensor(mean).view([1, -1, 1, 1]) 117 | self.std = torch.FloatTensor(std).view([1, -1, 1, 1]) 118 | 119 | def __call__(self, tensor): 120 | return tensor * self.std.to(tensor.device) + self.mean.to(tensor.device) 121 | 122 | 123 | def add_heatmap_to_image(heatmap, image): 124 | heatmap = cv2.resize(np.array(array2heatmap(heatmap)), (image.size)) 125 | cam_result = np.uint8(heatmap * 0.3 + np.array(image) * 0.5) 126 | cam_result = Image.fromarray(cam_result) 127 | return cam_result 128 | 129 | 130 | def array2heatmap(heatmap): 131 | heatmap = heatmap - heatmap.min() 132 | heatmap = heatmap / (heatmap.max() + 1e-8) 133 | heatmap = np.uint8(255 * heatmap) 134 | heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_SUMMER) 135 | heatmap = Image.fromarray(cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)) 136 | return heatmap 137 | 
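# Minimal usage sketch appended for illustration (the grid size and peak centers are assumptions):
if __name__ == '__main__':
    heatmap = np.zeros([120, 360], dtype=np.float32)  # hypothetical N_row x N_col ground grid
    for center in [(50, 60), (200, 80)]:  # (x, y) peak centers in grid cells
        draw_umich_gaussian(heatmap, center, sigma=3)
    print(heatmap.shape, heatmap.max())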
-------------------------------------------------------------------------------- /src/datasets/multiviewx.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import xml.etree.ElementTree as ET 5 | import re 6 | from torchvision.datasets import VisionDataset 7 | 8 | intrinsic_camera_matrix_filenames = ['intr_Camera1.xml', 'intr_Camera2.xml', 'intr_Camera3.xml', 'intr_Camera4.xml', 9 | 'intr_Camera5.xml', 'intr_Camera6.xml'] 10 | extrinsic_camera_matrix_filenames = ['extr_Camera1.xml', 'extr_Camera2.xml', 'extr_Camera3.xml', 'extr_Camera4.xml', 11 | 'extr_Camera5.xml', 'extr_Camera6.xml'] 12 | 13 | 14 | class MultiviewX(VisionDataset): 15 | def __init__(self, root): 16 | super().__init__(root) 17 | # image of shape C,H,W (C,N_row,N_col); xy indexging; x,y (w,h) (n_col,n_row) 18 | # MultiviewX has xy-indexing: H*W=640*1000, thus x is \in [0,1000), y \in [0,640) 19 | # MultiviewX has consistent unit: meter (m) for calibration & pos annotation 20 | self.__name__ = 'MultiviewX' 21 | self.img_shape, self.worldgrid_shape = [1080, 1920], [640, 1000] # H,W; N_row,N_col 22 | self.num_cam, self.num_frame = 6, 400 23 | # world x,y correspond to w,h 24 | self.indexing = 'xy' 25 | self.world_indexing_from_xy_mat = np.eye(3) 26 | self.world_indexing_from_ij_mat = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) 27 | # image is in xy indexing by default 28 | self.img_xy_from_xy_mat = np.eye(3) 29 | self.img_xy_from_ij_mat = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) 30 | # unit in meters 31 | self.worldcoord_unit = 1 32 | self.worldcoord_from_worldgrid_mat = np.array([[0.025, 0, 0], [0, 0.025, 0], [0, 0, 1]]) 33 | self.intrinsic_matrices, self.extrinsic_matrices = zip( 34 | *[self.get_intrinsic_extrinsic_matrix(cam) for cam in range(self.num_cam)]) 35 | 36 | def get_worldgrid_from_pos(self, pos): 37 | grid_x = pos % 1000 38 | grid_y = pos // 1000 39 | return np.array([[grid_x], [grid_y]], dtype=int).reshape([2, -1]) 40 | 41 | def get_pos_from_worldgrid(self, worldgrid): 42 | grid_x, grid_y = worldgrid[0, :], worldgrid[1, :] 43 | return grid_x + grid_y * 1000 44 | 45 | def get_worldgrid_from_worldcoord(self, world_coord): 46 | # datasets default unit: centimeter & origin: (-300,-900) 47 | coord_x, coord_y = world_coord[0, :], world_coord[1, :] 48 | grid_x = coord_x * 40 49 | grid_y = coord_y * 40 50 | return np.array([[grid_x], [grid_y]], dtype=int).reshape([2, -1]) 51 | 52 | def get_worldcoord_from_worldgrid(self, worldgrid): 53 | # datasets default unit: centimeter & origin: (-300,-900) 54 | grid_x, grid_y = worldgrid[0, :], worldgrid[1, :] 55 | coord_x = grid_x / 40 56 | coord_y = grid_y / 40 57 | return np.array([[coord_x], [coord_y]]).reshape([2, -1]) 58 | 59 | def get_worldcoord_from_pos(self, pos): 60 | grid = self.get_worldgrid_from_pos(pos) 61 | return self.get_worldcoord_from_worldgrid(grid) 62 | 63 | def get_pos_from_worldcoord(self, world_coord): 64 | grid = self.get_worldgrid_from_worldcoord(world_coord) 65 | return self.get_pos_from_worldgrid(grid) 66 | 67 | def get_intrinsic_extrinsic_matrix(self, camera_i): 68 | intrinsic_camera_path = os.path.join(self.root, 'calibrations', 'intrinsic') 69 | fp_calibration = cv2.FileStorage(os.path.join(intrinsic_camera_path, 70 | intrinsic_camera_matrix_filenames[camera_i]), 71 | flags=cv2.FILE_STORAGE_READ) 72 | intrinsic_matrix = fp_calibration.getNode('camera_matrix').mat() 73 | fp_calibration.release() 74 | 75 | extrinsic_camera_path = os.path.join(self.root, 
'calibrations', 'extrinsic') 76 | fp_calibration = cv2.FileStorage(os.path.join(extrinsic_camera_path, 77 | extrinsic_camera_matrix_filenames[camera_i]), 78 | flags=cv2.FILE_STORAGE_READ) 79 | rvec, tvec = fp_calibration.getNode('rvec').mat().squeeze(), fp_calibration.getNode('tvec').mat().squeeze() 80 | fp_calibration.release() 81 | 82 | rotation_matrix, _ = cv2.Rodrigues(rvec) 83 | translation_matrix = np.array(tvec, dtype=float).reshape(3, 1) 84 | extrinsic_matrix = np.hstack((rotation_matrix, translation_matrix)) 85 | 86 | return intrinsic_matrix, extrinsic_matrix 87 | 88 | 89 | def test(): 90 | from src.utils.projection import get_worldcoord_from_imagecoord 91 | from src.datasets.frameDataset import read_pom 92 | dataset = MultiviewX(os.path.expanduser('~/Data/MultiviewX'), ) 93 | pom = read_pom(dataset.root) 94 | 95 | for cam in range(dataset.num_cam): 96 | head_errors, foot_errors = [], [] 97 | for pos in range(0, np.product(dataset.worldgrid_shape), 16): 98 | bbox = pom[pos][cam] 99 | foot_wc = dataset.get_worldcoord_from_pos(pos) 100 | if bbox is None: 101 | continue 102 | foot_ic = np.array([[(bbox[0] + bbox[2]) / 2, bbox[3]]]).T 103 | head_ic = np.array([[(bbox[0] + bbox[2]) / 2, bbox[1]]]).T 104 | p_foot_wc = get_worldcoord_from_imagecoord(foot_ic, dataset.intrinsic_matrices[cam], 105 | dataset.extrinsic_matrices[cam]) 106 | p_head_wc = get_worldcoord_from_imagecoord(head_ic, dataset.intrinsic_matrices[cam], 107 | dataset.extrinsic_matrices[cam], z=1.8 / dataset.worldcoord_unit) 108 | head_errors.append(np.linalg.norm(p_head_wc - foot_wc)) 109 | foot_errors.append(np.linalg.norm(p_foot_wc - foot_wc)) 110 | pass 111 | 112 | print(f'average head error: {np.average(head_errors) * dataset.worldcoord_unit}, ' 113 | f'average foot error: {np.average(foot_errors) * dataset.worldcoord_unit} (world meters)') 114 | pass 115 | pass 116 | 117 | 118 | if __name__ == '__main__': 119 | test() 120 | -------------------------------------------------------------------------------- /src/datasets/wildtrack.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import xml.etree.ElementTree as ET 5 | import re 6 | from torchvision.datasets import VisionDataset 7 | 8 | intrinsic_camera_matrix_filenames = ['intr_CVLab1.xml', 'intr_CVLab2.xml', 'intr_CVLab3.xml', 'intr_CVLab4.xml', 9 | 'intr_IDIAP1.xml', 'intr_IDIAP2.xml', 'intr_IDIAP3.xml'] 10 | extrinsic_camera_matrix_filenames = ['extr_CVLab1.xml', 'extr_CVLab2.xml', 'extr_CVLab3.xml', 'extr_CVLab4.xml', 11 | 'extr_IDIAP1.xml', 'extr_IDIAP2.xml', 'extr_IDIAP3.xml'] 12 | 13 | 14 | class Wildtrack(VisionDataset): 15 | def __init__(self, root): 16 | super().__init__(root) 17 | # image of shape C,H,W (C,N_row,N_col); xy indexging; x,y (w,h) (n_col,n_row) 18 | # WILDTRACK has ij-indexing: H*W=480*1440, thus x (i) is \in [0,480), y (j) is \in [0,1440) 19 | # WILDTRACK has in-consistent unit: centi-meter (cm) for calibration & pos annotation, 20 | self.__name__ = 'Wildtrack' 21 | self.img_shape, self.worldgrid_shape = [1080, 1920], [480, 1440] # H,W; N_row,N_col 22 | self.num_cam, self.num_frame = 7, 2000 23 | # world x,y actually means i,j in Wildtrack, which correspond to h,w 24 | self.indexing = 'ij' 25 | self.world_indexing_from_xy_mat = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) 26 | self.world_indexing_from_ij_mat = np.eye(3) 27 | # image is in xy indexing by default 28 | self.img_xy_from_xy_mat = np.eye(3) 29 | self.img_xy_from_ij_mat = np.array([[0, 1, 0], [1, 0, 
0], [0, 0, 1]]) 30 | # unit in meters 31 | self.worldcoord_unit = 0.01 32 | self.worldcoord_from_worldgrid_mat = np.array([[2.5, 0, -300], [0, 2.5, -900], [0, 0, 1]]) 33 | self.intrinsic_matrices, self.extrinsic_matrices = zip( 34 | *[self.get_intrinsic_extrinsic_matrix(cam) for cam in range(self.num_cam)]) 35 | 36 | def get_worldgrid_from_pos(self, pos): 37 | grid_x = pos % 480 38 | grid_y = pos // 480 39 | return np.array([[grid_x], [grid_y]], dtype=int).reshape([2, -1]) 40 | 41 | def get_pos_from_worldgrid(self, worldgrid): 42 | grid_x, grid_y = worldgrid[0, :], worldgrid[1, :] 43 | return grid_x + grid_y * 480 44 | 45 | def get_worldgrid_from_worldcoord(self, world_coord): 46 | # datasets default unit: centimeter & origin: (-300,-900) 47 | coord_x, coord_y = world_coord[0, :], world_coord[1, :] 48 | grid_x = (coord_x + 300) / 2.5 49 | grid_y = (coord_y + 900) / 2.5 50 | return np.array([[grid_x], [grid_y]], dtype=int).reshape([2, -1]) 51 | 52 | def get_worldcoord_from_worldgrid(self, worldgrid): 53 | # datasets default unit: centimeter & origin: (-300,-900) 54 | grid_x, grid_y = worldgrid[0, :], worldgrid[1, :] 55 | coord_x = -300 + 2.5 * grid_x 56 | coord_y = -900 + 2.5 * grid_y 57 | return np.array([[coord_x], [coord_y]]).reshape([2, -1]) 58 | 59 | def get_worldcoord_from_pos(self, pos): 60 | grid = self.get_worldgrid_from_pos(pos) 61 | return self.get_worldcoord_from_worldgrid(grid) 62 | 63 | def get_pos_from_worldcoord(self, world_coord): 64 | grid = self.get_worldgrid_from_worldcoord(world_coord) 65 | return self.get_pos_from_worldgrid(grid) 66 | 67 | def get_intrinsic_extrinsic_matrix(self, camera_i): 68 | intrinsic_camera_path = os.path.join(self.root, 'calibrations', 'intrinsic_zero') 69 | intrinsic_params_file = cv2.FileStorage(os.path.join(intrinsic_camera_path, 70 | intrinsic_camera_matrix_filenames[camera_i]), 71 | flags=cv2.FILE_STORAGE_READ) 72 | intrinsic_matrix = intrinsic_params_file.getNode('camera_matrix').mat() 73 | intrinsic_params_file.release() 74 | 75 | extrinsic_params_file_root = ET.parse(os.path.join(self.root, 'calibrations', 'extrinsic', 76 | extrinsic_camera_matrix_filenames[camera_i])).getroot() 77 | 78 | rvec = extrinsic_params_file_root.findall('rvec')[0].text.lstrip().rstrip().split(' ') 79 | rvec = np.array(list(map(lambda x: float(x), rvec)), dtype=np.float32) 80 | 81 | tvec = extrinsic_params_file_root.findall('tvec')[0].text.lstrip().rstrip().split(' ') 82 | tvec = np.array(list(map(lambda x: float(x), tvec)), dtype=np.float32) 83 | 84 | rotation_matrix, _ = cv2.Rodrigues(rvec) 85 | translation_matrix = np.array(tvec, dtype=float).reshape(3, 1) 86 | extrinsic_matrix = np.hstack((rotation_matrix, translation_matrix)) 87 | 88 | return intrinsic_matrix, extrinsic_matrix 89 | 90 | 91 | def test(): 92 | from src.utils.projection import get_worldcoord_from_imagecoord 93 | from src.datasets.frameDataset import read_pom 94 | dataset = Wildtrack(os.path.expanduser('~/Data/Wildtrack'), ) 95 | pom = read_pom(dataset.root) 96 | 97 | for cam in range(dataset.num_cam): 98 | head_errors, foot_errors = [], [] 99 | for pos in range(0, np.product(dataset.worldgrid_shape), 16): 100 | bbox = pom[pos][cam] 101 | foot_wc = dataset.get_worldcoord_from_pos(pos) 102 | if bbox is None: 103 | continue 104 | foot_ic = np.array([[(bbox[0] + bbox[2]) / 2, bbox[3]]]).T 105 | head_ic = np.array([[(bbox[0] + bbox[2]) / 2, bbox[1]]]).T 106 | p_foot_wc = get_worldcoord_from_imagecoord(foot_ic, dataset.intrinsic_matrices[cam], 107 | dataset.extrinsic_matrices[cam]) 108 | p_head_wc = 
get_worldcoord_from_imagecoord(head_ic, dataset.intrinsic_matrices[cam], 109 | dataset.extrinsic_matrices[cam], z=1.8 / dataset.worldcoord_unit) 110 | head_errors.append(np.linalg.norm(p_head_wc - foot_wc)) 111 | foot_errors.append(np.linalg.norm(p_foot_wc - foot_wc)) 112 | pass 113 | 114 | print(f'average head error: {np.average(head_errors) * dataset.worldcoord_unit}, ' 115 | f'average foot error: {np.average(foot_errors) * dataset.worldcoord_unit} (world meters)') 116 | pass 117 | pass 118 | 119 | 120 | if __name__ == '__main__': 121 | test() 122 | -------------------------------------------------------------------------------- /src/models/mvselect.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from torch.distributions import Categorical 7 | 8 | 9 | def setup_args(feat, init_prob, keep_cams=None): 10 | B, N, C, H, W = feat.shape 11 | # init_prob should be of shape [B, N] in binary form 12 | if isinstance(init_prob, int): 13 | init_prob = F.one_hot(torch.tensor(init_prob).repeat(B), num_classes=N) 14 | elif isinstance(init_prob, np.ndarray): 15 | init_prob = F.one_hot(torch.tensor(init_prob), num_classes=N) 16 | init_prob = init_prob.bool().to(feat.device) 17 | if keep_cams is None: 18 | keep_cams = torch.ones([B, N], dtype=torch.bool) 19 | keep_cams = keep_cams.to(feat.device) 20 | cam_candidate = ~init_prob & keep_cams 21 | return init_prob, keep_cams, cam_candidate 22 | 23 | 24 | def create_pos_embedding(L, hidden_dim=128, temperature=10000, ): 25 | position = torch.arange(L).unsqueeze(1) / L * 2 * np.pi 26 | div_term = temperature ** (torch.arange(0, hidden_dim, 2) / hidden_dim) 27 | pe = torch.zeros(L, hidden_dim) 28 | pe[:, 0::2] = torch.sin(position / div_term) 29 | pe[:, 1::2] = torch.cos(position / div_term) 30 | return pe 31 | 32 | 33 | def masked_softmax(input, dim=-1, mask=None, epsilon=1e-8): 34 | if mask is None: 35 | mask = torch.ones_like(input, dtype=torch.bool) 36 | masked_exp = torch.exp(input) * mask.float() 37 | masked_sum = masked_exp.sum(dim, keepdim=True) + epsilon 38 | softmax = masked_exp / masked_sum 39 | return softmax 40 | 41 | 42 | def gumbel_softmax(logits: torch.Tensor, tau: float = 1, dim: int = -1, mask: torch.Tensor = None) -> torch.Tensor: 43 | # ~Gumbel(0,1) 44 | gumbels = (-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()) 45 | # ~Gumbel(logits,tau) 46 | gumbels = (logits + gumbels) / tau 47 | y_soft = masked_softmax(gumbels, dim, mask) 48 | 49 | return y_soft 50 | 51 | 52 | def softmax_to_hard(y_soft, dim=-1): 53 | index = y_soft.max(dim, keepdim=True)[1] 54 | y_hard = torch.zeros_like(y_soft).scatter_(dim, index, 1.0) 55 | ret = y_hard - y_soft.detach() + y_soft 56 | return ret 57 | 58 | 59 | def aggregate_feat(feat, selection, aggregation='mean'): 60 | if selection is None: 61 | overall_feat = feat.mean(dim=1) if aggregation == 'mean' else feat.max(dim=1)[0] 62 | else: 63 | selection = selection.bool().to(feat.device) 64 | overall_feat = feat * selection[:, :, None, None, None] 65 | if aggregation == 'mean': 66 | overall_feat = overall_feat.sum(dim=1) / (selection.sum(dim=1).view(-1, 1, 1, 1) + 1e-8) 67 | elif aggregation == 'max': 68 | overall_feat = overall_feat.max(dim=1)[0] 69 | else: 70 | raise Exception 71 | return overall_feat 72 | 73 | 74 | class CamSelect(nn.Module): 75 | def __init__(self, num_cam, hidden_dim, kernel_size=1, 
aggregation='max'): 76 | super().__init__() 77 | self.aggregation = aggregation 78 | if kernel_size == 1: 79 | stride, padding = 1, 0 80 | elif kernel_size == 3: 81 | stride, padding = 2, 1 82 | else: 83 | raise Exception 84 | self.feat_branch = nn.Sequential(nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, padding), nn.ReLU(), ) 85 | self.cam_emb = nn.Parameter(F.normalize(torch.randn([num_cam, hidden_dim]), p=2, dim=1)) 86 | self.emb_branch = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), ) 87 | self.action_head = nn.Linear(hidden_dim, num_cam) 88 | self.action_head.weight.data.fill_(0) 89 | self.action_head.bias.data.fill_(0) 90 | self.value_head = nn.Linear(hidden_dim, num_cam) 91 | self.value_head.weight.data.fill_(0) 92 | self.value_head.bias.data.fill_(0) 93 | 94 | def forward(self, feat, init_prob, keep_cams=None, eps_thres=0.0): 95 | B, N, C, H, W = feat.shape 96 | init_prob, keep_cams, cam_candidate = setup_args(feat, init_prob, keep_cams) 97 | init_feat = aggregate_feat(feat, init_prob, self.aggregation) 98 | 99 | cam_emb = init_prob.float() @ self.cam_emb 100 | cam_emb = self.emb_branch(cam_emb) 101 | cam_feat = self.feat_branch(init_feat).amax(dim=[2, 3]) 102 | # action_logit = self.action_head(cam_emb + cam_feat) 103 | # action_prob = F.softmax(action_logit, dim=-1) * cam_candidate 104 | # entropy = -(F.log_softmax(action_logit, dim=-1) * action_prob).sum(1) 105 | # state_value = self.value_head(cam_emb + cam_feat) 106 | 107 | # if self.training: 108 | # m = Categorical(action_prob) 109 | # action = m.sample() 110 | # log_prob = m.log_prob(action) 111 | # else: 112 | # action = torch.argmax(action_prob, dim=-1) 113 | # log_prob = torch.zeros([B], device=feat.device) 114 | 115 | # DQN 116 | action_value = self.value_head(cam_emb + cam_feat) 117 | if random.random() > eps_thres: 118 | action = torch.argmax(action_value + (cam_candidate.float() - 1) * 1e3, dim=-1) 119 | else: 120 | # m = Categorical(F.normalize(cam_candidate.float(), p=1, dim=1)) 121 | # action = m.sample() 122 | action = torch.randint(N, [B], device=feat.device) 123 | 124 | action = F.one_hot(action, num_classes=N).bool() 125 | overall_feat = aggregate_feat(feat, init_prob + action, self.aggregation) 126 | # return overall_feat, (log_prob, state_value, action, entropy) 127 | return overall_feat, (torch.zeros([B], device=feat.device), action_value, 128 | action, torch.zeros([B], device=feat.device)) 129 | 130 | 131 | def update_ema_variables(model, ema_model, alpha=0.99): 132 | for ema_param, param in zip(ema_model.parameters(), model.parameters()): 133 | ema_param.data = ema_param.data * alpha + param.data * (1 - alpha) 134 | for ema_bn, bn in zip(ema_model.modules(), model.module.modules() if hasattr(model, 'module') else model.modules()): 135 | if isinstance(bn, nn.BatchNorm2d) or isinstance(bn, nn.SyncBatchNorm): 136 | ema_bn.running_mean = ema_bn.running_mean * alpha + bn.running_mean * (1 - alpha) 137 | ema_bn.running_var = ema_bn.running_var * alpha + bn.running_var * (1 - alpha) 138 | 139 | 140 | def get_eps_thres(epoch, total_epochs, eps_start=0.9, eps_end=0.05): 141 | eps_thres = eps_end + (eps_start - eps_end) * (np.cos(epoch / total_epochs * np.pi) + 1) / 2 142 | return eps_thres 143 | -------------------------------------------------------------------------------- /mvdet_speed_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ['OMP_NUM_THREADS'] = '1' 4 | import time 5 | import numpy as np 6 | import torch 7 | 
import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision.transforms as T 10 | from kornia.geometry import warp_perspective 11 | from src.models.resnet import resnet18 12 | from src.models.shufflenetv2 import shufflenet_v2_x0_5 13 | from src.models.mvselect import CamSelect, aggregate_feat, setup_args 14 | from src.models.multiview_base import MultiviewBase 15 | from src.utils.image_utils import img_color_denormalize, array2heatmap 16 | from src.utils.projection import get_worldcoord_from_imgcoord_mat, project_2d_points 17 | import matplotlib.pyplot as plt 18 | 19 | 20 | def fill_fc_weights(layers): 21 | for m in layers.modules(): 22 | if isinstance(m, nn.Conv2d): 23 | if m.bias is not None: 24 | nn.init.constant_(m.bias, 0) 25 | 26 | 27 | def output_head(in_dim, feat_dim, out_dim): 28 | if feat_dim: 29 | fc = nn.Sequential(nn.Conv2d(in_dim, feat_dim, 3, 1, 1), nn.ReLU(), 30 | nn.Conv2d(feat_dim, out_dim, 1)) 31 | else: 32 | fc = nn.Sequential(nn.Conv2d(in_dim, out_dim, 1)) 33 | return fc 34 | 35 | 36 | class MVDet(nn.Module): 37 | def __init__(self, dataset, arch='resnet18', aggregation='max', 38 | use_bottleneck=True, hidden_dim=128, outfeat_dim=0, z=0): 39 | super().__init__() 40 | self.Rimg_shape, self.Rworld_shape = np.array(dataset.Rimg_shape), np.array(dataset.Rworld_shape) 41 | self.img_reduce = dataset.img_reduce 42 | 43 | # world grid change to xy indexing 44 | world_zoom_mat = np.diag([dataset.world_reduce, dataset.world_reduce, 1]) 45 | Rworldgrid_from_worldcoord_mat = np.linalg.inv( 46 | dataset.base.worldcoord_from_worldgrid_mat @ world_zoom_mat @ dataset.base.world_indexing_from_xy_mat) 47 | 48 | # z in meters by default 49 | # projection matrices: img feat -> world feat 50 | worldcoord_from_imgcoord_mats = [get_worldcoord_from_imgcoord_mat(dataset.base.intrinsic_matrices[cam], 51 | dataset.base.extrinsic_matrices[cam], 52 | z / dataset.base.worldcoord_unit) 53 | for cam in range(dataset.num_cam)] 54 | # Rworldgrid(xy)_from_imgcoord(xy) 55 | self.proj_mats = torch.stack([torch.from_numpy(Rworldgrid_from_worldcoord_mat @ 56 | worldcoord_from_imgcoord_mats[cam]) 57 | for cam in range(dataset.num_cam)]).float() 58 | 59 | if arch == 'resnet18': 60 | self.base = nn.Sequential(*list(resnet18(pretrained=True, 61 | replace_stride_with_dilation=[False, True, True]).children())[:-2]) 62 | base_dim = 512 63 | elif arch == 'shufflenet0.5': 64 | self.base = nn.Sequential(*list(shufflenet_v2_x0_5(pretrained=True, 65 | replace_stride_with_dilation=[False, True, True] 66 | ).children())[:-2]) 67 | base_dim = 192 68 | else: 69 | raise Exception('architecture currently support [vgg11, resnet18]') 70 | 71 | if use_bottleneck: 72 | self.bottleneck = nn.Sequential(nn.Conv2d(base_dim, hidden_dim, 1), nn.ReLU()) 73 | base_dim = hidden_dim 74 | else: 75 | self.bottleneck = nn.Identity() 76 | 77 | # img heads 78 | self.img_heatmap = output_head(base_dim, outfeat_dim, 1) 79 | self.img_offset = output_head(base_dim, outfeat_dim, 2) 80 | self.img_wh = output_head(base_dim, outfeat_dim, 2) 81 | # self.img_id = output_head(base_dim, outfeat_dim, len(dataset.pid_dict)) 82 | 83 | # world feat 84 | self.world_feat = nn.Sequential(nn.Conv2d(base_dim, hidden_dim, 3, padding=1), nn.ReLU(), 85 | nn.Conv2d(hidden_dim, hidden_dim, 3, padding=2, dilation=2), nn.ReLU(), 86 | nn.Conv2d(hidden_dim, hidden_dim, 3, padding=4, dilation=4), nn.ReLU(), ) 87 | 88 | # select camera based on initialization 89 | self.select_module = CamSelect(dataset.num_cam, hidden_dim, 3, aggregation) 90 | 91 | # 
world heads 92 | self.world_heatmap = output_head(hidden_dim, outfeat_dim, 1) 93 | self.world_offset = output_head(hidden_dim, outfeat_dim, 2) 94 | # self.world_id = output_head(hidden_dim, outfeat_dim, len(dataset.pid_dict)) 95 | 96 | # init 97 | self.img_heatmap[-1].bias.data.fill_(-2.19) 98 | fill_fc_weights(self.img_offset) 99 | fill_fc_weights(self.img_wh) 100 | self.world_heatmap[-1].bias.data.fill_(-2.19) 101 | fill_fc_weights(self.world_offset) 102 | pass 103 | 104 | def forward(self, imgs, init_prob=None, steps=0, ): 105 | B, N, _, _, _ = imgs.shape 106 | if self.select_module is None or init_prob is None or steps == 0: 107 | feat = [] 108 | for i in range(N): 109 | feat_i = self.get_feat(imgs, i) 110 | feat.append(feat_i) 111 | else: 112 | init_feat = self.get_feat(imgs, init_prob) 113 | feat = [init_feat] 114 | for _ in range(steps): 115 | init_prob, _, cam_candidate = setup_args(imgs, init_prob) 116 | cam_emb = init_prob.float() @ self.select_module.cam_emb 117 | cam_emb = self.select_module.emb_branch(cam_emb) 118 | cam_feat = self.select_module.feat_branch(init_feat).amax(dim=[2, 3]) 119 | action_value = self.select_module.value_head(cam_emb + cam_feat) 120 | action = torch.argmax(action_value + (cam_candidate.float() - 1) * 1e3, dim=-1) 121 | 122 | feat.append(self.get_feat(imgs, action)) 123 | init_feat, _ = torch.stack(feat, 1).max(1) 124 | init_prob += F.one_hot(action, num_classes=N).bool() 125 | overall_feat, _ = torch.stack(feat, 1).max(1) 126 | overall_res = self.get_output(overall_feat) 127 | return overall_res 128 | 129 | def get_feat(self, imgs, view_id): 130 | B, N, _, H, W = imgs.shape 131 | if isinstance(view_id, int): 132 | view_id = torch.tensor([view_id] * B) 133 | batch_id = torch.arange(B).to(imgs.device) 134 | view_id = view_id.to(imgs.device) 135 | 136 | # image and world feature maps from xy indexing, change them into world indexing / xy indexing (img) 137 | imgcoord_from_Rimggrid_mat = torch.diag(torch.tensor([self.img_reduce, self.img_reduce, 1]) 138 | ).unsqueeze(0).repeat(B, 1, 1).float() 139 | proj_mats = self.proj_mats.unsqueeze(0).repeat(B, 1, 1, 1)[batch_id, view_id] @ imgcoord_from_Rimggrid_mat 140 | 141 | imgs_feat = self.base(imgs[batch_id, view_id]) 142 | imgs_feat = self.bottleneck(imgs_feat) 143 | 144 | # world feat 145 | world_feat = warp_perspective(imgs_feat, proj_mats.to(imgs.device), self.Rworld_shape) 146 | 147 | return world_feat 148 | 149 | def get_output(self, world_feat): 150 | 151 | # world heads 152 | world_feat = self.world_feat(world_feat) 153 | world_heatmap = self.world_heatmap(world_feat) 154 | world_offset = self.world_offset(world_feat) 155 | 156 | return world_heatmap, world_offset 157 | 158 | 159 | if __name__ == '__main__': 160 | from src.datasets.frameDataset import frameDataset 161 | from src.datasets.wildtrack import Wildtrack 162 | from src.datasets.multiviewx import MultiviewX 163 | import torchvision.transforms as T 164 | from torch.utils.data import DataLoader 165 | from src.utils.decode import ctdet_decode 166 | from thop import profile 167 | import tqdm 168 | 169 | dataset = frameDataset(MultiviewX(os.path.expanduser('~/Data/MultiviewX')), split='test') 170 | dataloader = DataLoader(dataset, 1, num_workers=0) 171 | imgs, world_gt, imgs_gt, affine_mats, frame, keep_cams = next(iter(dataloader)) 172 | 173 | torch.backends.cudnn.benchmark = False 174 | model = MVDet(dataset).cuda() 175 | model.eval() 176 | # init_cam, step = 0, 2 177 | init_cam, step = None, None 178 | t0 = time.time() 179 | # avoid bottleneck @ 
dataloader 180 | for _ in tqdm.tqdm(range(1000)): 181 | model(imgs.cuda(), init_cam, step) 182 | print(time.time() - t0) 183 | pass 184 | -------------------------------------------------------------------------------- /src/trainer_mvcnn.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import random 3 | import time 4 | import copy 5 | import os 6 | import numpy as np 7 | import torch 8 | from torch import nn 9 | import torch.nn.functional as F 10 | import matplotlib.pyplot as plt 11 | from PIL import Image 12 | from src.utils.meters import AverageMeter 13 | from src.trainer import BaseTrainer, find_instance_lvl_strategy, find_dataset_lvl_strategy 14 | from src.utils.image_utils import add_heatmap_to_image, img_color_denormalize 15 | from src.models.mvselect import aggregate_feat, get_eps_thres 16 | 17 | 18 | class ClassifierTrainer(BaseTrainer): 19 | def __init__(self, model, logdir, args, ): 20 | super(ClassifierTrainer, self).__init__(model, logdir, args, ) 21 | 22 | def task_loss_reward(self, overall_feat, tgt, step): 23 | output = self.model.get_output(overall_feat) 24 | task_loss = F.cross_entropy(output, tgt, reduction='none') 25 | # reward = torch.zeros_like(task_loss) if step < self.args.steps - 1 else (output.argmax(1) == tgt).float() 26 | reward = (self.last_loss - task_loss).detach() 27 | self.last_loss = task_loss.detach() 28 | return task_loss, reward 29 | 30 | def train(self, epoch, dataloader, optimizer, scheduler=None, log_interval=200): 31 | self.model.train() 32 | if self.args.base_lr_ratio == 0: 33 | self.model.base.eval() 34 | losses, correct, miss = 0, 0, 1e-8 35 | t0 = time.time() 36 | action_sum = torch.zeros([dataloader.dataset.num_cam]).cuda() 37 | return_avg = None 38 | for batch_idx, (imgs, tgt, keep_cams) in enumerate(dataloader): 39 | B, N = imgs.shape[:2] 40 | imgs, tgt = imgs.cuda(), tgt.cuda() 41 | feat, _ = self.model.get_feat(imgs, None, self.args.down) 42 | if self.args.steps: 43 | eps_thres = get_eps_thres(epoch - 1 + batch_idx / len(dataloader), self.args.epochs) 44 | loss, (action_sum, return_avg, value_loss) = \ 45 | self.expand_episode(feat, keep_cams, tgt, eps_thres, (action_sum, return_avg)) 46 | else: 47 | overall_feat = aggregate_feat(feat, keep_cams, self.model.aggregation) 48 | output = self.model.get_output(overall_feat) 49 | loss = F.cross_entropy(output, tgt) 50 | 51 | pred = torch.argmax(output, 1) 52 | correct += (pred == tgt).sum().item() 53 | miss += B - (pred == tgt).sum().item() 54 | 55 | optimizer.zero_grad() 56 | loss.backward() 57 | optimizer.step() 58 | 59 | losses += loss.item() 60 | 61 | if scheduler is not None: 62 | if isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR): 63 | scheduler.step() 64 | elif isinstance(scheduler, torch.optim.lr_scheduler.CosineAnnealingWarmRestarts) or \ 65 | isinstance(scheduler, torch.optim.lr_scheduler.LambdaLR): 66 | scheduler.step(epoch - 1 + batch_idx / len(dataloader)) 67 | # logging 68 | if (batch_idx + 1) % log_interval == 0 or batch_idx + 1 == len(dataloader): 69 | # print(cyclic_scheduler.last_epoch, optimizer.param_groups[0]['lr']) 70 | t1 = time.time() 71 | t_epoch = t1 - t0 72 | print(f'Train epoch: {epoch}, batch:{(batch_idx + 1)}, ' 73 | f'loss: {losses / (batch_idx + 1):.3f}, time: {t_epoch:.1f}') 74 | if self.args.steps: 75 | print(f'value loss: {value_loss:.3f}, eps: {eps_thres:.3f}, return: {return_avg[-1]:.2f}') 76 | # print(f'value loss: {value_loss:.3f}, policy loss: {policy_loss:.3f}, ' 77 | # f'return: 
{return_avg[-1]:.2f}, entropy: {entropies.mean():.3f}') 78 | print(' '.join('cam {} {:.2f} |'.format(cam, freq) for cam, freq in 79 | zip(range(N), F.normalize(action_sum, p=1, dim=0).cpu()))) 80 | pass 81 | return losses / len(dataloader), None if self.args.steps else correct / (correct + miss) * 100.0 82 | 83 | def test(self, dataloader, init_cam=None): 84 | t0 = time.time() 85 | self.model.eval() 86 | K = len(init_cam) if init_cam is not None else 1 87 | losses, correct, miss = torch.zeros([K]), torch.zeros([K]), torch.zeros([K]) + 1e-8 88 | action_sum = torch.zeros([K, dataloader.dataset.num_cam]).cuda() 89 | for batch_idx, (imgs, tgt, keep_cams) in enumerate(dataloader): 90 | B, N = imgs.shape[:2] 91 | imgs, tgt = imgs.cuda(), tgt.cuda() 92 | outputs, actions = [], [] 93 | with torch.no_grad(): 94 | if self.args.steps == 0 or init_cam is None: 95 | output, _, (_, _, action, _) = self.model(imgs, None, self.args.down) 96 | outputs.append(output) 97 | actions.append(action) 98 | else: 99 | feat, _ = self.model.get_feat(imgs, None, self.args.down) 100 | # K, B, N 101 | for k in range(K): 102 | overall_feat, (_, _, action, _) = \ 103 | self.model.do_steps(feat, init_cam[k].repeat([B, 1]), self.args.steps, keep_cams) 104 | output = self.model.get_output(overall_feat) 105 | outputs.append(output) 106 | actions.append(action) 107 | 108 | for k in range(K): 109 | loss = F.cross_entropy(outputs[k], tgt) 110 | if init_cam is not None: 111 | # record actions 112 | action_sum[k] += torch.cat(actions[k]).sum(dim=0) 113 | losses[k] += loss.item() 114 | 115 | pred = torch.argmax(outputs[k], 1) 116 | correct[k] += (pred == tgt).sum().item() 117 | miss[k] += B - (pred == tgt).sum().item() 118 | 119 | for k in range(K): 120 | if init_cam is not None: 121 | print(f'init camera {init_cam[k].nonzero()[0].item()}: MVSelect') 122 | idx = action_sum[k].nonzero().cpu()[:, 0] 123 | print(' '.join('cam {} {:.2f} |'.format(cam, freq) for cam, freq in 124 | zip(idx, F.normalize(action_sum[k], p=1, dim=0).cpu()[idx]))) 125 | 126 | print(f'Test, loss: {losses[k] / len(dataloader):.3f}, prec: {correct[k] / (correct[k] + miss[k]):.2%}' + 127 | ('' if init_cam is not None else f', time: {time.time() - t0:.1f}s')) 128 | 129 | if init_cam is not None: 130 | prec = correct / (correct + miss) 131 | print('*************************************') 132 | print(f'MVSelect average prec {prec.mean()*100:.1f}±{prec.std()*100:.1f}%, time: {time.time() - t0:.1f}') 133 | print('*************************************') 134 | return losses.mean() / len(dataloader), [correct.sum() / (correct + miss).sum() * 100.0, ] 135 | 136 | def test_cam_combination(self, dataloader, step=0): 137 | self.model.eval() 138 | t0 = time.time() 139 | candidates = np.eye(dataloader.dataset.num_cam) 140 | combinations = np.array(list(itertools.combinations(candidates, step + 1))).sum(1) 141 | K, N = combinations.shape 142 | loss_s, pred_s, gt_s = [], [], [] 143 | for batch_idx, (imgs, tgt, keep_cams) in enumerate(dataloader): 144 | B, N = imgs.shape[:2] 145 | gt_s.append(tgt) 146 | tgt = tgt.unsqueeze(0).repeat([K, 1]) 147 | # K, B, N 148 | with torch.no_grad(): 149 | output, _ = self.model.forward_combination(imgs.cuda(), None, self.args.down, combinations, keep_cams) 150 | loss = F.cross_entropy(output.flatten(0, 1), tgt.flatten(0, 1).cuda(), reduction="none") 151 | pred = torch.argmax(output, -1) 152 | loss_s.append(loss.unflatten(0, [K, B]).cpu()) 153 | pred_s.append(pred.cpu()) 154 | loss_s, pred_s, gt_s = torch.cat(loss_s, 1), torch.cat(pred_s, 1), 
torch.cat(gt_s) 155 | # K, num_frames 156 | tp_s = (pred_s == gt_s[None, :]).float() 157 | # instance level selection 158 | instance_lvl_strategy = find_instance_lvl_strategy(tp_s, combinations) 159 | instance_lvl_oracle = np.take_along_axis(tp_s, instance_lvl_strategy, axis=0).mean(1).numpy()[:, None] 160 | # dataset level selection 161 | keep_cam_idx = combinations[:, keep_cams[0].bool().numpy()].sum(1).astype(bool) 162 | dataset_lvl_prec = tp_s.mean(1).numpy()[:, None] 163 | dataset_lvl_strategy = find_dataset_lvl_strategy(dataset_lvl_prec, combinations) 164 | dataset_lvl_best_prec = dataset_lvl_prec[dataset_lvl_strategy] 165 | oracle_info = f'{step} steps, averave acc {dataset_lvl_prec[keep_cam_idx].mean()*100:.1f}±{dataset_lvl_prec[keep_cam_idx].std()*100:.1f}%, ' \ 166 | f'dataset lvl best {dataset_lvl_best_prec.mean()*100:.1f}±{dataset_lvl_best_prec.std()*100:.1f}%, ' \ 167 | f'instance lvl oracle {instance_lvl_oracle.mean()*100:.1f}±{instance_lvl_oracle.std()*100:.1f}%, time: {time.time() - t0:.1f}s' 168 | print(oracle_info) 169 | return loss_s.mean(1).numpy(), dataset_lvl_prec * 100.0, instance_lvl_oracle * 100.0, oracle_info 170 | -------------------------------------------------------------------------------- /src/models/mvdet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.transforms as T 7 | from kornia.geometry import warp_perspective 8 | from src.models.resnet import resnet18 9 | from src.models.shufflenetv2 import shufflenet_v2_x0_5 10 | from src.models.mvselect import CamSelect 11 | from src.models.multiview_base import MultiviewBase 12 | from src.utils.image_utils import img_color_denormalize, array2heatmap 13 | from src.utils.projection import get_worldcoord_from_imgcoord_mat, project_2d_points 14 | import matplotlib.pyplot as plt 15 | 16 | 17 | def fill_fc_weights(layers): 18 | for m in layers.modules(): 19 | if isinstance(m, nn.Conv2d): 20 | if m.bias is not None: 21 | nn.init.constant_(m.bias, 0) 22 | 23 | 24 | def output_head(in_dim, feat_dim, out_dim): 25 | if feat_dim: 26 | fc = nn.Sequential(nn.Conv2d(in_dim, feat_dim, 3, 1, 1), nn.ReLU(), 27 | nn.Conv2d(feat_dim, out_dim, 1)) 28 | else: 29 | fc = nn.Sequential(nn.Conv2d(in_dim, out_dim, 1)) 30 | return fc 31 | 32 | 33 | class MVDet(MultiviewBase): 34 | def __init__(self, dataset, arch='resnet18', aggregation='max', 35 | use_bottleneck=True, hidden_dim=128, outfeat_dim=0, z=0): 36 | super().__init__(dataset, aggregation) 37 | self.Rimg_shape, self.Rworld_shape = np.array(dataset.Rimg_shape), np.array(dataset.Rworld_shape) 38 | self.img_reduce = dataset.img_reduce 39 | 40 | # world grid change to xy indexing 41 | world_zoom_mat = np.diag([dataset.world_reduce, dataset.world_reduce, 1]) 42 | Rworldgrid_from_worldcoord_mat = np.linalg.inv( 43 | dataset.base.worldcoord_from_worldgrid_mat @ world_zoom_mat @ dataset.base.world_indexing_from_xy_mat) 44 | 45 | # z in meters by default 46 | # projection matrices: img feat -> world feat 47 | worldcoord_from_imgcoord_mats = [get_worldcoord_from_imgcoord_mat(dataset.base.intrinsic_matrices[cam], 48 | dataset.base.extrinsic_matrices[cam], 49 | z / dataset.base.worldcoord_unit) 50 | for cam in range(dataset.num_cam)] 51 | # Rworldgrid(xy)_from_imgcoord(xy) 52 | self.proj_mats = torch.stack([torch.from_numpy(Rworldgrid_from_worldcoord_mat @ 53 | worldcoord_from_imgcoord_mats[cam]) 54 | for cam in 
range(dataset.num_cam)]).float() 55 | 56 | if arch == 'resnet18': 57 | self.base = nn.Sequential(*list(resnet18(pretrained=True, 58 | replace_stride_with_dilation=[False, True, True]).children())[:-2]) 59 | base_dim = 512 60 | elif arch == 'shufflenet0.5': 61 | self.base = nn.Sequential(*list(shufflenet_v2_x0_5(pretrained=True, 62 | replace_stride_with_dilation=[False, True, True] 63 | ).children())[:-2]) 64 | base_dim = 192 65 | else: 66 | raise Exception('architecture currently support [vgg11, resnet18]') 67 | 68 | if use_bottleneck: 69 | self.bottleneck = nn.Sequential(nn.Conv2d(base_dim, hidden_dim, 1), nn.ReLU()) 70 | base_dim = hidden_dim 71 | else: 72 | self.bottleneck = nn.Identity() 73 | 74 | # img heads 75 | self.img_heatmap = output_head(base_dim, outfeat_dim, 1) 76 | self.img_offset = output_head(base_dim, outfeat_dim, 2) 77 | self.img_wh = output_head(base_dim, outfeat_dim, 2) 78 | # self.img_id = output_head(base_dim, outfeat_dim, len(dataset.pid_dict)) 79 | 80 | # world feat 81 | self.world_feat = nn.Sequential(nn.Conv2d(base_dim, hidden_dim, 3, padding=1), nn.ReLU(), 82 | nn.Conv2d(hidden_dim, hidden_dim, 3, padding=2, dilation=2), nn.ReLU(), 83 | nn.Conv2d(hidden_dim, hidden_dim, 3, padding=4, dilation=4), nn.ReLU(), ) 84 | 85 | # select camera based on initialization 86 | self.select_module = CamSelect(dataset.num_cam, hidden_dim, 3, aggregation) 87 | 88 | # world heads 89 | self.world_heatmap = output_head(hidden_dim, outfeat_dim, 1) 90 | self.world_offset = output_head(hidden_dim, outfeat_dim, 2) 91 | # self.world_id = output_head(hidden_dim, outfeat_dim, len(dataset.pid_dict)) 92 | 93 | # init 94 | self.img_heatmap[-1].bias.data.fill_(-2.19) 95 | fill_fc_weights(self.img_offset) 96 | fill_fc_weights(self.img_wh) 97 | self.world_heatmap[-1].bias.data.fill_(-2.19) 98 | fill_fc_weights(self.world_offset) 99 | pass 100 | 101 | def get_feat(self, imgs, M, down=1, visualize=False): 102 | B, N, _, H, W = imgs.shape 103 | imgs = F.interpolate(imgs.flatten(0, 1), scale_factor=1 / down) 104 | 105 | inverse_affine_mats = torch.inverse(M.view([B * N, 3, 3])) 106 | # image and world feature maps from xy indexing, change them into world indexing / xy indexing (img) 107 | imgcoord_from_Rimggrid_mat = inverse_affine_mats @ \ 108 | torch.diag(torch.tensor([self.img_reduce * down, self.img_reduce * down, 1]) 109 | ).unsqueeze(0).repeat(B * N, 1, 1).float() 110 | # Rworldgrid(xy)_from_Rimggrid(xy) 111 | # proj_mats = torch.diag(torch.tensor([1 / down, 1 / down, 1])).unsqueeze(0).repeat(B * N, 1, 1).float() @ \ 112 | # self.proj_mats.unsqueeze(0).repeat(B, 1, 1, 1).flatten(0, 1) @ imgcoord_from_Rimggrid_mat 113 | proj_mats = self.proj_mats[:N].unsqueeze(0).repeat(B, 1, 1, 1).flatten(0, 1) @ imgcoord_from_Rimggrid_mat 114 | 115 | if visualize: 116 | denorm = img_color_denormalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 117 | proj_imgs = warp_perspective(F.interpolate(imgs, scale_factor=1 / 8), proj_mats.to(imgs.device), 118 | (self.Rworld_shape / down).astype(int)).unflatten(0, [B, N]) 119 | for cam in range(N): 120 | visualize_img = T.ToPILImage()(denorm(imgs.detach())[cam * B]) 121 | # visualize_img.save(f'../../imgs/augimg{cam + 1}.png') 122 | plt.imshow(visualize_img) 123 | plt.show() 124 | visualize_img = T.ToPILImage()(denorm(proj_imgs.detach())[0, cam]) 125 | plt.imshow(visualize_img) 126 | plt.show() 127 | 128 | imgs_feat = self.base(imgs) 129 | imgs_feat = self.bottleneck(imgs_feat) 130 | 131 | # img heads 132 | imgs_heatmap = self.img_heatmap(imgs_feat) 133 | imgs_offset = 
self.img_offset(imgs_feat) 134 | imgs_wh = self.img_wh(imgs_feat) 135 | 136 | if visualize: 137 | for cam in range(N): 138 | visualize_img = array2heatmap(torch.norm(imgs_feat[cam * B].detach(), dim=0).cpu()) 139 | # visualize_img.save(f'../../imgs/augimgfeat{cam + 1}.png') 140 | plt.imshow(visualize_img) 141 | plt.show() 142 | 143 | # world feat 144 | world_feat = warp_perspective(imgs_feat, proj_mats.to(imgs.device), self.Rworld_shape).unflatten(0, [B, N]) 145 | 146 | if visualize: 147 | for cam in range(N): 148 | visualize_img = array2heatmap(torch.norm(world_feat[0, cam].detach(), dim=0).cpu()) 149 | # visualize_img.save(f'../../imgs/projfeat{cam + 1}.png') 150 | plt.imshow(visualize_img) 151 | plt.show() 152 | 153 | # world_feat = self.world_feat_pre(world_feat) * keep_cams.view(B * N, 1, 1, 1).to(imgs.device) 154 | return world_feat, (F.interpolate(imgs_heatmap, tuple(self.Rimg_shape)), 155 | F.interpolate(imgs_offset, tuple(self.Rimg_shape)), 156 | F.interpolate(imgs_wh, tuple(self.Rimg_shape))) 157 | 158 | def get_output(self, world_feat, visualize=False): 159 | 160 | # world heads 161 | world_feat = self.world_feat(world_feat) 162 | world_heatmap = self.world_heatmap(world_feat) 163 | world_offset = self.world_offset(world_feat) 164 | # world_id = self.world_id(world_feat) 165 | 166 | if visualize: 167 | visualize_img = array2heatmap(torch.norm(world_feat[0].detach(), dim=0).cpu()) 168 | # visualize_img.save(f'../../imgs/worldfeatall.png') 169 | plt.imshow(visualize_img) 170 | plt.show() 171 | visualize_img = array2heatmap(torch.sigmoid(world_heatmap.detach())[0, 0].cpu()) 172 | # visualize_img.save(f'../../imgs/worldres.png') 173 | plt.imshow(visualize_img) 174 | plt.show() 175 | 176 | return world_heatmap, world_offset 177 | 178 | 179 | 180 | if __name__ == '__main__': 181 | from src.datasets.frameDataset import frameDataset 182 | from src.datasets.wildtrack import Wildtrack 183 | from src.datasets.multiviewx import MultiviewX 184 | import torchvision.transforms as T 185 | from torch.utils.data import DataLoader 186 | from src.utils.decode import ctdet_decode 187 | from thop import profile 188 | 189 | dataset = frameDataset(MultiviewX(os.path.expanduser('~/Data/MultiviewX')), split='train', augmentation=True) 190 | dataloader = DataLoader(dataset, 1, False, num_workers=0) 191 | 192 | model = MVDet(dataset).cuda() 193 | imgs, world_gt, imgs_gt, affine_mats, frame, keep_cams = next(iter(dataloader)) 194 | keep_cams[0, 3] = 0 195 | init_cam = 0 196 | model.train() 197 | (world_heatmap, world_offset), _, cam_train = model(imgs.cuda(), affine_mats, 2, init_cam, 3) 198 | xysc_train = ctdet_decode(world_heatmap, world_offset) 199 | # macs, params = profile(model, inputs=(imgs[:, :3].cuda(), affine_mats[:, :3].contiguous())) 200 | # macs, params = profile(model.select_module, inputs=(torch.randn([1, 128, 160, 250]).cuda(), 201 | # F.one_hot(torch.tensor([1]), num_classes=6).cuda())) 202 | # macs, params = profile(model, inputs=(torch.rand([1, 128, 160, 250]).cuda(),)) 203 | pass 204 | -------------------------------------------------------------------------------- /src/models/shufflenetv2.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Any, List 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch import Tensor 6 | 7 | from torch.hub import load_state_dict_from_url 8 | from torchvision.utils import _log_api_usage_once 9 | 10 | __all__ = ["ShuffleNetV2", "shufflenet_v2_x0_5", "shufflenet_v2_x1_0", 
"shufflenet_v2_x1_5", "shufflenet_v2_x2_0"] 11 | 12 | model_urls = { 13 | "shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", 14 | "shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", 15 | "shufflenetv2_x1.5": None, 16 | "shufflenetv2_x2.0": None, 17 | } 18 | 19 | 20 | def channel_shuffle(x: Tensor, groups: int) -> Tensor: 21 | batchsize, num_channels, height, width = x.size() 22 | channels_per_group = num_channels // groups 23 | 24 | # reshape 25 | x = x.view(batchsize, groups, channels_per_group, height, width) 26 | 27 | x = torch.transpose(x, 1, 2).contiguous() 28 | 29 | # flatten 30 | x = x.view(batchsize, -1, height, width) 31 | 32 | return x 33 | 34 | 35 | class InvertedResidual(nn.Module): 36 | def __init__(self, inp: int, oup: int, stride: int, stride2dilation: bool = False, dilation: int = 1) -> None: 37 | super().__init__() 38 | 39 | if not (1 <= stride <= 3): 40 | raise ValueError("illegal stride value") 41 | self.stride = stride 42 | self.stride2dilation = stride2dilation 43 | self.dilation = dilation 44 | 45 | branch_features = oup // 2 46 | assert self.stride != 1 or self.stride2dilation or (inp == branch_features << 1) 47 | 48 | if self.stride > 1 or self.stride2dilation: 49 | self.branch1 = nn.Sequential( 50 | self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, dilation=self.dilation), 51 | nn.BatchNorm2d(inp), 52 | nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 53 | nn.BatchNorm2d(branch_features), 54 | nn.ReLU(inplace=True), 55 | ) 56 | else: 57 | self.branch1 = nn.Sequential() 58 | 59 | self.branch2 = nn.Sequential( 60 | nn.Conv2d( 61 | inp if self.stride > 1 or self.stride2dilation else branch_features, 62 | branch_features, 63 | kernel_size=1, 64 | stride=1, 65 | padding=0, 66 | bias=False, 67 | ), 68 | nn.BatchNorm2d(branch_features), 69 | nn.ReLU(inplace=True), 70 | self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, 71 | dilation=self.dilation), 72 | nn.BatchNorm2d(branch_features), 73 | nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 74 | nn.BatchNorm2d(branch_features), 75 | nn.ReLU(inplace=True), 76 | ) 77 | 78 | @staticmethod 79 | def depthwise_conv( 80 | i: int, o: int, kernel_size: int, stride: int = 1, bias: bool = False, dilation: int = 1 81 | ) -> nn.Conv2d: 82 | return nn.Conv2d(i, o, kernel_size, stride, dilation, bias=bias, groups=i, dilation=dilation) 83 | 84 | def forward(self, x: Tensor) -> Tensor: 85 | if self.stride == 1 and not self.stride2dilation: 86 | x1, x2 = x.chunk(2, dim=1) 87 | out = torch.cat((x1, self.branch2(x2)), dim=1) 88 | else: 89 | out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) 90 | 91 | out = channel_shuffle(out, 2) 92 | 93 | return out 94 | 95 | 96 | class ShuffleNetV2(nn.Module): 97 | def __init__( 98 | self, 99 | stages_repeats: List[int], 100 | stages_out_channels: List[int], 101 | num_classes: int = 1000, 102 | replace_stride_with_dilation=None, 103 | ) -> None: 104 | super().__init__() 105 | _log_api_usage_once(self) 106 | 107 | if len(stages_repeats) != 3: 108 | raise ValueError("expected stages_repeats as list of 3 positive ints") 109 | if len(stages_out_channels) != 5: 110 | raise ValueError("expected stages_out_channels as list of 5 positive ints") 111 | self._stage_out_channels = stages_out_channels 112 | 113 | self.dilation = 1 114 | if replace_stride_with_dilation is None: 115 | # each element in the 
tuple indicates if we should replace 116 | # the 2x2 stride with a dilated convolution instead 117 | replace_stride_with_dilation = [False, False, False] 118 | if len(replace_stride_with_dilation) != 3: 119 | raise ValueError("replace_stride_with_dilation should be None " 120 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 121 | 122 | input_channels = 3 123 | output_channels = self._stage_out_channels[0] 124 | self.conv1 = nn.Sequential( 125 | nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False), 126 | nn.BatchNorm2d(output_channels), 127 | nn.ReLU(inplace=True), 128 | ) 129 | input_channels = output_channels 130 | 131 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 132 | 133 | # Static annotations for mypy 134 | self.stage2: nn.Sequential 135 | self.stage3: nn.Sequential 136 | self.stage4: nn.Sequential 137 | stage_names = [f"stage{i}" for i in [2, 3, 4]] 138 | for i, (name, repeats, output_channels) in enumerate( 139 | zip(stage_names, stages_repeats, self._stage_out_channels[1:])): 140 | seq = [self._make_layer(input_channels, output_channels, 2, replace_stride_with_dilation[i])] 141 | for _ in range(repeats - 1): 142 | seq.append(self._make_layer(output_channels, output_channels, 1)) 143 | setattr(self, name, nn.Sequential(*seq)) 144 | input_channels = output_channels 145 | 146 | output_channels = self._stage_out_channels[-1] 147 | self.conv5 = nn.Sequential( 148 | nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False), 149 | nn.BatchNorm2d(output_channels), 150 | nn.ReLU(inplace=True), 151 | ) 152 | 153 | self.fc = nn.Linear(output_channels, num_classes) 154 | 155 | def _make_layer(self, input_channels, output_channels, stride=1, dilate=False): 156 | if dilate: 157 | self.dilation *= stride 158 | stride = 1 159 | return InvertedResidual(input_channels, output_channels, stride, dilate, self.dilation) 160 | 161 | def _forward_impl(self, x: Tensor) -> Tensor: 162 | # See note [TorchScript super()] 163 | x = self.conv1(x) 164 | x = self.maxpool(x) 165 | x = self.stage2(x) 166 | x = self.stage3(x) 167 | x = self.stage4(x) 168 | x = self.conv5(x) 169 | x = x.mean([2, 3]) # globalpool 170 | x = self.fc(x) 171 | return x 172 | 173 | def forward(self, x: Tensor) -> Tensor: 174 | return self._forward_impl(x) 175 | 176 | 177 | def _shufflenetv2(arch: str, pretrained: bool, progress: bool, *args: Any, **kwargs: Any) -> ShuffleNetV2: 178 | model = ShuffleNetV2(*args, **kwargs) 179 | 180 | if pretrained: 181 | model_url = model_urls[arch] 182 | if model_url is None: 183 | raise ValueError(f"No checkpoint is available for model type {arch}") 184 | else: 185 | state_dict = load_state_dict_from_url(model_url, progress=progress) 186 | model.load_state_dict(state_dict) 187 | 188 | return model 189 | 190 | 191 | def shufflenet_v2_x0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: 192 | """ 193 | Constructs a ShuffleNetV2 with 0.5x output channels, as described in 194 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 195 | `_. 
196 | 197 | Args: 198 | pretrained (bool): If True, returns a model pre-trained on ImageNet 199 | progress (bool): If True, displays a progress bar of the download to stderr 200 | """ 201 | return _shufflenetv2("shufflenetv2_x0.5", pretrained, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) 202 | 203 | 204 | def shufflenet_v2_x1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: 205 | """ 206 | Constructs a ShuffleNetV2 with 1.0x output channels, as described in 207 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 208 | `_. 209 | 210 | Args: 211 | pretrained (bool): If True, returns a model pre-trained on ImageNet 212 | progress (bool): If True, displays a progress bar of the download to stderr 213 | """ 214 | return _shufflenetv2("shufflenetv2_x1.0", pretrained, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) 215 | 216 | 217 | def shufflenet_v2_x1_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: 218 | """ 219 | Constructs a ShuffleNetV2 with 1.5x output channels, as described in 220 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 221 | `_. 222 | 223 | Args: 224 | pretrained (bool): If True, returns a model pre-trained on ImageNet 225 | progress (bool): If True, displays a progress bar of the download to stderr 226 | """ 227 | return _shufflenetv2("shufflenetv2_x1.5", pretrained, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) 228 | 229 | 230 | def shufflenet_v2_x2_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: 231 | """ 232 | Constructs a ShuffleNetV2 with 2.0x output channels, as described in 233 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 234 | `_. 
235 | 236 | Args: 237 | pretrained (bool): If True, returns a model pre-trained on ImageNet 238 | progress (bool): If True, displays a progress bar of the download to stderr 239 | """ 240 | return _shufflenetv2("shufflenetv2_x2.0", pretrained, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) 241 | 242 | 243 | if __name__ == '__main__': 244 | img = torch.zeros([14, 3, 720, 1280]) 245 | model = shufflenet_v2_x0_5(pretrained=True, replace_stride_with_dilation=[False, True, True]) 246 | output = model(img) 247 | -------------------------------------------------------------------------------- /src/utils/mvrender.py: -------------------------------------------------------------------------------- 1 | # this is borrowed from https://github.com/ajhamdi/mvtorch/blob/main/mvtorch/mvrenderer.py 2 | 3 | from mvtorch.utils import * 4 | from mvtorch.ops import check_and_correct_rotation_matrix 5 | 6 | import torch 7 | from torch.autograd import Variable 8 | from torch import nn 9 | import numpy as np 10 | from pytorch3d.structures import Meshes, Pointclouds 11 | from pytorch3d.renderer.mesh import Textures 12 | from pytorch3d.renderer import (OpenGLPerspectiveCameras, look_at_view_transform, 13 | RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams, 14 | HardPhongShader, 15 | OpenGLOrthographicCameras, 16 | PointsRasterizationSettings, 17 | PointsRasterizer, 18 | NormWeightedCompositor, DirectionalLights) 19 | from pytorch3d.renderer.cameras import camera_position_from_spherical_angles 20 | 21 | ORTHOGONAL_THRESHOLD = 1e-6 22 | EXAHSTION_LIMIT = 20 23 | 24 | 25 | class MVRenderer(nn.Module): 26 | """ 27 | The Multi-view differntiable renderer main class that render multiple views differntiably from some given viewpoints. It can render meshes and point clouds as well 28 | Args: 29 | `nb_views` int , The number of views used in the multi-view setup 30 | `image_size` int , The image sizes of the rendered views. 31 | `pc_rendering` : bool , flag to use point cloud rendering instead of mesh rendering 32 | `object_color` : str , The color setup of the objects/points rendered. Choices: ["white", "random","black","red","green","blue", "custom"] 33 | `background_color` : str , The color setup of the rendering background. Choices: ["white", "random","black","red","green","blue", "custom"] 34 | `faces_per_pixel` int , The number of faces rendered per pixel when mesh rendering is used (`pc_rendering` == `False`) . 35 | `points_radius`: float , the radius of the points rendered. The more points in a specific `image_size`, the less radius required for proper rendering. 36 | `points_per_pixel` int , The number of points rendered per pixel when point cloud rendering is used (`pc_rendering` == `True`) . 37 | `light_direction` : str , The setup of the light used in rendering when mesh rendering is available. Choices: ["fixed", "random", "relative"] 38 | `cull_backfaces` : bool , Allow backface-culling when rendering meshes (`pc_rendering` == `False`). 
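        `return_mapping` : bool , If True, `forward` additionally returns the rasterizer's per-pixel point indices and blending weights (only available when `pc_rendering` == `True`).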
39 | 40 | Returns: 41 | an MVTN object that can render multiple views according to predefined setup 42 | """ 43 | 44 | def __init__(self, nb_views, image_size=224, pc_rendering=True, object_color="white", background_color="black", 45 | faces_per_pixel=1, points_radius=0.006, points_per_pixel=1, light_direction="random", 46 | cull_backfaces=False, return_mapping=True): 47 | super().__init__() 48 | self.nb_views = nb_views 49 | self.image_size = image_size 50 | self.pc_rendering = pc_rendering 51 | self.object_color = object_color 52 | self.background_color = background_color 53 | self.faces_per_pixel = faces_per_pixel 54 | self.points_radius = points_radius 55 | self.points_per_pixel = points_per_pixel 56 | self.light_direction_type = light_direction 57 | self.cull_backfaces = cull_backfaces 58 | self.return_mapping = return_mapping 59 | 60 | def render_meshes(self, meshes, color, azim, elev, dist, lights, background_color=(1.0, 1.0, 1.0), ): 61 | c_batch_size = len(meshes) 62 | verts = [msh.verts_list()[0].cuda() for msh in meshes] 63 | faces = [msh.faces_list()[0].cuda() for msh in meshes] 64 | new_meshes = Meshes(verts=verts, 65 | faces=faces, 66 | textures=None) 67 | max_vert = new_meshes.verts_padded().shape[1] 68 | 69 | new_meshes.textures = Textures(verts_rgb=color.cuda() * torch.ones((c_batch_size, max_vert, 3)).cuda()) 70 | 71 | # Create a Meshes object for the teapot. Here we have only one mesh in the batch. 72 | R, T = look_at_view_transform(dist=batch_tensor(dist.T, dim=1, squeeze=True), 73 | elev=batch_tensor(elev.T, dim=1, squeeze=True), 74 | azim=batch_tensor(azim.T, dim=1, squeeze=True)) 75 | R, T = check_and_correct_rotation_matrix(R, T, EXAHSTION_LIMIT, azim, elev, dist) 76 | 77 | cameras = OpenGLPerspectiveCameras(device="cuda:{}".format(torch.cuda.current_device()), R=R, T=T) 78 | camera = OpenGLPerspectiveCameras(device="cuda:{}".format(torch.cuda.current_device()), 79 | R=R[None, 0, ...], T=T[None, 0, ...]) 80 | 81 | raster_settings = RasterizationSettings(image_size=self.image_size, 82 | blur_radius=0.0, 83 | faces_per_pixel=self.faces_per_pixel, 84 | cull_backfaces=self.cull_backfaces, 85 | ) 86 | renderer = MeshRenderer(rasterizer=MeshRasterizer(cameras=camera, raster_settings=raster_settings), 87 | shader=HardPhongShader(blend_params=BlendParams(background_color=background_color), 88 | device=lights.device, cameras=camera, lights=lights)) 89 | new_meshes = new_meshes.extend(self.nb_views) 90 | 91 | rendered_images = renderer(new_meshes, cameras=cameras, lights=lights) 92 | 93 | rendered_images = unbatch_tensor(rendered_images, batch_size=self.nb_views, 94 | dim=1, unsqueeze=True).transpose(0, 1) 95 | 96 | rendered_images = rendered_images[..., 0:3].transpose(2, 4).transpose(3, 4) 97 | return rendered_images, cameras 98 | 99 | def render_points(self, points, color, azim, elev, dist, background_color=(0.0, 0.0, 0.0), ): 100 | point_cloud = Pointclouds(points=points.float(), 101 | features=color.to(points.device) * torch.ones_like(points, dtype=torch.float)).cuda() 102 | 103 | # Create a Meshes object for the teapot. Here we have only one mesh in the batch. 
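        # One orthographic camera is built per (azim, elev, dist) viewpoint; R and T are
        # first passed through check_and_correct_rotation_matrix before the cameras are constructed.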
104 | R, T = look_at_view_transform(dist=batch_tensor(dist.T, dim=1, squeeze=True), 105 | elev=batch_tensor(elev.T, dim=1, squeeze=True), 106 | azim=batch_tensor(azim.T, dim=1, squeeze=True)) 107 | R, T = check_and_correct_rotation_matrix(R, T, EXAHSTION_LIMIT, azim, elev, dist) 108 | 109 | cameras = OpenGLOrthographicCameras(device="cuda:{}".format(torch.cuda.current_device()), R=R, T=T, znear=0.01) 110 | raster_settings = PointsRasterizationSettings(image_size=self.image_size, 111 | radius=self.points_radius, 112 | points_per_pixel=self.points_per_pixel) 113 | rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings) 114 | compositor = NormWeightedCompositor(background_color=background_color) 115 | 116 | point_cloud = point_cloud.extend(self.nb_views) 117 | point_cloud.scale_(batch_tensor(1.0 / dist.T, dim=1, squeeze=True)[..., None][..., None]) 118 | 119 | fragments = rasterizer(point_cloud, ) 120 | 121 | # Construct weights based on the distance of a point to the true point. 122 | # However, this could be done differently: e.g. predicted as opposed 123 | # to a function of the weights. 124 | r = rasterizer.raster_settings.radius 125 | dists2 = fragments.dists.permute(0, 3, 1, 2) 126 | weights = 1 - dists2 / (r * r) 127 | rendered_images = compositor(fragments.idx.long().permute(0, 3, 1, 2), 128 | weights, point_cloud.features_packed().permute(1, 0), ) 129 | # permute so image comes at the end 130 | rendered_images = rendered_images.permute(0, 2, 3, 1) 131 | 132 | rendered_images = unbatch_tensor(rendered_images, batch_size=self.nb_views, 133 | dim=1, unsqueeze=True).transpose(0, 1) 134 | weights = unbatch_tensor(weights, batch_size=self.nb_views, dim=1, unsqueeze=True).transpose(0, 1) 135 | indxs = unbatch_tensor(fragments.idx.long().permute(0, 3, 1, 2), batch_size=self.nb_views, 136 | dim=1, unsqueeze=True).transpose(0, 1) 137 | 138 | rendered_images = rendered_images[..., 0:3].transpose(2, 4).transpose(3, 4) 139 | return rendered_images, indxs, weights, cameras 140 | 141 | def rendering_color(self, custom_color=(1.0, 0, 0)): 142 | if self.object_color == "custom": 143 | color = custom_color 144 | elif self.object_color == "random" and not self.training: 145 | color = torch_color("white") 146 | else: 147 | color = torch_color(self.object_color, max_lightness=True, ) 148 | return color 149 | 150 | def light_direction(self, azim, elev, dist): 151 | if self.light_direction_type == "fixed": 152 | return ((0, 1.0, 0),) 153 | elif self.light_direction_type == "random" and self.training: 154 | return (tuple(1.0 - 2 * np.random.rand(3)),) 155 | else: 156 | relative_view = Variable( 157 | camera_position_from_spherical_angles(distance=batch_tensor(dist.T, dim=1, squeeze=True), 158 | elevation=batch_tensor(elev.T, dim=1, squeeze=True), 159 | azimuth=batch_tensor(azim.T, dim=1, squeeze=True))).float() 160 | return relative_view 161 | 162 | def forward(self, meshes, points, azim, elev, dist, color=None): 163 | """ 164 | The main rendering function of the MVRenderer class. It can render meshes (if `self.pc_rendering` == `False`) or 3D point clouds(if `self.pc_rendering` == `True`). 165 | Arge: 166 | `meshes`: a list of B `Pytorch3D.Mesh` to be rendered , B batch size. In case not available, just pass `None`. 167 | `points`: B * N * 3 tensor, a batch of B point clouds to be rendered where each point cloud has N points and each point has X,Y,Z property. In case not available, just pass `None` . 
168 | `azim`: B * M tensor, a B batch of M azimth angles that represent the azimth angles of the M view-points to render the points or meshes from. 169 | `elev`: B * M tensor, a B batch of M elevation angles that represent the elevation angles of the M view-points to render the points or meshes from. 170 | `dist`: B * M tensor, a B batch of M unit distances that represent the distances of the M view-points to render the points or meshes from. 171 | `color`: B * N * 3 tensor, The RGB colors of batch of point clouds/meshes with N is the number of points/vertices and B batch size. Only if `self.object_color` == `custom`, otherwise this option not used 172 | 173 | """ 174 | background_color = torch_color(self.background_color, self.background_color, max_lightness=True, ).cuda() 175 | color = self.rendering_color(color) 176 | 177 | if not self.pc_rendering: 178 | lights = DirectionalLights(device=background_color.device, direction=self.light_direction(azim, elev, dist)) 179 | 180 | rendered_images, cameras = self.render_meshes(meshes=meshes, color=color, azim=azim, elev=elev, dist=dist, 181 | lights=lights, background_color=background_color) 182 | else: 183 | rendered_images, indxs, weights, cameras = self.render_points(points=points, color=color, azim=azim, 184 | elev=elev, dist=dist, 185 | background_color=background_color) 186 | if self.return_mapping: 187 | return rendered_images, indxs, weights, cameras 188 | else: 189 | return rendered_images, cameras 190 | 191 | def render_and_save(self, meshes, points, azim, elev, dist, images_path, cameras_path, color=None): 192 | with torch.no_grad(): 193 | rendered_images, cameras = self.forward(meshes, points, azim, elev, dist, color) 194 | save_grid(image_batch=rendered_images[0, ...], 195 | save_path=images_path, nrow=self.nb_views) 196 | save_cameras(cameras, save_path=cameras_path, scale=0.22, dpi=200) 197 | -------------------------------------------------------------------------------- /src/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.hub import load_state_dict_from_url 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 7 | 'wide_resnet50_2', 'wide_resnet101_2'] 8 | 9 | model_urls = { 10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 16 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 17 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 18 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 19 | } 20 | 21 | 22 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 23 | """3x3 convolution with padding""" 24 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 25 | padding=dilation, groups=groups, bias=False, dilation=dilation) 26 | 27 | 28 | def conv1x1(in_planes, out_planes, stride=1): 29 | """1x1 convolution""" 30 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, 
stride=stride, bias=False) 31 | 32 | 33 | class BasicBlock(nn.Module): 34 | expansion = 1 35 | 36 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 37 | base_width=64, dilation=1, norm_layer=None): 38 | super(BasicBlock, self).__init__() 39 | if norm_layer is None: 40 | norm_layer = nn.BatchNorm2d 41 | if groups != 1 or base_width != 64: 42 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 43 | # if dilation > 1: 44 | # raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 45 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 46 | self.conv1 = conv3x3(inplanes, planes, stride, dilation=dilation) 47 | self.bn1 = norm_layer(planes) 48 | self.relu = nn.ReLU(inplace=True) 49 | self.conv2 = conv3x3(planes, planes) 50 | self.bn2 = norm_layer(planes) 51 | self.downsample = downsample 52 | self.stride = stride 53 | 54 | def forward(self, x): 55 | identity = x 56 | 57 | out = self.conv1(x) 58 | out = self.bn1(out) 59 | out = self.relu(out) 60 | 61 | out = self.conv2(out) 62 | out = self.bn2(out) 63 | 64 | if self.downsample is not None: 65 | identity = self.downsample(x) 66 | 67 | out += identity 68 | out = self.relu(out) 69 | 70 | return out 71 | 72 | 73 | class Bottleneck(nn.Module): 74 | expansion = 4 75 | 76 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 77 | base_width=64, dilation=1, norm_layer=None): 78 | super(Bottleneck, self).__init__() 79 | if norm_layer is None: 80 | norm_layer = nn.BatchNorm2d 81 | width = int(planes * (base_width / 64.)) * groups 82 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 83 | self.conv1 = conv1x1(inplanes, width) 84 | self.bn1 = norm_layer(width) 85 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 86 | self.bn2 = norm_layer(width) 87 | self.conv3 = conv1x1(width, planes * self.expansion) 88 | self.bn3 = norm_layer(planes * self.expansion) 89 | self.relu = nn.ReLU(inplace=True) 90 | self.downsample = downsample 91 | self.stride = stride 92 | 93 | def forward(self, x): 94 | identity = x 95 | 96 | out = self.conv1(x) 97 | out = self.bn1(out) 98 | out = self.relu(out) 99 | 100 | out = self.conv2(out) 101 | out = self.bn2(out) 102 | out = self.relu(out) 103 | 104 | out = self.conv3(out) 105 | out = self.bn3(out) 106 | 107 | if self.downsample is not None: 108 | identity = self.downsample(x) 109 | 110 | out += identity 111 | out = self.relu(out) 112 | 113 | return out 114 | 115 | 116 | class ResNet(nn.Module): 117 | 118 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 119 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 120 | norm_layer=None, in_channels=3): 121 | super(ResNet, self).__init__() 122 | if norm_layer is None: 123 | norm_layer = nn.BatchNorm2d 124 | self._norm_layer = norm_layer 125 | 126 | self.inplanes = 64 127 | self.dilation = 1 128 | if replace_stride_with_dilation is None: 129 | # each element in the tuple indicates if we should replace 130 | # the 2x2 stride with a dilated convolution instead 131 | replace_stride_with_dilation = [False, False, False] 132 | if len(replace_stride_with_dilation) != 3: 133 | raise ValueError("replace_stride_with_dilation should be None " 134 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 135 | self.groups = groups 136 | self.base_width = width_per_group 137 | self.conv1 = nn.Conv2d(in_channels, self.inplanes, kernel_size=7, stride=2, padding=3, 138 | bias=False) 139 | 
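# Note: unlike the stock torchvision ResNet, the stem convolution above takes a configurable `in_channels` (default 3), so the backbone can also consume non-RGB inputs.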
self.bn1 = norm_layer(self.inplanes) 140 | self.relu = nn.ReLU(inplace=True) 141 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 142 | self.layer1 = self._make_layer(block, 64, layers[0]) 143 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 144 | dilate=replace_stride_with_dilation[0]) 145 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 146 | dilate=replace_stride_with_dilation[1]) 147 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 148 | dilate=replace_stride_with_dilation[2]) 149 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 150 | self.fc = nn.Linear(512 * block.expansion, num_classes) 151 | 152 | for m in self.modules(): 153 | if isinstance(m, nn.Conv2d): 154 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 155 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 156 | nn.init.constant_(m.weight, 1) 157 | nn.init.constant_(m.bias, 0) 158 | 159 | # Zero-initialize the last BN in each residual branch, 160 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 161 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 162 | if zero_init_residual: 163 | for m in self.modules(): 164 | if isinstance(m, Bottleneck): 165 | nn.init.constant_(m.bn3.weight, 0) 166 | elif isinstance(m, BasicBlock): 167 | nn.init.constant_(m.bn2.weight, 0) 168 | 169 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 170 | norm_layer = self._norm_layer 171 | downsample = None 172 | previous_dilation = self.dilation 173 | if dilate: 174 | self.dilation *= stride 175 | stride = 1 176 | if stride != 1 or self.inplanes != planes * block.expansion: 177 | downsample = nn.Sequential( 178 | conv1x1(self.inplanes, planes * block.expansion, stride), 179 | norm_layer(planes * block.expansion), 180 | ) 181 | 182 | layers = [] 183 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 184 | self.base_width, previous_dilation, norm_layer)) 185 | self.inplanes = planes * block.expansion 186 | for _ in range(1, blocks): 187 | layers.append(block(self.inplanes, planes, groups=self.groups, 188 | base_width=self.base_width, dilation=self.dilation, 189 | norm_layer=norm_layer)) 190 | 191 | return nn.Sequential(*layers) 192 | 193 | def forward(self, x): 194 | x = self.conv1(x) 195 | x = self.bn1(x) 196 | x = self.relu(x) 197 | x = self.maxpool(x) 198 | 199 | x = self.layer1(x) 200 | x = self.layer2(x) 201 | x = self.layer3(x) 202 | x = self.layer4(x) 203 | 204 | x = self.avgpool(x) 205 | x = torch.flatten(x, 1) 206 | x = self.fc(x) 207 | 208 | return x 209 | 210 | 211 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 212 | model = ResNet(block, layers, **kwargs) 213 | if pretrained: 214 | state_dict = load_state_dict_from_url(model_urls[arch], 215 | progress=progress) 216 | model.load_state_dict(state_dict) 217 | return model 218 | 219 | 220 | def resnet18(pretrained=False, progress=True, **kwargs): 221 | r"""ResNet-18 model from 222 | `"Deep Residual Learning for Image Recognition" `_ 223 | 224 | Args: 225 | pretrained (bool): If True, returns a model pre-trained on ImageNet 226 | progress (bool): If True, displays a progress bar of the download to stderr 227 | """ 228 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 229 | **kwargs) 230 | 231 | 232 | def resnet34(pretrained=False, progress=True, **kwargs): 233 | r"""ResNet-34 model from 234 | `"Deep Residual Learning for Image 
Recognition" `_ 235 | 236 | Args: 237 | pretrained (bool): If True, returns a model pre-trained on ImageNet 238 | progress (bool): If True, displays a progress bar of the download to stderr 239 | """ 240 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 241 | **kwargs) 242 | 243 | 244 | def resnet50(pretrained=False, progress=True, **kwargs): 245 | r"""ResNet-50 model from 246 | `"Deep Residual Learning for Image Recognition" `_ 247 | 248 | Args: 249 | pretrained (bool): If True, returns a model pre-trained on ImageNet 250 | progress (bool): If True, displays a progress bar of the download to stderr 251 | """ 252 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 253 | **kwargs) 254 | 255 | 256 | def resnet101(pretrained=False, progress=True, **kwargs): 257 | r"""ResNet-101 model from 258 | `"Deep Residual Learning for Image Recognition" `_ 259 | 260 | Args: 261 | pretrained (bool): If True, returns a model pre-trained on ImageNet 262 | progress (bool): If True, displays a progress bar of the download to stderr 263 | """ 264 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 265 | **kwargs) 266 | 267 | 268 | def resnet152(pretrained=False, progress=True, **kwargs): 269 | r"""ResNet-152 model from 270 | `"Deep Residual Learning for Image Recognition" `_ 271 | 272 | Args: 273 | pretrained (bool): If True, returns a model pre-trained on ImageNet 274 | progress (bool): If True, displays a progress bar of the download to stderr 275 | """ 276 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 277 | **kwargs) 278 | 279 | 280 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 281 | r"""ResNeXt-50 32x4d model from 282 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 283 | 284 | Args: 285 | pretrained (bool): If True, returns a model pre-trained on ImageNet 286 | progress (bool): If True, displays a progress bar of the download to stderr 287 | """ 288 | kwargs['groups'] = 32 289 | kwargs['width_per_group'] = 4 290 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 291 | pretrained, progress, **kwargs) 292 | 293 | 294 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 295 | r"""ResNeXt-101 32x8d model from 296 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 297 | 298 | Args: 299 | pretrained (bool): If True, returns a model pre-trained on ImageNet 300 | progress (bool): If True, displays a progress bar of the download to stderr 301 | """ 302 | kwargs['groups'] = 32 303 | kwargs['width_per_group'] = 8 304 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 305 | pretrained, progress, **kwargs) 306 | 307 | 308 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 309 | r"""Wide ResNet-50-2 model from 310 | `"Wide Residual Networks" `_ 311 | 312 | The model is the same as ResNet except for the bottleneck number of channels 313 | which is twice larger in every block. The number of channels in outer 1x1 314 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 315 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
316 | 317 | Args: 318 | pretrained (bool): If True, returns a model pre-trained on ImageNet 319 | progress (bool): If True, displays a progress bar of the download to stderr 320 | """ 321 | kwargs['width_per_group'] = 64 * 2 322 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 323 | pretrained, progress, **kwargs) 324 | 325 | 326 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 327 | r"""Wide ResNet-101-2 model from 328 | `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_ 329 | 330 | The model is the same as ResNet except for the bottleneck number of channels 331 | which is twice larger in every block. The number of channels in outer 1x1 332 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 333 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 334 | 335 | Args: 336 | pretrained (bool): If True, returns a model pre-trained on ImageNet 337 | progress (bool): If True, displays a progress bar of the download to stderr 338 | """ 339 | kwargs['width_per_group'] = 64 * 2 340 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 341 | pretrained, progress, **kwargs) 342 | 343 | 344 | if __name__ == '__main__': 345 | img = torch.zeros([14, 3, 720, 1280]) 346 | model = resnet18(pretrained=True, replace_stride_with_dilation=[False, True, True]) 347 | output = model(img) 348 | -------------------------------------------------------------------------------- /src/datasets/scanobjectnn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import shutil 4 | from tqdm import tqdm 5 | from typing import Dict, List, Optional, Tuple 6 | import glob 7 | import h5py 8 | import pandas as pd 9 | import numpy as np 10 | import torch 11 | from torch.utils.data import Dataset 12 | import torchvision.transforms as T 13 | import matplotlib.pyplot as plt 14 | from src.datasets import ModelNet40 15 | 16 | from PIL import Image 17 | import open3d as o3d 18 | # Data structures and functions for rendering 19 | from pytorch3d.structures import Meshes, Pointclouds 20 | from pytorch3d.vis.plotly_vis import AxisArgs, plot_batch_individually, plot_scene 21 | from pytorch3d.vis.texture_vis import texturesuv_image_matplotlib 22 | from pytorch3d.renderer import ( 23 | look_at_view_transform, 24 | FoVPerspectiveCameras, 25 | PointLights, 26 | DirectionalLights, 27 | Materials, 28 | RasterizationSettings, 29 | MeshRenderer, 30 | MeshRasterizer, 31 | SoftPhongShader, 32 | HardPhongShader, 33 | Textures, 34 | FoVOrthographicCameras, 35 | PointsRasterizationSettings, 36 | PointsRenderer, 37 | PulsarPointsRenderer, 38 | PointsRasterizer, 39 | AlphaCompositor, 40 | NormWeightedCompositor 41 | ) 42 | 43 | 44 | # this code is borrowed from https://github.com/ajhamdi/mvtorch/blob/main/mvtorch/data.py 45 | class ScanObjectNNPointCloud(Dataset): 46 | """ 47 | This class loads ScanObjectNN from a given directory into a Dataset object. 48 | ScanObjectNN is a point cloud dataset of realistic shapes from the ScanNet dataset and can be downloaded from 49 | https://github.com/hkust-vgd/scanobjectnn . 50 | """ 51 | 52 | def __init__( 53 | self, 54 | data_dir, 55 | split, 56 | nb_points=100000, 57 | normals: bool = False, 58 | suncg: bool = False, 59 | variant: str = "obj_only", 60 | dset_norm: str = "inf", 61 | 62 | ): 63 | """ 64 | Store each object's synset id and models id from data_dir. 65 | Args: 66 | data_dir: Path to ShapeNetCore data.
67 | synsets: List of synset categories to load from ShapeNetCore in the form of 68 | synset offsets or labels. A combination of both is also accepted. 69 | When no category is specified, all categories in data_dir are loaded. 70 | version: (int) version of ShapeNetCore data in data_dir, 1 or 2. 71 | Default is set to be 1. Version 1 has 57 categories and verions 2 has 55 72 | categories. 73 | Note: version 1 has two categories 02858304(boat) and 02992529(cellphone) 74 | that are hyponyms of categories 04530566(watercraft) and 04401088(telephone) 75 | respectively. You can combine the categories manually if needed. 76 | Version 2 doesn't have 02858304(boat) or 02834778(bicycle) compared to 77 | version 1. 78 | load_textures: Boolean indicating whether textures should loaded for the model. 79 | Textures will be of type TexturesAtlas i.e. a texture map per face. 80 | texture_resolution: Int specifying the resolution of the texture map per face 81 | created using the textures in the obj file. A 82 | (texture_resolution, texture_resolution, 3) map is created per face. 83 | """ 84 | super().__init__() 85 | self.data_dir = data_dir 86 | self.nb_points = nb_points 87 | self.normals = normals 88 | self.suncg = suncg 89 | self.variant = variant 90 | self.dset_norm = dset_norm 91 | self.split = split 92 | self.classes = {0: 'bag', 10: 'bed', 1: 'bin', 2: 'box', 3: 'cabinet', 4: 'chair', 5: 'desk', 6: 'display', 93 | 7: 'door', 11: 'pillow', 8: 'shelf', 12: 'sink', 13: 'sofa', 9: 'table', 14: 'toilet'} 94 | 95 | self.labels_dict = {"train": {}, "test": {}} 96 | self.objects_paths = {"train": [], "test": []} 97 | 98 | if self.variant != "hardest": 99 | pcdataset = pd.read_csv(os.path.join( 100 | data_dir, "split_new.txt"), sep="\t", names=['obj_id', 'label', "split"]) 101 | for ii in range(len(pcdataset)): 102 | if pcdataset["split"][ii] != "t": 103 | self.labels_dict["train"][pcdataset["obj_id"] 104 | [ii]] = pcdataset["label"][ii] 105 | else: 106 | self.labels_dict["test"][pcdataset["obj_id"] 107 | [ii]] = pcdataset["label"][ii] 108 | 109 | all_obj_ids = glob.glob(os.path.join(self.data_dir, "*/*.bin")) 110 | filtered_ids = list(filter(lambda x: "part" not in os.path.split( 111 | x)[-1] and "indices" not in os.path.split(x)[-1], all_obj_ids)) 112 | 113 | self.objects_paths["train"] = sorted( 114 | [x for x in filtered_ids if os.path.split(x)[-1] in self.labels_dict["train"].keys()]) 115 | self.objects_paths["test"] = sorted( 116 | [x for x in filtered_ids if os.path.split(x)[-1] in self.labels_dict["test"].keys()]) 117 | else: 118 | filename = os.path.join(data_dir, "{}_objectdataset_augmentedrot_scale75.h5".format(self.split)) 119 | with h5py.File(filename, "r") as f: 120 | self.labels_dict[self.split] = np.array(f["label"]) 121 | self.objects_paths[self.split] = np.array(f["data"]) 122 | # print("1############", len(self.labels_dict[self.split])) 123 | # print("2############", len(self.labels_dict[self.split])) 124 | 125 | def __getitem__(self, idx: int) -> Dict: 126 | """ 127 | Read a model by the given index. 
no mesh is available in this dataset, so return None and a correction factor of 1.0 128 | 129 | """ 130 | if self.variant != "hardest": 131 | obj_path = self.objects_paths[self.split][idx] 132 | # obj_path,label 133 | points = self.load_pc_file(obj_path) 134 | # subsample to the required number of points (evenly spaced indices; the random variant is commented out below) 135 | if len(points) > self.nb_points: 136 | # points = points[np.random.randint(points.shape[0], size=self.nb_points), :] 137 | points = points[np.linspace(0, points.shape[0] - 1, self.nb_points).astype(int), :] 138 | # print(pc.min(),classes[label],obj_path) 139 | label = self.labels_dict[self.split][os.path.split(obj_path)[-1]] 140 | else: 141 | 142 | points = self.objects_paths[self.split][idx] 143 | label = self.labels_dict[self.split][idx] 144 | 145 | points = points 146 | points[:, :3] = np_center_and_normalize(points[:, :3], p=self.dset_norm) 147 | return label, None, points 148 | 149 | def __len__(self): 150 | return len(self.objects_paths[self.split]) 151 | 152 | def load_pc_file(self, filename): 153 | # load bin file 154 | # pc=np.fromfile(filename, dtype=np.float32) 155 | pc = np.fromfile(filename, dtype=np.float32) 156 | 157 | # first entry is the number of points 158 | # then x, y, z, nx, ny, nz, r, g, b, label, nyu_label 159 | if (self.suncg): 160 | pc = pc[1:].reshape((-1, 3)) 161 | else: 162 | pc = pc[1:].reshape((-1, 11)) 163 | 164 | # return pc 165 | 166 | # only use x, y, z for now 167 | if self.variant == "with_bg": 168 | # pc = np.array(pc[:, 0:3]) 169 | return pc 170 | 171 | else: 172 | ## To remove background points 173 | ## filter unwanted classes 174 | filtered_idx = np.intersect1d(np.intersect1d(np.where( 175 | pc[:, -1] != 0)[0], np.where(pc[:, -1] != 1)[0]), np.where(pc[:, -1] != 2)[0]) 176 | (values, counts) = np.unique(pc[filtered_idx, -1], return_counts=True) 177 | max_ind = np.argmax(counts) 178 | idx = np.where(pc[:, -1] == values[max_ind])[0] 179 | # pc = np.array(pc[idx, 0:3]) 180 | pc = np.array(pc[idx]) 181 | return pc 182 | 183 | 184 | def np_center_and_normalize(points, p="inf"): 185 | """ 186 | a helper numpy function that centers and normalizes 3D point clouds 187 | """ 188 | N = points.shape[0] 189 | center = points.mean(0) 190 | if p != "fro" and p != "no": 191 | scale = np.max(np.linalg.norm(points - center, ord=float(p), axis=1)) 192 | elif p == "fro": 193 | scale = np.linalg.norm(points - center, ord=p) 194 | elif p == "no": 195 | scale = 1.0 196 | points = points - center[None] 197 | points = points * (1.0 / float(scale)) 198 | return points 199 | 200 | 201 | def render_pointcloud_mv_img(points, num_cam, renderer, cameras, lights, device='cuda'): 202 | # x, y, z, nx, ny, nz, r, g, b, label, nyu_label 203 | point_cloud = Pointclouds(points=[torch.tensor(points[:, :3], dtype=torch.float32, device=device)], 204 | normals=[torch.tensor(points[:, 3:6], dtype=torch.float32, device=device)], 205 | features=[torch.tensor(points[:, 6:9], dtype=torch.float32, device=device) / 256]).extend(num_cam) 206 | images = renderer(point_cloud) 207 | return images 208 | 209 | 210 | def render_mesh_mv_img(points, num_cam, renderer, cameras, lights, device='cuda'): 211 | # x, y, z, nx, ny, nz, r, g, b, label, nyu_label 212 | # Create an open3d.geometry.PointCloud object from the NumPy array 213 | point_cloud = o3d.geometry.PointCloud() 214 | point_cloud.points = o3d.utility.Vector3dVector(points[:, :3]) 215 | point_cloud.normals = o3d.utility.Vector3dVector(points[:, 3:6]) 216 | point_cloud.colors = o3d.utility.Vector3dVector(points[:, 6:9]) 217 | 218 | # Perform
poisson surface reconstruction 219 | mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(point_cloud, depth=8) 220 | # mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(point_cloud, 221 | # o3d.utility.DoubleVector([0.01, 0.1])) 222 | # mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_alpha_shape(point_cloud, alpha=0.2) 223 | 224 | # Smooth the mesh 225 | # mesh = mesh.filter_smooth_simple() 226 | 227 | # Create a batch of meshes by repeating the cow mesh and associated textures. 228 | # Meshes has a useful `extend` method which allows us do this very easily. 229 | # This also extends the textures. 230 | meshes = Meshes(verts=[torch.tensor(mesh.vertices, dtype=torch.float32, device=device)], 231 | faces=[torch.tensor(mesh.triangles, dtype=torch.int64, device=device)], 232 | textures=Textures(verts_rgb=torch.tensor(mesh.vertex_colors, dtype=torch.float32, device=device)[ 233 | None] / 255)).extend(num_cam) 234 | 235 | # We can pass arbitrary keyword arguments to the rasterizer/shader via the renderer 236 | # so the renderer does not need to be reinitialized if any of the settings change. 237 | images = renderer(meshes, cameras=cameras, lights=lights) 238 | return images 239 | 240 | 241 | def save_mv_img_dataset(base, root, num_cam, split, visualize=False): 242 | point_cnt_avg = 0 243 | device = 'cuda' 244 | 245 | # # Define the settings for rasterization and shading. Here we set the output image to be of size 246 | # # 512x512. As we are rendering images for visualization purposes only we will set faces_per_pixel=1 247 | # # and blur_radius=0.0. We also set bin_size and max_faces_per_bin to None which ensure that 248 | # # the faster coarse-to-fine rasterization method is used. Refer to rasterize_meshes.py for 249 | # # explanations of these parameters. Refer to docs/notes/renderer.md for an explanation of 250 | # # the difference between naive and coarse-to-fine rasterization. 251 | # raster_settings = RasterizationSettings(image_size=512, blur_radius=0.0, faces_per_pixel=1, ) 252 | 253 | # # Get a batch of viewing angles. 254 | # elev = 30 255 | # azim = torch.linspace(-180, 180, num_cam + 1)[:num_cam] 256 | 257 | # # All the cameras helper methods support mixed type inputs and broadcasting. So we can 258 | # # view the camera from the same distance and specify dist=2.7 as a float, 259 | # # and then specify elevation and azimuth angles for each viewpoint as tensors. 260 | # R, T = look_at_view_transform(dist=2, elev=elev, azim=azim) 261 | # cameras = FoVPerspectiveCameras(device=device, R=R, T=T) 262 | 263 | # lights = PointLights(device=device, location=[[0.0, 3.0, 0.0]]) 264 | 265 | # # Create a Phong renderer by composing a rasterizer and a shader. The textured Phong shader will 266 | # # interpolate the texture uv coordinates for each vertex, sample from a texture image and 267 | # # apply the Phong lighting model 268 | # renderer = MeshRenderer(rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings), 269 | # shader=HardPhongShader(device=device, cameras=cameras, lights=lights)) 270 | 271 | # Initialize a camera. 272 | 273 | # Get a batch of viewing angles. 274 | elev = 30 275 | azim = torch.linspace(-180, 180, num_cam + 1)[:num_cam] 276 | 277 | # Define the settings for rasterization and shading. Here we set the output image to be of size 278 | # 512x512. As we are rendering images for visualization purposes only we will set faces_per_pixel=1 279 | # and blur_radius=0.0. 
Refer to raster_points.py for explanations of these parameters. 280 | raster_settings = PointsRasterizationSettings( 281 | image_size=512, 282 | radius=0.003, 283 | points_per_pixel=10 284 | ) 285 | 286 | # All the cameras helper methods support mixed type inputs and broadcasting. So we can 287 | # view the camera from the same distance and specify dist=2.7 as a float, 288 | # and then specify elevation and azimuth angles for each viewpoint as tensors. 289 | R, T = look_at_view_transform(dist=2, elev=elev, azim=azim) 290 | cameras = FoVPerspectiveCameras(device=device, R=R, T=T) 291 | 292 | lights = PointLights(device=device, location=[[0.0, 3.0, 0.0]]) 293 | 294 | # Create a points renderer by compositing points using an alpha compositor (nearer points 295 | # are weighted more heavily). See [1] for an explanation. 296 | rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings) 297 | renderer = PointsRenderer( 298 | rasterizer=rasterizer, 299 | compositor=AlphaCompositor() 300 | ) 301 | for class_name in base.classes.values(): 302 | if os.path.exists(f'{root}/{class_name}/{split}/'): 303 | shutil.rmtree(f'{root}/{class_name}/{split}/') 304 | 305 | for idx in tqdm(range(len(base))): 306 | target, mesh, points = base.__getitem__(idx) 307 | imgs = render_pointcloud_mv_img(points, num_cam, renderer, cameras, lights) 308 | # imgs = self.mvrenderer(mesh, points[None].cuda(), 309 | # azim=torch.linspace(0, 360, self.num_cam + 1)[None, :self.num_cam].cuda(), 310 | # elev=30 * torch.ones([1, self.num_cam]).cuda(), 311 | # dist=2 * torch.ones([1, self.num_cam]).cuda())[0][0].permute(0, 2, 3, 1) 312 | # imgs = render_mesh_mv_img(points, num_cam, renderer, cameras, lights) 313 | for cam, img in enumerate(imgs): 314 | if visualize: 315 | plt.imshow(img.cpu().numpy()) 316 | plt.show() 317 | os.makedirs(f'{root}/{base.classes[target]}/{split}/', exist_ok=True) 318 | img = Image.fromarray((img.cpu().numpy() * 255).astype('uint8')) 319 | img.save(f'{root}/{base.classes[target]}/{split}/{idx:04d}_{cam + 1:02d}.png') 320 | print(f'{root}/{base.classes[target]}/{split}/{idx:04d}') 321 | point_cnt_avg += len(points) 322 | pass 323 | print(point_cnt_avg / len(base)) 324 | 325 | 326 | class ScanObjectNN(ModelNet40): 327 | classnames = ['bag', 'bin', 'box', 'cabinet', 'chair', 'desk', 'display', 328 | 'door', 'shelf', 'table', 'bed', 'pillow', 'sink', 'sofa', 'toilet'] 329 | 330 | def __init__(self, root, split='train', per_cls_instances=0, dropout=0.0): 331 | super().__init__(root, 12, split, per_cls_instances, dropout) 332 | 333 | 334 | if __name__ == '__main__': 335 | split = 'train' 336 | dataset = ScanObjectNNPointCloud('/home/houyz/Data/ScanObjectNN', split) 337 | save_mv_img_dataset(dataset, '/home/houyz/Data/ScanObjectNN_pc', 12, split) 338 | split = 'test' 339 | dataset = ScanObjectNNPointCloud('/home/houyz/Data/ScanObjectNN', split) 340 | save_mv_img_dataset(dataset, '/home/houyz/Data/ScanObjectNN_pc', 12, split) 341 | # dataset = ScanObjectNN('/home/houyz/Data/ScanObjectNN') 342 | # dataset.__getitem__(0) 343 | # dataset.__getitem__(len(dataset) - 1, visualize=True) 344 | pass 345 | -------------------------------------------------------------------------------- /src/evaluation/test-demo.txt: -------------------------------------------------------------------------------- 1 | 1800 308 660 2 | 1800 392 824 3 | 1800 284 528 4 | 1800 300 568 5 | 1800 216 624 6 | 1800 272 664 7 | 1800 260 500 8 | 1800 264 568 9 | 1800 316 808 10 | 1800 368 764 11 | 1800 76 488 12 | 1800 240 532 13 | 
1800 444 604 14 | 1800 440 824 15 | 1800 288 1112 16 | 1800 468 632 17 | 1800 236 60 18 | 1800 288 1200 19 | 1800 388 804 20 | 1800 412 828 21 | 1805 428 608 22 | 1805 216 624 23 | 1805 300 568 24 | 1805 260 500 25 | 1805 428 832 26 | 1805 308 660 27 | 1805 264 568 28 | 1805 304 824 29 | 1805 284 528 30 | 1805 396 828 31 | 1805 276 668 32 | 1805 76 460 33 | 1805 236 532 34 | 1805 364 740 35 | 1805 268 1144 36 | 1805 376 820 37 | 1805 272 1224 38 | 1810 216 624 39 | 1810 260 500 40 | 1810 300 568 41 | 1810 304 664 42 | 1810 360 720 43 | 1810 252 1172 44 | 1810 284 528 45 | 1810 264 568 46 | 1810 292 836 47 | 1810 420 616 48 | 1810 76 432 49 | 1810 364 828 50 | 1810 236 532 51 | 1810 420 836 52 | 1810 276 668 53 | 1810 464 632 54 | 1810 392 840 55 | 1810 256 1232 56 | 1810 468 524 57 | 1810 464 432 58 | 1815 216 624 59 | 1815 300 568 60 | 1815 284 844 61 | 1815 420 836 62 | 1815 284 528 63 | 1815 304 664 64 | 1815 360 828 65 | 1815 260 500 66 | 1815 416 628 67 | 1815 384 844 68 | 1815 276 668 69 | 1815 264 568 70 | 1815 356 692 71 | 1815 236 1200 72 | 1815 236 532 73 | 1815 84 420 74 | 1815 444 512 75 | 1815 244 1248 76 | 1815 440 404 77 | 1815 464 628 78 | 1815 248 1268 79 | 1815 360 712 80 | 1820 216 624 81 | 1820 352 672 82 | 1820 260 500 83 | 1820 304 664 84 | 1820 284 528 85 | 1820 300 568 86 | 1820 276 664 87 | 1820 88 420 88 | 1820 424 508 89 | 1820 276 852 90 | 1820 420 836 91 | 1820 412 632 92 | 1820 264 568 93 | 1820 356 836 94 | 1820 236 532 95 | 1820 380 848 96 | 1820 424 384 97 | 1820 452 632 98 | 1825 216 624 99 | 1825 276 668 100 | 1825 352 660 101 | 1825 408 640 102 | 1825 304 664 103 | 1825 260 500 104 | 1825 356 840 105 | 1825 240 532 106 | 1825 400 512 107 | 1825 296 568 108 | 1825 284 528 109 | 1825 92 420 110 | 1825 264 568 111 | 1825 380 852 112 | 1825 416 836 113 | 1825 444 636 114 | 1825 280 856 115 | 1825 408 364 116 | 1825 224 1304 117 | 1825 212 1276 118 | 1830 216 624 119 | 1830 404 648 120 | 1830 240 532 121 | 1830 260 500 122 | 1830 352 644 123 | 1830 308 664 124 | 1830 260 568 125 | 1830 288 868 126 | 1830 276 664 127 | 1830 284 528 128 | 1830 352 840 129 | 1830 300 568 130 | 1830 388 512 131 | 1830 380 852 132 | 1830 92 428 133 | 1830 416 832 134 | 1830 436 636 135 | 1830 392 340 136 | 1835 308 664 137 | 1835 304 864 138 | 1835 216 624 139 | 1835 300 568 140 | 1835 260 496 141 | 1835 276 664 142 | 1835 84 444 143 | 1835 240 536 144 | 1835 380 852 145 | 1835 352 628 146 | 1835 264 568 147 | 1835 352 828 148 | 1835 284 528 149 | 1835 372 504 150 | 1835 408 832 151 | 1835 400 652 152 | 1835 428 640 153 | 1835 380 312 154 | 1835 300 884 155 | 1840 216 624 156 | 1840 380 852 157 | 1840 360 612 158 | 1840 264 500 159 | 1840 300 568 160 | 1840 304 668 161 | 1840 80 456 162 | 1840 316 868 163 | 1840 260 568 164 | 1840 240 532 165 | 1840 276 664 166 | 1840 360 496 167 | 1840 288 528 168 | 1840 396 652 169 | 1840 404 832 170 | 1840 368 288 171 | 1840 352 824 172 | 1840 420 644 173 | 1845 216 624 174 | 1845 364 596 175 | 1845 384 856 176 | 1845 264 500 177 | 1845 304 668 178 | 1845 76 460 179 | 1845 284 528 180 | 1845 332 860 181 | 1845 300 568 182 | 1845 392 656 183 | 1845 240 532 184 | 1845 264 568 185 | 1845 352 492 186 | 1845 412 648 187 | 1845 404 828 188 | 1845 276 664 189 | 1845 356 260 190 | 1845 352 828 191 | 1845 300 648 192 | 1850 216 624 193 | 1850 304 664 194 | 1850 364 584 195 | 1850 300 568 196 | 1850 284 528 197 | 1850 264 568 198 | 1850 260 500 199 | 1850 76 460 200 | 1850 404 828 201 | 1850 344 492 202 | 1850 384 664 203 | 1850 276 664 204 | 1850 352 
824 205 | 1850 236 532 206 | 1850 408 664 207 | 1850 340 848 208 | 1850 340 236 209 | 1850 388 852 210 | 1855 304 664 211 | 1855 216 624 212 | 1855 300 568 213 | 1855 352 824 214 | 1855 260 500 215 | 1855 284 528 216 | 1855 76 456 217 | 1855 324 208 218 | 1855 264 568 219 | 1855 404 828 220 | 1855 340 492 221 | 1855 388 852 222 | 1855 280 664 223 | 1855 368 576 224 | 1855 240 532 225 | 1855 372 676 226 | 1855 392 680 227 | 1855 340 848 228 | 1855 372 556 229 | 1860 304 664 230 | 1860 260 500 231 | 1860 300 568 232 | 1860 216 624 233 | 1860 76 460 234 | 1860 276 664 235 | 1860 288 524 236 | 1860 372 556 237 | 1860 340 496 238 | 1860 404 824 239 | 1860 240 536 240 | 1860 264 568 241 | 1860 364 688 242 | 1860 388 852 243 | 1860 312 180 244 | 1860 352 828 245 | 1860 340 848 246 | 1860 388 692 247 | 1860 476 520 248 | 1865 376 712 249 | 1865 344 504 250 | 1865 348 704 251 | 1865 300 664 252 | 1865 216 624 253 | 1865 376 544 254 | 1865 404 828 255 | 1865 240 532 256 | 1865 264 500 257 | 1865 276 664 258 | 1865 76 468 259 | 1865 300 568 260 | 1865 468 516 261 | 1865 392 848 262 | 1865 352 828 263 | 1865 264 568 264 | 1865 288 528 265 | 1865 312 152 266 | 1865 344 848 267 | 1870 344 504 268 | 1870 336 716 269 | 1870 364 728 270 | 1870 272 496 271 | 1870 456 516 272 | 1870 404 824 273 | 1870 216 624 274 | 1870 304 668 275 | 1870 276 664 276 | 1870 72 488 277 | 1870 376 532 278 | 1870 352 824 279 | 1870 240 532 280 | 1870 388 852 281 | 1870 300 568 282 | 1870 344 852 283 | 1870 264 568 284 | 1870 288 528 285 | 1870 300 116 286 | 1875 352 744 287 | 1875 444 508 288 | 1875 276 664 289 | 1875 216 624 290 | 1875 348 512 291 | 1875 68 512 292 | 1875 240 532 293 | 1875 304 668 294 | 1875 376 520 295 | 1875 300 568 296 | 1875 288 524 297 | 1875 272 496 298 | 1875 384 856 299 | 1875 328 724 300 | 1875 404 824 301 | 1875 352 828 302 | 1875 264 568 303 | 1875 348 848 304 | 1875 468 444 305 | 1880 272 492 306 | 1880 216 624 307 | 1880 404 824 308 | 1880 344 508 309 | 1880 300 564 310 | 1880 276 664 311 | 1880 432 504 312 | 1880 308 736 313 | 1880 64 536 314 | 1880 240 532 315 | 1880 304 668 316 | 1880 332 752 317 | 1880 380 512 318 | 1880 348 848 319 | 1880 388 852 320 | 1880 264 568 321 | 1880 352 828 322 | 1880 284 524 323 | 1880 468 432 324 | 1885 420 500 325 | 1885 220 624 326 | 1885 372 492 327 | 1885 300 568 328 | 1885 308 760 329 | 1885 280 664 330 | 1885 264 496 331 | 1885 304 668 332 | 1885 404 824 333 | 1885 68 560 334 | 1885 236 532 335 | 1885 264 568 336 | 1885 344 512 337 | 1885 388 852 338 | 1885 284 528 339 | 1885 460 420 340 | 1885 292 732 341 | 1885 356 824 342 | 1885 352 848 343 | 1885 376 512 344 | 1890 216 624 345 | 1890 368 484 346 | 1890 408 496 347 | 1890 404 824 348 | 1890 284 528 349 | 1890 300 568 350 | 1890 256 496 351 | 1890 304 660 352 | 1890 64 584 353 | 1890 280 756 354 | 1890 388 852 355 | 1890 352 848 356 | 1890 276 668 357 | 1890 344 516 358 | 1890 356 824 359 | 1890 240 532 360 | 1890 452 408 361 | 1890 264 568 362 | 1890 268 724 363 | 1890 460 472 364 | 1895 300 564 365 | 1895 64 620 366 | 1895 216 624 367 | 1895 396 496 368 | 1895 240 532 369 | 1895 360 480 370 | 1895 256 496 371 | 1895 304 668 372 | 1895 264 568 373 | 1895 388 852 374 | 1895 288 528 375 | 1895 400 832 376 | 1895 276 664 377 | 1895 76 776 378 | 1895 356 828 379 | 1895 256 748 380 | 1895 444 392 381 | 1895 476 820 382 | 1895 344 520 383 | 1895 352 848 384 | 1895 244 708 385 | 1895 56 780 386 | 1895 464 420 387 | 1895 448 456 388 | 1900 384 492 389 | 1900 220 624 390 | 1900 264 500 391 | 1900 64 644 392 | 
1900 404 828 393 | 1900 240 532 394 | 1900 300 564 395 | 1900 352 468 396 | 1900 460 832 397 | 1900 264 568 398 | 1900 304 664 399 | 1900 232 732 400 | 1900 288 528 401 | 1900 432 380 402 | 1900 232 696 403 | 1900 276 664 404 | 1900 352 844 405 | 1900 60 768 406 | 1900 344 516 407 | 1900 388 852 408 | 1900 88 764 409 | 1900 356 824 410 | 1900 436 448 411 | 1905 404 824 412 | 1905 212 676 413 | 1905 264 500 414 | 1905 216 624 415 | 1905 60 672 416 | 1905 304 664 417 | 1905 340 460 418 | 1905 368 484 419 | 1905 240 532 420 | 1905 352 836 421 | 1905 280 664 422 | 1905 300 568 423 | 1905 212 708 424 | 1905 448 844 425 | 1905 284 524 426 | 1905 460 732 427 | 1905 352 520 428 | 1905 388 852 429 | 1905 428 436 430 | 1905 428 376 431 | 1905 268 572 432 | 1905 96 748 433 | 1905 356 816 434 | 1905 64 752 435 | 1910 240 532 436 | 1910 268 504 437 | 1910 284 528 438 | 1910 356 828 439 | 1910 384 856 440 | 1910 424 372 441 | 1910 404 820 442 | 1910 328 448 443 | 1910 304 668 444 | 1910 216 624 445 | 1910 432 864 446 | 1910 460 880 447 | 1910 48 700 448 | 1910 276 668 449 | 1910 352 476 450 | 1910 352 848 451 | 1910 192 696 452 | 1910 300 572 453 | 1910 72 732 454 | 1910 416 420 455 | 1910 368 532 456 | 1910 188 676 457 | 1910 96 720 458 | 1910 200 648 459 | 1910 440 732 460 | 1910 264 572 461 | 1915 412 360 462 | 1915 316 440 463 | 1915 264 500 464 | 1915 288 524 465 | 1915 440 900 466 | 1915 240 532 467 | 1915 216 624 468 | 1915 336 472 469 | 1915 304 668 470 | 1915 424 884 471 | 1915 404 824 472 | 1915 352 828 473 | 1915 280 668 474 | 1915 380 860 475 | 1915 264 568 476 | 1915 172 664 477 | 1915 348 848 478 | 1915 88 700 479 | 1915 300 564 480 | 1915 416 408 481 | 1915 80 720 482 | 1915 36 720 483 | 1915 416 736 484 | 1915 184 628 485 | 1915 380 532 486 | 1915 32 248 487 | 1915 108 692 488 | 1920 264 504 489 | 1920 296 436 490 | 1920 300 668 491 | 1920 320 464 492 | 1920 388 732 493 | 1920 276 664 494 | 1920 100 672 495 | 1920 396 544 496 | 1920 288 524 497 | 1920 80 676 498 | 1920 424 920 499 | 1920 404 344 500 | 1920 216 624 501 | 1920 240 532 502 | 1920 304 568 503 | 1920 380 860 504 | 1920 156 632 505 | 1920 412 900 506 | 1920 452 712 507 | 1920 352 820 508 | 1920 264 564 509 | 1920 36 252 510 | 1920 176 604 511 | 1920 400 832 512 | 1920 456 732 513 | 1920 352 844 514 | 1920 160 652 515 | 1920 408 372 516 | 1925 360 736 517 | 1925 264 500 518 | 1925 300 668 519 | 1925 400 336 520 | 1925 280 436 521 | 1925 216 624 522 | 1925 280 664 523 | 1925 80 656 524 | 1925 240 532 525 | 1925 380 860 526 | 1925 264 568 527 | 1925 408 928 528 | 1925 40 256 529 | 1925 308 464 530 | 1925 404 552 531 | 1925 404 824 532 | 1925 140 608 533 | 1925 100 648 534 | 1925 300 564 535 | 1925 168 580 536 | 1925 428 720 537 | 1925 356 824 538 | 1925 436 392 539 | 1925 288 524 540 | 1925 352 844 541 | 1930 404 824 542 | 1930 388 944 543 | 1930 300 568 544 | 1930 416 560 545 | 1930 164 548 546 | 1930 336 740 547 | 1930 216 624 548 | 1930 304 668 549 | 1930 240 536 550 | 1930 280 664 551 | 1930 264 568 552 | 1930 80 636 553 | 1930 264 500 554 | 1930 268 432 555 | 1930 396 312 556 | 1930 460 376 557 | 1930 136 580 558 | 1930 284 528 559 | 1930 44 260 560 | 1930 100 632 561 | 1930 288 464 562 | 1930 432 360 563 | 1930 380 864 564 | 1930 356 828 565 | 1930 132 560 566 | 1930 400 728 567 | 1930 352 848 568 | 1930 392 964 569 | 1930 16 612 570 | 1935 248 432 571 | 1935 308 740 572 | 1935 372 964 573 | 1935 216 624 574 | 1935 300 568 575 | 1935 424 564 576 | 1935 284 524 577 | 1935 240 532 578 | 1935 308 668 579 | 1935 276 664 580 | 
1935 72 608 581 | 1935 260 500 582 | 1935 264 568 583 | 1935 280 468 584 | 1935 428 348 585 | 1935 468 848 586 | 1935 400 824 587 | 1935 392 292 588 | 1935 392 972 589 | 1935 372 308 590 | 1935 128 540 591 | 1935 44 260 592 | 1935 356 820 593 | 1935 448 344 594 | 1935 380 856 595 | 1935 372 728 596 | 1935 156 536 597 | 1935 12 588 598 | 1935 156 512 599 | 1935 104 608 600 | 1940 216 624 601 | 1940 240 532 602 | 1940 276 668 603 | 1940 284 524 604 | 1940 300 568 605 | 1940 232 432 606 | 1940 308 668 607 | 1940 460 868 608 | 1940 360 984 609 | 1940 424 568 610 | 1940 400 824 611 | 1940 284 744 612 | 1940 144 492 613 | 1940 264 568 614 | 1940 264 464 615 | 1940 424 320 616 | 1940 352 848 617 | 1940 264 500 618 | 1940 68 592 619 | 1940 388 272 620 | 1940 352 732 621 | 1940 44 260 622 | 1940 444 324 623 | 1940 356 824 624 | 1940 468 288 625 | 1940 380 856 626 | 1940 380 988 627 | 1940 128 512 628 | 1940 376 292 629 | 1940 468 888 630 | 1940 96 584 631 | 1945 348 1004 632 | 1945 216 624 633 | 1945 284 524 634 | 1945 300 568 635 | 1945 264 568 636 | 1945 212 432 637 | 1945 256 744 638 | 1945 68 568 639 | 1945 304 668 640 | 1945 240 536 641 | 1945 424 568 642 | 1945 280 664 643 | 1945 252 468 644 | 1945 140 464 645 | 1945 348 848 646 | 1945 400 824 647 | 1945 376 860 648 | 1945 264 500 649 | 1945 444 300 650 | 1945 368 276 651 | 1945 384 256 652 | 1945 352 828 653 | 1945 460 256 654 | 1945 328 732 655 | 1945 444 888 656 | 1945 120 484 657 | 1945 44 260 658 | 1945 416 296 659 | 1945 348 264 660 | 1945 456 908 661 | 1945 468 856 662 | 1945 88 564 663 | 1945 356 1024 664 | 1950 336 1024 665 | 1950 216 624 666 | 1950 300 568 667 | 1950 420 572 668 | 1950 284 528 669 | 1950 300 664 670 | 1950 232 748 671 | 1950 264 568 672 | 1950 352 824 673 | 1950 132 440 674 | 1950 240 472 675 | 1950 352 848 676 | 1950 276 664 677 | 1950 264 500 678 | 1950 196 440 679 | 1950 376 864 680 | 1950 240 532 681 | 1950 400 824 682 | 1950 468 560 683 | 1950 356 248 684 | 1950 380 232 685 | 1950 300 728 686 | 1950 344 1044 687 | 1950 448 940 688 | 1950 112 456 689 | 1950 68 552 690 | 1950 428 904 691 | 1950 44 260 692 | 1950 404 248 693 | 1950 96 544 694 | 1950 440 276 695 | 1950 452 228 696 | 1950 452 884 697 | 1955 264 500 698 | 1955 300 568 699 | 1955 424 572 700 | 1955 276 664 701 | 1955 444 556 702 | 1955 216 624 703 | 1955 284 528 704 | 1955 300 664 705 | 1955 260 568 706 | 1955 128 404 707 | 1955 180 436 708 | 1955 224 472 709 | 1955 356 820 710 | 1955 368 208 711 | 1955 332 1056 712 | 1955 312 1048 713 | 1955 72 528 714 | 1955 212 744 715 | 1955 356 232 716 | 1955 384 852 717 | 1955 396 828 718 | 1955 348 848 719 | 1955 92 524 720 | 1955 240 532 721 | 1955 276 724 722 | 1955 404 236 723 | 1955 44 264 724 | 1955 432 252 725 | 1955 100 432 726 | 1955 436 928 727 | 1955 336 224 728 | 1955 432 972 729 | 1955 448 208 730 | 1960 264 500 731 | 1960 316 1080 732 | 1960 216 624 733 | 1960 300 568 734 | 1960 296 1068 735 | 1960 288 524 736 | 1960 420 564 737 | 1960 160 440 738 | 1960 212 476 739 | 1960 300 664 740 | 1960 264 568 741 | 1960 428 544 742 | 1960 240 532 743 | 1960 120 380 744 | 1960 92 404 745 | 1960 400 828 746 | 1960 352 824 747 | 1960 348 852 748 | 1960 276 664 749 | 1960 64 508 750 | 1960 180 744 751 | 1960 344 208 752 | 1960 384 856 753 | 1960 44 260 754 | 1960 368 180 755 | 1960 88 500 756 | 1960 248 724 757 | 1960 420 968 758 | 1960 392 212 759 | 1960 432 192 760 | 1960 444 560 761 | 1965 284 1088 762 | 1965 264 500 763 | 1965 300 568 764 | 1965 216 624 765 | 1965 148 444 766 | 1965 288 528 767 | 1965 300 
664 768 | 1965 412 548 769 | 1965 156 748 770 | 1965 280 668 771 | 1965 200 480 772 | 1965 400 828 773 | 1965 240 532 774 | 1965 304 1100 775 | 1965 352 820 776 | 1965 264 568 777 | 1965 120 352 778 | 1965 220 724 779 | 1965 420 188 780 | 1965 348 848 781 | 1965 404 984 782 | 1965 88 480 783 | 1965 428 964 784 | 1965 380 856 785 | 1965 92 372 786 | 1965 44 264 787 | 1965 360 168 788 | 1965 60 496 789 | 1965 388 180 790 | 1970 264 500 791 | 1970 216 624 792 | 1970 300 664 793 | 1970 288 528 794 | 1970 292 1120 795 | 1970 192 476 796 | 1970 380 856 797 | 1970 300 568 798 | 1970 280 668 799 | 1970 404 520 800 | 1970 136 748 801 | 1970 132 444 802 | 1970 272 1112 803 | 1970 400 824 804 | 1970 388 536 805 | 1970 112 324 806 | 1970 88 456 807 | 1970 264 564 808 | 1970 84 344 809 | 1970 348 848 810 | 1970 240 532 811 | 1970 60 468 812 | 1970 44 260 813 | 1970 384 160 814 | 1970 352 824 815 | 1970 196 720 816 | 1970 392 1012 817 | 1970 420 1004 818 | 1970 332 172 819 | 1975 260 1132 820 | 1975 264 500 821 | 1975 216 624 822 | 1975 116 456 823 | 1975 300 664 824 | 1975 300 564 825 | 1975 464 860 826 | 1975 88 432 827 | 1975 288 528 828 | 1975 384 852 829 | 1975 276 664 830 | 1975 352 824 831 | 1975 60 444 832 | 1975 184 480 833 | 1975 376 512 834 | 1975 352 848 835 | 1975 264 568 836 | 1975 108 296 837 | 1975 400 824 838 | 1975 384 1032 839 | 1975 280 1144 840 | 1975 168 720 841 | 1975 392 496 842 | 1975 240 532 843 | 1975 88 308 844 | 1975 44 260 845 | 1975 396 804 846 | 1975 84 328 847 | 1975 412 1036 848 | 1975 108 748 849 | 1975 468 880 850 | 1980 248 1152 851 | 1980 80 416 852 | 1980 216 624 853 | 1980 284 528 854 | 1980 300 568 855 | 1980 268 1164 856 | 1980 384 852 857 | 1980 264 500 858 | 1980 280 664 859 | 1980 360 484 860 | 1980 176 484 861 | 1980 300 668 862 | 1980 240 532 863 | 1980 264 568 864 | 1980 112 460 865 | 1980 352 844 866 | 1980 448 888 867 | 1980 84 280 868 | 1980 108 268 869 | 1980 352 820 870 | 1980 400 820 871 | 1980 372 1052 872 | 1980 380 480 873 | 1980 396 800 874 | 1980 144 720 875 | 1980 64 432 876 | 1980 380 1072 877 | 1980 404 1040 878 | 1980 44 260 879 | 1980 84 752 880 | 1985 216 624 881 | 1985 264 500 882 | 1985 232 1168 883 | 1985 436 908 884 | 1985 264 568 885 | 1985 300 664 886 | 1985 172 484 887 | 1985 276 664 888 | 1985 288 524 889 | 1985 240 532 890 | 1985 352 832 891 | 1985 384 852 892 | 1985 76 396 893 | 1985 300 568 894 | 1985 120 724 895 | 1985 400 832 896 | 1985 256 1184 897 | 1985 340 460 898 | 1985 104 236 899 | 1985 64 756 900 | 1985 364 444 901 | 1985 84 244 902 | 1985 96 392 903 | 1985 392 1072 904 | 1985 48 260 905 | 1985 388 1092 906 | 1985 468 912 907 | 1985 396 812 908 | 1985 96 464 909 | 1985 80 264 910 | 1985 348 852 911 | 1990 216 624 912 | 1990 172 484 913 | 1990 220 1192 914 | 1990 264 500 915 | 1990 300 664 916 | 1990 76 476 917 | 1990 276 664 918 | 1990 352 416 919 | 1990 240 532 920 | 1990 300 568 921 | 1990 244 1208 922 | 1990 264 568 923 | 1990 400 828 924 | 1990 284 528 925 | 1990 420 928 926 | 1990 352 824 927 | 1990 384 852 928 | 1990 88 372 929 | 1990 328 436 930 | 1990 72 216 931 | 1990 76 236 932 | 1990 100 212 933 | 1990 68 384 934 | 1990 352 848 935 | 1990 456 940 936 | 1990 380 1096 937 | 1990 44 264 938 | 1990 396 808 939 | 1995 404 948 940 | 1995 172 484 941 | 1995 264 500 942 | 1995 300 668 943 | 1995 300 564 944 | 1995 216 624 945 | 1995 384 852 946 | 1995 344 396 947 | 1995 400 828 948 | 1995 260 564 949 | 1995 60 476 950 | 1995 352 824 951 | 1995 288 528 952 | 1995 276 664 953 | 1995 80 352 954 | 1995 240 532 955 | 
1995 228 1224 956 | 1995 100 348 957 | 1995 84 172 958 | 1995 440 964 959 | 1995 208 1216 960 | 1995 316 412 961 | 1995 348 844 962 | 1995 80 732 963 | 1995 44 260 964 | 1995 64 180 965 | 1995 396 808 966 | 1995 64 336 967 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ['OMP_NUM_THREADS'] = '1' 4 | import time 5 | import itertools 6 | import argparse 7 | import sys 8 | import shutil 9 | from distutils.dir_util import copy_tree 10 | import datetime 11 | import tqdm 12 | import random 13 | import numpy as np 14 | import torch 15 | from torch import optim 16 | from torch.utils.data import DataLoader 17 | from src.datasets import * 18 | from src.models.mvdet import MVDet 19 | from src.models.mvcnn import MVCNN 20 | from src.utils.logger import Logger 21 | from src.utils.draw_curve import draw_curve 22 | from src.utils.str2bool import str2bool 23 | from src.trainer import PerspectiveTrainer, find_dataset_lvl_strategy 24 | from src.trainer_mvcnn import ClassifierTrainer 25 | 26 | 27 | def main(args): 28 | # check if in debug mode 29 | gettrace = getattr(sys, 'gettrace', None) 30 | if gettrace(): 31 | print('Hmm, Big Debugger is watching me') 32 | is_debug = True 33 | torch.autograd.set_detect_anomaly(True) 34 | else: 35 | print('No sys.gettrace') 36 | is_debug = False 37 | 38 | # seed 39 | if args.seed is not None: 40 | random.seed(args.seed) 41 | np.random.seed(args.seed) 42 | torch.manual_seed(args.seed) 43 | torch.cuda.manual_seed(args.seed) 44 | torch.cuda.manual_seed_all(args.seed) 45 | 46 | # deterministic 47 | if args.deterministic: 48 | torch.backends.cudnn.deterministic = True 49 | torch.backends.cudnn.benchmark = False 50 | torch.autograd.set_detect_anomaly(True) 51 | else: 52 | torch.backends.cudnn.benchmark = True 53 | 54 | # dataset 55 | if 'modelnet' in args.dataset: 56 | if args.dataset == 'modelnet40_12': 57 | fpath = os.path.expanduser('~/Data/modelnet/modelnet40_images_new_12x') 58 | num_cam = 12 59 | elif args.dataset == 'modelnet40_20': 60 | fpath = os.path.expanduser('~/Data/modelnet/modelnet40v2png_ori4') 61 | num_cam = 20 62 | else: 63 | raise Exception 64 | 65 | args.task = 'mvcnn' 66 | result_type = ['prec'] 67 | args.lr = 5e-5 if args.lr is None else args.lr 68 | args.select_lr = 1e-4 if args.select_lr is None else args.select_lr 69 | args.batch_size = 8 if args.batch_size is None else args.batch_size 70 | 71 | train_set = ModelNet40(fpath, num_cam, split='train', ) 72 | val_set = ModelNet40(fpath, num_cam, split='train', per_cls_instances=25) 73 | test_set = ModelNet40(fpath, num_cam, split='test', ) 74 | elif args.dataset=='scanobjectnn': 75 | fpath = os.path.expanduser('~/Data/ScanObjectNN') 76 | 77 | args.task = 'mvcnn' 78 | result_type = ['prec'] 79 | args.lr = 5e-5 if args.lr is None else args.lr 80 | args.select_lr = 1e-4 if args.select_lr is None else args.select_lr 81 | args.batch_size = 8 if args.batch_size is None else args.batch_size 82 | 83 | train_set = ScanObjectNN(fpath, split='train', ) 84 | val_set = ScanObjectNN(fpath, split='train', per_cls_instances=25) 85 | test_set = ScanObjectNN(fpath, split='test', ) 86 | else: 87 | if args.dataset == 'wildtrack': 88 | base = Wildtrack(os.path.expanduser('~/Data/Wildtrack')) 89 | elif args.dataset == 'multiviewx': 90 | base = MultiviewX(os.path.expanduser('~/Data/MultiviewX')) 91 | else: 92 | raise Exception('must choose from [wildtrack, multiviewx]') 93 | 
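# Multiview detection branch (Wildtrack / MultiviewX): the MVDet task below reports MODA, MODP, precision and recall, and uses different defaults from the classification branch above (per-frame batch size 1, base lr 5e-4).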
94 | args.task = 'mvdet' 95 | result_type = ['moda', 'modp', 'prec', 'recall'] 96 | args.lr = 5e-4 if args.lr is None else args.lr 97 | args.select_lr = 1e-4 if args.select_lr is None else args.select_lr 98 | args.batch_size = 1 if args.batch_size is None else args.batch_size 99 | 100 | train_set = frameDataset(base, split='trainval', world_reduce=args.world_reduce, 101 | img_reduce=args.img_reduce, world_kernel_size=args.world_kernel_size, 102 | img_kernel_size=args.img_kernel_size, 103 | dropout=args.dropcam, augmentation=args.augmentation) 104 | val_set = frameDataset(base, split='val', world_reduce=args.world_reduce, 105 | img_reduce=args.img_reduce, world_kernel_size=args.world_kernel_size, 106 | img_kernel_size=args.img_kernel_size) 107 | test_set = frameDataset(base, split='test', world_reduce=args.world_reduce, 108 | img_reduce=args.img_reduce, world_kernel_size=args.world_kernel_size, 109 | img_kernel_size=args.img_kernel_size) 110 | 111 | if args.steps: 112 | args.lr /= 5 113 | # args.epochs *= 2 114 | 115 | def seed_worker(worker_id): 116 | worker_seed = torch.initial_seed() % 2 ** 32 117 | np.random.seed(worker_seed) 118 | random.seed(worker_seed) 119 | 120 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, 121 | pin_memory=True, worker_init_fn=seed_worker) 122 | val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, 123 | pin_memory=True, worker_init_fn=seed_worker) 124 | test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, 125 | pin_memory=True, worker_init_fn=seed_worker) 126 | N = train_set.num_cam 127 | 128 | # logging 129 | select_settings = f'steps{args.steps}_' 130 | lr_settings = f'base{args.base_lr_ratio}other{args.other_lr_ratio}' + \ 131 | f'select{args.select_lr}' if args.steps else '' 132 | logdir = f'logs/{args.dataset}/{"DEBUG_" if is_debug else ""}{args.arch}_{args.aggregation}_down{args.down}_' \ 133 | f'{select_settings if args.steps else ""}' \ 134 | f'lr{args.lr}{lr_settings}_b{args.batch_size}_e{args.epochs}_dropcam{args.dropcam}_' \ 135 | f'{datetime.datetime.today():%Y-%m-%d_%H-%M-%S}' if not args.eval \ 136 | else f'logs/{args.dataset}/EVAL_{args.resume}' 137 | os.makedirs(logdir, exist_ok=True) 138 | copy_tree('src', logdir + '/scripts/src') 139 | for script in os.listdir('.'): 140 | if script.split('.')[-1] == 'py': 141 | dst_file = os.path.join(logdir, 'scripts', os.path.basename(script)) 142 | shutil.copyfile(script, dst_file) 143 | sys.stdout = Logger(os.path.join(logdir, 'log.txt'), ) 144 | print(logdir) 145 | print('Settings:') 146 | print(vars(args)) 147 | 148 | # model 149 | if args.task == 'mvcnn': 150 | model = MVCNN(train_set, args.arch, args.aggregation).cuda() 151 | else: 152 | model = MVDet(train_set, args.arch, args.aggregation, 153 | args.use_bottleneck, args.hidden_dim, args.outfeat_dim).cuda() 154 | 155 | # load checkpoint 156 | if args.steps: 157 | with open(f'logs/{args.dataset}/{args.arch}_performance.txt', 'r') as fp: 158 | result_str = fp.read() 159 | print(result_str) 160 | load_dir = result_str.split('\n')[1].replace('# ', '') 161 | pretrained_dict = torch.load(f'{load_dir}/model.pth') 162 | model_dict = model.state_dict() 163 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and 'select' not in k} 164 | model_dict.update(pretrained_dict) 165 | model.load_state_dict(model_dict) 166 | 167 | if args.resume: 168 | print(f'loading 
checkpoint: logs/{args.dataset}/{args.resume}') 169 | pretrained_dict = torch.load(f'logs/{args.dataset}/{args.resume}/model.pth') 170 | model_dict = model.state_dict() 171 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 172 | model_dict.update(pretrained_dict) 173 | model.load_state_dict(model_dict) 174 | 175 | param_dicts = [{"params": [p for n, p in model.named_parameters() 176 | if 'base' not in n and 'select' not in n and p.requires_grad], 177 | "lr": args.lr * args.other_lr_ratio, }, 178 | {"params": [p for n, p in model.named_parameters() if 'base' in n and p.requires_grad], 179 | "lr": args.lr * args.base_lr_ratio, }, 180 | {"params": [p for n, p in model.named_parameters() if 'select' in n and p.requires_grad], 181 | "lr": args.select_lr, }, ] 182 | optimizer = optim.Adam(param_dicts, lr=args.lr, weight_decay=args.weight_decay) 183 | 184 | def warmup_lr_scheduler(epoch, warmup_epochs=0.1 * args.epochs): 185 | if epoch < warmup_epochs: 186 | return epoch / warmup_epochs 187 | else: 188 | return (np.cos((epoch - warmup_epochs) / (args.epochs - warmup_epochs) * np.pi) + 1) / 2 189 | 190 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, warmup_lr_scheduler) 191 | 192 | if args.task == 'mvcnn': 193 | trainer = ClassifierTrainer(model, logdir, args) 194 | else: 195 | trainer = PerspectiveTrainer(model, logdir, args) 196 | 197 | # draw curve 198 | x_epoch = [] 199 | train_loss_s = [] 200 | train_prec_s = [] 201 | test_loss_s = [] 202 | test_prec_s = [] 203 | 204 | # learn 205 | if not args.eval: 206 | # trainer.test(test_loader) 207 | for epoch in tqdm.tqdm(range(1, args.epochs + 1)): 208 | print('Training...') 209 | train_loss, train_prec = trainer.train(epoch, train_loader, optimizer, scheduler) 210 | if epoch % max(args.epochs // 10, 1) == 0: 211 | print('Testing...') 212 | test_loss, test_prec = trainer.test(test_loader, torch.eye(N) if args.steps else None) 213 | 214 | # draw & save 215 | x_epoch.append(epoch) 216 | train_loss_s.append(train_loss) 217 | train_prec_s.append(train_prec) 218 | test_loss_s.append(test_loss) 219 | test_prec_s.append(test_prec[0]) 220 | draw_curve(os.path.join(logdir, 'learning_curve.jpg'), x_epoch, train_loss_s, test_loss_s, 221 | train_prec_s, test_prec_s) 222 | torch.save(model.state_dict(), os.path.join(logdir, 'model.pth')) 223 | 224 | def log_best2cam_strategy(result_type=('prec',), max_steps=4): 225 | candidates = np.eye(N) 226 | combinations = np.array(list(itertools.combinations(candidates, 2))).sum(1) 227 | combination_indices = np.array(list(itertools.combinations(list(range(N)), 2))) 228 | info_str = {} 229 | # diagonal: step == 0 230 | val_loss_diag, val_prec_diag, _, _ = trainer.test_cam_combination(val_loader, 0) 231 | test_loss_diag, test_prec_diag, _, info_str[0] = trainer.test_cam_combination(test_loader, 0) 232 | # non-diagonal: step == 1 233 | val_loss_s, val_prec_s, val_oracle_s, _ = trainer.test_cam_combination(val_loader, 1) 234 | test_loss_s, test_prec_s, test_oracle_s, info_str[1] = trainer.test_cam_combination(test_loader, 1) 235 | for i in range(2, max_steps + 1): 236 | _, _, _, info_str[i] = trainer.test_cam_combination(test_loader, i) 237 | info_str = '\n'.join(info_str.values()) 238 | 239 | def combine2mat(diag_terms, non_diag_terms): 240 | combined_mat = np.zeros([len(diag_terms), len(diag_terms)] + list(diag_terms.shape[1:])) 241 | combined_mat[np.eye(len(diag_terms), dtype=bool)] = diag_terms 242 | non_diag_indices = list(itertools.combinations(list(range(len(diag_terms))), 2)) 
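# mirror each pairwise (2-camera) result into both (i, j) and (j, i); the returned matrix is
# therefore symmetric, with the single-camera (step == 0) results already on its diagonal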
243 | for i in range(len(non_diag_indices)): 244 | idx = non_diag_indices[i] 245 | combined_mat[idx[0], idx[1]] = combined_mat[idx[1], idx[0]] = non_diag_terms[i] 246 | return combined_mat 247 | 248 | def find_cam(init_cam, combination_id): 249 | cam_tuple = list(combination_indices[combination_id]) 250 | cam_tuple.remove(init_cam) 251 | return cam_tuple[0] 252 | 253 | val_loss_strategy = find_dataset_lvl_strategy(-val_loss_s, combinations) 254 | val_metric_strategy = find_dataset_lvl_strategy(val_prec_s[:, 0], combinations) 255 | test_metric_strategy = find_dataset_lvl_strategy(test_prec_s[:, 0], combinations) 256 | 257 | _, prec = trainer.test(test_loader) 258 | np.savetxt(f'{logdir}/losses_val_test.txt', np.concatenate([combine2mat(val_loss_diag, val_loss_s), 259 | combine2mat(test_loss_diag, test_loss_s)]), '%.2f') 260 | for i in range(len(result_type)): 261 | fname = f'{result_type[i]}_{prec[i]:.1f}_' \ 262 | f'Lstrategy{test_prec_s[val_loss_strategy].mean(0)[i]:.1f}_' \ 263 | f'Rstrategy{test_prec_s[val_metric_strategy].mean(0)[i]:.1f}_' \ 264 | f'theory{test_prec_s[test_metric_strategy].mean(0)[i]:.1f}_' \ 265 | f'avg{test_prec_s.mean(0)[i]:.1f}.txt' 266 | np.savetxt(f'{logdir}/{fname}', 267 | np.concatenate([combine2mat(val_prec_diag, val_prec_s)[:, :, i], 268 | combine2mat(test_prec_diag, test_prec_s)[:, :, i]]), '%.1f', 269 | header=f'loading checkpoint...\n' 270 | f'{logdir}\n' 271 | f'val / test', 272 | footer=(f'\n{info_str}\n\n' if i == 0 else '') + f'\tdataset level: loss strategy\n' + 273 | ' '.join(f'cam {find_cam(cam, val_loss_strategy[cam])} |' for cam in range(N)) + '\n' + 274 | ' '.join(f'{test_prec_s[val_loss_strategy][cam, i]:.1f}% |' 275 | for cam in range(N)) + '\n' + 276 | f'\tdataset level: result strategy\n' + 277 | ' '.join(f'cam {find_cam(cam, val_metric_strategy[cam])} |' for cam in range(N)) + '\n' + 278 | ' '.join(f'{test_prec_s[val_metric_strategy][cam, i]:.1f}% |' 279 | for cam in range(N)) + '\n' + 280 | f'\tdataset level: theory\n' + 281 | ' '.join(f'cam {find_cam(cam, test_metric_strategy[cam])} |' for cam in range(N)) + '\n' + 282 | ' '.join(f'{test_prec_s[test_metric_strategy][cam, i]:.1f}% |' 283 | for cam in range(N)) + '\n' + 284 | f'\tinstance level: oracle\n' + 285 | ' '.join(f'----- |' for cam in range(N)) + '\n' + 286 | ' '.join(f'{test_oracle_s[cam, i]:.1f}% |' 287 | for cam in range(N)) + '\n' + 288 | f'2 best cam: loss_strategy {test_prec_s[val_loss_strategy].mean(0)[i]:.1f}, ' 289 | f'result_strategy {test_prec_s[val_metric_strategy].mean(0)[i]:.1f}, ' 290 | f'theory {test_prec_s[test_metric_strategy].mean(0)[i]:.1f}, ' 291 | f'oracle {test_oracle_s.mean(0)[i]:.1f}, average {test_prec_s.mean(0)[i]:.1f}\n' 292 | f'all cam: {prec[i]:.1f}') 293 | with open(f'{logdir}/{fname}', 'r') as fp: 294 | if i == 0: 295 | print(fp.read()) 296 | if not args.eval and i == 0: 297 | shutil.copyfile(f'{logdir}/{fname}', f'logs/{args.dataset}/{args.arch}_performance.txt') 298 | 299 | print('Test loaded model...') 300 | print(logdir) 301 | if args.steps == 0: 302 | log_best2cam_strategy(result_type) 303 | else: 304 | if args.eval: 305 | trainer.test(test_loader, torch.eye(N)) 306 | trainer.test(test_loader) 307 | 308 | 309 | if __name__ == '__main__': 310 | # common settings 311 | parser = argparse.ArgumentParser(description='view selection for multiview classification & detection') 312 | parser.add_argument('--eval', action='store_true', help='evaluation only') 313 | parser.add_argument('--arch', type=str, default='resnet18') 314 | 
parser.add_argument('--aggregation', type=str, default='max', choices=['mean', 'max']) 315 | parser.add_argument('-d', '--dataset', type=str, default='wildtrack', 316 | choices=['wildtrack', 'multiviewx', 'modelnet40_12', 'modelnet40_20', 'scanobjectnn']) 317 | parser.add_argument('-j', '--num_workers', type=int, default=4) 318 | parser.add_argument('-b', '--batch_size', type=int, default=None, help='input batch size for training') 319 | parser.add_argument('--dropcam', type=float, default=0.0) 320 | parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train') 321 | parser.add_argument('--lr', type=float, default=None, help='learning rate for task network') 322 | parser.add_argument('--select_lr', type=float, default=None, help='learning rate for MVselect') 323 | parser.add_argument('--base_lr_ratio', type=float, default=1.0) 324 | parser.add_argument('--other_lr_ratio', type=float, default=1.0) 325 | parser.add_argument('--weight_decay', type=float, default=1e-4) 326 | parser.add_argument('--resume', type=str, default=None) 327 | parser.add_argument('--visualize', action='store_true') 328 | parser.add_argument('--seed', type=int, default=None, help='random seed') 329 | parser.add_argument('--deterministic', type=str2bool, default=False) 330 | # MVSelect settings 331 | parser.add_argument('--steps', type=int, default=0, 332 | help='number of camera views to choose. if 0, then no selection') 333 | parser.add_argument('--gamma', type=float, default=0.99, help='reward discount factor (default: 0.99)') 334 | parser.add_argument('--down', type=int, default=1, help='down sample the image to 1/N size') 335 | # parser.add_argument('--beta_entropy', type=float, default=0.01) 336 | # multiview detection specific settings 337 | parser.add_argument('--eval_init_cam', type=str2bool, default=False, 338 | help='only consider pedestrians covered by the initial camera') 339 | parser.add_argument('--reID', action='store_true') 340 | parser.add_argument('--augmentation', type=str2bool, default=True) 341 | parser.add_argument('--id_ratio', type=float, default=0) 342 | parser.add_argument('--cls_thres', type=float, default=0.6) 343 | parser.add_argument('--alpha', type=float, default=1.0, help='ratio for per view loss') 344 | parser.add_argument('--use_mse', type=str2bool, default=False) 345 | parser.add_argument('--use_bottleneck', type=str2bool, default=True) 346 | parser.add_argument('--hidden_dim', type=int, default=128) 347 | parser.add_argument('--outfeat_dim', type=int, default=0) 348 | parser.add_argument('--world_reduce', type=int, default=4) 349 | parser.add_argument('--world_kernel_size', type=int, default=10) 350 | parser.add_argument('--img_reduce', type=int, default=12) 351 | parser.add_argument('--img_kernel_size', type=int, default=10) 352 | 353 | args = parser.parse_args() 354 | 355 | main(args) 356 | -------------------------------------------------------------------------------- /src/datasets/frameDataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import time 5 | from operator import itemgetter 6 | from PIL import Image 7 | from kornia.geometry import warp_perspective 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | import torchvision.transforms as T 12 | from torchvision.datasets import VisionDataset 13 | from src.utils.projection import * 14 | from src.utils.image_utils import draw_umich_gaussian, random_affine 15 | import matplotlib.pyplot 
as plt 16 | 17 | 18 | def get_gt(Rshape, x_s, y_s, w_s=None, h_s=None, v_s=None, reduce=4, top_k=100, kernel_size=4): 19 | H, W = Rshape 20 | heatmap = np.zeros([1, H, W], dtype=np.float32) 21 | reg_mask = np.zeros([top_k], dtype=bool) 22 | idx = np.zeros([top_k], dtype=np.int64) 23 | pid = np.zeros([top_k], dtype=np.int64) 24 | offset = np.zeros([top_k, 2], dtype=np.float32) 25 | wh = np.zeros([top_k, 2], dtype=np.float32) 26 | 27 | for k in range(len(v_s)): 28 | ct = np.array([x_s[k] / reduce, y_s[k] / reduce], dtype=np.float32) 29 | if 0 <= ct[0] < W and 0 <= ct[1] < H: 30 | ct_int = ct.astype(np.int32) 31 | draw_umich_gaussian(heatmap[0], ct_int, kernel_size / reduce) 32 | reg_mask[k] = 1 33 | idx[k] = ct_int[1] * W + ct_int[0] 34 | pid[k] = v_s[k] 35 | offset[k] = ct - ct_int 36 | if w_s is not None and h_s is not None: 37 | wh[k] = [w_s[k] / reduce, h_s[k] / reduce] 38 | # plt.imshow(heatmap[0]) 39 | # plt.show() 40 | 41 | ret = {'heatmap': torch.from_numpy(heatmap), 'reg_mask': torch.from_numpy(reg_mask), 'idx': torch.from_numpy(idx), 42 | 'pid': torch.from_numpy(pid), 'offset': torch.from_numpy(offset)} 43 | if w_s is not None and h_s is not None: 44 | ret.update({'wh': torch.from_numpy(wh)}) 45 | return ret 46 | 47 | 48 | def read_pom(root): 49 | bbox_by_pos_cam = {} 50 | cam_pos_pattern = re.compile(r'(\d+) (\d+)') 51 | cam_pos_bbox_pattern = re.compile(r'(\d+) (\d+) ([-\d]+) ([-\d]+) (\d+) (\d+)') 52 | with open(os.path.join(root, 'rectangles.pom'), 'r') as fp: 53 | for line in fp: 54 | if 'RECTANGLE' in line: 55 | cam, pos = map(int, cam_pos_pattern.search(line).groups()) 56 | if pos not in bbox_by_pos_cam: 57 | bbox_by_pos_cam[pos] = {} 58 | if 'notvisible' in line: 59 | bbox_by_pos_cam[pos][cam] = None 60 | else: 61 | cam, pos, left, top, right, bottom = map(int, cam_pos_bbox_pattern.search(line).groups()) 62 | bbox_by_pos_cam[pos][cam] = [max(left, 0), max(top, 0), 63 | min(right, 1920 - 1), min(bottom, 1080 - 1)] 64 | return bbox_by_pos_cam 65 | 66 | 67 | class frameDataset(VisionDataset): 68 | def __init__(self, base, split='train', reID=False, world_reduce=4, img_reduce=12, 69 | world_kernel_size=10, img_kernel_size=10, 70 | split_ratio=(0.8, 0.1, 0.1), top_k=100, force_download=True, dropout=0.0, augmentation=False): 71 | super().__init__(base.root) 72 | 73 | self.base = base 74 | self.num_cam, self.num_frame = base.num_cam, base.num_frame 75 | # world (grid) reduce: on top of the 2.5cm grid 76 | self.reID, self.top_k = reID, top_k 77 | # reduce = input/output 78 | self.world_reduce, self.img_reduce = world_reduce, img_reduce 79 | self.img_shape, self.worldgrid_shape = base.img_shape, base.worldgrid_shape # H,W; N_row,N_col 80 | self.world_kernel_size, self.img_kernel_size = world_kernel_size, img_kernel_size 81 | self.dropout = dropout 82 | self.transform = T.Compose([T.ToTensor(), T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 83 | T.Resize((np.array(self.img_shape) * 8 // self.img_reduce).tolist())]) 84 | self.augmentation = augmentation 85 | 86 | self.Rworld_shape = list(map(lambda x: x // self.world_reduce, self.worldgrid_shape)) 87 | self.Rimg_shape = np.ceil(np.array(self.img_shape) / self.img_reduce).astype(int).tolist() 88 | 89 | # split = ('train', 'val', 'test'), split_ratio=(0.8, 0.1, 0.1) 90 | split_ratio = tuple(sum(split_ratio[:i + 1]) for i in range(len(split_ratio))) 91 | assert split_ratio[-1] == 1 92 | self.split = split 93 | if split == 'train': 94 | frame_range = range(0, int(self.num_frame * split_ratio[0])) 95 | elif split == 'val': 96 
| frame_range = range(int(self.num_frame * split_ratio[0]), int(self.num_frame * split_ratio[1])) 97 | elif split == 'trainval': 98 | frame_range = range(0, int(self.num_frame * split_ratio[1])) 99 | elif split == 'test': 100 | frame_range = range(int(self.num_frame * split_ratio[1]), self.num_frame) 101 | else: 102 | raise Exception 103 | 104 | self.world_from_img, self.img_from_world = self.get_world_imgs_trans() 105 | world_masks = torch.ones([self.num_cam, 1] + self.worldgrid_shape) 106 | self.imgs_region = warp_perspective(world_masks, self.img_from_world, self.img_shape, 'nearest') 107 | self.Rworld_coverage = self.get_world_coverage().bool() 108 | 109 | self.img_fpaths = self.get_image_fpaths(frame_range) 110 | self.world_gt, self.imgs_gt, self.pid_dict, self.frames = self.get_gt_targets( 111 | split if split == 'trainval' else f'{split} \t', frame_range) 112 | # gt in mot format for evaluation 113 | self.gt_fname = f'{self.root}/gt' 114 | if not os.path.exists(f'{self.gt_fname}.txt') or force_download: 115 | self.prepare_gt() 116 | pass 117 | 118 | def get_image_fpaths(self, frame_range): 119 | img_fpaths = {cam: {} for cam in range(self.num_cam)} 120 | for camera_folder in sorted(os.listdir(os.path.join(self.root, 'Image_subsets'))): 121 | cam = int(camera_folder[-1]) - 1 122 | if cam >= self.num_cam: 123 | continue 124 | for fname in sorted(os.listdir(os.path.join(self.root, 'Image_subsets', camera_folder))): 125 | frame = int(fname.split('.')[0]) 126 | if frame in frame_range: 127 | img_fpaths[cam][frame] = os.path.join(self.root, 'Image_subsets', camera_folder, fname) 128 | return img_fpaths 129 | 130 | def get_gt_targets(self, split, frame_range): 131 | num_world_bbox, num_imgs_bbox = 0, 0 132 | world_gt = {} 133 | imgs_gt = {} 134 | pid_dict = {} 135 | frames = [] 136 | for fname in sorted(os.listdir(os.path.join(self.root, 'annotations_positions'))): 137 | frame = int(fname.split('.')[0]) 138 | if frame in frame_range: 139 | frames.append(frame) 140 | with open(os.path.join(self.root, 'annotations_positions', fname)) as json_file: 141 | all_pedestrians = json.load(json_file) 142 | world_pts, world_pids = [], [] 143 | img_bboxs, img_pids = [[] for _ in range(self.num_cam)], [[] for _ in range(self.num_cam)] 144 | for pedestrian in all_pedestrians: 145 | grid_x, grid_y = self.base.get_worldgrid_from_pos(pedestrian['positionID']).squeeze() 146 | if pedestrian['personID'] not in pid_dict: 147 | pid_dict[pedestrian['personID']] = len(pid_dict) 148 | num_world_bbox += 1 149 | if self.base.indexing == 'xy': 150 | world_pts.append((grid_x, grid_y)) 151 | else: 152 | world_pts.append((grid_y, grid_x)) 153 | world_pids.append(pid_dict[pedestrian['personID']]) 154 | for cam in range(self.num_cam): 155 | if itemgetter('xmin', 'ymin', 'xmax', 'ymax')(pedestrian['views'][cam]) != (-1, -1, -1, -1): 156 | img_bboxs[cam].append(itemgetter('xmin', 'ymin', 'xmax', 'ymax') 157 | (pedestrian['views'][cam])) 158 | img_pids[cam].append(pid_dict[pedestrian['personID']]) 159 | num_imgs_bbox += 1 160 | world_gt[frame] = (np.array(world_pts), np.array(world_pids)) 161 | imgs_gt[frame] = {} 162 | for cam in range(self.num_cam): 163 | # x1y1x2y2 164 | imgs_gt[frame][cam] = (np.array(img_bboxs[cam]), np.array(img_pids[cam])) 165 | 166 | print(f'{split}:\t pid: {len(pid_dict)}, frame: {len(frames)}, ' 167 | f'world bbox: {num_world_bbox / len(frames):.1f}, ' 168 | f'imgs bbox per cam: {num_imgs_bbox / len(frames) / self.num_cam:.1f}') 169 | return world_gt, imgs_gt, pid_dict, frames 170 | 171 | def 
get_world_coverage(self): 172 | # world grid change to xy indexing 173 | world_zoom_mat = np.diag([self.world_reduce, self.world_reduce, 1]) 174 | Rworldgrid_from_worldcoord_mat = np.linalg.inv( 175 | self.base.worldcoord_from_worldgrid_mat @ world_zoom_mat @ self.base.world_indexing_from_xy_mat) 176 | 177 | # z in meters by default 178 | # projection matrices: img feat -> world feat 179 | worldcoord_from_imgcoord_mats = [get_worldcoord_from_imgcoord_mat(self.base.intrinsic_matrices[cam], 180 | self.base.extrinsic_matrices[cam], ) 181 | for cam in range(self.num_cam)] 182 | # Rworldgrid(xy)_from_imgcoord(xy) 183 | proj_mats = torch.stack([torch.from_numpy(Rworldgrid_from_worldcoord_mat @ 184 | worldcoord_from_imgcoord_mats[cam]) 185 | for cam in range(self.num_cam)]).float() 186 | 187 | imgs = torch.ones([self.num_cam, 1, self.base.img_shape[0], self.base.img_shape[1]]) 188 | coverage = warp_perspective(imgs, proj_mats, self.Rworld_shape) 189 | return coverage 190 | 191 | def get_world_imgs_trans(self, z=0): 192 | # image and world feature maps from xy indexing, change them into world indexing / xy indexing (img) 193 | # world grid change to xy indexing 194 | Rworldgrid_from_worldcoord_mat = np.linalg.inv(self.base.worldcoord_from_worldgrid_mat @ 195 | self.base.world_indexing_from_xy_mat) 196 | 197 | # z in meters by default 198 | # projection matrices: img feat -> world feat 199 | worldcoord_from_imgcoord_mats = [get_worldcoord_from_imgcoord_mat(self.base.intrinsic_matrices[cam], 200 | self.base.extrinsic_matrices[cam], 201 | z / self.base.worldcoord_unit) 202 | for cam in range(self.num_cam)] 203 | # worldgrid(xy)_from_img(xy) 204 | proj_mats = [Rworldgrid_from_worldcoord_mat @ worldcoord_from_imgcoord_mats[cam] @ self.base.img_xy_from_xy_mat 205 | for cam in range(self.num_cam)] 206 | world_from_img = torch.tensor(np.stack(proj_mats)) 207 | # img(xy)_from_worldgrid(xy) 208 | img_from_world = torch.tensor(np.stack([np.linalg.inv(proj_mat) for proj_mat in proj_mats])) 209 | return world_from_img.float(), img_from_world.float() 210 | 211 | def prepare_gt(self): 212 | og_gt = [[] for _ in range(self.num_cam)] 213 | for fname in sorted(os.listdir(os.path.join(self.root, 'annotations_positions'))): 214 | frame = int(fname.split('.')[0]) 215 | with open(os.path.join(self.root, 'annotations_positions', fname)) as json_file: 216 | all_pedestrians = json.load(json_file) 217 | for single_pedestrian in all_pedestrians: 218 | def is_in_cam(cam, grid_x, grid_y): 219 | visible = not (single_pedestrian['views'][cam]['xmin'] == -1 and 220 | single_pedestrian['views'][cam]['xmax'] == -1 and 221 | single_pedestrian['views'][cam]['ymin'] == -1 and 222 | single_pedestrian['views'][cam]['ymax'] == -1) 223 | in_view = (single_pedestrian['views'][cam]['xmin'] > 0 and 224 | single_pedestrian['views'][cam]['xmax'] < 1920 and 225 | single_pedestrian['views'][cam]['ymin'] > 0 and 226 | single_pedestrian['views'][cam]['ymax'] < 1080) 227 | 228 | # Rgrid_x, Rgrid_y = grid_x // self.world_reduce, grid_y // self.world_reduce 229 | # in_map = Rgrid_x < self.Rworld_shape[0] and Rgrid_y < self.Rworld_shape[1] 230 | return visible and in_view 231 | 232 | grid_x, grid_y = self.base.get_worldgrid_from_pos(single_pedestrian['positionID']).squeeze() 233 | for cam in range(self.num_cam): 234 | if is_in_cam(cam, grid_x, grid_y): 235 | og_gt[cam].append(np.array([frame, grid_x, grid_y])) 236 | og_gt = [np.stack(og_gt[cam], axis=0) for cam in range(self.num_cam)] 237 | np.savetxt(f'{self.gt_fname}.txt', 
np.unique(np.concatenate(og_gt, axis=0), axis=0), '%d') 238 | for cam in range(self.num_cam): 239 | np.savetxt(f'{self.gt_fname}_{cam}.txt', og_gt[cam], '%d') 240 | 241 | def __getitem__(self, index, visualize=False): 242 | def plt_visualize(): 243 | import cv2 244 | from matplotlib.patches import Circle 245 | fig, ax = plt.subplots(1) 246 | ax.imshow(img) 247 | for i in range(len(img_x_s)): 248 | x, y = img_x_s[i], img_y_s[i] 249 | if x > 0 and y > 0: 250 | ax.add_patch(Circle((x, y), 10)) 251 | plt.show() 252 | img0 = img.copy() 253 | for bbox in img_bboxs: 254 | bbox = tuple(int(pt) for pt in bbox) 255 | cv2.rectangle(img0, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2) 256 | plt.imshow(img0) 257 | plt.show() 258 | 259 | frame = list(self.world_gt.keys())[index] 260 | # imgs 261 | imgs, imgs_gt, affine_mats, masks = [], [], [], [] 262 | for cam in range(self.num_cam): 263 | img = np.array(Image.open(self.img_fpaths[cam][frame]).convert('RGB')) 264 | img_bboxs, img_pids = self.imgs_gt[frame][cam] 265 | if self.augmentation: 266 | img, img_bboxs, img_pids, M = random_affine(img, img_bboxs, img_pids) 267 | else: 268 | M = np.eye(3) 269 | imgs.append(self.transform(img)) 270 | affine_mats.append(torch.from_numpy(M).float()) 271 | img_x_s, img_y_s = (img_bboxs[:, 0] + img_bboxs[:, 2]) / 2, img_bboxs[:, 3] 272 | img_w_s, img_h_s = (img_bboxs[:, 2] - img_bboxs[:, 0]), (img_bboxs[:, 3] - img_bboxs[:, 1]) 273 | 274 | img_gt = get_gt(self.Rimg_shape, img_x_s, img_y_s, img_w_s, img_h_s, v_s=img_pids, 275 | reduce=self.img_reduce, top_k=self.top_k, kernel_size=self.img_kernel_size) 276 | imgs_gt.append(img_gt) 277 | if visualize: 278 | plt_visualize() 279 | 280 | imgs = torch.stack(imgs) 281 | affine_mats = torch.stack(affine_mats) 282 | # inverse_M = torch.inverse( 283 | # torch.cat([affine_mats, torch.tensor([0, 0, 1]).view(1, 1, 3).repeat(self.num_cam, 1, 1)], dim=1))[:, :2] 284 | imgs_gt = {key: torch.stack([img_gt[key] for img_gt in imgs_gt]) for key in imgs_gt[0]} 285 | drop, keep_cams = np.random.rand() < self.dropout, torch.ones(self.num_cam, dtype=torch.bool) 286 | if drop: 287 | num_drop = np.random.randint(self.num_cam - 1) 288 | drop_cams = np.random.choice(self.num_cam, num_drop, replace=False) 289 | for cam in drop_cams: 290 | keep_cams[cam] = 0 291 | for key in imgs_gt: 292 | imgs_gt[key][cam] = 0 293 | # world gt 294 | world_pt_s, world_pid_s = self.world_gt[frame] 295 | world_gt = get_gt(self.Rworld_shape, world_pt_s[:, 0], world_pt_s[:, 1], v_s=world_pid_s, 296 | reduce=self.world_reduce, top_k=self.top_k, kernel_size=self.world_kernel_size) 297 | return imgs, world_gt, imgs_gt, affine_mats, frame, keep_cams 298 | 299 | def __len__(self): 300 | return len(self.world_gt.keys()) 301 | 302 | 303 | def test(test_projection=False): 304 | from torch.utils.data import DataLoader 305 | from src.datasets.wildtrack import Wildtrack 306 | from src.datasets.multiviewx import MultiviewX 307 | 308 | dataset = frameDataset(Wildtrack(os.path.expanduser('~/Data/Wildtrack')), force_download=True) 309 | dataset = frameDataset(MultiviewX(os.path.expanduser('~/Data/MultiviewX')), force_download=True) 310 | # dataset = frameDataset(Wildtrack(os.path.expanduser('~/Data/Wildtrack')), split='train', semi_supervised=.1) 311 | # dataset = frameDataset(MultiviewX(os.path.expanduser('~/Data/MultiviewX')), split='train', semi_supervised=.1) 312 | # dataset = frameDataset(Wildtrack(os.path.expanduser('~/Data/Wildtrack')), split='train', semi_supervised=0.5) 313 | # dataset = 
frameDataset(MultiviewX(os.path.expanduser('~/Data/MultiviewX')), split='train', semi_supervised=0.5) 314 | min_dist = np.inf 315 | for world_gt in dataset.world_gt.values(): 316 | x, y = world_gt[0][:, 0], world_gt[0][:, 1] 317 | if x.size and y.size: 318 | xy_dists = ((x - x[:, None]) ** 2 + (y - y[:, None]) ** 2) ** 0.5 319 | np.fill_diagonal(xy_dists, np.inf) 320 | min_dist = min(min_dist, np.min(xy_dists)) 321 | pass 322 | dataloader = DataLoader(dataset, 2, True, num_workers=0) 323 | # imgs, world_gt, imgs_gt, M, frame, keep_cams = next(iter(dataloader)) 324 | t0 = time.time() 325 | for i in range(10): 326 | imgs, world_gt, imgs_gt, M, frame, keep_cams = dataset.__getitem__(i, visualize=False) 327 | print(time.time() - t0) 328 | 329 | pass 330 | if test_projection: 331 | import matplotlib.pyplot as plt 332 | from src.utils.projection import get_worldcoord_from_imagecoord 333 | world_grid_maps = [] 334 | xx, yy = np.meshgrid(np.arange(0, 1920, 20), np.arange(0, 1080, 20)) 335 | H, W = xx.shape 336 | image_coords = np.stack([xx, yy], axis=2).reshape([-1, 2]) 337 | for cam in range(dataset.num_cam): 338 | world_coords = get_worldcoord_from_imagecoord(image_coords.transpose(), 339 | dataset.base.intrinsic_matrices[cam], 340 | dataset.base.extrinsic_matrices[cam]) 341 | world_grids = dataset.base.get_worldgrid_from_worldcoord(world_coords).transpose().reshape([H, W, 2]) 342 | world_grid_map = np.zeros(dataset.worldgrid_shape) 343 | for i in range(H): 344 | for j in range(W): 345 | x, y = world_grids[i, j] 346 | if dataset.base.indexing == 'xy': 347 | if x in range(dataset.worldgrid_shape[1]) and y in range(dataset.worldgrid_shape[0]): 348 | world_grid_map[int(y), int(x)] += 1 349 | else: 350 | if x in range(dataset.worldgrid_shape[0]) and y in range(dataset.worldgrid_shape[1]): 351 | world_grid_map[int(x), int(y)] += 1 352 | world_grid_map = world_grid_map != 0 353 | plt.imshow(world_grid_map) 354 | plt.show() 355 | world_grid_maps.append(world_grid_map) 356 | pass 357 | plt.imshow(np.sum(np.stack(world_grid_maps), axis=0)) 358 | plt.show() 359 | pass 360 | 361 | 362 | if __name__ == '__main__': 363 | test(False) 364 | --------------------------------------------------------------------------------
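For reference, a minimal sketch of how the CenterNet-style targets built by get_gt() in frameDataset.py map back to grid coordinates: idx encodes the integer cell as row * W + col, and offset keeps the sub-cell remainder. The helper name below is hypothetical and not part of the repository, which decodes model outputs through src/utils/decode.py and the gather helpers in src/utils/tensor_utils.py instead.

import numpy as np

def decode_gt_centers(idx, offset, reg_mask, W):
    """Recover (x, y) centers in reduced-grid units from get_gt()-style targets."""
    idx, offset, reg_mask = (np.asarray(a) for a in (idx, offset, reg_mask))
    idx, offset = idx[reg_mask], offset[reg_mask]          # keep valid slots only
    cols, rows = idx % W, idx // W                         # idx == row * W + col
    return np.stack([cols, rows], axis=1) + offset         # add back the fractional part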