├── libs ├── __init__.py ├── .DS_Store ├── process_result.py ├── utils.py └── functions.py ├── ycbv_qua.jpg ├── dataset ├── .DS_Store ├── fps │ ├── .DS_Store │ ├── lm │ │ ├── obj01_fps128.mat │ │ ├── obj02_fps128.mat │ │ ├── obj03_fps128.mat │ │ ├── obj04_fps128.mat │ │ ├── obj05_fps128.mat │ │ ├── obj06_fps128.mat │ │ ├── obj07_fps128.mat │ │ ├── obj08_fps128.mat │ │ ├── obj09_fps128.mat │ │ ├── obj10_fps128.mat │ │ ├── obj11_fps128.mat │ │ ├── obj12_fps128.mat │ │ ├── obj13_fps128.mat │ │ ├── obj14_fps128.mat │ │ └── obj15_fps128.mat │ └── ycbv │ │ ├── obj01_fps128.mat │ │ ├── obj02_fps128.mat │ │ ├── obj03_fps128.mat │ │ ├── obj04_fps128.mat │ │ ├── obj05_fps128.mat │ │ ├── obj06_fps128.mat │ │ ├── obj07_fps128.mat │ │ ├── obj08_fps128.mat │ │ ├── obj09_fps128.mat │ │ ├── obj10_fps128.mat │ │ ├── obj11_fps128.mat │ │ ├── obj12_fps128.mat │ │ ├── obj13_fps128.mat │ │ ├── obj14_fps128.mat │ │ ├── obj15_fps128.mat │ │ ├── obj16_fps128.mat │ │ ├── obj17_fps128.mat │ │ ├── obj18_fps128.mat │ │ ├── obj19_fps128.mat │ │ ├── obj20_fps128.mat │ │ └── obj21_fps128.mat ├── lmo.py └── lm.py ├── model ├── .DS_Store └── hrnet_backbone.py ├── detection ├── .DS_Store ├── __init__.py ├── image_list.py ├── generalized_rcnn.py ├── transform.py ├── _utils.py ├── keypoint_rcnn.py └── faster_rcnn.py ├── reference ├── .DS_Store ├── transforms.py ├── engine.py ├── group_by_aspect_ratio.py ├── train.py └── utils.py ├── analysis.py ├── cfg.yaml ├── README.md ├── .gitignore ├── LICENSE ├── environment.yaml ├── main_lmo.py ├── main_lm.py └── main_ycbv.py /libs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ycbv_qua.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/ycbv_qua.jpg -------------------------------------------------------------------------------- /libs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/libs/.DS_Store -------------------------------------------------------------------------------- /dataset/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/.DS_Store -------------------------------------------------------------------------------- /model/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/model/.DS_Store -------------------------------------------------------------------------------- /detection/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/detection/.DS_Store -------------------------------------------------------------------------------- /detection/__init__.py: -------------------------------------------------------------------------------- 1 | from .faster_rcnn import * 2 | from .keypoint_rcnn import * 3 | -------------------------------------------------------------------------------- /reference/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/reference/.DS_Store -------------------------------------------------------------------------------- /dataset/fps/.DS_Store: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/.DS_Store -------------------------------------------------------------------------------- /dataset/fps/lm/obj01_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/lm/obj01_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/lm/obj02_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/lm/obj02_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/lm/obj03_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/lm/obj03_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/lm/obj04_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/lm/obj04_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/lm/obj05_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/lm/obj05_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/lm/obj06_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/lm/obj06_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/lm/obj07_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/lm/obj07_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/lm/obj08_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/lm/obj08_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/lm/obj09_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/lm/obj09_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/lm/obj10_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/lm/obj10_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/lm/obj11_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/lm/obj11_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/lm/obj12_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/lm/obj12_fps128.mat 
-------------------------------------------------------------------------------- /dataset/fps/lm/obj13_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/lm/obj13_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/lm/obj14_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/lm/obj14_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/lm/obj15_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/lm/obj15_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/ycbv/obj01_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/ycbv/obj01_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/ycbv/obj02_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/ycbv/obj02_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/ycbv/obj03_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/ycbv/obj03_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/ycbv/obj04_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/ycbv/obj04_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/ycbv/obj05_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/ycbv/obj05_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/ycbv/obj06_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/ycbv/obj06_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/ycbv/obj07_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/ycbv/obj07_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/ycbv/obj08_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/ycbv/obj08_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/ycbv/obj09_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/ycbv/obj09_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/ycbv/obj10_fps128.mat: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/ycbv/obj10_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/ycbv/obj11_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/ycbv/obj11_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/ycbv/obj12_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/ycbv/obj12_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/ycbv/obj13_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/ycbv/obj13_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/ycbv/obj14_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/ycbv/obj14_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/ycbv/obj15_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/ycbv/obj15_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/ycbv/obj16_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/ycbv/obj16_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/ycbv/obj17_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/ycbv/obj17_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/ycbv/obj18_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/ycbv/obj18_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/ycbv/obj19_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/ycbv/obj19_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/ycbv/obj20_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/ycbv/obj20_fps128.mat -------------------------------------------------------------------------------- /dataset/fps/ycbv/obj21_fps128.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoChenYS/ROPE/HEAD/dataset/fps/ycbv/obj21_fps128.mat -------------------------------------------------------------------------------- /analysis.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | from libs.process_result 
import Result_processor 3 | from libs.utils import get_logger 4 | 5 | def main(cfg): 6 | logger = get_logger(cfg) 7 | processor = Result_processor(cfg, mat_file=cfg.OUTPUT_DIR+'/'+cfg.obj+'/'+cfg.log_name+'_result.mat') 8 | processor.ycbv_auc(logger=logger) 9 | 10 | 11 | if __name__ == "__main__": 12 | import argparse 13 | parser = argparse.ArgumentParser( 14 | description=__doc__) 15 | 16 | parser.add_argument('--cfg', help='experiment configure file name', required=True, type=str) 17 | parser.add_argument('--obj', required=True, type=str) 18 | parser.add_argument('--log_name', required=True, type=str) 19 | args = parser.parse_args() 20 | cfg = CN(new_allowed=True) 21 | cfg.defrost() 22 | cfg.merge_from_file(args.cfg) 23 | cfg.obj = args.obj 24 | cfg.log_name = args.log_name 25 | 26 | main(cfg) 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /detection/image_list.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import torch 3 | from torch.jit.annotations import List, Tuple 4 | from torch import Tensor 5 | 6 | 7 | class ImageList(object): 8 | """ 9 | Structure that holds a list of images (of possibly 10 | varying sizes) as a single tensor. 11 | This works by padding the images to the same size, 12 | and storing in a field the original sizes of each image 13 | """ 14 | 15 | def __init__(self, tensors, image_sizes): 16 | # type: (Tensor, List[Tuple[int, int]]) -> None 17 | """ 18 | Arguments: 19 | tensors (tensor) 20 | image_sizes (list[tuple[int, int]]) 21 | """ 22 | self.tensors = tensors 23 | self.image_sizes = image_sizes 24 | 25 | def to(self, device): 26 | # type: (Device) -> ImageList # noqa 27 | cast_tensor = self.tensors.to(device) 28 | return ImageList(cast_tensor, self.image_sizes) 29 | -------------------------------------------------------------------------------- /libs/process_result.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from . 
import utils 3 | from scipy.io import loadmat, savemat 4 | 5 | 6 | class Result_processor(object): 7 | def __init__(self, cfg, mat_file): 8 | temp = loadmat(mat_file) 9 | self.test_idx = temp['test_idx'].squeeze() 10 | self.cfg = cfg 11 | self.result_dict = temp 12 | self.mat_file = mat_file 13 | 14 | def ycbv_auc(self, logger=None): 15 | poses = torch.tensor(self.result_dict['pose_recordc']).float() 16 | pts_model_h, is_sym = utils.get_ycbv_3dmodel(self.cfg,homo=True) 17 | annos = loadmat(self.cfg.YCBV_DIR+'/test_annos/obj{:02d}.mat'.format(utils.get_ycbv_objid(self.cfg.obj))) 18 | PMs = torch.tensor(annos['PMs']).float() 19 | Xs = [] 20 | Ys = [] 21 | for i in range(51): 22 | x = 0.02*i 23 | Xs.append(x) 24 | if is_sym: 25 | n, N = utils.ADDS_accuracy(poses, pts_model_h, PMs, 100*x, P_is_matrix=False) 26 | else: 27 | n, N = utils.ADD_accuracy(poses, pts_model_h, PMs, 100*x, P_is_matrix=False) 28 | assert N==len(self.test_idx) 29 | y = n/N 30 | Ys.append(y) 31 | logger.info('Computing ADD(-S) for AUC, x:{:.4f}, y:{:.4f}'.format(x,y)) 32 | import sklearn.metrics as M 33 | auc = M.auc(Xs, Ys) 34 | if logger is not None: 35 | logger.info('Obj: {}, AUC: {:1.4f}'.format(self.cfg.obj, auc)) 36 | self.result_dict.update({'Xs':Xs, 'Ys':Ys, 'auc':auc}) 37 | savemat(self.mat_file, self.result_dict) 38 | 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /reference/transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | from torchvision.transforms import functional as F 5 | 6 | 7 | def _flip_coco_person_keypoints(kps, width): 8 | flip_inds = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15] 9 | flipped_data = kps[:, flip_inds] 10 | flipped_data[..., 0] = width - flipped_data[..., 0] 11 | # Maintain COCO convention that if visibility == 0, then x, y = 0 12 | inds = flipped_data[..., 2] == 0 13 | flipped_data[inds] = 0 14 | return flipped_data 15 | 16 | 17 | class Compose(object): 18 | def __init__(self, transforms): 19 | self.transforms = transforms 20 | 21 | def __call__(self, image, target): 22 | for t in self.transforms: 23 | image, target = t(image, target) 24 | return image, target 25 | 26 | 27 | class RandomHorizontalFlip(object): 28 | def __init__(self, prob): 29 | self.prob = prob 30 | 31 | def __call__(self, image, target): 32 | if random.random() < self.prob: 33 | height, width = image.shape[-2:] 34 | image = image.flip(-1) 35 | bbox = target["boxes"] 36 | bbox[:, [0, 2]] = width - bbox[:, [2, 0]] 37 | target["boxes"] = bbox 38 | if "masks" in target: 39 | target["masks"] = target["masks"].flip(-1) 40 | if "keypoints" in target: 41 | keypoints = target["keypoints"] 42 | keypoints = _flip_coco_person_keypoints(keypoints, width) 43 | target["keypoints"] = keypoints 44 | return image, target 45 | 46 | 47 | class ToTensor(object): 48 | def __call__(self, image, target): 49 | image = F.to_tensor(image) 50 | return image, target 51 | -------------------------------------------------------------------------------- /cfg.yaml: -------------------------------------------------------------------------------- 1 | N_PTS: 11 2 | DEVICE: 'cuda' 3 | BATCH_SIZE: 3 4 | TEST_BATCH_SIZE: 1 5 | LM_DIR: '' 6 | LMO_DIR: '' 7 | LM_SYNT_DIR: '' 8 | YCBV_DIR: '' 9 | OUTPUT_DIR: '' 10 | WORKERS: 4 11 | END_EPOCH: 200 12 | PRINT_FREQ: 100 13 | LR: 0.0002 14 | LR_DECAY: 0.2 15 | LR_STEPS: 16 | - 60 17 | - 120 18 | - 170 19 | 20 | MODEL: 21 | INIT_WEIGHTS: true 22 | 
PRETRAINED: '' 23 | TARGET_TYPE: gaussian 24 | IMAGE_SIZE: 25 | - 640 26 | - 480 27 | EXTRA: 28 | PRETRAINED_LAYERS: 29 | - 'conv1' 30 | - 'bn1' 31 | - 'conv2' 32 | - 'bn2' 33 | - 'layer1' 34 | - 'transition1' 35 | - 'stage2' 36 | - 'transition2' 37 | - 'stage3' 38 | - 'transition3' 39 | - 'stage4' 40 | FINAL_CONV_KERNEL: 1 41 | STAGE1: 42 | NUM_MODULES: 1 43 | NUM_BRANCHES: 1 44 | BLOCK: BOTTLENECK 45 | NUM_BLOCKS: 46 | - 4 47 | NUM_CHANNELS: 48 | - 64 49 | FUSE_METHOD: SUM 50 | STAGE2: 51 | NUM_MODULES: 1 52 | NUM_BRANCHES: 2 53 | BLOCK: BASIC 54 | NUM_BLOCKS: 55 | - 4 56 | - 4 57 | NUM_CHANNELS: 58 | - 32 59 | - 64 60 | FUSE_METHOD: SUM 61 | STAGE3: 62 | NUM_MODULES: 4 63 | NUM_BRANCHES: 3 64 | BLOCK: BASIC 65 | NUM_BLOCKS: 66 | - 4 67 | - 4 68 | - 4 69 | NUM_CHANNELS: 70 | - 32 71 | - 64 72 | - 128 73 | FUSE_METHOD: SUM 74 | STAGE4: 75 | NUM_MODULES: 3 76 | NUM_BRANCHES: 4 77 | BLOCK: BASIC 78 | NUM_BLOCKS: 79 | - 4 80 | - 4 81 | - 4 82 | - 4 83 | NUM_CHANNELS: 84 | - 32 85 | - 64 86 | - 128 87 | - 256 88 | FUSE_METHOD: SUM 89 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Robust Object Pose Estimation (ROPE) 2 | This repo contains the code used in the paper 3 | ## [Occlusion-Robust Object Pose Estimation with Holistic Representation](https://openaccess.thecvf.com/content/WACV2022/papers/Chen_Occlusion-Robust_Object_Pose_Estimation_With_Holistic_Representation_WACV_2022_paper.pdf) 4 | 5 | ![](ycbv_qua.jpg) 6 | 7 | ### Environment 8 | Our system environment is provided in environment.yaml for reference. 9 | 10 | ### Datasets 11 | The Linemod (lm), Linemod-Occluded (lmo) and YCB-Video (ycbv) datasets can be downloaded from the [BOP](https://bop.felk.cvut.cz/datasets/) website. The paths to the datasets should then be specified in the cfg.yaml file. 12 | 13 | For better initialisation, the pretrained hrnet backbone file can be downloaded from [here](https://github.com/leoxiaobin/deep-high-resolution-net.pytorch).
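The dataset locations (LM_DIR, LMO_DIR, LM_SYNT_DIR, YCBV_DIR) and the output directory are set in cfg.yaml, and the downloaded backbone weights are presumably pointed to by MODEL.PRETRAINED in the same file. As a quick sanity check before training, the snippet below only verifies that cfg.yaml parses and that the checkpoint file is readable; it is a minimal sketch that assumes the checkpoint is a plain torch-saved state dict, not a description of this repo's actual loading code in model/hrnet_backbone.py.

````python
# Sanity-check sketch (assumptions: MODEL.PRETRAINED has been pointed at the
# downloaded .pth file and that file is a plain state dict; neither is part of
# this repo's documented API).
import torch
from yacs.config import CfgNode as CN

cfg = CN(new_allowed=True)
cfg.merge_from_file('cfg.yaml')
print('LM_DIR:', cfg.LM_DIR, '| OUTPUT_DIR:', cfg.OUTPUT_DIR)

state_dict = torch.load(cfg.MODEL.PRETRAINED, map_location='cpu')
print('Loaded {} tensors from {}'.format(len(state_dict), cfg.MODEL.PRETRAINED))
````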
14 | 15 | ### Usage examples 16 | To train for the lm test set in distributed mode 17 | ````bash 18 | python -m torch.distributed.launch --nproc_per_node=NUM_GPUS --use_env main_lm.py --cfg cfg.yaml --obj duck --log_name YOUR_LOG_NAME 19 | ```` 20 | 21 | To train for the lmo test set in single GPU mode 22 | ````bash 23 | CUDA_VISIBLE_DEVICES=GPU_ID python main_lmo.py --cfg cfg.yaml --obj ape --log_name YOUR_LOG_NAME 24 | ```` 25 | 26 | To load a trained model and test it on the lmo dataset 27 | ````bash 28 | CUDA_VISIBLE_DEVICES=GPU_ID python main_lmo.py --cfg cfg.yaml --obj cat --log_name YOUR_LOG_NAME --resume --test-only 29 | ```` 30 | 31 | 32 | To train for the ycbv test set for object 01 33 | ````bash 34 | python -m torch.distributed.launch --nproc_per_node=NUM_GPUS --use_env main_ycbv.py --cfg cfg.yaml --obj 01 --log_name YOUR_LOG_NAME 35 | ```` 36 | 37 | To compute the AUC for a ycbv test result for object 20 38 | ````bash 39 | python analysis.py --cfg cfg.yaml --log_name YOUR_LOG_NAME --obj 20 40 | ```` 41 | 42 | ### Cite this work 43 | ```` 44 | @inproceedings{chen2022occlusion, 45 | Author = {Chen, Bo and Chin, Tat-Jun and Klimavicius, Marius}, 46 | Title = {Occlusion-Robust Object Pose Estimation with Holistic Representation}, 47 | Booktitle = {WACV}, 48 | Year = {2022} 49 | } 50 | ```` 51 | 52 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .DS_Store 132 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Bo Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | 24 | BSD 3-Clause License 25 | 26 | Copyright (c) Soumith Chintala 2016, 27 | All rights reserved. 28 | 29 | Redistribution and use in source and binary forms, with or without 30 | modification, are permitted provided that the following conditions are met: 31 | 32 | * Redistributions of source code must retain the above copyright notice, this 33 | list of conditions and the following disclaimer. 34 | 35 | * Redistributions in binary form must reproduce the above copyright notice, 36 | this list of conditions and the following disclaimer in the documentation 37 | and/or other materials provided with the distribution. 38 | 39 | * Neither the name of the copyright holder nor the names of its 40 | contributors may be used to endorse or promote products derived from 41 | this software without specific prior written permission. 42 | 43 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 44 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 45 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 46 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 47 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 48 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 49 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 50 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 51 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 52 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /dataset/lmo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import os 4 | import numpy as np 5 | import fnmatch 6 | from PIL import Image 7 | from libs.utils import batch_project 8 | from scipy.io import loadmat, savemat 9 | from torch.utils.data import Dataset 10 | import argparse 11 | from yacs.config import CfgNode as CN 12 | import torchvision 13 | from torchvision.transforms import functional as F 14 | import json 15 | 16 | def get_objid(obj): 17 | obj_dict = {'ape':1, 'benchvise':2, 'cam':4, 'can':5, 'cat':6, 'driller':8, 'duck':9, 'eggbox':10, 'glue':11, 'holepuncher':12, 18 | 'iron':13, 'lamp':14, 'phone':15} 19 | return obj_dict[obj] 20 | 21 | def get_lmo_PM_gt_img_list(root, objid): 22 | PM_file = root + '/{:06d}/scene_gt.json'.format(2) 23 | with open(PM_file) as f: 24 | PMs = json.load(f) 25 | len_i = len(PMs) 26 | PM = torch.zeros(len_i, 3, 4) 27 | img_list = [] 28 | for idx in range(len_i): 29 | list_idx = PMs[str(idx)] 30 | objid_list = [temp['obj_id'] for temp in list_idx] 31 | if objid in objid_list: 32 | ttt = [ temp for temp in list_idx if temp['obj_id']==objid] 33 | R = torch.tensor(ttt[0]['cam_R_m2c']).view(1,3,3) 34 | T = 0.1*torch.tensor(ttt[0]['cam_t_m2c']).view(1,3,1) 35 | PM[idx,:,:] = torch.cat((R,T),dim=-1) 36 | img_list.append(idx) 37 | return PM, img_list 38 | 39 | def get_K(): 40 | fx = 572.41140 41 | fy = 573.57043 42 | u = 325.26110 43 | v = 242.04899 44 | K = torch.tensor( 45 | [[fx, 0, u], 46 | [0, fy, v], 47 | [0, 0, 1]], 48 | dtype=torch.float) 49 | return K 50 | 51 | class lmo(Dataset): 52 | def __init__(self, cfg): 53 | self.img_path = os.path.join(cfg.LMO_DIR,'000002/rgb') 54 | self.objid = get_objid(cfg.obj) 55 | self.pts3d = torch.tensor(loadmat('dataset/fps/lm/obj{:02d}_fps128.mat'.format(self.objid))['fps'])[:cfg.N_PTS,:] 56 | self.npts = cfg.N_PTS 57 | self.PMs, self.img_list = get_lmo_PM_gt_img_list(cfg.LMO_DIR, self.objid) 58 | self.cfg = cfg 59 | self.K = get_K() 60 | 61 | 62 | def __len__(self,): 63 | return 1214 # return the full set, then must create subset using self.img_list 64 | 65 | def __getitem__(self, idx): 66 | img = Image.open(os.path.join(self.img_path, '{:06d}.png'.format(idx))) 67 | PM = self.PMs[idx].view(1,3,4) 68 | pts2d = batch_project(PM,self.pts3d,self.K,angle_axis=False).squeeze() 69 | 70 | W,H = self.cfg.MODEL.IMAGE_SIZE 71 | xmin = pts2d[:,0].min()-5 72 | xmax = pts2d[:,0].max()+5 73 | ymin = pts2d[:,1].min()-5 74 | ymax = pts2d[:,1].max()+5 75 | 76 | num_objs = 1 77 | boxes = [xmin, ymin, xmax, ymax] 78 | boxes = torch.as_tensor(boxes, dtype=torch.float32).view(1,4) 79 | labels = torch.ones((num_objs,), dtype=torch.int64) 80 | vis = torch.ones(self.npts, 1) 81 | vis[pts2d[:,0]<0, 0] = 0 82 | vis[pts2d[:,0]>W, 0] = 0 83 | vis[pts2d[:,1]<0, 0] = 0 84 | vis[pts2d[:,1]>H, 0] = 0 85 | keypoints = torch.cat((pts2d, 
vis),dim=-1).view(self.npts, -1, 3) 86 | 87 | image_id = torch.tensor([idx]) 88 | area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) 89 | iscrowd = torch.zeros((num_objs,), dtype=torch.int64) 90 | 91 | target = {} 92 | target["boxes"] = boxes 93 | target["labels"] = labels 94 | target["image_id"] = image_id 95 | target["area"] = area 96 | target["iscrowd"] = iscrowd 97 | target["keypoints"] = keypoints 98 | target["PM"] = PM.squeeze() 99 | 100 | img = F.to_tensor(img) 101 | 102 | return img, target 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /reference/engine.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | import time 4 | import torch 5 | 6 | from libs.functions import Evaluator 7 | from . import utils 8 | 9 | 10 | def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, obj, logger=None): 11 | model.train() 12 | metric_logger = utils.MetricLogger(delimiter=" ") 13 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 14 | header = 'Obj: {} Epoch: [{}]'.format(obj, epoch) 15 | 16 | lr_scheduler = None 17 | if epoch == 0: 18 | warmup_factor = 1. / 1000 19 | warmup_iters = min(1000, len(data_loader) - 1) 20 | 21 | lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor) 22 | 23 | for images1, images2, targets1, targets2 in metric_logger.log_every(data_loader, print_freq, header, logger): 24 | 25 | images = list(image.to(device) for image in images1) 26 | if images2[0] is not None: ########################### temp ################################################ 27 | images.extend(list(image.to(device) for image in images2)) 28 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets1] 29 | if targets2[0] is not None: ########################## temp ################################################# 30 | targets.extend([{k: v.to(device) for k, v in t.items()} for t in targets2]) 31 | 32 | loss_dict = model(images, targets) 33 | losses = sum(loss for loss in loss_dict.values()) 34 | 35 | # reduce losses over all GPUs for logging purposes 36 | loss_dict_reduced = utils.reduce_dict(loss_dict) 37 | losses_reduced = sum(loss for loss in loss_dict_reduced.values()) 38 | 39 | loss_value = losses_reduced.item() 40 | 41 | if not math.isfinite(loss_value): 42 | print("Loss is {}, stopping training".format(loss_value)) 43 | print(loss_dict_reduced) 44 | sys.exit(1) 45 | 46 | optimizer.zero_grad() 47 | losses.backward() 48 | optimizer.step() 49 | 50 | if lr_scheduler is not None: 51 | lr_scheduler.step() 52 | 53 | metric_logger.update(loss=losses_reduced, **loss_dict_reduced) 54 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 55 | 56 | return metric_logger 57 | 58 | 59 | def _get_iou_types(model): 60 | model_without_ddp = model 61 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 62 | model_without_ddp = model.module 63 | iou_types = ["bbox"] 64 | if isinstance(model_without_ddp, torchvision.models.detection.MaskRCNN): 65 | iou_types.append("segm") 66 | if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN): 67 | iou_types.append("keypoints") 68 | return iou_types 69 | 70 | 71 | @torch.no_grad() 72 | def evaluate(model, data_loader, device, logger=None): 73 | n_threads = torch.get_num_threads() 74 | # FIXME remove this and make paste_masks_in_image run on the GPU 75 | 
torch.set_num_threads(1) 76 | cpu_device = torch.device("cpu") 77 | model.eval() 78 | metric_logger = utils.MetricLogger(delimiter=" ") 79 | evaluator = Evaluator() 80 | header = 'Test:' 81 | 82 | for images, targets in metric_logger.log_every(data_loader, 100, header, logger): 83 | images = list(img.to(device) for img in images) 84 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets] 85 | 86 | torch.cuda.synchronize() 87 | model_time = time.time() 88 | outputs = model(images) 89 | 90 | outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs] 91 | model_time = time.time() - model_time 92 | 93 | res = {target["image_id"].item(): output for target, output in zip(targets, outputs)} 94 | evaluator_time = time.time() 95 | evaluator.update(res) 96 | evaluator_time = time.time() - evaluator_time 97 | metric_logger.update(model_time=model_time, evaluator_time=evaluator_time) 98 | 99 | # gather the stats from all processes 100 | metric_logger.synchronize_between_processes() 101 | print("Averaged stats:", metric_logger) 102 | evaluator.gather_all() 103 | return evaluator 104 | 105 | -------------------------------------------------------------------------------- /detection/generalized_rcnn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | """ 3 | Implements the Generalized R-CNN framework 4 | """ 5 | 6 | from collections import OrderedDict 7 | import torch 8 | from torch import nn 9 | import warnings 10 | from torch.jit.annotations import Tuple, List, Dict, Optional 11 | from torch import Tensor 12 | import dsntnn 13 | 14 | 15 | class GeneralizedRCNN(nn.Module): 16 | """ 17 | Main class for Generalized R-CNN. 18 | 19 | Arguments: 20 | backbone (nn.Module): 21 | rpn (nn.Module): 22 | roi_heads (nn.Module): takes the features + the proposals from the RPN and computes 23 | detections / masks from it. 24 | transform (nn.Module): performs the data transformation from the inputs to feed into 25 | the model 26 | """ 27 | 28 | def __init__(self, backbone, rpn, roi_heads, transform): 29 | super(GeneralizedRCNN, self).__init__() 30 | self.transform = transform 31 | self.backbone = backbone 32 | self.rpn = rpn 33 | self.roi_heads = roi_heads 34 | # used only on torchscript mode 35 | self._has_warned = False 36 | 37 | @torch.jit.unused 38 | def eager_outputs(self, losses, detections): 39 | # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] 40 | if self.training: 41 | return losses 42 | 43 | return detections 44 | 45 | def forward(self, images, targets=None): 46 | # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] 47 | """ 48 | Arguments: 49 | images (list[Tensor]): images to be processed 50 | targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional) 51 | 52 | Returns: 53 | result (list[BoxList] or dict[Tensor]): the output from the model. 54 | During training, it returns a dict[Tensor] which contains the losses. 55 | During testing, it returns list[BoxList] contains additional fields 56 | like `scores`, `labels` and `mask` (for Mask R-CNN models). 
57 | 58 | """ 59 | if self.training and targets is None: 60 | raise ValueError("In training mode, targets should be passed") 61 | if self.training: 62 | assert targets is not None 63 | for target in targets: 64 | boxes = target["boxes"] 65 | if isinstance(boxes, torch.Tensor): 66 | if len(boxes.shape) != 2 or boxes.shape[-1] != 4: 67 | raise ValueError("Expected target boxes to be a tensor" 68 | "of shape [N, 4], got {:}.".format( 69 | boxes.shape)) 70 | else: 71 | raise ValueError("Expected target boxes to be of type " 72 | "Tensor, got {:}.".format(type(boxes))) 73 | 74 | original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], []) 75 | for img in images: 76 | val = img.shape[-2:] 77 | assert len(val) == 2 78 | original_image_sizes.append((val[0], val[1])) 79 | 80 | images, targets = self.transform(images, targets) 81 | 82 | # Check for degenerate boxes 83 | # TODO: Move this to a function 84 | if targets is not None: 85 | for target_idx, target in enumerate(targets): 86 | boxes = target["boxes"] 87 | degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] 88 | if degenerate_boxes.any(): 89 | # print the first degenrate box 90 | bb_idx = degenerate_boxes.any(dim=1).nonzero().view(-1)[0] 91 | degen_bb: List[float] = boxes[bb_idx].tolist() 92 | raise ValueError("All bounding boxes should have positive height and width." 93 | " Found invaid box {} for target at index {}." 94 | .format(degen_bb, target_idx)) 95 | 96 | features = self.backbone(images.tensors) 97 | 98 | losses = {} 99 | if isinstance(features, torch.Tensor): 100 | features = OrderedDict([('0', features)]) 101 | proposals, proposal_losses = self.rpn(images, features, targets) 102 | detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets) 103 | detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) 104 | 105 | 106 | losses.update(detector_losses) 107 | losses.update(proposal_losses) 108 | 109 | if torch.jit.is_scripting(): 110 | if not self._has_warned: 111 | warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting") 112 | self._has_warned = True 113 | return (losses, detections) 114 | else: 115 | return self.eager_outputs(losses, detections) 116 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: rope 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=1_llvm 10 | - absl-py=0.13.0=pyhd8ed1ab_0 11 | - av=8.0.2=py38h497d03b_4 12 | - blas=1.0=mkl 13 | - blosc=1.21.0=h9c3ff4c_0 14 | - brotli=1.0.9=h7f98852_5 15 | - brotli-bin=1.0.9=h7f98852_5 16 | - brunsli=0.1=h9c3ff4c_0 17 | - bzip2=1.0.8=h7f98852_4 18 | - c-ares=1.17.2=h7f98852_0 19 | - ca-certificates=2021.5.30=ha878542_0 20 | - cached-property=1.5.2=hd8ed1ab_1 21 | - cached_property=1.5.2=pyha770c72_1 22 | - cairo=1.16.0=hcf35c78_1003 23 | - ceres-solver=2.0.0=h5605472_1 24 | - certifi=2021.5.30=py38h578d9bd_0 25 | - cfitsio=3.470=hb418390_7 26 | - charls=2.2.0=h9c3ff4c_0 27 | - cloudpickle=1.6.0=py_0 28 | - cudatoolkit=10.1.243=h6bb024c_0 29 | - cycler=0.10.0=py_2 30 | - cython=0.29.24=py38h709712a_0 31 | - cytoolz=0.11.0=py38h497a2fe_3 32 | - dask-core=2021.8.1=pyhd8ed1ab_0 33 | - dataclasses=0.8=pyhc8e2a94_3 34 | - dbus=1.13.6=he372182_0 35 | - decorator=4.4.2=py_0 36 | - easydict=1.9=py_0 37 | - eigen=3.4.0=h4bd325d_0 38 | - 
expat=2.2.10=h9c3ff4c_0 39 | - ffmpeg=4.3.2=hca11adc_0 40 | - fontconfig=2.13.1=hba837de_1005 41 | - freetype=2.10.4=h0708190_1 42 | - fsspec=2021.7.0=pyhd8ed1ab_0 43 | - geos=3.9.1=h9c3ff4c_2 44 | - gettext=0.19.8.1=hf34092f_1004 45 | - gflags=2.2.2=he1b5a44_1004 46 | - giflib=5.2.1=h36c2ea0_2 47 | - glib=2.65.0=h6f030ca_0 48 | - glog=0.4.0=h49b9bf7_3 49 | - gmp=6.2.1=h58526e2_0 50 | - gnutls=3.6.13=h85f3911_1 51 | - graphite2=1.3.13=h58526e2_1001 52 | - grpcio=1.38.1=py38hdd6454d_0 53 | - gst-plugins-base=1.14.5=h0935bb2_2 54 | - gstreamer=1.14.5=h36ae1b5_2 55 | - h5py=3.1.0=nompi_py38hafa665b_100 56 | - harfbuzz=2.4.0=h9f30f68_3 57 | - hdf5=1.10.6=nompi_h6a2412b_1114 58 | - icu=64.2=he1b5a44_1 59 | - ilmbase=2.5.2=h8b12597_0 60 | - imagecodecs=2021.7.30=py38hb5ce8f7_0 61 | - imageio=2.9.0=py_0 62 | - imgaug=0.4.0=py_1 63 | - importlib-metadata=4.8.1=py38h578d9bd_0 64 | - jasper=1.900.1=h07fcdf6_1006 65 | - joblib=1.0.1=pyhd8ed1ab_0 66 | - jpeg=9d=h36c2ea0_0 67 | - json_tricks=3.15.5=pyhd8ed1ab_0 68 | - jxrlib=1.1=h7f98852_2 69 | - kiwisolver=1.3.2=py38h1fd1430_0 70 | - kornia=0.5.8=pyhd8ed1ab_0 71 | - krb5=1.19.2=hcc1bbae_0 72 | - lame=3.100=h7f98852_1001 73 | - lcms2=2.12=hddcbb42_0 74 | - ld_impl_linux-64=2.36.1=hea4e1c9_2 75 | - lerc=2.2.1=h9c3ff4c_0 76 | - libaec=1.0.5=h9c3ff4c_0 77 | - libblas=3.8.0=16_mkl 78 | - libbrotlicommon=1.0.9=h7f98852_5 79 | - libbrotlidec=1.0.9=h7f98852_5 80 | - libbrotlienc=1.0.9=h7f98852_5 81 | - libcblas=3.8.0=16_mkl 82 | - libclang=9.0.1=default_hde54327_0 83 | - libcurl=7.78.0=h2574ce0_0 84 | - libdeflate=1.8=h7f98852_0 85 | - libedit=3.1.20191231=he28a2e2_2 86 | - libev=4.33=h516909a_1 87 | - libevent=2.1.10=hcdb4288_3 88 | - libffi=3.2.1=he1b5a44_1007 89 | - libgcc-ng=11.1.0=hc902ee8_8 90 | - libgfortran-ng=11.1.0=h69a702a_8 91 | - libgfortran5=11.1.0=h6c583b3_8 92 | - libiconv=1.15=h516909a_1006 93 | - liblapack=3.8.0=16_mkl 94 | - liblapacke=3.8.0=16_mkl 95 | - libllvm9=9.0.1=hf817b99_2 96 | - libnghttp2=1.43.0=h812cca2_0 97 | - libopencv=4.4.0=py38_2 98 | - libpng=1.6.37=h21135ba_2 99 | - libprotobuf=3.17.2=h780b84a_1 100 | - libssh2=1.10.0=ha56f1ee_0 101 | - libstdcxx-ng=11.1.0=h56837e0_8 102 | - libtiff=4.3.0=hf544144_0 103 | - libuuid=2.32.1=h7f98852_1000 104 | - libwebp-base=1.2.1=h7f98852_0 105 | - libxcb=1.13=h7f98852_1003 106 | - libxkbcommon=0.10.0=he1b5a44_0 107 | - libxml2=2.9.10=hee79883_0 108 | - libzopfli=1.0.3=h9c3ff4c_0 109 | - llvm-openmp=12.0.1=h4bd325d_1 110 | - lmdb=0.9.24=h516909a_0 111 | - locket=0.2.0=py_2 112 | - lz4-c=1.9.3=h9c3ff4c_1 113 | - markdown=3.3.4=pyhd8ed1ab_0 114 | - matplotlib-base=3.4.3=py38hf4fb855_0 115 | - metis=5.1.0=h58526e2_1006 116 | - mkl=2020.4=h726a3e6_304 117 | - mpfr=4.1.0=h9202a9a_1 118 | - ncurses=6.2=h58526e2_4 119 | - nettle=3.6=he412f7d_0 120 | - networkx=2.6.2=pyhd8ed1ab_0 121 | - ninja=1.10.2=h4bd325d_0 122 | - nspr=4.30=h9c3ff4c_0 123 | - nss=3.69=hb5efdd6_0 124 | - numpy=1.21.2=py38he2449b9_0 125 | - olefile=0.46=pyh9f0ad1d_1 126 | - opencv=4.4.0=py38_2 127 | - openexr=2.5.2=he513fc3_0 128 | - openh264=2.1.1=h780b84a_0 129 | - openjpeg=2.4.0=hb52868f_1 130 | - openssl=1.1.1k=h7f98852_1 131 | - packaging=21.0=pyhd8ed1ab_0 132 | - pandas=1.3.2=py38h43a58ef_0 133 | - partd=1.2.0=pyhd8ed1ab_0 134 | - pcre=8.45=h9c3ff4c_0 135 | - pillow=8.3.1=py38h8e6f84c_0 136 | - pip=21.2.4=pyhd8ed1ab_0 137 | - pixman=0.38.0=h516909a_1003 138 | - plyfile=0.7.2=pyh9f0ad1d_0 139 | - pthread-stubs=0.4=h36c2ea0_1001 140 | - py-opencv=4.4.0=py38h23f93f0_2 141 | - pycocotools=2.0.1=py38h1e0a361_1 142 | - 
pyparsing=2.4.7=pyh9f0ad1d_0 143 | - python=3.8.5=h4d41432_2_cpython 144 | - python-dateutil=2.8.2=pyhd8ed1ab_0 145 | - python-lmdb=0.96=py38he1b5a44_0 146 | - python_abi=3.8=2_cp38 147 | - pytorch=1.6.0=py3.8_cuda10.1.243_cudnn7.6.3_0 148 | - pywavelets=1.1.1=py38h5c078b8_3 149 | - pyyaml=5.4.1=py38h497a2fe_1 150 | - qt=5.12.5=hd8c4c69_1 151 | - readline=8.1=h46c0cb4_0 152 | - rowan=1.3.0.post1=pyh9f0ad1d_0 153 | - scikit-image=0.17.2=py38h51da96c_4 154 | - scikit-learn=0.24.2=py38h1561384_1 155 | - scipy=1.7.1=py38h56a6a73_0 156 | - setuptools=57.4.0=py38h578d9bd_0 157 | - shapely=1.7.1=py38hb7fe4a8_5 158 | - six=1.16.0=pyh6c4a22f_0 159 | - sklearn-contrib-lightning=0.5.0=py38hb3f55d8_1 160 | - snappy=1.1.8=he1b5a44_3 161 | - sqlite=3.36.0=h9cd32fc_0 162 | - suitesparse=5.10.1=hd8046ac_0 163 | - tbb=2020.2=h4bd325d_4 164 | - tensorboard=1.15.0=py38_0 165 | - tensorboardx=2.1=py_0 166 | - threadpoolctl=2.2.0=pyh8a188c0_0 167 | - tifffile=2021.8.8=pyhd8ed1ab_0 168 | - tk=8.6.11=h21135ba_0 169 | - tmux=3.1=ha1ba12b_0 170 | - toolz=0.11.1=py_0 171 | - torchvision=0.7.0=py38_cu101 172 | - tornado=6.1=py38h497a2fe_1 173 | - tqdm=4.48.2=pyh9f0ad1d_0 174 | - transforms3d=0.3.1=py_0 175 | - werkzeug=2.0.1=pyhd8ed1ab_0 176 | - wheel=0.37.0=pyhd8ed1ab_1 177 | - x264=1!161.3030=h7f98852_1 178 | - xorg-kbproto=1.0.7=h7f98852_1002 179 | - xorg-libice=1.0.10=h7f98852_0 180 | - xorg-libsm=1.2.3=hd9c2040_1000 181 | - xorg-libx11=1.7.2=h7f98852_0 182 | - xorg-libxau=1.0.9=h7f98852_0 183 | - xorg-libxdmcp=1.1.3=h7f98852_0 184 | - xorg-libxext=1.3.4=h7f98852_1 185 | - xorg-libxrender=0.9.10=h7f98852_1003 186 | - xorg-renderproto=0.11.1=h7f98852_1002 187 | - xorg-xextproto=7.3.0=h7f98852_1002 188 | - xorg-xproto=7.0.31=h7f98852_1007 189 | - xz=5.2.5=h516909a_1 190 | - yacs=0.1.6=py_0 191 | - yaml=0.2.5=h516909a_0 192 | - zfp=0.5.5=h9c3ff4c_5 193 | - zipp=3.5.0=pyhd8ed1ab_0 194 | - zlib=1.2.11=h516909a_1010 195 | - zstd=1.5.0=ha95c52a_0 196 | - pip: 197 | - cachetools==4.1.1 198 | - chardet==3.0.4 199 | - charset-normalizer==2.0.4 200 | - defusedxml==0.7.1 201 | - dgx-authenticator==0.0.3+feature.qol.fixes.255c11ed 202 | - dsntnn==0.5.3 203 | - flatbuffers==2.0 204 | - future==0.18.2 205 | - google-api-core==1.22.4 206 | - google-api-python-client==1.8.2 207 | - google-auth==1.18.0 208 | - google-auth-httplib2==0.0.3 209 | - google-auth-oauthlib==0.4.1 210 | - googleapis-common-protos==1.52.0 211 | - httplib2==0.18.1 212 | - idna==2.10 213 | - imageio-ffmpeg==0.4.2 214 | - imath==0.0.1 215 | - importlib-resources==2.0.1 216 | - kaggle==1.5.12 217 | - oauthlib==3.1.0 218 | - onnx==1.10.1 219 | - onnxruntime-gpu==1.8.1 220 | - protobuf==3.13.0 221 | - pyasn1==0.4.8 222 | - pyasn1-modules==0.2.8 223 | - python-slugify==5.0.2 224 | - pytz==2020.1 225 | - requests==2.26.0 226 | - requests-oauthlib==1.3.0 227 | - rsa==4.6 228 | - text-unidecode==1.3 229 | - torchdiffeq==0.2.1 230 | - torchgeometry==0.1.2 231 | - typing-extensions==3.10.0.0 232 | - uritemplate==3.0.1 233 | - urllib3==1.25.10 234 | prefix: 235 | -------------------------------------------------------------------------------- /reference/group_by_aspect_ratio.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | from collections import defaultdict 3 | import copy 4 | from itertools import repeat, chain 5 | import math 6 | import numpy as np 7 | 8 | import torch 9 | import torch.utils.data 10 | from torch.utils.data.sampler import BatchSampler, Sampler 11 | from torch.utils.model_zoo import tqdm 12 | 
import torchvision 13 | 14 | from PIL import Image 15 | 16 | 17 | def _repeat_to_at_least(iterable, n): 18 | repeat_times = math.ceil(n / len(iterable)) 19 | repeated = chain.from_iterable(repeat(iterable, repeat_times)) 20 | return list(repeated) 21 | 22 | 23 | class GroupedBatchSampler(BatchSampler): 24 | """ 25 | Wraps another sampler to yield a mini-batch of indices. 26 | It enforces that the batch only contain elements from the same group. 27 | It also tries to provide mini-batches which follows an ordering which is 28 | as close as possible to the ordering from the original sampler. 29 | Arguments: 30 | sampler (Sampler): Base sampler. 31 | group_ids (list[int]): If the sampler produces indices in range [0, N), 32 | `group_ids` must be a list of `N` ints which contains the group id of each sample. 33 | The group ids must be a continuous set of integers starting from 34 | 0, i.e. they must be in the range [0, num_groups). 35 | batch_size (int): Size of mini-batch. 36 | """ 37 | def __init__(self, sampler, group_ids, batch_size): 38 | if not isinstance(sampler, Sampler): 39 | raise ValueError( 40 | "sampler should be an instance of " 41 | "torch.utils.data.Sampler, but got sampler={}".format(sampler) 42 | ) 43 | self.sampler = sampler 44 | self.group_ids = group_ids 45 | self.batch_size = batch_size 46 | 47 | def __iter__(self): 48 | buffer_per_group = defaultdict(list) 49 | samples_per_group = defaultdict(list) 50 | 51 | num_batches = 0 52 | for idx in self.sampler: 53 | group_id = self.group_ids[idx] 54 | buffer_per_group[group_id].append(idx) 55 | samples_per_group[group_id].append(idx) 56 | if len(buffer_per_group[group_id]) == self.batch_size: 57 | yield buffer_per_group[group_id] 58 | num_batches += 1 59 | del buffer_per_group[group_id] 60 | assert len(buffer_per_group[group_id]) < self.batch_size 61 | 62 | # now we have run out of elements that satisfy 63 | # the group criteria, let's return the remaining 64 | # elements so that the size of the sampler is 65 | # deterministic 66 | expected_num_batches = len(self) 67 | num_remaining = expected_num_batches - num_batches 68 | if num_remaining > 0: 69 | # for the remaining batches, take first the buffers with largest number 70 | # of elements 71 | for group_id, _ in sorted(buffer_per_group.items(), 72 | key=lambda x: len(x[1]), reverse=True): 73 | remaining = self.batch_size - len(buffer_per_group[group_id]) 74 | samples_from_group_id = _repeat_to_at_least(samples_per_group[group_id], remaining) 75 | buffer_per_group[group_id].extend(samples_from_group_id[:remaining]) 76 | assert len(buffer_per_group[group_id]) == self.batch_size 77 | yield buffer_per_group[group_id] 78 | num_remaining -= 1 79 | if num_remaining == 0: 80 | break 81 | assert num_remaining == 0 82 | 83 | def __len__(self): 84 | return len(self.sampler) // self.batch_size 85 | 86 | 87 | def _compute_aspect_ratios_slow(dataset, indices=None): 88 | print("Your dataset doesn't support the fast path for " 89 | "computing the aspect ratios, so will iterate over " 90 | "the full dataset and load every image instead. 
" 91 | "This might take some time...") 92 | if indices is None: 93 | indices = range(len(dataset)) 94 | 95 | class SubsetSampler(Sampler): 96 | def __init__(self, indices): 97 | self.indices = indices 98 | 99 | def __iter__(self): 100 | return iter(self.indices) 101 | 102 | def __len__(self): 103 | return len(self.indices) 104 | 105 | sampler = SubsetSampler(indices) 106 | data_loader = torch.utils.data.DataLoader( 107 | dataset, batch_size=1, sampler=sampler, 108 | num_workers=14, # you might want to increase it for faster processing 109 | collate_fn=lambda x: x[0]) 110 | aspect_ratios = [] 111 | with tqdm(total=len(dataset)) as pbar: 112 | # for _i, (img, _) in enumerate(data_loader): 113 | for _i, (img, _, _, _) in enumerate(data_loader): 114 | pbar.update(1) 115 | height, width = img.shape[-2:] 116 | aspect_ratio = float(width) / float(height) 117 | aspect_ratios.append(aspect_ratio) 118 | return aspect_ratios 119 | 120 | 121 | def _compute_aspect_ratios_custom_dataset(dataset, indices=None): 122 | if indices is None: 123 | indices = range(len(dataset)) 124 | aspect_ratios = [] 125 | for i in indices: 126 | height, width = dataset.get_height_and_width(i) 127 | aspect_ratio = float(width) / float(height) 128 | aspect_ratios.append(aspect_ratio) 129 | return aspect_ratios 130 | 131 | 132 | def _compute_aspect_ratios_coco_dataset(dataset, indices=None): 133 | if indices is None: 134 | indices = range(len(dataset)) 135 | aspect_ratios = [] 136 | for i in indices: 137 | img_info = dataset.coco.imgs[dataset.ids[i]] 138 | aspect_ratio = float(img_info["width"]) / float(img_info["height"]) 139 | aspect_ratios.append(aspect_ratio) 140 | return aspect_ratios 141 | 142 | 143 | def _compute_aspect_ratios_voc_dataset(dataset, indices=None): 144 | if indices is None: 145 | indices = range(len(dataset)) 146 | aspect_ratios = [] 147 | for i in indices: 148 | # this doesn't load the data into memory, because PIL loads it lazily 149 | width, height = Image.open(dataset.images[i]).size 150 | aspect_ratio = float(width) / float(height) 151 | aspect_ratios.append(aspect_ratio) 152 | return aspect_ratios 153 | 154 | 155 | def _compute_aspect_ratios_subset_dataset(dataset, indices=None): 156 | if indices is None: 157 | indices = range(len(dataset)) 158 | 159 | ds_indices = [dataset.indices[i] for i in indices] 160 | return compute_aspect_ratios(dataset.dataset, ds_indices) 161 | 162 | 163 | def compute_aspect_ratios(dataset, indices=None): 164 | if hasattr(dataset, "get_height_and_width"): 165 | return _compute_aspect_ratios_custom_dataset(dataset, indices) 166 | 167 | if isinstance(dataset, torchvision.datasets.CocoDetection): 168 | return _compute_aspect_ratios_coco_dataset(dataset, indices) 169 | 170 | if isinstance(dataset, torchvision.datasets.VOCDetection): 171 | return _compute_aspect_ratios_voc_dataset(dataset, indices) 172 | 173 | if isinstance(dataset, torch.utils.data.Subset): 174 | return _compute_aspect_ratios_subset_dataset(dataset, indices) 175 | 176 | # slow path 177 | return _compute_aspect_ratios_slow(dataset, indices) 178 | 179 | 180 | def _quantize(x, bins): 181 | bins = copy.deepcopy(bins) 182 | bins = sorted(bins) 183 | quantized = list(map(lambda y: bisect.bisect_right(bins, y), x)) 184 | return quantized 185 | 186 | 187 | def create_aspect_ratio_groups(dataset, k=0): 188 | aspect_ratios = compute_aspect_ratios(dataset) 189 | bins = (2 ** np.linspace(-1, 1, 2 * k + 1)).tolist() if k > 0 else [1.0] 190 | groups = _quantize(aspect_ratios, bins) 191 | # count number of elements per group 
192 | counts = np.unique(groups, return_counts=True)[1] 193 | fbins = [0] + bins + [np.inf] 194 | print("Using {} as bins for aspect ratio quantization".format(fbins)) 195 | print("Count of instances per bin: {}".format(counts)) 196 | return groups 197 | -------------------------------------------------------------------------------- /main_lmo.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import time 4 | from yacs.config import CfgNode as CN 5 | from scipy.io import savemat 6 | 7 | import torch 8 | import torch.utils.data 9 | from torch import nn 10 | import torchvision 11 | import torchvision.models.detection 12 | import torchvision.models.detection.mask_rcnn 13 | 14 | from reference.group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups 15 | from reference.engine import train_one_epoch, evaluate 16 | from reference import utils 17 | from dataset.lm import lm_with_synt 18 | from dataset.lmo import lmo 19 | from detection.keypoint_rcnn import keypointrcnn_hrnet 20 | from libs.utils import get_logger 21 | 22 | def main(args, cfg): 23 | utils.init_distributed_mode(args) 24 | logger = get_logger(cfg) 25 | device = torch.device(cfg.DEVICE) 26 | 27 | # Data loading code 28 | print("Loading data") 29 | 30 | dataset = lm_with_synt(cfg) 31 | dataset_test_full = lmo(cfg) 32 | valid_list = dataset_test_full.img_list 33 | dataset_test = torch.utils.data.Subset(dataset_test_full, valid_list) 34 | 35 | print("Creating data loaders. Is distributed? ", args.distributed) 36 | if args.distributed: 37 | train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) 38 | test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False) 39 | else: 40 | train_sampler = torch.utils.data.RandomSampler(dataset) 41 | test_sampler = torch.utils.data.SequentialSampler(dataset_test) 42 | 43 | if args.aspect_ratio_group_factor >= 0: 44 | group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor) 45 | train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, cfg.BATCH_SIZE) 46 | else: 47 | train_batch_sampler = torch.utils.data.BatchSampler( 48 | train_sampler, cfg.BATCH_SIZE, drop_last=True) 49 | 50 | data_loader = torch.utils.data.DataLoader( 51 | dataset, batch_sampler=train_batch_sampler, num_workers=cfg.WORKERS, 52 | collate_fn=utils.collate_fn) 53 | 54 | data_loader_test = torch.utils.data.DataLoader( 55 | dataset_test, batch_size=cfg.TEST_BATCH_SIZE, 56 | sampler=test_sampler, num_workers=cfg.WORKERS, 57 | collate_fn=utils.collate_fn) 58 | 59 | print("Creating model") 60 | model = keypointrcnn_hrnet(cfg, resume=args.resume, min_size=480, max_size=640) 61 | model.to(device) 62 | 63 | model_without_ddp = model 64 | if args.distributed: 65 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 66 | model_without_ddp = model.module 67 | 68 | params = [p for p in model.parameters() if p.requires_grad] 69 | optimizer = torch.optim.Adam(params, lr=cfg.LR) 70 | 71 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=cfg.LR_STEPS, gamma=cfg.LR_DECAY) 72 | 73 | if args.resume: 74 | checkpoint = torch.load(os.path.join(cfg.OUTPUT_DIR,cfg.obj,'{}.pth'.format(cfg.log_name)), map_location='cpu') 75 | model_without_ddp.load_state_dict(checkpoint['model']) 76 | optimizer.load_state_dict(checkpoint['optimizer']) 77 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 78 | args.start_epoch = 
checkpoint['epoch'] + 1 79 | 80 | if args.test_only: 81 | evaluator = evaluate(model, data_loader_test, device=device, logger=logger) 82 | 83 | boxes, pose_record1, pose_record2, pose_record3, pose_recordc, pts2d_record1, pts2d_record2, pts2d_record3, \ 84 | corrects1, corrects2, corrects3, correctsc \ 85 | = evaluator.get_accuracy(cfg, args.start_epoch-1, n_test=len(valid_list), testset_name='lmo', n_min=4, thres=1, logger=logger) 86 | savemat(os.path.join(cfg.OUTPUT_DIR, cfg.obj, '{}_result.mat'.format(cfg.log_name)), 87 | {'boxes':boxes.cpu().numpy(), 'pose_record1': pose_record1.detach().cpu().numpy(), 88 | 'pose_record2': pose_record2.detach().cpu().numpy(), 'pose_record3': pose_record3.detach().cpu().numpy(), 89 | 'pose_recordc': pose_recordc.detach().cpu().numpy(), 'pts2d_record1': pts2d_record1.detach().cpu().numpy(), 90 | 'pts2d_record2': pts2d_record2.detach().cpu().numpy(), 'pts2d_record3': pts2d_record3.detach().cpu().numpy(), 91 | 'corrects1':corrects1.detach().cpu().numpy(), 'corrects2':corrects2.detach().cpu().numpy(), 92 | 'corrects3':corrects3.detach().cpu().numpy(), 'correctsc':correctsc.detach().cpu().numpy(), 'test_idx': valid_list}) 93 | return 94 | 95 | print("Start training") 96 | start_time = time.time() 97 | for epoch in range(args.start_epoch, cfg.END_EPOCH): 98 | if args.distributed: 99 | train_sampler.set_epoch(epoch) 100 | 101 | train_one_epoch(model, optimizer, data_loader, device, epoch, cfg.PRINT_FREQ, cfg.obj, logger) 102 | lr_scheduler.step() 103 | if cfg.OUTPUT_DIR: 104 | utils.save_on_master({ 105 | 'model': model_without_ddp.state_dict(), 106 | 'optimizer': optimizer.state_dict(), 107 | 'lr_scheduler': lr_scheduler.state_dict(), 108 | 'args': args, 109 | 'cfg': cfg, 110 | 'epoch': epoch}, 111 | os.path.join(cfg.OUTPUT_DIR, cfg.obj, '{}.pth'.format(cfg.log_name))) 112 | 113 | if epoch==cfg.END_EPOCH-1: 114 | evaluator = evaluate(model, data_loader_test, device=device, logger=logger) 115 | boxes, pose_record1, pose_record2, pose_record3, pose_recordc, pts2d_record1, pts2d_record2, pts2d_record3, \ 116 | corrects1, corrects2, corrects3, correctsc \ 117 | = evaluator.get_accuracy(cfg, epoch, n_test=len(valid_list), testset_name='lmo', n_min=4, thres=1, logger=logger) 118 | savemat(os.path.join(cfg.OUTPUT_DIR, cfg.obj, '{}_result.mat'.format(cfg.log_name)), 119 | {'boxes':boxes.cpu().numpy(), 'pose_record1': pose_record1.detach().cpu().numpy(), 120 | 'pose_record2': pose_record2.detach().cpu().numpy(), 'pose_record3': pose_record3.detach().cpu().numpy(), 121 | 'pose_recordc': pose_recordc.detach().cpu().numpy(), 'pts2d_record1': pts2d_record1.detach().cpu().numpy(), 122 | 'pts2d_record2': pts2d_record2.detach().cpu().numpy(), 'pts2d_record3': pts2d_record3.detach().cpu().numpy(), 123 | 'corrects1':corrects1.detach().cpu().numpy(), 'corrects2':corrects2.detach().cpu().numpy(), 124 | 'corrects3':corrects3.detach().cpu().numpy(), 'correctsc':correctsc.detach().cpu().numpy(), 'test_idx': valid_list}) 125 | 126 | total_time = time.time() - start_time 127 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 128 | print('Training time {}'.format(total_time_str)) 129 | 130 | if __name__ == "__main__": 131 | import argparse 132 | parser = argparse.ArgumentParser( 133 | description=__doc__) 134 | 135 | parser.add_argument('--resume', dest="resume",action="store_true") 136 | parser.add_argument('--start_epoch', default=0, type=int, help='start epoch') 137 | parser.add_argument('--aspect-ratio-group-factor', default=-1, type=int) 138 | 
parser.add_argument("--test-only",dest="test_only",help="Only test the model",action="store_true",) 139 | parser.add_argument('--world-size', default=1, type=int, help='number of distributed processes') 140 | parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training') 141 | parser.add_argument('--cfg', help='experiment configure file name', required=True, type=str) 142 | parser.add_argument('--obj', required=True, type=str) 143 | parser.add_argument('--sigma1', default=1.5, required=False, type=float) 144 | parser.add_argument('--sigma2', default=3, required=False, type=float) 145 | parser.add_argument('--sigma3', default=8, required=False, type=float) 146 | parser.add_argument('--log_name', required=True, type=str) 147 | parser.add_argument('--distrib', default=1, type=int) 148 | args = parser.parse_args() 149 | cfg = CN(new_allowed=True) 150 | cfg.defrost() 151 | cfg.merge_from_file(args.cfg) 152 | cfg.obj = args.obj 153 | cfg.log_name = args.log_name 154 | cfg.sigma1 = args.sigma1 155 | cfg.sigma2 = args.sigma2 156 | cfg.sigma3 = args.sigma3 157 | cfg.freeze() 158 | 159 | main(args, cfg) 160 | 161 | 162 | 163 | 164 | 165 | -------------------------------------------------------------------------------- /main_lm.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import time 4 | from yacs.config import CfgNode as CN 5 | from scipy.io import savemat 6 | 7 | import torch 8 | import torch.utils.data 9 | from torch import nn 10 | import torchvision 11 | import torchvision.models.detection 12 | import torchvision.models.detection.mask_rcnn 13 | 14 | from reference.group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups 15 | from reference.engine import train_one_epoch, evaluate 16 | from reference import utils 17 | from dataset.lm import lm_with_synt, lm 18 | from detection.keypoint_rcnn import keypointrcnn_hrnet 19 | from libs.utils import get_logger, get_lm_img_idx 20 | 21 | def main(args, cfg): 22 | utils.init_distributed_mode(args) 23 | logger = get_logger(cfg) 24 | device = torch.device(cfg.DEVICE) 25 | 26 | print("Loading data") 27 | 28 | dataset_train_full = lm_with_synt(cfg) 29 | dataset_test_full = lm(cfg) 30 | n_lm = len(dataset_test_full) 31 | n_lm_synt = len(dataset_train_full) - n_lm 32 | train_idx, test_idx = get_lm_img_idx(cfg, n_lm, n_lm_synt) 33 | dataset = torch.utils.data.Subset(dataset_train_full, train_idx) 34 | dataset_test = torch.utils.data.Subset(dataset_test_full, test_idx) 35 | 36 | print("Creating data loaders") 37 | if args.distributed: 38 | train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) 39 | test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False) 40 | else: 41 | train_sampler = torch.utils.data.RandomSampler(dataset) 42 | test_sampler = torch.utils.data.SequentialSampler(dataset_test) 43 | 44 | if args.aspect_ratio_group_factor >= 0: 45 | group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor) 46 | train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, cfg.BATCH_SIZE) 47 | else: 48 | train_batch_sampler = torch.utils.data.BatchSampler( 49 | train_sampler, cfg.BATCH_SIZE, drop_last=True) 50 | 51 | data_loader = torch.utils.data.DataLoader( 52 | dataset, batch_sampler=train_batch_sampler, num_workers=cfg.WORKERS, 53 | collate_fn=utils.collate_fn) 54 | 55 | data_loader_test = torch.utils.data.DataLoader( 56 | dataset_test, 
batch_size=cfg.TEST_BATCH_SIZE, 57 | sampler=test_sampler, num_workers=cfg.WORKERS, 58 | collate_fn=utils.collate_fn) 59 | 60 | print("Creating model") 61 | model = keypointrcnn_hrnet(cfg, resume=args.resume, min_size=480, max_size=640) 62 | model.to(device) 63 | 64 | model_without_ddp = model 65 | if args.distributed: 66 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 67 | model_without_ddp = model.module 68 | 69 | params = [p for p in model.parameters() if p.requires_grad] 70 | optimizer = torch.optim.Adam(params, lr=cfg.LR) 71 | 72 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=cfg.LR_STEPS, gamma=cfg.LR_DECAY) 73 | 74 | if args.resume: 75 | checkpoint = torch.load(os.path.join(cfg.OUTPUT_DIR,cfg.obj,'{}.pth'.format(cfg.log_name)), map_location='cpu') 76 | model_without_ddp.load_state_dict(checkpoint['model']) 77 | optimizer.load_state_dict(checkpoint['optimizer']) 78 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 79 | args.start_epoch = checkpoint['epoch'] + 1 80 | 81 | if args.test_only: 82 | evaluator = evaluate(model, data_loader_test, device=device, logger=logger) 83 | 84 | boxes, pose_record1, pose_record2, pose_record3, pose_recordc, pts2d_record1, pts2d_record2, pts2d_record3, \ 85 | corrects1, corrects2, corrects3, correctsc \ 86 | = evaluator.get_accuracy(cfg, args.start_epoch-1, n_test=len(test_idx), testset_name='lm', n_min=4, thres=1, logger=logger) 87 | savemat(os.path.join(cfg.OUTPUT_DIR, cfg.obj, '{}_result.mat'.format(cfg.log_name)), 88 | {'boxes':boxes.cpu().numpy(), 'pose_record1': pose_record1.detach().cpu().numpy(), 89 | 'pose_record2': pose_record2.detach().cpu().numpy(), 'pose_record3': pose_record3.detach().cpu().numpy(), 90 | 'pose_recordc': pose_recordc.detach().cpu().numpy(), 'pts2d_record1': pts2d_record1.detach().cpu().numpy(), 91 | 'pts2d_record2': pts2d_record2.detach().cpu().numpy(), 'pts2d_record3': pts2d_record3.detach().cpu().numpy(), 92 | 'corrects1':corrects1.detach().cpu().numpy(), 'corrects2':corrects2.detach().cpu().numpy(), 93 | 'corrects3':corrects3.detach().cpu().numpy(), 'correctsc':correctsc.detach().cpu().numpy(), 'test_idx': test_idx}) 94 | 95 | return 96 | 97 | print("Start training") 98 | start_time = time.time() 99 | for epoch in range(args.start_epoch, cfg.END_EPOCH): 100 | if args.distributed: 101 | train_sampler.set_epoch(epoch) 102 | train_one_epoch(model, optimizer, data_loader, device, epoch, cfg.PRINT_FREQ, cfg.obj, logger) 103 | lr_scheduler.step() 104 | if cfg.OUTPUT_DIR: 105 | utils.save_on_master({ 106 | 'model': model_without_ddp.state_dict(), 107 | 'optimizer': optimizer.state_dict(), 108 | 'lr_scheduler': lr_scheduler.state_dict(), 109 | 'args': args, 110 | 'cfg': cfg, 111 | 'epoch': epoch}, 112 | os.path.join(cfg.OUTPUT_DIR, cfg.obj, '{}.pth'.format(cfg.log_name))) 113 | 114 | if epoch==cfg.END_EPOCH-1: 115 | evaluator = evaluate(model, data_loader_test, device=device, logger=logger) 116 | 117 | boxes, pose_record1, pose_record2, pose_record3, pose_recordc, pts2d_record1, pts2d_record2, pts2d_record3, \ 118 | corrects1, corrects2, corrects3, correctsc \ 119 | = evaluator.get_accuracy(cfg, epoch, n_test=len(test_idx), testset_name='lm', n_min=4, thres=1, logger=logger) 120 | savemat(os.path.join(cfg.OUTPUT_DIR, cfg.obj, '{}_result.mat'.format(cfg.log_name)), 121 | {'boxes':boxes.cpu().numpy(), 'pose_record1': pose_record1.detach().cpu().numpy(), 122 | 'pose_record2': pose_record2.detach().cpu().numpy(), 'pose_record3': 
pose_record3.detach().cpu().numpy(), 123 | 'pose_recordc': pose_recordc.detach().cpu().numpy(), 'pts2d_record1': pts2d_record1.detach().cpu().numpy(), 124 | 'pts2d_record2': pts2d_record2.detach().cpu().numpy(), 'pts2d_record3': pts2d_record3.detach().cpu().numpy(), 125 | 'corrects1':corrects1.detach().cpu().numpy(), 'corrects2':corrects2.detach().cpu().numpy(), 126 | 'corrects3':corrects3.detach().cpu().numpy(), 'correctsc':correctsc.detach().cpu().numpy(), 'test_idx': test_idx}) 127 | 128 | total_time = time.time() - start_time 129 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 130 | print('Training time {}'.format(total_time_str)) 131 | 132 | 133 | if __name__ == "__main__": 134 | import argparse 135 | parser = argparse.ArgumentParser( 136 | description=__doc__) 137 | 138 | parser.add_argument('--resume', dest="resume",action="store_true") 139 | parser.add_argument('--start_epoch', default=0, type=int, help='start epoch') 140 | parser.add_argument('--aspect-ratio-group-factor', default=-1, type=int) 141 | parser.add_argument("--test-only",dest="test_only",help="Only test the model",action="store_true",) 142 | parser.add_argument('--world-size', default=1, type=int, help='number of distributed processes') 143 | parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training') 144 | parser.add_argument('--cfg', help='experiment configure file name', required=True, type=str) 145 | parser.add_argument('--obj', required=True, type=str) 146 | parser.add_argument('--sigma1', default=1.5, required=False, type=float) 147 | parser.add_argument('--sigma2', default=3, required=False, type=float) 148 | parser.add_argument('--sigma3', default=8, required=False, type=float) 149 | parser.add_argument('--log_name', required=True, type=str) 150 | args = parser.parse_args() 151 | cfg = CN(new_allowed=True) 152 | cfg.defrost() 153 | cfg.merge_from_file(args.cfg) 154 | cfg.obj = args.obj 155 | cfg.log_name = args.log_name 156 | cfg.sigma1 = args.sigma1 157 | cfg.sigma2 = args.sigma2 158 | cfg.sigma3 = args.sigma3 159 | cfg.freeze() 160 | 161 | main(args, cfg) 162 | 163 | 164 | 165 | 166 | 167 | -------------------------------------------------------------------------------- /main_ycbv.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import time 4 | from yacs.config import CfgNode as CN 5 | from scipy.io import savemat 6 | 7 | import torch 8 | import torch.utils.data 9 | from torch import nn 10 | import torchvision 11 | import torchvision.models.detection 12 | import torchvision.models.detection.mask_rcnn 13 | 14 | from reference.group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups 15 | from reference.engine import train_one_epoch, evaluate 16 | from reference import utils 17 | from dataset.ycbv import ycbv_train_w_synt, ycbv_test 18 | from detection.keypoint_rcnn import keypointrcnn_hrnet 19 | 20 | from libs.utils import get_logger 21 | 22 | 23 | def main(args, cfg): 24 | utils.init_distributed_mode(args) 25 | logger = get_logger(cfg) 26 | device = torch.device(cfg.DEVICE) 27 | 28 | # Data loading code 29 | print("Loading data") 30 | 31 | dataset = ycbv_train_w_synt(cfg) 32 | dataset_test = ycbv_test(cfg) 33 | valid_list = list(range(len(dataset_test))) 34 | 35 | print("Creating data loaders. Is distributed? 
", args.distributed) 36 | if args.distributed: 37 | train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) 38 | test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False) 39 | else: 40 | train_sampler = torch.utils.data.RandomSampler(dataset) 41 | test_sampler = torch.utils.data.SequentialSampler(dataset_test) 42 | 43 | if args.aspect_ratio_group_factor >= 0: 44 | group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor) 45 | train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, cfg.BATCH_SIZE) 46 | else: 47 | train_batch_sampler = torch.utils.data.BatchSampler( 48 | train_sampler, cfg.BATCH_SIZE, drop_last=True) 49 | 50 | data_loader = torch.utils.data.DataLoader( 51 | dataset, batch_sampler=train_batch_sampler, num_workers=cfg.WORKERS, 52 | collate_fn=utils.collate_fn) 53 | 54 | data_loader_test = torch.utils.data.DataLoader( 55 | dataset_test, batch_size=cfg.TEST_BATCH_SIZE, 56 | sampler=test_sampler, num_workers=cfg.WORKERS, 57 | collate_fn=utils.collate_fn) 58 | 59 | print("Creating model") 60 | model = keypointrcnn_hrnet(cfg, resume=args.resume, min_size=480, max_size=640) 61 | model.to(device) 62 | 63 | model_without_ddp = model 64 | if args.distributed: 65 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 66 | model_without_ddp = model.module 67 | 68 | params = [p for p in model.parameters() if p.requires_grad] 69 | optimizer = torch.optim.Adam(params, lr=cfg.LR) 70 | 71 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=cfg.LR_STEPS, gamma=cfg.LR_DECAY) 72 | 73 | if args.resume: 74 | checkpoint = torch.load(os.path.join(cfg.OUTPUT_DIR, cfg.obj,'{}.pth'.format(cfg.log_name)), map_location='cpu') 75 | model_without_ddp.load_state_dict(checkpoint['model']) 76 | optimizer.load_state_dict(checkpoint['optimizer']) 77 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 78 | args.start_epoch = checkpoint['epoch'] + 1 79 | 80 | if args.test_only: 81 | evaluator = evaluate(model, data_loader_test, device=device, logger=logger) 82 | 83 | boxes, pose_record1, pose_record2, pose_record3, pose_recordc, pts2d_record1, pts2d_record2, pts2d_record3, \ 84 | corrects1, corrects2, corrects3, correctsc, seq_ids, img_ids \ 85 | = evaluator.get_ycbv_accuracy(cfg, args.start_epoch-1, n_test=len(valid_list), testset_name='ycbv', n_min=4, thres=1, logger=logger) 86 | savemat(os.path.join(cfg.OUTPUT_DIR, cfg.obj, '{}_result.mat'.format(cfg.log_name)), 87 | {'boxes':boxes.cpu().numpy(), 'pose_record1': pose_record1.detach().cpu().numpy(), 88 | 'pose_record2': pose_record2.detach().cpu().numpy(), 'pose_record3': pose_record3.detach().cpu().numpy(), 89 | 'pose_recordc': pose_recordc.detach().cpu().numpy(), 'pts2d_record1': pts2d_record1.detach().cpu().numpy(), 90 | 'pts2d_record2': pts2d_record2.detach().cpu().numpy(), 'pts2d_record3': pts2d_record3.detach().cpu().numpy(), 91 | 'corrects1':corrects1.detach().cpu().numpy(), 'corrects2':corrects2.detach().cpu().numpy(), 92 | 'corrects3':corrects3.detach().cpu().numpy(), 'correctsc':correctsc.detach().cpu().numpy(), 'test_idx': valid_list, 93 | 'seq_ids':seq_ids, 'img_ids':img_ids}) 94 | return 95 | 96 | print("Start training") 97 | start_time = time.time() 98 | for epoch in range(args.start_epoch, cfg.END_EPOCH): 99 | if args.distributed: 100 | train_sampler.set_epoch(epoch) 101 | 102 | train_one_epoch(model, optimizer, data_loader, device, epoch, cfg.PRINT_FREQ, cfg.obj, logger) 103 | lr_scheduler.step() 104 | 
if cfg.OUTPUT_DIR: 105 | utils.save_on_master({ 106 | 'model': model_without_ddp.state_dict(), 107 | 'optimizer': optimizer.state_dict(), 108 | 'lr_scheduler': lr_scheduler.state_dict(), 109 | 'args': args, 110 | 'cfg': cfg, 111 | 'epoch': epoch}, 112 | os.path.join(cfg.OUTPUT_DIR, cfg.obj, '{}.pth'.format(cfg.log_name))) 113 | 114 | if epoch==cfg.END_EPOCH-1: 115 | evaluator = evaluate(model, data_loader_test, device=device, logger=logger) 116 | 117 | boxes, pose_record1, pose_record2, pose_record3, pose_recordc, pts2d_record1, pts2d_record2, pts2d_record3, \ 118 | corrects1, corrects2, corrects3, correctsc, seq_ids, img_ids \ 119 | = evaluator.get_ycbv_accuracy(cfg, epoch, n_test=len(valid_list), testset_name='ycbv', n_min=4, thres=1, logger=logger) 120 | savemat(os.path.join(cfg.OUTPUT_DIR, cfg.obj, '{}_result.mat'.format(cfg.log_name)), 121 | {'boxes':boxes.cpu().numpy(), 'pose_record1': pose_record1.detach().cpu().numpy(), 122 | 'pose_record2': pose_record2.detach().cpu().numpy(), 'pose_record3': pose_record3.detach().cpu().numpy(), 123 | 'pose_recordc': pose_recordc.detach().cpu().numpy(), 'pts2d_record1': pts2d_record1.detach().cpu().numpy(), 124 | 'pts2d_record2': pts2d_record2.detach().cpu().numpy(), 'pts2d_record3': pts2d_record3.detach().cpu().numpy(), 125 | 'corrects1':corrects1.detach().cpu().numpy(), 'corrects2':corrects2.detach().cpu().numpy(), 126 | 'corrects3':corrects3.detach().cpu().numpy(), 'correctsc':correctsc.detach().cpu().numpy(), 'test_idx': valid_list, 127 | 'seq_ids':seq_ids, 'img_ids':img_ids}) 128 | 129 | total_time = time.time() - start_time 130 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 131 | print('Training time {}'.format(total_time_str)) 132 | 133 | 134 | if __name__ == "__main__": 135 | import argparse 136 | parser = argparse.ArgumentParser( 137 | description=__doc__) 138 | 139 | parser.add_argument('--resume', dest="resume",action="store_true") 140 | parser.add_argument('--start_epoch', default=0, type=int, help='start epoch') 141 | parser.add_argument('--aspect-ratio-group-factor', default=-1, type=int) 142 | parser.add_argument("--test-only",dest="test_only",help="Only test the model",action="store_true",) 143 | parser.add_argument('--world-size', default=1, type=int, help='number of distributed processes') 144 | parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training') 145 | parser.add_argument('--cfg', help='experiment configure file name', required=True, type=str) 146 | parser.add_argument('--obj', required=True, type=str) 147 | parser.add_argument('--sigma1', default=1.5, required=False, type=float) 148 | parser.add_argument('--sigma2', default=3, required=False, type=float) 149 | parser.add_argument('--sigma3', default=8, required=False, type=float) 150 | parser.add_argument('--log_name', required=True, type=str) 151 | parser.add_argument('--distrib', default=1, type=int) 152 | args = parser.parse_args() 153 | cfg = CN(new_allowed=True) 154 | cfg.defrost() 155 | cfg.merge_from_file(args.cfg) 156 | cfg.obj = args.obj 157 | cfg.log_name = args.log_name 158 | cfg.sigma1 = args.sigma1 159 | cfg.sigma2 = args.sigma2 160 | cfg.sigma3 = args.sigma3 161 | cfg.freeze() 162 | 163 | main(args, cfg) 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | -------------------------------------------------------------------------------- /reference/train.py: -------------------------------------------------------------------------------- 1 | r"""PyTorch Detection Training. 
2 | 3 | To run in a multi-gpu environment, use the distributed launcher:: 4 | 5 | python -m torch.distributed.launch --nproc_per_node=$NGPU --use_env \ 6 | train.py ... --world-size $NGPU 7 | 8 | The default hyperparameters are tuned for training on 8 gpus and 2 images per gpu. 9 | --lr 0.02 --batch-size 2 --world-size 8 10 | If you use different number of gpus, the learning rate should be changed to 0.02/8*$NGPU. 11 | 12 | On top of that, for training Faster/Mask R-CNN, the default hyperparameters are 13 | --epochs 26 --lr-steps 16 22 --aspect-ratio-group-factor 3 14 | 15 | Also, if you train Keypoint R-CNN, the default hyperparameters are 16 | --epochs 46 --lr-steps 36 43 --aspect-ratio-group-factor 3 17 | Because the number of images is smaller in the person keypoint subset of COCO, 18 | the number of epochs should be adapted so that we have the same number of iterations. 19 | """ 20 | import datetime 21 | import os 22 | import time 23 | 24 | import torch 25 | import torch.utils.data 26 | from torch import nn 27 | import torchvision 28 | import torchvision.models.detection 29 | import torchvision.models.detection.mask_rcnn 30 | 31 | from coco_utils import get_coco, get_coco_kp 32 | 33 | from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups 34 | from engine import train_one_epoch, evaluate 35 | 36 | from . import utils 37 | from . import transforms as T 38 | 39 | 40 | def get_dataset(name, image_set, transform, data_path): 41 | paths = { 42 | "coco": (data_path, get_coco, 91), 43 | "coco_kp": (data_path, get_coco_kp, 2) 44 | } 45 | p, ds_fn, num_classes = paths[name] 46 | 47 | ds = ds_fn(p, image_set=image_set, transforms=transform) 48 | return ds, num_classes 49 | 50 | 51 | def get_transform(train): 52 | transforms = [] 53 | transforms.append(T.ToTensor()) 54 | if train: 55 | transforms.append(T.RandomHorizontalFlip(0.5)) 56 | return T.Compose(transforms) 57 | 58 | 59 | def main(args): 60 | utils.init_distributed_mode(args) 61 | print(args) 62 | 63 | device = torch.device(args.device) 64 | 65 | # Data loading code 66 | print("Loading data") 67 | 68 | dataset, num_classes = get_dataset(args.dataset, "train", get_transform(train=True), args.data_path) 69 | dataset_test, _ = get_dataset(args.dataset, "val", get_transform(train=False), args.data_path) 70 | 71 | print("Creating data loaders") 72 | if args.distributed: 73 | train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) 74 | test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test) 75 | else: 76 | train_sampler = torch.utils.data.RandomSampler(dataset) 77 | test_sampler = torch.utils.data.SequentialSampler(dataset_test) 78 | 79 | if args.aspect_ratio_group_factor >= 0: 80 | group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor) 81 | train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size) 82 | else: 83 | train_batch_sampler = torch.utils.data.BatchSampler( 84 | train_sampler, args.batch_size, drop_last=True) 85 | 86 | data_loader = torch.utils.data.DataLoader( 87 | dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, 88 | collate_fn=utils.collate_fn) 89 | 90 | data_loader_test = torch.utils.data.DataLoader( 91 | dataset_test, batch_size=1, 92 | sampler=test_sampler, num_workers=args.workers, 93 | collate_fn=utils.collate_fn) 94 | 95 | print("Creating model") 96 | model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, 97 | pretrained=args.pretrained) 98 | model.to(device) 
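# Note: a handle to the unwrapped module is kept below (model_without_ddp) so that
# checkpoints are saved from its state_dict(); this avoids the 'module.' key prefix
# that DistributedDataParallel adds, so the resulting .pth files load identically
# whether or not training was distributed.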
99 | 100 | model_without_ddp = model 101 | if args.distributed: 102 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 103 | model_without_ddp = model.module 104 | 105 | params = [p for p in model.parameters() if p.requires_grad] 106 | optimizer = torch.optim.SGD( 107 | params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 108 | 109 | # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) 110 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma) 111 | 112 | if args.resume: 113 | checkpoint = torch.load(args.resume, map_location='cpu') 114 | model_without_ddp.load_state_dict(checkpoint['model']) 115 | optimizer.load_state_dict(checkpoint['optimizer']) 116 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 117 | args.start_epoch = checkpoint['epoch'] + 1 118 | 119 | if args.test_only: 120 | evaluate(model, data_loader_test, device=device) 121 | return 122 | 123 | print("Start training") 124 | start_time = time.time() 125 | for epoch in range(args.start_epoch, args.epochs): 126 | if args.distributed: 127 | train_sampler.set_epoch(epoch) 128 | train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq) 129 | lr_scheduler.step() 130 | if args.output_dir: 131 | utils.save_on_master({ 132 | 'model': model_without_ddp.state_dict(), 133 | 'optimizer': optimizer.state_dict(), 134 | 'lr_scheduler': lr_scheduler.state_dict(), 135 | 'args': args, 136 | 'epoch': epoch}, 137 | os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) 138 | 139 | # evaluate after every epoch 140 | evaluate(model, data_loader_test, device=device) 141 | 142 | total_time = time.time() - start_time 143 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 144 | print('Training time {}'.format(total_time_str)) 145 | 146 | 147 | if __name__ == "__main__": 148 | import argparse 149 | parser = argparse.ArgumentParser( 150 | description=__doc__) 151 | 152 | parser.add_argument('--data-path', default='/datasets01/COCO/022719/', help='dataset') 153 | parser.add_argument('--dataset', default='coco', help='dataset') 154 | parser.add_argument('--model', default='maskrcnn_resnet50_fpn', help='model') 155 | parser.add_argument('--device', default='cuda', help='device') 156 | parser.add_argument('-b', '--batch-size', default=2, type=int, 157 | help='images per gpu, the total batch size is $NGPU x batch_size') 158 | parser.add_argument('--epochs', default=26, type=int, metavar='N', 159 | help='number of total epochs to run') 160 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 161 | help='number of data loading workers (default: 4)') 162 | parser.add_argument('--lr', default=0.02, type=float, 163 | help='initial learning rate, 0.02 is the default value for training ' 164 | 'on 8 gpus and 2 images_per_gpu') 165 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 166 | help='momentum') 167 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 168 | metavar='W', help='weight decay (default: 1e-4)', 169 | dest='weight_decay') 170 | parser.add_argument('--lr-step-size', default=8, type=int, help='decrease lr every step-size epochs') 171 | parser.add_argument('--lr-steps', default=[16, 22], nargs='+', type=int, help='decrease lr every step-size epochs') 172 | parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma') 173 | 
parser.add_argument('--print-freq', default=20, type=int, help='print frequency') 174 | parser.add_argument('--output-dir', default='.', help='path where to save') 175 | parser.add_argument('--resume', default='', help='resume from checkpoint') 176 | parser.add_argument('--start_epoch', default=0, type=int, help='start epoch') 177 | parser.add_argument('--aspect-ratio-group-factor', default=3, type=int) 178 | parser.add_argument( 179 | "--test-only", 180 | dest="test_only", 181 | help="Only test the model", 182 | action="store_true", 183 | ) 184 | parser.add_argument( 185 | "--pretrained", 186 | dest="pretrained", 187 | help="Use pre-trained models from the modelzoo", 188 | action="store_true", 189 | ) 190 | 191 | # distributed training parameters 192 | parser.add_argument('--world-size', default=1, type=int, 193 | help='number of distributed processes') 194 | parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training') 195 | 196 | args = parser.parse_args() 197 | 198 | if args.output_dir: 199 | utils.mkdir(args.output_dir) 200 | 201 | main(args) 202 | -------------------------------------------------------------------------------- /libs/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | from scipy.io import loadmat, savemat 5 | from sklearn.neighbors import KDTree 6 | import kornia as kn 7 | import logging 8 | import json 9 | 10 | def get_logger(cfg): 11 | if not os.path.exists(cfg.OUTPUT_DIR+'/'+cfg.obj+'/'): 12 | os.mkdir(cfg.OUTPUT_DIR+'/'+cfg.obj+'/') 13 | logging.basicConfig(filename=cfg.OUTPUT_DIR+'/'+cfg.obj+'/'+cfg.log_name+'.out', level=logging.INFO, 14 | format='%(asctime)s %(message)s', datefmt='%Y/%m/%d %I:%M:%S %p') 15 | logger = logging.getLogger(__name__) 16 | logger.setLevel(logging.INFO) 17 | ch = logging.StreamHandler() 18 | ch.setLevel(logging.INFO) 19 | logger.addHandler(ch) 20 | return logger 21 | 22 | def get_objid(obj): 23 | obj_dict = {'ape':1, 'benchvise':2, 'cam':4, 'can':5, 'cat':6, 'driller':8, 'duck':9, 'eggbox':10, 'glue':11, 'holepuncher':12, 24 | 'iron':13, 'lamp':14, 'phone':15} 25 | return obj_dict[obj] 26 | 27 | def get_ycbv_objid(obj): 28 | # obj_dict = {'master_chef_can':1, 'cracker_box':2, 'sugar_box':3, 'tomato_soup_can':4, 'mustard_bottle':5, 'tuna_fish_can':6, 'pudding_box':7, 'gelatin_box':8, 29 | # 'potted_meat_can':9, 'banana':10, 'pitcher_base':11, 'bleach_cleanser':12, 'bowl':13, 'mug':14, 'power_drill':15, 'wood_block':16, 'scissors':17, 'large_marker':18, 30 | # 'large_clamp':19, 'extra_large_clamp':20, 'foam_brick':21} 31 | obj_dict = {'01':1, '02':2, '03':3, '04':4, '05':5, '06':6, '07':7, '08':8, '09':9, '10':10, '11':11, '12':12, '13':13, '14':14, '15':15, '16':16, '17':17, '18':18, 32 | '19':19, '20':20, '21':21} 33 | return obj_dict[obj] 34 | 35 | def get_lm_img_idx(cfg, n_lm, n_lm_synt=None): 36 | objid = get_objid(cfg.obj) 37 | f = open(os.path.join(cfg.LM_DIR, '{:06d}/training_range.txt'.format(objid))) 38 | train_idx = [int(x) for x in f] 39 | whole_idx = list(range(n_lm)) 40 | test_idx = [item for item in whole_idx if item not in train_idx] 41 | if n_lm_synt is None: 42 | return train_idx, test_idx 43 | else: 44 | synt_idx = list(range(n_lm, n_lm+n_lm_synt)) 45 | train_idx.extend(synt_idx) 46 | return train_idx, test_idx 47 | 48 | def get_lm_dataset_size(cfg): 49 | objid = get_objid(cfg.obj) 50 | PM_file = cfg.LM_DIR + '/{:06d}/scene_gt.json'.format(objid) 51 | with open(PM_file) as f: 52 | PM = 
json.load(f) 53 | return len(PM) 54 | 55 | def get_lm_o_fps(cfg): 56 | # get the Farthest Point Sample of the 3D mesh 57 | objid = get_objid(cfg.obj) 58 | fps_file = cfg.LM_DIR+'/lm_models/lm_fps/obj{:02d}_fps128.mat'.format(objid) 59 | fpsvis = loadmat(fps_file) 60 | pts3d_fps = torch.tensor(fpsvis['fps'][0:cfg.N_PTS, :]) # size [N_PTS, 3] 61 | return pts3d_fps 62 | 63 | def get_lm_o_3dmodel(cfg, homo=False): 64 | objid = get_objid(cfg.obj) 65 | model_file = cfg.LM_DIR+'/lm_models/lm_meshes_cm/obj{:02d}.mat'.format(objid) 66 | pts3d = loadmat(model_file)['pts3d'] 67 | pts3d = torch.tensor(pts3d, dtype=torch.float) 68 | if homo: 69 | pts3d = torch.cat((pts3d, torch.ones(pts3d.size(0), 1)), dim=-1) 70 | return pts3d 71 | 72 | def get_ycbv_fps(cfg): 73 | # get the Farthest Point Sample of the 3D mesh 74 | objid = get_ycbv_objid(cfg.obj) 75 | fps_file = cfg.YCBV_DIR+'/models/obj{:02d}_fps128.mat'.format(objid) 76 | fps = loadmat(fps_file) 77 | pts3d_fps = torch.tensor(fps['fps'][0:cfg.N_PTS, :]) # size [N_PTS, 3] 78 | return pts3d_fps 79 | 80 | def get_ycbv_3dmodel(cfg, homo=False): 81 | objid = get_ycbv_objid(cfg.obj) 82 | model_file = cfg.YCBV_DIR+'/models/obj{:02d}.mat'.format(objid) 83 | model = loadmat(model_file) 84 | pts3d = torch.tensor(model['pts3d']) 85 | is_sym = model['sym'] 86 | if homo: 87 | pts3d = torch.cat((pts3d, torch.ones(pts3d.size(0), 1)), dim=-1) 88 | return pts3d, is_sym 89 | 90 | def get_lm_pose(cfg, idx): 91 | objid = get_objid(cfg.obj) 92 | PM_file = cfg.LM_DIR + '/{:06d}/scene_gt.json'.format(objid) 93 | with open(PM_file) as f: 94 | PMs = json.load(f) 95 | R = torch.tensor(PMs[str(idx)][0]['cam_R_m2c']).view(1,3,3) 96 | T = 0.1*torch.tensor(PMs[str(idx)][0]['cam_t_m2c']).view(1,3,1) 97 | return torch.cat((R,T),dim=-1) 98 | 99 | def get_lmo_pose(cfg, idx): 100 | objid = get_objid(cfg.obj) 101 | PM_file = cfg.LMO_DIR + '/000002/scene_gt.json' 102 | with open(PM_file) as f: 103 | PMs = json.load(f) 104 | list_idx = PMs[str(idx)] 105 | objid_list = [temp['obj_id'] for temp in list_idx] 106 | assert objid in objid_list, 'Image id {} doesn\'t have object {} in sight.'.format(idx, objid) 107 | ttt = [ temp for temp in list_idx if temp['obj_id']==objid] 108 | R = torch.tensor(ttt[0]['cam_R_m2c']).view(1,3,3) 109 | T = 0.1*torch.tensor(ttt[0]['cam_t_m2c']).view(1,3,1) 110 | return torch.cat((R,T),dim=-1) 111 | 112 | def get_PM_gt(cfg, dataset_name, idx): 113 | if dataset_name == 'lm': 114 | return get_lm_pose(cfg, idx) 115 | if dataset_name == 'lmo': 116 | return get_lmo_pose(cfg, idx) 117 | 118 | def get_distance(cfg): 119 | # get the diameter of the object (in cm) 120 | objid = get_objid(cfg.obj) 121 | with open(os.path.join(cfg.LM_DIR,'lm_models/models/models_info.json')) as f: 122 | models_info = json.load(f) 123 | diameter = 0.1*torch.tensor(models_info[str(objid)]['diameter']).view(1) 124 | assert diameter.size()[0] == 1 125 | return diameter 126 | 127 | def get_ycbv_distance(cfg): 128 | # get the diameter of the object (in cm) 129 | objid = get_ycbv_objid(cfg.obj) 130 | with open(os.path.join(cfg.YCBV_DIR,'models/models_info.json')) as f: 131 | models_info = json.load(f) 132 | diameter = 0.1*torch.tensor(models_info[str(objid)]['diameter']).view(1) 133 | assert diameter.size()[0] == 1 134 | return diameter 135 | 136 | def get_K(): 137 | fx = 572.41140 138 | fy = 573.57043 139 | u = 325.26110 140 | v = 242.04899 141 | K = torch.tensor( 142 | [[fx, 0, u], 143 | [0, fy, v], 144 | [0, 0, 1]], 145 | dtype=torch.float) 146 | return K 147 | 148 | def 
get_kp_consensus_aslist(pts2d1, pts2d2, pts2d3, thres, n_min): 149 | bs = pts2d1.size(0) 150 | npts = pts2d1.size(1) 151 | pts2dc = [] 152 | ids = [] 153 | for i in range(bs): 154 | p1 = pts2d1[i] 155 | p2 = pts2d2[i] 156 | p3 = pts2d3[i] 157 | 158 | # dist = (p1-p2).norm(dim=-1) + (p1-p3).norm(dim=-1) 159 | dist = (p1-p2).norm(dim=-1) 160 | # dist = (p1-p3).norm(dim=-1) 161 | 162 | ids_i = np.where((dist= warmup_iters: 260 | return 1 261 | alpha = float(x) / warmup_iters 262 | return warmup_factor * (1 - alpha) + alpha 263 | 264 | return torch.optim.lr_scheduler.LambdaLR(optimizer, f) 265 | 266 | 267 | def mkdir(path): 268 | try: 269 | os.makedirs(path) 270 | except OSError as e: 271 | if e.errno != errno.EEXIST: 272 | raise 273 | 274 | 275 | def setup_for_distributed(is_master): 276 | """ 277 | This function disables printing when not in master process 278 | """ 279 | import builtins as __builtin__ 280 | builtin_print = __builtin__.print 281 | 282 | def print(*args, **kwargs): 283 | force = kwargs.pop('force', False) 284 | if is_master or force: 285 | builtin_print(*args, **kwargs) 286 | 287 | __builtin__.print = print 288 | 289 | 290 | def is_dist_avail_and_initialized(): 291 | if not dist.is_available(): 292 | return False 293 | if not dist.is_initialized(): 294 | return False 295 | return True 296 | 297 | 298 | def get_world_size(): 299 | if not is_dist_avail_and_initialized(): 300 | return 1 301 | return dist.get_world_size() 302 | 303 | 304 | def get_rank(): 305 | if not is_dist_avail_and_initialized(): 306 | return 0 307 | return dist.get_rank() 308 | 309 | 310 | def is_main_process(): 311 | return get_rank() == 0 312 | 313 | 314 | def save_on_master(*args, **kwargs): 315 | if is_main_process(): 316 | torch.save(*args, **kwargs) 317 | 318 | 319 | def init_distributed_mode(args): 320 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 321 | args.rank = int(os.environ["RANK"]) 322 | args.world_size = int(os.environ['WORLD_SIZE']) 323 | args.gpu = int(os.environ['LOCAL_RANK']) 324 | elif 'SLURM_PROCID' in os.environ: 325 | args.rank = int(os.environ['SLURM_PROCID']) 326 | args.gpu = args.rank % torch.cuda.device_count() 327 | else: 328 | print('Not using distributed mode') 329 | args.distributed = False 330 | return 331 | 332 | args.distributed = True 333 | 334 | torch.cuda.set_device(args.gpu) 335 | args.dist_backend = 'nccl' 336 | print('| distributed init (rank {}): {}'.format( 337 | args.rank, args.dist_url), flush=True) 338 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 339 | world_size=args.world_size, rank=args.rank) 340 | torch.distributed.barrier() 341 | setup_for_distributed(args.rank == 0) 342 | 343 | 344 | 345 | 346 | 347 | 348 | -------------------------------------------------------------------------------- /detection/transform.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | import torch 4 | from torch import nn, Tensor 5 | from torch.nn import functional as F 6 | import torchvision 7 | from torch.jit.annotations import List, Tuple, Dict, Optional 8 | 9 | from .image_list import ImageList 10 | from .roi_heads import paste_masks_in_image 11 | 12 | 13 | @torch.jit.unused 14 | def _resize_image_and_masks_onnx(image, self_min_size, self_max_size, target): 15 | # type: (Tensor, float, float, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]] 16 | from torch.onnx import operators 17 | im_shape = 
operators.shape_as_tensor(image)[-2:] 18 | min_size = torch.min(im_shape).to(dtype=torch.float32) 19 | max_size = torch.max(im_shape).to(dtype=torch.float32) 20 | scale_factor = torch.min(self_min_size / min_size, self_max_size / max_size) 21 | 22 | image = torch.nn.functional.interpolate( 23 | image[None], scale_factor=scale_factor, mode='bilinear', recompute_scale_factor=True, 24 | align_corners=False)[0] 25 | 26 | if target is None: 27 | return image, target 28 | 29 | if "masks" in target: 30 | mask = target["masks"] 31 | mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor)[:, 0].byte() 32 | target["masks"] = mask 33 | return image, target 34 | 35 | 36 | def _resize_image_and_masks(image, self_min_size, self_max_size, target): 37 | # type: (Tensor, float, float, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]] 38 | im_shape = torch.tensor(image.shape[-2:]) 39 | min_size = float(torch.min(im_shape)) 40 | max_size = float(torch.max(im_shape)) 41 | scale_factor = self_min_size / min_size 42 | if max_size * scale_factor > self_max_size: 43 | scale_factor = self_max_size / max_size 44 | image = torch.nn.functional.interpolate( 45 | image[None], scale_factor=scale_factor, mode='bilinear', recompute_scale_factor=True, 46 | align_corners=False)[0] 47 | 48 | if target is None: 49 | return image, target 50 | 51 | if "masks" in target: 52 | mask = target["masks"] 53 | mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor)[:, 0].byte() 54 | target["masks"] = mask 55 | return image, target 56 | 57 | 58 | class GeneralizedRCNNTransform(nn.Module): 59 | """ 60 | Performs input / target transformation before feeding the data to a GeneralizedRCNN 61 | model. 62 | 63 | The transformations it perform are: 64 | - input normalization (mean subtraction and std division) 65 | - input / target resizing to match min_size / max_size 66 | 67 | It returns a ImageList for the inputs, and a List[Dict[Tensor]] for the targets 68 | """ 69 | 70 | def __init__(self, min_size, max_size, image_mean, image_std): 71 | super(GeneralizedRCNNTransform, self).__init__() 72 | if not isinstance(min_size, (list, tuple)): 73 | min_size = (min_size,) 74 | self.min_size = min_size 75 | self.max_size = max_size 76 | self.image_mean = image_mean 77 | self.image_std = image_std 78 | 79 | def forward(self, 80 | images, # type: List[Tensor] 81 | targets=None # type: Optional[List[Dict[str, Tensor]]] 82 | ): 83 | # type: (...) 
-> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]] 84 | images = [img for img in images] 85 | if targets is not None: 86 | # make a copy of targets to avoid modifying it in-place 87 | # once torchscript supports dict comprehension 88 | # this can be simplified as as follows 89 | # targets = [{k: v for k,v in t.items()} for t in targets] 90 | targets_copy: List[Dict[str, Tensor]] = [] 91 | for t in targets: 92 | data: Dict[str, Tensor] = {} 93 | for k, v in t.items(): 94 | data[k] = v 95 | targets_copy.append(data) 96 | targets = targets_copy 97 | for i in range(len(images)): 98 | image = images[i] 99 | target_index = targets[i] if targets is not None else None 100 | 101 | if image.dim() != 3: 102 | raise ValueError("images is expected to be a list of 3d tensors " 103 | "of shape [C, H, W], got {}".format(image.shape)) 104 | image = self.normalize(image) 105 | image, target_index = self.resize(image, target_index) 106 | images[i] = image 107 | if targets is not None and target_index is not None: 108 | targets[i] = target_index 109 | 110 | image_sizes = [img.shape[-2:] for img in images] 111 | images = self.batch_images(images) 112 | image_sizes_list = torch.jit.annotate(List[Tuple[int, int]], []) 113 | for image_size in image_sizes: 114 | assert len(image_size) == 2 115 | image_sizes_list.append((image_size[0], image_size[1])) 116 | 117 | image_list = ImageList(images, image_sizes_list) 118 | return image_list, targets 119 | 120 | def normalize(self, image): 121 | dtype, device = image.dtype, image.device 122 | mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device) 123 | std = torch.as_tensor(self.image_std, dtype=dtype, device=device) 124 | return (image - mean[:, None, None]) / std[:, None, None] 125 | 126 | def torch_choice(self, k): 127 | # type: (List[int]) -> int 128 | """ 129 | Implements `random.choice` via torch ops so it can be compiled with 130 | TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803 131 | is fixed. 132 | """ 133 | index = int(torch.empty(1).uniform_(0., float(len(k))).item()) 134 | return k[index] 135 | 136 | def resize(self, image, target): 137 | # type: (Tensor, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]] 138 | h, w = image.shape[-2:] 139 | if self.training: 140 | size = float(self.torch_choice(self.min_size)) 141 | else: 142 | # FIXME assume for now that testing uses the largest scale 143 | size = float(self.min_size[-1]) 144 | if torchvision._is_tracing(): 145 | image, target = _resize_image_and_masks_onnx(image, size, float(self.max_size), target) 146 | else: 147 | image, target = _resize_image_and_masks(image, size, float(self.max_size), target) 148 | 149 | if target is None: 150 | return image, target 151 | 152 | bbox = target["boxes"] 153 | bbox = resize_boxes(bbox, (h, w), image.shape[-2:]) 154 | target["boxes"] = bbox 155 | 156 | if "keypoints" in target: 157 | keypoints = target["keypoints"] 158 | keypoints = resize_keypoints(keypoints, (h, w), image.shape[-2:]) 159 | target["keypoints"] = keypoints 160 | return image, target 161 | 162 | # _onnx_batch_images() is an implementation of 163 | # batch_images() that is supported by ONNX tracing. 
164 | @torch.jit.unused 165 | def _onnx_batch_images(self, images, size_divisible=32): 166 | # type: (List[Tensor], int) -> Tensor 167 | max_size = [] 168 | for i in range(images[0].dim()): 169 | max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64) 170 | max_size.append(max_size_i) 171 | stride = size_divisible 172 | max_size[1] = (torch.ceil((max_size[1].to(torch.float32)) / stride) * stride).to(torch.int64) 173 | max_size[2] = (torch.ceil((max_size[2].to(torch.float32)) / stride) * stride).to(torch.int64) 174 | max_size = tuple(max_size) 175 | 176 | # work around for 177 | # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 178 | # which is not yet supported in onnx 179 | padded_imgs = [] 180 | for img in images: 181 | padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] 182 | padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) 183 | padded_imgs.append(padded_img) 184 | 185 | return torch.stack(padded_imgs) 186 | 187 | def max_by_axis(self, the_list): 188 | # type: (List[List[int]]) -> List[int] 189 | maxes = the_list[0] 190 | for sublist in the_list[1:]: 191 | for index, item in enumerate(sublist): 192 | maxes[index] = max(maxes[index], item) 193 | return maxes 194 | 195 | def batch_images(self, images, size_divisible=32): 196 | # type: (List[Tensor], int) -> Tensor 197 | if torchvision._is_tracing(): 198 | # batch_images() does not export well to ONNX 199 | # call _onnx_batch_images() instead 200 | return self._onnx_batch_images(images, size_divisible) 201 | 202 | max_size = self.max_by_axis([list(img.shape) for img in images]) 203 | stride = float(size_divisible) 204 | max_size = list(max_size) 205 | max_size[1] = int(math.ceil(float(max_size[1]) / stride) * stride) 206 | max_size[2] = int(math.ceil(float(max_size[2]) / stride) * stride) 207 | 208 | batch_shape = [len(images)] + max_size 209 | batched_imgs = images[0].new_full(batch_shape, 0) 210 | for img, pad_img in zip(images, batched_imgs): 211 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 212 | 213 | return batched_imgs 214 | 215 | def postprocess(self, 216 | result, # type: List[Dict[str, Tensor]] 217 | image_shapes, # type: List[Tuple[int, int]] 218 | original_image_sizes # type: List[Tuple[int, int]] 219 | ): 220 | # type: (...) 
-> List[Dict[str, Tensor]] 221 | if self.training: 222 | return result 223 | for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)): 224 | boxes = pred["boxes"] 225 | boxes = resize_boxes(boxes, im_s, o_im_s) 226 | result[i]["boxes"] = boxes 227 | if "masks" in pred: 228 | masks = pred["masks"] 229 | masks = paste_masks_in_image(masks, boxes, o_im_s) 230 | result[i]["masks"] = masks 231 | if "keypoints" in pred: 232 | keypoints = pred["keypoints"] 233 | keypoints = resize_keypoints(keypoints, im_s, o_im_s) 234 | result[i]["keypoints"] = keypoints 235 | if "keypoints1" in pred: 236 | keypoints1 = pred["keypoints1"] 237 | keypoints1 = resize_keypoints(keypoints1, im_s, o_im_s) 238 | result[i]["keypoints1"] = keypoints1 239 | if "keypoints2" in pred: 240 | keypoints2 = pred["keypoints2"] 241 | keypoints2 = resize_keypoints(keypoints2, im_s, o_im_s) 242 | result[i]["keypoints2"] = keypoints2 243 | if "keypoints3" in pred: 244 | keypoints3 = pred["keypoints3"] 245 | keypoints3 = resize_keypoints(keypoints3, im_s, o_im_s) 246 | result[i]["keypoints3"] = keypoints3 247 | return result 248 | 249 | def __repr__(self): 250 | format_string = self.__class__.__name__ + '(' 251 | _indent = '\n ' 252 | format_string += "{0}Normalize(mean={1}, std={2})".format(_indent, self.image_mean, self.image_std) 253 | format_string += "{0}Resize(min_size={1}, max_size={2}, mode='bilinear')".format(_indent, self.min_size, 254 | self.max_size) 255 | format_string += '\n)' 256 | return format_string 257 | 258 | 259 | def resize_keypoints(keypoints, original_size, new_size): 260 | # type: (Tensor, List[int], List[int]) -> Tensor 261 | ratios = [ 262 | torch.tensor(s, dtype=torch.float32, device=keypoints.device) / 263 | torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device) 264 | for s, s_orig in zip(new_size, original_size) 265 | ] 266 | ratio_h, ratio_w = ratios 267 | resized_data = keypoints.clone() 268 | if torch._C._get_tracing_state(): 269 | resized_data_0 = resized_data[:, :, 0] * ratio_w 270 | resized_data_1 = resized_data[:, :, 1] * ratio_h 271 | resized_data = torch.stack((resized_data_0, resized_data_1, resized_data[:, :, 2]), dim=2) 272 | else: 273 | resized_data[..., 0] *= ratio_w 274 | resized_data[..., 1] *= ratio_h 275 | return resized_data 276 | 277 | 278 | def resize_boxes(boxes, original_size, new_size): 279 | # type: (Tensor, List[int], List[int]) -> Tensor 280 | ratios = [ 281 | torch.tensor(s, dtype=torch.float32, device=boxes.device) / 282 | torch.tensor(s_orig, dtype=torch.float32, device=boxes.device) 283 | for s, s_orig in zip(new_size, original_size) 284 | ] 285 | ratio_height, ratio_width = ratios 286 | xmin, ymin, xmax, ymax = boxes.unbind(1) 287 | 288 | xmin = xmin * ratio_width 289 | xmax = xmax * ratio_width 290 | ymin = ymin * ratio_height 291 | ymax = ymax * ratio_height 292 | return torch.stack((xmin, ymin, xmax, ymax), dim=1) 293 | -------------------------------------------------------------------------------- /dataset/lm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import os 4 | import numpy as np 5 | import fnmatch 6 | from PIL import Image 7 | from libs.utils import batch_project 8 | from scipy.io import loadmat, savemat 9 | from torch.utils.data import Dataset 10 | import imgaug.augmenters as iaa 11 | import imgaug as ia 12 | from imgaug.augmentables import Keypoint, KeypointsOnImage 13 | import argparse 14 | from yacs.config import CfgNode as CN 15 | 
import torchvision 16 | import imageio 17 | import json 18 | 19 | def get_objid(obj): 20 | obj_dict = {'ape':1, 'benchvise':2, 'cam':4, 'can':5, 'cat':6, 'driller':8, 'duck':9, 'eggbox':10, 'glue':11, 'holepuncher':12, 21 | 'iron':13, 'lamp':14, 'phone':15} 22 | return obj_dict[obj] 23 | 24 | def get_lm_PM_gt(root, objid): 25 | PM_file = root + '/{:06d}/scene_gt.json'.format(objid) 26 | with open(PM_file) as f: 27 | PM = json.load(f) 28 | return PM 29 | 30 | def get_lm_synt_PM_gt(root, objid): 31 | PM_file = root + '/{:06d}/scene_gt.json'.format(objid) 32 | with open(PM_file) as f: 33 | PM = json.load(f) 34 | return PM 35 | 36 | def get_K(device): 37 | fx = 572.41140 38 | fy = 573.57043 39 | u = 325.26110 40 | v = 242.04899 41 | K = torch.tensor( 42 | [[fx, 0, u], 43 | [0, fy, v], 44 | [0, 0, 1]], 45 | device=device, dtype=torch.float) 46 | return K 47 | 48 | def divide_box(bbox, n_range=(3,6), p_range=(0.25, 0.7), img_w=640, img_h=480): 49 | # bbox: size [4], format [x,y,w,h] 50 | n = torch.randint(n_range[0], n_range[1], (1,)).item() 51 | p = (p_range[1]-p_range[0])*torch.rand(1).item()+p_range[0] 52 | cells = torch.zeros(n, n, 2) 53 | occlude = torch.rand(n,n)<=p 54 | X = bbox[0] 55 | Y = bbox[1] 56 | W = bbox[2] 57 | H = bbox[3] 58 | if W%n != 0: 59 | W = W - W%n 60 | if H%n != 0: 61 | H = H - H%n 62 | assert W%n == 0 63 | assert H%n == 0 64 | assert X+W <= img_w, 'X: {}, W: {}, img_w: {}'.format(X, W, img_w) 65 | assert Y+H <= img_h, 'Y: {}, H: {}, img_h: {}'.format(Y, H, img_h) 66 | w = int(W/n) 67 | h = int(H/n) 68 | for i in range(n): 69 | for j in range(n): 70 | cells[i,j,0] = X + i*w 71 | cells[i,j,1] = Y + j*h 72 | return cells.view(-1,2).long(), occlude.view(-1), w, h 73 | 74 | def get_patch_xy(num_patches, img_w, img_h, obj_bbox, cell_w, cell_h): 75 | patch_xy = torch.zeros(num_patches, 2) 76 | max_w = img_w - cell_w 77 | max_h = img_h - cell_h 78 | X = obj_bbox[0] 79 | Y = obj_bbox[1] 80 | XX = X + obj_bbox[2] 81 | YY = Y + obj_bbox[3] 82 | assert XX>X and X>=0 and XX<=img_w, 'X {}, XX {}, Y {}, YY {}, cell_w {}, cell_h {}, img_w {}, img_h {}.'.format(X, XX, Y, YY, cell_w, cell_h, img_w, img_h) 83 | assert YY>Y and Y>=0 and YY<=img_h, 'X {}, XX {}, Y {}, YY {}, cell_w {}, cell_h {}, img_w {}, img_h {}.'.format(X, XX, Y, YY, cell_w, cell_h, img_w, img_h) 84 | for i in range(num_patches): 85 | x = torch.randint(0, max_w-1, (1,)) 86 | y = torch.randint(0, max_h-1, (1,)) 87 | trial = 0 88 | while x>=X and x=Y and y 1000: 93 | print('Can find patch! X {}, XX {}, Y {}, YY {}, cell_w {}, cell_h {}, img_w {}, img_h {}.' 
94 | .format(X, XX, Y, YY, cell_w, cell_h, img_w, img_h)) 95 | patch_xy[i,0] = x 96 | patch_xy[i,1] = y 97 | return patch_xy 98 | 99 | def get_bbox(pts2d, img_size, coco_format=False): 100 | W = img_size[-2] 101 | H = img_size[-1] 102 | xmin = int(max(pts2d[:,0].min().round().item()-15, 0)) 103 | xmax = int(min(pts2d[:,0].max().round().item()+15, W)) 104 | assert xmax>xmin 105 | ymin = int(max(pts2d[:,1].min().round().item()-15, 0)) 106 | ymax = int(min(pts2d[:,1].max().round().item()+15, H)) 107 | assert ymax>ymin 108 | if coco_format: 109 | return [xmin, ymin, xmax, ymax] 110 | else: 111 | return [xmin, ymin, xmax-xmin, ymax-ymin] 112 | 113 | def check_if_inside(pts2d, x1, x2, y1, y2): 114 | r1 = pts2d[:, 0]-0.5 >= x1 -0.5 115 | r2 = pts2d[:, 0]-0.5 <= x2 -1 + 0.5 116 | r3 = pts2d[:, 1]-0.5 >= y1 -0.5 117 | r4 = pts2d[:, 1]-0.5 <= y2 -1 + 0.5 118 | return r1*r2*r3*r4 119 | 120 | def obj_out_of_view(W, H, pts2d): 121 | xmin = pts2d[:,0].min().item() 122 | xmax = pts2d[:,0].max().item() 123 | ymin = pts2d[:,1].min().item() 124 | ymax = pts2d[:,1].max().item() 125 | if xmin>W or xmax<0 or ymin>H or ymax<0: 126 | return True 127 | else: 128 | return False 129 | 130 | def occlude_obj(img, pts2d, vis=None, p_white_noise=0.1, p_occlude=(0.25, 0.7)): 131 | # img: image tensor of size [3, h, w] 132 | _, img_h, img_w = img.size() 133 | 134 | if obj_out_of_view(img_w, img_h, pts2d): 135 | return img, None 136 | 137 | bbox = get_bbox(pts2d, [img_w, img_h]) 138 | cells, occ_cell, cell_w, cell_h = divide_box(bbox, p_range=p_occlude, img_w=img_w, img_h=img_h) 139 | num_cells = cells.size(0) 140 | noise_occ_id = torch.rand(num_cells) <= p_white_noise 141 | actual_noise_occ = noise_occ_id * occ_cell 142 | num_patch_occ = occ_cell.sum() - actual_noise_occ.sum() 143 | patches_xy = get_patch_xy(num_patch_occ, img_w, img_h, bbox, cell_w, cell_h) 144 | j = 0 145 | for i in range(num_cells): 146 | if occ_cell[i]: 147 | x1 = cells[i,0].item() 148 | x2 = x1 + cell_w 149 | y1 = cells[i,1].item() 150 | y2 = y1 + cell_h 151 | 152 | if vis is not None: 153 | vis = vis*(~check_if_inside(pts2d, x1, x2, y1, y2)) 154 | 155 | if noise_occ_id[i]: # white_noise occlude 156 | img[:, y1:y2, x1:x2] = torch.rand(3, cell_h, cell_w) 157 | else: # patch occlude 158 | xx1 = patches_xy[j, 0].long().item() 159 | xx2 = xx1 + cell_w 160 | yy1 = patches_xy[j, 1].long().item() 161 | yy2 = yy1 + cell_h 162 | img[:, y1:y2, x1:x2] = img[:, yy1:yy2, xx1:xx2].clone() 163 | j += 1 164 | assert num_patch_occ == j 165 | return img, vis 166 | 167 | 168 | def kps2tensor(kps): 169 | n = len(kps.keypoints) 170 | pts2d = np.array([kps.keypoints[i].coords for i in range(n)]) 171 | return torch.tensor(pts2d, dtype=torch.float).squeeze() 172 | 173 | 174 | def augment_lm(img, pts2d, device, is_synt=False): 175 | assert len(img.size()) == 3 176 | 177 | H, W = img.size()[-2:] 178 | bbox = get_bbox(pts2d, [W, H]) 179 | min_x_shift = int(-bbox[0]+10) 180 | max_x_shift = int(W-bbox[0]-bbox[2]-10) 181 | min_y_shift = int(-bbox[1]+10) 182 | max_y_shift = int(H-bbox[1]-bbox[3]-10) 183 | assert max_x_shift > min_x_shift, 'H: {}, W: {}, bbox: {}, {}, {}, {}'.format(H, W, bbox[0], bbox[1], bbox[2], bbox[3]) 184 | assert max_y_shift > min_y_shift, 'H: {}, W: {}, bbox: {}, {}, {}, {}'.format(H, W, bbox[0], bbox[1], bbox[2], bbox[3]) 185 | 186 | img = img.permute(1,2,0).numpy() 187 | nkp = pts2d.size(0) 188 | kp_list = [Keypoint(x=pts2d[i][0].item(), y=pts2d[i][1].item()) for i in range(nkp)] 189 | kps = KeypointsOnImage(kp_list, shape=img.shape) 190 | 191 | if 
is_synt: 192 | step0 = iaa.Affine(scale=(0.35,0.6)) 193 | img, kps = step0(image=img, keypoints=kps) 194 | 195 | rotate = iaa.Affine(rotate=(-30, 30)) 196 | scale = iaa.Affine(scale=(0.8, 1.2)) 197 | trans = iaa.Affine(translate_px={"x": (min_x_shift, max_x_shift), "y": (min_y_shift, max_y_shift)}) 198 | bright = iaa.MultiplyAndAddToBrightness(mul=(0.7, 1.3)) 199 | hue_satu = iaa.MultiplyHueAndSaturation(mul_hue=(0.95,1.05), mul_saturation=(0.5,1.5)) 200 | contrast = iaa.GammaContrast((0.8, 1.2)) 201 | random_aug = iaa.SomeOf((3, 6), [rotate, trans, scale, bright, hue_satu, contrast]) 202 | img1, kps1 = random_aug(image=img, keypoints=kps) 203 | 204 | img1 = torch.tensor(img1).permute(2,0,1).to(device) 205 | pts2d1 = kps2tensor(kps1).to(device) 206 | 207 | if pts2d1[:,0].min()>W or pts2d1[:,0].max()<0 or pts2d1[:,1].min()>H or pts2d1[:,1].max()<0: 208 | img1 = torch.tensor(img).permute(2,0,1).to(device) 209 | pts2d1 = kps2tensor(kps).to(device) 210 | 211 | return img1, img1.clone(), pts2d1 212 | 213 | 214 | def blackout(img, pts2d): 215 | assert len(img.size()) == 3 216 | H, W = img.size()[-2:] 217 | x, y, w, h = get_bbox(pts2d, [W, H]) 218 | img2 = torch.zeros_like(img) 219 | img2[:, y:y+h, x:x+w] = img[:, y:y+h, x:x+w].clone() 220 | return img2 221 | 222 | 223 | class lm(Dataset): 224 | def __init__(self, cfg): 225 | self.objid = get_objid(cfg.obj) 226 | self.lm_root = cfg.LM_DIR 227 | self.data_path = os.path.join(self.lm_root,'{:06d}/rgb'.format(self.objid)) 228 | self.PMs = get_lm_PM_gt(self.lm_root, self.objid) 229 | self.pts3d = torch.tensor(loadmat('dataset/fps/lm/obj{:02d}_fps128.mat'.format(self.objid))['fps'])[:cfg.N_PTS,:] 230 | self.npts = cfg.N_PTS 231 | self.cfg = cfg 232 | self.K = get_K('cpu') 233 | self.n_lm = len(self.PMs) 234 | 235 | def __len__(self,): 236 | return self.n_lm 237 | 238 | def __getitem__(self, idx): 239 | img = imageio.imread(os.path.join(self.data_path, '{:06d}.png'.format(idx))) 240 | img = torch.tensor(img).permute(2,0,1) 241 | R = torch.tensor(self.PMs[str(idx)][0]['cam_R_m2c']).view(1,3,3) 242 | T = 0.1*torch.tensor(self.PMs[str(idx)][0]['cam_t_m2c']).view(1,3,1) 243 | PM = torch.cat((R,T),dim=-1) 244 | pts2d = batch_project(PM, self.pts3d, self.K, angle_axis=False).squeeze() 245 | 246 | tru = torch.ones(1, dtype=torch.bool) 247 | fal = torch.zeros(1, dtype=torch.bool) 248 | num_objs = 1 249 | W, H = self.cfg.MODEL.IMAGE_SIZE 250 | 251 | boxes = get_bbox(pts2d, self.cfg.MODEL.IMAGE_SIZE, coco_format=True) 252 | boxes = torch.as_tensor(boxes, dtype=torch.float32).view(1,4) 253 | labels = torch.ones((num_objs,), dtype=torch.int64) 254 | vis = torch.ones(self.npts, 1) 255 | vis[pts2d[:,0]<0, 0] = 0 256 | vis[pts2d[:,0]>W, 0] = 0 257 | vis[pts2d[:,1]<0, 0] = 0 258 | vis[pts2d[:,1]>H, 0] = 0 259 | keypoints = torch.cat((pts2d, vis),dim=-1).view(1, self.npts, 3) 260 | 261 | image_id = torch.tensor([idx]) 262 | area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) 263 | iscrowd = torch.zeros((num_objs,), dtype=torch.int64) 264 | 265 | target = {} 266 | target["boxes"] = boxes.clone() 267 | target["labels"] = labels.clone() 268 | target["image_id"] = image_id.clone() 269 | target["area"] = area.clone() 270 | target["iscrowd"] = iscrowd.clone() 271 | target["keypoints"] = keypoints.clone() 272 | target['fmco1'] = tru.clone() 273 | target['fmco2'] = fal.clone() 274 | 275 | return img.float()/255, target 276 | 277 | 278 | class lm_with_synt(Dataset): 279 | def __init__(self, cfg): 280 | self.objid = get_objid(cfg.obj) 281 | self.synt_root = 
cfg.LM_SYNT_DIR 282 | self.lm_root = cfg.LM_DIR 283 | self.data_path = os.path.join(self.lm_root,'{:06d}/rgb'.format(self.objid)) 284 | self.PMs = get_lm_PM_gt(self.lm_root, self.objid) 285 | self.PMs_synt = get_lm_synt_PM_gt(self.synt_root, self.objid) 286 | self.pts3d = torch.tensor(loadmat('dataset/fps/lm/obj{:02d}_fps128.mat'.format(self.objid))['fps'])[:cfg.N_PTS,:] 287 | self.npts = cfg.N_PTS 288 | self.cfg = cfg 289 | self.K = get_K('cpu') 290 | self.n_lm = len(self.PMs) 291 | self.n_synt = len(self.PMs_synt) 292 | 293 | def __len__(self,): 294 | return self.n_lm + self.n_synt 295 | 296 | def __getitem__(self, idx): 297 | idx0 = idx 298 | if idx < self.n_lm: 299 | img = imageio.imread(os.path.join(self.data_path, '{:06d}.png'.format(idx))) 300 | img = torch.tensor(img).permute(2,0,1) 301 | R = torch.tensor(self.PMs[str(idx)][0]['cam_R_m2c']).view(1,3,3) 302 | T = 0.1*torch.tensor(self.PMs[str(idx)][0]['cam_t_m2c']).view(1,3,1) 303 | PM = torch.cat((R,T),dim=-1) 304 | pts2d = batch_project(PM, self.pts3d, self.K, angle_axis=False).squeeze() 305 | is_synt = False 306 | else: 307 | idx = idx-self.n_lm 308 | img = imageio.imread(os.path.join(self.synt_root, '{:06d}/rgb/{:06d}.png'.format(self.objid, idx))) 309 | img = torch.tensor(img).permute(2,0,1) 310 | R = torch.tensor(self.PMs_synt[str(idx)][0]['cam_R_m2c']).view(1,3,3) 311 | T = 0.1*torch.tensor(self.PMs_synt[str(idx)][0]['cam_t_m2c']).view(1,3,1) 312 | PM = torch.cat((R,T),dim=-1) 313 | pts2d = batch_project(PM, self.pts3d, self.K, angle_axis=False).squeeze() 314 | is_synt = True 315 | 316 | img1, img2, pts2d = augment_lm(img, pts2d, 'cpu', is_synt) 317 | if torch.rand(1) < 0.95: 318 | img2, _ = occlude_obj(img2.clone(), pts2d.clone(), p_occlude=(0.15, 0.7)) 319 | img2 = blackout(img2, pts2d.clone()) 320 | 321 | num_objs = 1 322 | W, H = self.cfg.MODEL.IMAGE_SIZE 323 | 324 | boxes = get_bbox(pts2d, self.cfg.MODEL.IMAGE_SIZE, coco_format=True) 325 | boxes = torch.as_tensor(boxes, dtype=torch.float32).view(1,4) 326 | labels = torch.ones((num_objs,), dtype=torch.int64) 327 | vis = torch.ones(self.npts, 1) 328 | vis[pts2d[:,0]<0, 0] = 0 329 | vis[pts2d[:,0]>W, 0] = 0 330 | vis[pts2d[:,1]<0, 0] = 0 331 | vis[pts2d[:,1]>H, 0] = 0 332 | keypoints = torch.cat((pts2d, vis),dim=-1).view(1, self.npts, 3) 333 | 334 | image_id = torch.tensor([idx0]) 335 | area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) 336 | iscrowd = torch.zeros((num_objs,), dtype=torch.int64) 337 | 338 | target1 = {} 339 | target1["boxes"] = boxes.clone() 340 | target1["labels"] = labels.clone() 341 | target1["image_id"] = image_id.clone() 342 | target1["area"] = area.clone() 343 | target1["iscrowd"] = iscrowd.clone() 344 | target1["keypoints"] = keypoints.clone() 345 | 346 | target2 = {} 347 | target2["boxes"] = boxes.clone() 348 | target2["labels"] = labels.clone() 349 | target2["image_id"] = image_id.clone() 350 | target2["area"] = area.clone() 351 | target2["iscrowd"] = iscrowd.clone() 352 | target2["keypoints"] = keypoints.clone() 353 | 354 | return img1.float()/255, img2.float()/255, target1, target2 355 | 356 | 357 | 358 | 359 | 360 | 361 | 362 | 363 | 364 | 365 | 366 | 367 | 368 | 369 | 370 | 371 | -------------------------------------------------------------------------------- /libs/functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from reference.utils import all_gather 3 | from . 
import utils 4 | import numpy as np 5 | import cv2 as cv 6 | import kornia as kn 7 | import rowan 8 | from tqdm import tqdm 9 | from scipy.io import loadmat 10 | import json 11 | 12 | class Evaluator(object): 13 | def __init__(self): 14 | self.img_ids = [] 15 | self.outputs = [] 16 | self.img_ids_all = [] 17 | self.outputs_all = [] 18 | self.has_gathered = False 19 | 20 | def update(self, res): 21 | for img_id, output in res.items(): 22 | self.img_ids.append(img_id) 23 | self.outputs.append(output) 24 | 25 | def gather_all(self): 26 | ids_list = all_gather(self.img_ids) 27 | outputs_list = all_gather(self.outputs) 28 | assert len(ids_list)==len(outputs_list) 29 | for ids, outs in zip(ids_list, outputs_list): 30 | self.img_ids_all.extend(ids) 31 | self.outputs_all.extend(outs) 32 | self.has_gathered = True 33 | 34 | def get_accuracy(self, cfg, epoch, n_test, testset_name, n_min=4, thres=1, logger=None): 35 | assert self.has_gathered 36 | n_correct_sum1 = 0 37 | n_correct_sum2 = 0 38 | n_correct_sum3 = 0 39 | n_correct_sumc = 0 40 | n_correct_sum1s = 0 41 | n_correct_sum2s = 0 42 | n_correct_sum3s = 0 43 | n_correct_sumcs = 0 44 | size_vec = torch.tensor([cfg.MODEL.IMAGE_SIZE],dtype=torch.float) 45 | if testset_name=='lm': 46 | pts3d_fps = utils.get_lm_o_fps(cfg) 47 | pts_model_h = utils.get_lm_o_3dmodel(cfg,homo=True) 48 | n_dataset = utils.get_lm_dataset_size(cfg) 49 | if testset_name=='lmo': 50 | pts3d_fps = utils.get_lm_o_fps(cfg) 51 | pts_model_h = utils.get_lm_o_3dmodel(cfg,homo=True) 52 | n_dataset = 1214 53 | pts3d_h = torch.cat((pts3d_fps, torch.ones(cfg.N_PTS, 1)), dim=-1) 54 | K = utils.get_K() 55 | diameter = utils.get_distance(cfg) 56 | 57 | pts2d_record1 = torch.zeros(n_dataset, cfg.N_PTS, 2) 58 | pts2d_record2 = torch.zeros(n_dataset, cfg.N_PTS, 2) 59 | pts2d_record3 = torch.zeros(n_dataset, cfg.N_PTS, 2) 60 | boxes = torch.zeros(n_dataset, 4) 61 | pose_record1 = torch.zeros(n_dataset, 6) 62 | pose_record2 = torch.zeros(n_dataset, 6) 63 | pose_record3 = torch.zeros(n_dataset, 6) 64 | pose_recordc = torch.zeros(n_dataset, 6) 65 | corrects1 = torch.zeros(n_dataset, 1) 66 | corrects2 = torch.zeros(n_dataset, 1) 67 | corrects3 = torch.zeros(n_dataset, 1) 68 | correctsc = torch.zeros(n_dataset, 1) 69 | for idx, out_dict in tqdm(zip(self.img_ids_all, self.outputs_all)): 70 | if len(out_dict['scores'])==0: 71 | continue 72 | PM_gt = utils.get_PM_gt(cfg, testset_name, idx) 73 | top1id = out_dict['scores'].argmax() 74 | box_pred = out_dict['boxes'][top1id] 75 | boxes[idx, :] = box_pred.detach() 76 | 77 | pts2d_out_coord1 = out_dict['keypoints1'][top1id,:,:2].unsqueeze(0) 78 | pts2d_out_coord2 = out_dict['keypoints2'][top1id,:,:2].unsqueeze(0) 79 | pts2d_out_coord3 = out_dict['keypoints3'][top1id,:,:2].unsqueeze(0) 80 | 81 | P1 = pnp(pts2d_out_coord1, pts3d_fps, K) 82 | pose_record1[idx, :] = P1.detach() 83 | pts2d_record1[idx, :, :] = pts2d_out_coord1.detach() 84 | n_correct1, _, ticks1 = utils.ADD_accuracy_withID(P1, pts_model_h, PM_gt, diameter) 85 | n_correct_sum1 += n_correct1 86 | if cfg.obj == 'eggbox' or cfg.obj == 'glue': 87 | n_correct1s, _, ticks1s = utils.ADDS_accuracy_withID(P1, pts_model_h, PM_gt, diameter) 88 | n_correct_sum1s += n_correct1s 89 | corrects1[idx, :] = ticks1.detach() 90 | 91 | P2 = pnp(pts2d_out_coord2, pts3d_fps, K) 92 | pose_record2[idx, :] = P2.detach() 93 | pts2d_record2[idx, :, :] = pts2d_out_coord2.detach() 94 | n_correct2, _, ticks2 = utils.ADD_accuracy_withID(P2, pts_model_h, PM_gt, diameter) 95 | n_correct_sum2 += n_correct2 96 | if cfg.obj == 
'eggbox' or cfg.obj == 'glue': 97 | n_correct2s, _, ticks2s = utils.ADDS_accuracy_withID(P2, pts_model_h, PM_gt, diameter) 98 | n_correct_sum2s += n_correct2s 99 | corrects2[idx, :] = ticks2.detach() 100 | 101 | P3 = pnp(pts2d_out_coord3, pts3d_fps, K) 102 | pose_record3[idx, :] = P3.detach() 103 | pts2d_record3[idx, :, :] = pts2d_out_coord3.detach() 104 | n_correct3, _, ticks3 = utils.ADD_accuracy_withID(P3, pts_model_h, PM_gt, diameter) 105 | n_correct_sum3 += n_correct3 106 | if cfg.obj == 'eggbox' or cfg.obj == 'glue': 107 | n_correct3s, _, ticks3s = utils.ADDS_accuracy_withID(P3, pts_model_h, PM_gt, diameter) 108 | n_correct_sum3s += n_correct3s 109 | corrects3[idx, :] = ticks3.detach() 110 | 111 | pts2d_out_coordc, consist_ids = utils.get_kp_consensus_aslist(pts2d_out_coord1, pts2d_out_coord2, pts2d_out_coord3, thres, n_min) 112 | Pc = [pnp(pts2d_out_coordc[i].unsqueeze(0), pts3d_fps[consist_ids[i]], K) for i in range(len(pts2d_out_coordc))] 113 | Pc = torch.cat(tuple(Pc), dim=0) 114 | 115 | pose_recordc[idx, :] = Pc.detach() 116 | n_correctc, _, ticksc = utils.ADD_accuracy_withID(Pc, pts_model_h, PM_gt, diameter) 117 | n_correct_sumc += n_correctc 118 | if cfg.obj == 'eggbox' or cfg.obj == 'glue': 119 | n_correctcs, _, tickscs = utils.ADDS_accuracy_withID(Pc, pts_model_h, PM_gt, diameter) 120 | n_correct_sumcs += n_correctcs 121 | correctsc[idx, :] = ticksc.detach() 122 | 123 | if logger is not None: 124 | if cfg.obj == 'eggbox' or cfg.obj == 'glue': 125 | logger.info('Epoch {:3d} {:s} {} test, {}, s1:{}, s2:{}, s3:{}, n_min: {:d}, thres: {}, ADD1: {:1.4f}, ADD2: {:1.4f}, ADD3: {:1.4f}, \ 126 | ADDc: {:1.4f}, {:s}, ADD1s: {:1.4f}, ADD2s: {:1.4f}, ADD3s: {:1.4f}, ADDcs: {:1.4f}'.format(epoch, testset_name, cfg.obj, \ 127 | cfg.log_name, cfg.sigma1, cfg.sigma2, cfg.sigma3, n_min, thres, n_correct_sum1/n_test, n_correct_sum2/n_test, n_correct_sum3/n_test, \ 128 | n_correct_sumc/n_test, cfg.obj, n_correct_sum1s/n_test, n_correct_sum2s/n_test, n_correct_sum3s/n_test, n_correct_sumcs/n_test)) 129 | else: 130 | logger.info('Epoch {:3d} {:s} {} test, {}, s1:{}, s2:{}, s3:{}, n_min: {:d}, thres: {}, ADD1: {:1.4f}, ADD2: {:1.4f}, ADD3: {:1.4f} \ 131 | ADDc: {:1.4f}'.format(epoch, testset_name, cfg.obj, cfg.log_name, cfg.sigma1, cfg.sigma2, cfg.sigma3, n_min, thres, \ 132 | n_correct_sum1/n_test, n_correct_sum2/n_test, n_correct_sum3/n_test, n_correct_sumc/n_test)) 133 | return boxes, pose_record1, pose_record2, pose_record3, pose_recordc, pts2d_record1, pts2d_record2, pts2d_record3, corrects1, corrects2, \ 134 | corrects3, correctsc 135 | 136 | 137 | def get_ycbv_accuracy(self, cfg, epoch, n_test, testset_name, n_min=4, thres=1, logger=None): 138 | assert self.has_gathered 139 | n_correct_sum1 = 0 140 | n_correct_sum2 = 0 141 | n_correct_sum3 = 0 142 | n_correct_sumc = 0 143 | n_correct_sum1s = 0 144 | n_correct_sum2s = 0 145 | n_correct_sum3s = 0 146 | n_correct_sumcs = 0 147 | size_vec = torch.tensor([cfg.MODEL.IMAGE_SIZE],dtype=torch.float) 148 | 149 | assert testset_name=='ycbv' 150 | pts3d_fps = utils.get_ycbv_fps(cfg) 151 | pts_model_h, is_sym = utils.get_ycbv_3dmodel(cfg,homo=True) 152 | 153 | annos = loadmat(cfg.YCBV_DIR+'/test_annos/obj{:02d}.mat'.format(utils.get_ycbv_objid(cfg.obj))) 154 | seq_ids = annos['seq_ids'].squeeze() 155 | img_ids = annos['img_ids'].squeeze() 156 | Ks = annos['Ks'] 157 | PMs = annos['PMs'] 158 | n_dataset = len(img_ids) 159 | 160 | pts3d_h = torch.cat((pts3d_fps, torch.ones(cfg.N_PTS, 1)), dim=-1) 161 | diameter = utils.get_ycbv_distance(cfg) 162 | 163 | 
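        # Pre-allocate per-image records for the three keypoint heads and their consensus:
        # predicted 2D keypoints (n_dataset x N_PTS x 2), top-scoring boxes (n_dataset x 4),
        # 6-DoF PnP poses stored as [angle-axis | translation] (n_dataset x 6), and binary
        # ADD(-S) correctness ticks (n_dataset x 1). Images with no detection keep all-zero rows.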
pts2d_record1 = torch.zeros(n_dataset, cfg.N_PTS, 2) 164 | pts2d_record2 = torch.zeros(n_dataset, cfg.N_PTS, 2) 165 | pts2d_record3 = torch.zeros(n_dataset, cfg.N_PTS, 2) 166 | # distance_record = torch.zeros(n_dataset, 1) 167 | boxes = torch.zeros(n_dataset, 4) 168 | pose_record1 = torch.zeros(n_dataset, 6) 169 | pose_record2 = torch.zeros(n_dataset, 6) 170 | pose_record3 = torch.zeros(n_dataset, 6) 171 | pose_recordc = torch.zeros(n_dataset, 6) 172 | corrects1 = torch.zeros(n_dataset, 1) 173 | corrects2 = torch.zeros(n_dataset, 1) 174 | corrects3 = torch.zeros(n_dataset, 1) 175 | correctsc = torch.zeros(n_dataset, 1) 176 | for idx, out_dict in tqdm(zip(self.img_ids_all, self.outputs_all)): 177 | if len(out_dict['scores'])==0: 178 | continue 179 | PM_gt = torch.tensor(PMs[idx]).float().unsqueeze(0) 180 | K = torch.tensor(Ks[idx]).float() 181 | top1id = out_dict['scores'].argmax() 182 | box_pred = out_dict['boxes'][top1id] 183 | boxes[idx, :] = box_pred.detach() 184 | 185 | pts2d_out_coord1 = out_dict['keypoints1'][top1id,:,:2].unsqueeze(0) 186 | pts2d_out_coord2 = out_dict['keypoints2'][top1id,:,:2].unsqueeze(0) 187 | pts2d_out_coord3 = out_dict['keypoints3'][top1id,:,:2].unsqueeze(0) 188 | 189 | P1 = pnp(pts2d_out_coord1, pts3d_fps, K) 190 | pose_record1[idx, :] = P1.detach() 191 | pts2d_record1[idx, :, :] = pts2d_out_coord1.detach() 192 | n_correct1, _, ticks1 = utils.ADD_accuracy_withID(P1, pts_model_h, PM_gt, diameter) 193 | n_correct_sum1 += n_correct1 194 | if is_sym: 195 | n_correct1s, _, ticks1s = utils.ADDS_accuracy_withID(P1, pts_model_h, PM_gt, diameter) 196 | n_correct_sum1s += n_correct1s 197 | corrects1[idx, :] = ticks1.detach() 198 | 199 | P2 = pnp(pts2d_out_coord2, pts3d_fps, K) 200 | pose_record2[idx, :] = P2.detach() 201 | pts2d_record2[idx, :, :] = pts2d_out_coord2.detach() 202 | n_correct2, _, ticks2 = utils.ADD_accuracy_withID(P2, pts_model_h, PM_gt, diameter) 203 | n_correct_sum2 += n_correct2 204 | if is_sym: 205 | n_correct2s, _, ticks2s = utils.ADDS_accuracy_withID(P2, pts_model_h, PM_gt, diameter) 206 | n_correct_sum2s += n_correct2s 207 | corrects2[idx, :] = ticks2.detach() 208 | 209 | P3 = pnp(pts2d_out_coord3, pts3d_fps, K) 210 | pose_record3[idx, :] = P3.detach() 211 | pts2d_record3[idx, :, :] = pts2d_out_coord3.detach() 212 | n_correct3, _, ticks3 = utils.ADD_accuracy_withID(P3, pts_model_h, PM_gt, diameter) 213 | n_correct_sum3 += n_correct3 214 | if is_sym: 215 | n_correct3s, _, ticks3s = utils.ADDS_accuracy_withID(P3, pts_model_h, PM_gt, diameter) 216 | n_correct_sum3s += n_correct3s 217 | corrects3[idx, :] = ticks3.detach() 218 | 219 | pts2d_out_coordc, consist_ids = utils.get_kp_consensus_aslist(pts2d_out_coord1, pts2d_out_coord2, pts2d_out_coord3, thres, n_min) 220 | Pc = [pnp(pts2d_out_coordc[i].unsqueeze(0), pts3d_fps[consist_ids[i]], K) for i in range(len(pts2d_out_coordc))] 221 | Pc = torch.cat(tuple(Pc), dim=0) 222 | 223 | pose_recordc[idx, :] = Pc.detach() 224 | n_correctc, _, ticksc = utils.ADD_accuracy_withID(Pc, pts_model_h, PM_gt, diameter) 225 | n_correct_sumc += n_correctc 226 | if is_sym: 227 | n_correctcs, _, tickscs = utils.ADDS_accuracy_withID(Pc, pts_model_h, PM_gt, diameter) 228 | n_correct_sumcs += n_correctcs 229 | correctsc[idx, :] = ticksc.detach() 230 | 231 | if logger is not None: 232 | if is_sym: 233 | logger.info('Epoch {:3d} {:s} obj{} test, {}, s1:{}, s2:{}, s3:{}, n_min: {:d}, thres: {}, ADD1: {:1.4f}, ADD2: {:1.4f}, ADD3: {:1.4f}, \ 234 | ADDc: {:1.4f}, {:s}, ADD1s: {:1.4f}, ADD2s: {:1.4f}, ADD3s: {:1.4f}, ADDcs: 
{:1.4f}'.format(epoch, testset_name, cfg.obj, \ 235 | cfg.log_name, cfg.sigma1, cfg.sigma2, cfg.sigma3, n_min, thres, n_correct_sum1/n_test, n_correct_sum2/n_test, n_correct_sum3/n_test, \ 236 | n_correct_sumc/n_test, cfg.obj, n_correct_sum1s/n_test, n_correct_sum2s/n_test, n_correct_sum3s/n_test, n_correct_sumcs/n_test)) 237 | else: 238 | logger.info('Epoch {:3d} {:s} obj{} test, {}, s1:{}, s2:{}, s3:{}, n_min: {:d}, thres: {}, ADD1: {:1.4f}, ADD2: {:1.4f}, ADD3: {:1.4f} \ 239 | ADDc: {:1.4f}'.format(epoch, testset_name, cfg.obj, cfg.log_name, cfg.sigma1, cfg.sigma2, cfg.sigma3, n_min, thres, \ 240 | n_correct_sum1/n_test, n_correct_sum2/n_test, n_correct_sum3/n_test, n_correct_sumc/n_test)) 241 | return boxes, pose_record1, pose_record2, pose_record3, pose_recordc, pts2d_record1, pts2d_record2, pts2d_record3, corrects1, corrects2, \ 242 | corrects3, correctsc, seq_ids, img_ids 243 | 244 | 245 | 246 | def pnp(pts2d, pts3d, K): 247 | bs = pts2d.size(0) 248 | n = pts2d.size(1) 249 | device = pts2d.device 250 | pts3d_np = np.array(pts3d.detach().cpu()) 251 | K_np = np.array(K.cpu()) 252 | P_6d = torch.zeros(bs,6,device=device) 253 | R_inv = torch.tensor([[-1,0,0],[0,-1,0],[0,0,-1]],device=device,dtype=torch.float) 254 | 255 | for i in range(bs): 256 | pts2d_i_np = np.ascontiguousarray(pts2d[i].detach().cpu()).reshape((n,1,2)) 257 | _, rvec, T, _ = cv.solvePnPRansac(objectPoints=pts3d_np, imagePoints=pts2d_i_np, cameraMatrix=K_np, distCoeffs=None, flags=cv.SOLVEPNP_ITERATIVE, useExtrinsicGuess=True) 258 | angle_axis = torch.tensor(rvec,device=device,dtype=torch.float).view(1, 3) 259 | T = torch.tensor(T,device=device,dtype=torch.float).view(1, 3) 260 | if T[0,2] < 0: 261 | RR = kn.angle_axis_to_rotation_matrix(angle_axis) 262 | RR = R_inv.matmul(RR) 263 | RR = rowan.from_matrix(RR.cpu(), require_orthogonal=False) 264 | ax = rowan.to_axis_angle(RR) 265 | angle_axis = torch.tensor(ax[0]*ax[1],device=device,dtype=torch.float).view(1,3) 266 | T = R_inv.matmul(T.t()).t() 267 | P_6d[i,:] = torch.cat((angle_axis,T),dim=-1) 268 | return P_6d 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | -------------------------------------------------------------------------------- /detection/_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.jit.annotations import List, Tuple 5 | from torch import Tensor 6 | import torchvision 7 | 8 | 9 | class BalancedPositiveNegativeSampler(object): 10 | """ 11 | This class samples batches, ensuring that they contain a fixed proportion of positives 12 | """ 13 | 14 | def __init__(self, batch_size_per_image, positive_fraction): 15 | # type: (int, float) -> None 16 | """ 17 | Arguments: 18 | batch_size_per_image (int): number of elements to be selected per image 19 | positive_fraction (float): percentace of positive elements per batch 20 | """ 21 | self.batch_size_per_image = batch_size_per_image 22 | self.positive_fraction = positive_fraction 23 | 24 | def __call__(self, matched_idxs): 25 | # type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]] 26 | """ 27 | Arguments: 28 | matched idxs: list of tensors containing -1, 0 or positive values. 29 | Each tensor corresponds to a specific image. 30 | -1 values are ignored, 0 are considered as negatives and > 0 as 31 | positives. 32 | 33 | Returns: 34 | pos_idx (list[tensor]) 35 | neg_idx (list[tensor]) 36 | 37 | Returns two lists of binary masks for each image. 
38 | The first list contains the positive elements that were selected, 39 | and the second list the negative example. 40 | """ 41 | pos_idx = [] 42 | neg_idx = [] 43 | for matched_idxs_per_image in matched_idxs: 44 | positive = torch.nonzero(matched_idxs_per_image >= 1).squeeze(1) 45 | negative = torch.nonzero(matched_idxs_per_image == 0).squeeze(1) 46 | 47 | num_pos = int(self.batch_size_per_image * self.positive_fraction) 48 | # protect against not enough positive examples 49 | num_pos = min(positive.numel(), num_pos) 50 | num_neg = self.batch_size_per_image - num_pos 51 | # protect against not enough negative examples 52 | num_neg = min(negative.numel(), num_neg) 53 | 54 | # randomly select positive and negative examples 55 | perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos] 56 | perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg] 57 | 58 | pos_idx_per_image = positive[perm1] 59 | neg_idx_per_image = negative[perm2] 60 | 61 | # create binary mask from indices 62 | pos_idx_per_image_mask = torch.zeros_like( 63 | matched_idxs_per_image, dtype=torch.uint8 64 | ) 65 | neg_idx_per_image_mask = torch.zeros_like( 66 | matched_idxs_per_image, dtype=torch.uint8 67 | ) 68 | 69 | pos_idx_per_image_mask[pos_idx_per_image] = 1 70 | neg_idx_per_image_mask[neg_idx_per_image] = 1 71 | 72 | pos_idx.append(pos_idx_per_image_mask) 73 | neg_idx.append(neg_idx_per_image_mask) 74 | 75 | return pos_idx, neg_idx 76 | 77 | 78 | @torch.jit._script_if_tracing 79 | def encode_boxes(reference_boxes, proposals, weights): 80 | # type: (torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor 81 | """ 82 | Encode a set of proposals with respect to some 83 | reference boxes 84 | 85 | Arguments: 86 | reference_boxes (Tensor): reference boxes 87 | proposals (Tensor): boxes to be encoded 88 | """ 89 | 90 | # perform some unpacking to make it JIT-fusion friendly 91 | wx = weights[0] 92 | wy = weights[1] 93 | ww = weights[2] 94 | wh = weights[3] 95 | 96 | proposals_x1 = proposals[:, 0].unsqueeze(1) 97 | proposals_y1 = proposals[:, 1].unsqueeze(1) 98 | proposals_x2 = proposals[:, 2].unsqueeze(1) 99 | proposals_y2 = proposals[:, 3].unsqueeze(1) 100 | 101 | reference_boxes_x1 = reference_boxes[:, 0].unsqueeze(1) 102 | reference_boxes_y1 = reference_boxes[:, 1].unsqueeze(1) 103 | reference_boxes_x2 = reference_boxes[:, 2].unsqueeze(1) 104 | reference_boxes_y2 = reference_boxes[:, 3].unsqueeze(1) 105 | 106 | # implementation starts here 107 | ex_widths = proposals_x2 - proposals_x1 108 | ex_heights = proposals_y2 - proposals_y1 109 | ex_ctr_x = proposals_x1 + 0.5 * ex_widths 110 | ex_ctr_y = proposals_y1 + 0.5 * ex_heights 111 | 112 | gt_widths = reference_boxes_x2 - reference_boxes_x1 113 | gt_heights = reference_boxes_y2 - reference_boxes_y1 114 | gt_ctr_x = reference_boxes_x1 + 0.5 * gt_widths 115 | gt_ctr_y = reference_boxes_y1 + 0.5 * gt_heights 116 | 117 | targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths 118 | targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights 119 | targets_dw = ww * torch.log(gt_widths / ex_widths) 120 | targets_dh = wh * torch.log(gt_heights / ex_heights) 121 | 122 | targets = torch.cat((targets_dx, targets_dy, targets_dw, targets_dh), dim=1) 123 | return targets 124 | 125 | 126 | class BoxCoder(object): 127 | """ 128 | This class encodes and decodes a set of bounding boxes into 129 | the representation used for training the regressors. 130 | """ 131 | 132 | def __init__(self, weights, bbox_xform_clip=math.log(1000. 
/ 16)): 133 | # type: (Tuple[float, float, float, float], float) -> None 134 | """ 135 | Arguments: 136 | weights (4-element tuple) 137 | bbox_xform_clip (float) 138 | """ 139 | self.weights = weights 140 | self.bbox_xform_clip = bbox_xform_clip 141 | 142 | def encode(self, reference_boxes, proposals): 143 | # type: (List[Tensor], List[Tensor]) -> List[Tensor] 144 | boxes_per_image = [len(b) for b in reference_boxes] 145 | reference_boxes = torch.cat(reference_boxes, dim=0) 146 | proposals = torch.cat(proposals, dim=0) 147 | targets = self.encode_single(reference_boxes, proposals) 148 | return targets.split(boxes_per_image, 0) 149 | 150 | def encode_single(self, reference_boxes, proposals): 151 | """ 152 | Encode a set of proposals with respect to some 153 | reference boxes 154 | 155 | Arguments: 156 | reference_boxes (Tensor): reference boxes 157 | proposals (Tensor): boxes to be encoded 158 | """ 159 | dtype = reference_boxes.dtype 160 | device = reference_boxes.device 161 | weights = torch.as_tensor(self.weights, dtype=dtype, device=device) 162 | targets = encode_boxes(reference_boxes, proposals, weights) 163 | 164 | return targets 165 | 166 | def decode(self, rel_codes, boxes): 167 | # type: (Tensor, List[Tensor]) -> Tensor 168 | assert isinstance(boxes, (list, tuple)) 169 | assert isinstance(rel_codes, torch.Tensor) 170 | boxes_per_image = [b.size(0) for b in boxes] 171 | concat_boxes = torch.cat(boxes, dim=0) 172 | box_sum = 0 173 | for val in boxes_per_image: 174 | box_sum += val 175 | pred_boxes = self.decode_single( 176 | rel_codes.reshape(box_sum, -1), concat_boxes 177 | ) 178 | return pred_boxes.reshape(box_sum, -1, 4) 179 | 180 | def decode_single(self, rel_codes, boxes): 181 | """ 182 | From a set of original boxes and encoded relative box offsets, 183 | get the decoded boxes. 184 | 185 | Arguments: 186 | rel_codes (Tensor): encoded boxes 187 | boxes (Tensor): reference boxes. 188 | """ 189 | 190 | boxes = boxes.to(rel_codes.dtype) 191 | 192 | widths = boxes[:, 2] - boxes[:, 0] 193 | heights = boxes[:, 3] - boxes[:, 1] 194 | ctr_x = boxes[:, 0] + 0.5 * widths 195 | ctr_y = boxes[:, 1] + 0.5 * heights 196 | 197 | wx, wy, ww, wh = self.weights 198 | dx = rel_codes[:, 0::4] / wx 199 | dy = rel_codes[:, 1::4] / wy 200 | dw = rel_codes[:, 2::4] / ww 201 | dh = rel_codes[:, 3::4] / wh 202 | 203 | # Prevent sending too large values into torch.exp() 204 | dw = torch.clamp(dw, max=self.bbox_xform_clip) 205 | dh = torch.clamp(dh, max=self.bbox_xform_clip) 206 | 207 | pred_ctr_x = dx * widths[:, None] + ctr_x[:, None] 208 | pred_ctr_y = dy * heights[:, None] + ctr_y[:, None] 209 | pred_w = torch.exp(dw) * widths[:, None] 210 | pred_h = torch.exp(dh) * heights[:, None] 211 | 212 | pred_boxes1 = pred_ctr_x - torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w 213 | pred_boxes2 = pred_ctr_y - torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h 214 | pred_boxes3 = pred_ctr_x + torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w 215 | pred_boxes4 = pred_ctr_y + torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h 216 | pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1) 217 | return pred_boxes 218 | 219 | 220 | class Matcher(object): 221 | """ 222 | This class assigns to each predicted "element" (e.g., a box) a ground-truth 223 | element. 
Each predicted element will have exactly zero or one matches; each 224 | ground-truth element may be assigned to zero or more predicted elements. 225 | 226 | Matching is based on the MxN match_quality_matrix, that characterizes how well 227 | each (ground-truth, predicted)-pair match. For example, if the elements are 228 | boxes, the matrix may contain box IoU overlap values. 229 | 230 | The matcher returns a tensor of size N containing the index of the ground-truth 231 | element m that matches to prediction n. If there is no match, a negative value 232 | is returned. 233 | """ 234 | 235 | BELOW_LOW_THRESHOLD = -1 236 | BETWEEN_THRESHOLDS = -2 237 | 238 | __annotations__ = { 239 | 'BELOW_LOW_THRESHOLD': int, 240 | 'BETWEEN_THRESHOLDS': int, 241 | } 242 | 243 | def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False): 244 | # type: (float, float, bool) -> None 245 | """ 246 | Args: 247 | high_threshold (float): quality values greater than or equal to 248 | this value are candidate matches. 249 | low_threshold (float): a lower quality threshold used to stratify 250 | matches into three levels: 251 | 1) matches >= high_threshold 252 | 2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold) 253 | 3) BELOW_LOW_THRESHOLD matches in [0, low_threshold) 254 | allow_low_quality_matches (bool): if True, produce additional matches 255 | for predictions that have only low-quality match candidates. See 256 | set_low_quality_matches_ for more details. 257 | """ 258 | self.BELOW_LOW_THRESHOLD = -1 259 | self.BETWEEN_THRESHOLDS = -2 260 | assert low_threshold <= high_threshold 261 | self.high_threshold = high_threshold 262 | self.low_threshold = low_threshold 263 | self.allow_low_quality_matches = allow_low_quality_matches 264 | 265 | def __call__(self, match_quality_matrix): 266 | """ 267 | Args: 268 | match_quality_matrix (Tensor[float]): an MxN tensor, containing the 269 | pairwise quality between M ground-truth elements and N predicted elements. 270 | 271 | Returns: 272 | matches (Tensor[int64]): an N tensor where N[i] is a matched gt in 273 | [0, M - 1] or a negative value indicating that prediction i could not 274 | be matched. 
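        A minimal illustration (the IoU values below are made up for demonstration),
        using the same 0.7 / 0.3 thresholds as the RPN defaults in this repo::

            >>> iou = torch.tensor([[0.9, 0.2, 0.4],
            ...                     [0.1, 0.6, 0.1]])  # 2 ground-truth boxes x 3 predictions
            >>> matcher = Matcher(high_threshold=0.7, low_threshold=0.3)
            >>> matcher(iou)  # pred 0 matches gt 0; preds 1 and 2 fall BETWEEN_THRESHOLDS
            tensor([ 0, -2, -2])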
275 | """ 276 | if match_quality_matrix.numel() == 0: 277 | # empty targets or proposals not supported during training 278 | if match_quality_matrix.shape[0] == 0: 279 | raise ValueError( 280 | "No ground-truth boxes available for one of the images " 281 | "during training") 282 | else: 283 | raise ValueError( 284 | "No proposal boxes available for one of the images " 285 | "during training") 286 | 287 | # match_quality_matrix is M (gt) x N (predicted) 288 | # Max over gt elements (dim 0) to find best gt candidate for each prediction 289 | matched_vals, matches = match_quality_matrix.max(dim=0) 290 | if self.allow_low_quality_matches: 291 | all_matches = matches.clone() 292 | else: 293 | all_matches = None 294 | 295 | # Assign candidate matches with low quality to negative (unassigned) values 296 | below_low_threshold = matched_vals < self.low_threshold 297 | between_thresholds = (matched_vals >= self.low_threshold) & ( 298 | matched_vals < self.high_threshold 299 | ) 300 | matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD 301 | matches[between_thresholds] = self.BETWEEN_THRESHOLDS 302 | 303 | if self.allow_low_quality_matches: 304 | assert all_matches is not None 305 | self.set_low_quality_matches_(matches, all_matches, match_quality_matrix) 306 | 307 | return matches 308 | 309 | def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix): 310 | """ 311 | Produce additional matches for predictions that have only low-quality matches. 312 | Specifically, for each ground-truth find the set of predictions that have 313 | maximum overlap with it (including ties); for each prediction in that set, if 314 | it is unmatched, then match it to the ground-truth with which it has the highest 315 | quality value. 316 | """ 317 | # For each gt, find the prediction with which it has highest quality 318 | highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1) 319 | # Find highest quality match available, even if it is low, including ties 320 | gt_pred_pairs_of_highest_quality = torch.nonzero( 321 | match_quality_matrix == highest_quality_foreach_gt[:, None] 322 | ) 323 | # Example gt_pred_pairs_of_highest_quality: 324 | # tensor([[ 0, 39796], 325 | # [ 1, 32055], 326 | # [ 1, 32070], 327 | # [ 2, 39190], 328 | # [ 2, 40255], 329 | # [ 3, 40390], 330 | # [ 3, 41455], 331 | # [ 4, 45470], 332 | # [ 5, 45325], 333 | # [ 5, 46390]]) 334 | # Each row is a (gt index, prediction index) 335 | # Note how gt items 1, 2, 3, and 5 each have two ties 336 | 337 | pred_inds_to_update = gt_pred_pairs_of_highest_quality[:, 1] 338 | matches[pred_inds_to_update] = all_matches[pred_inds_to_update] 339 | 340 | 341 | def smooth_l1_loss(input, target, beta: float = 1. 
/ 9, size_average: bool = True): 342 | """ 343 | very similar to the smooth_l1_loss from pytorch, but with 344 | the extra beta parameter 345 | """ 346 | n = torch.abs(input - target) 347 | cond = n < beta 348 | loss = torch.where(cond, 0.5 * n ** 2 / beta, n - 0.5 * beta) 349 | if size_average: 350 | return loss.mean() 351 | return loss.sum() 352 | -------------------------------------------------------------------------------- /detection/keypoint_rcnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from torchvision.ops import MultiScaleRoIAlign, RoIAlign 5 | 6 | from torchvision.models.utils import load_state_dict_from_url 7 | 8 | from .faster_rcnn import FasterRCNN 9 | from model.hrnet_backbone import get_hrnet 10 | 11 | 12 | class KeypointRCNN(FasterRCNN): 13 | """ 14 | Implements Keypoint R-CNN. 15 | 16 | The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each 17 | image, and should be in 0-1 range. Different images can have different sizes. 18 | 19 | The behavior of the model changes depending if it is in training or evaluation mode. 20 | 21 | During training, the model expects both the input tensors, as well as a targets (list of dictionary), 22 | containing: 23 | - boxes (FloatTensor[N, 4]): the ground-truth boxes in [x1, y1, x2, y2] format, with values of x 24 | between 0 and W and values of y between 0 and H 25 | - labels (Int64Tensor[N]): the class label for each ground-truth box 26 | - keypoints (FloatTensor[N, K, 3]): the K keypoints location for each of the N instances, in the 27 | format [x, y, visibility], where visibility=0 means that the keypoint is not visible. 28 | 29 | The model returns a Dict[Tensor] during training, containing the classification and regression 30 | losses for both the RPN and the R-CNN, and the keypoint loss. 31 | 32 | During inference, the model requires only the input tensors, and returns the post-processed 33 | predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as 34 | follows: 35 | - boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with values of x 36 | between 0 and W and values of y between 0 and H 37 | - labels (Int64Tensor[N]): the predicted labels for each image 38 | - scores (Tensor[N]): the scores or each prediction 39 | - keypoints (FloatTensor[N, K, 3]): the locations of the predicted keypoints, in [x, y, v] format. 40 | 41 | Arguments: 42 | backbone (nn.Module): the network used to compute the features for the model. 43 | It should contain a out_channels attribute, which indicates the number of output 44 | channels that each feature map has (and it should be the same for all feature maps). 45 | The backbone should return a single Tensor or and OrderedDict[Tensor]. 46 | num_classes (int): number of output classes of the model (including the background). 47 | If box_predictor is specified, num_classes should be None. 48 | min_size (int): minimum size of the image to be rescaled before feeding it to the backbone 49 | max_size (int): maximum size of the image to be rescaled before feeding it to the backbone 50 | image_mean (Tuple[float, float, float]): mean values used for input normalization. 51 | They are generally the mean values of the dataset on which the backbone has been trained 52 | on 53 | image_std (Tuple[float, float, float]): std values used for input normalization. 
54 | They are generally the std values of the dataset on which the backbone has been trained on 55 | rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature 56 | maps. 57 | rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN 58 | rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training 59 | rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing 60 | rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training 61 | rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing 62 | rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals 63 | rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be 64 | considered as positive during training of the RPN. 65 | rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be 66 | considered as negative during training of the RPN. 67 | rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN 68 | for computing the loss 69 | rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training 70 | of the RPN 71 | box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in 72 | the locations indicated by the bounding boxes 73 | box_head (nn.Module): module that takes the cropped feature maps as input 74 | box_predictor (nn.Module): module that takes the output of box_head and returns the 75 | classification logits and box regression deltas. 76 | box_score_thresh (float): during inference, only return proposals with a classification score 77 | greater than box_score_thresh 78 | box_nms_thresh (float): NMS threshold for the prediction head. Used during inference 79 | box_detections_per_img (int): maximum number of detections per image, for all classes. 80 | box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be 81 | considered as positive during training of the classification head 82 | box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be 83 | considered as negative during training of the classification head 84 | box_batch_size_per_image (int): number of proposals that are sampled during training of the 85 | classification head 86 | box_positive_fraction (float): proportion of positive proposals in a mini-batch during training 87 | of the classification head 88 | bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the 89 | bounding boxes 90 | keypoint_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in 91 | the locations indicated by the bounding boxes, which will be used for the keypoint head. 
92 | keypoint_head (nn.Module): module that takes the cropped feature maps as input 93 | keypoint_predictor (nn.Module): module that takes the output of the keypoint_head and returns the 94 | heatmap logits 95 | 96 | Example:: 97 | 98 | >>> import torch 99 | >>> import torchvision 100 | >>> from torchvision.models.detection import KeypointRCNN 101 | >>> from torchvision.models.detection.rpn import AnchorGenerator 102 | >>> 103 | >>> # load a pre-trained model for classification and return 104 | >>> # only the features 105 | >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features 106 | >>> # KeypointRCNN needs to know the number of 107 | >>> # output channels in a backbone. For mobilenet_v2, it's 1280 108 | >>> # so we need to add it here 109 | >>> backbone.out_channels = 1280 110 | >>> 111 | >>> # let's make the RPN generate 5 x 3 anchors per spatial 112 | >>> # location, with 5 different sizes and 3 different aspect 113 | >>> # ratios. We have a Tuple[Tuple[int]] because each feature 114 | >>> # map could potentially have different sizes and 115 | >>> # aspect ratios 116 | >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),), 117 | >>> aspect_ratios=((0.5, 1.0, 2.0),)) 118 | >>> 119 | >>> # let's define what are the feature maps that we will 120 | >>> # use to perform the region of interest cropping, as well as 121 | >>> # the size of the crop after rescaling. 122 | >>> # if your backbone returns a Tensor, featmap_names is expected to 123 | >>> # be ['0']. More generally, the backbone should return an 124 | >>> # OrderedDict[Tensor], and in featmap_names you can choose which 125 | >>> # feature maps to use. 126 | >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'], 127 | >>> output_size=7, 128 | >>> sampling_ratio=2) 129 | >>> 130 | >>> keypoint_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'], 131 | >>> output_size=14, 132 | >>> sampling_ratio=2) 133 | >>> # put the pieces together inside a KeypointRCNN model 134 | >>> model = KeypointRCNN(backbone, 135 | >>> num_classes=2, 136 | >>> rpn_anchor_generator=anchor_generator, 137 | >>> box_roi_pool=roi_pooler, 138 | >>> keypoint_roi_pool=keypoint_roi_pooler) 139 | >>> model.eval() 140 | >>> model.eval() 141 | >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] 142 | >>> predictions = model(x) 143 | """ 144 | def __init__(self, backbone, num_classes=None, 145 | # transform parameters 146 | min_size=None, max_size=1333, 147 | image_mean=None, image_std=None, 148 | # RPN parameters 149 | rpn_anchor_generator=None, rpn_head=None, 150 | rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000, 151 | rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000, 152 | rpn_nms_thresh=0.7, 153 | rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3, 154 | rpn_batch_size_per_image=256, rpn_positive_fraction=0.5, 155 | # Box parameters 156 | box_roi_pool=None, box_head=None, box_predictor=None, 157 | box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100, 158 | box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5, 159 | box_batch_size_per_image=512, box_positive_fraction=0.25, 160 | bbox_reg_weights=None, 161 | # keypoint parameters 162 | keypoint_roi_pool=None, keypoint_head1=None, keypoint_predictor1=None, 163 | keypoint_head2=None, keypoint_predictor2=None, 164 | keypoint_head3=None, keypoint_predictor3=None, 165 | dist_head=None, sigma1=1.5, sigma2=3, sigma3=8, num_keypoints=11): 166 | 167 | assert isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))) 168 | if 
min_size is None: 169 | min_size = (640, 672, 704, 736, 768, 800) 170 | 171 | if num_classes is not None: 172 | if keypoint_predictor1 is not None: 173 | raise ValueError("num_classes should be None when keypoint_predictor is specified") 174 | 175 | out_channels = backbone.out_channels 176 | 177 | if keypoint_roi_pool is None: 178 | keypoint_roi_pool = RoIAlign( 179 | output_size=28, 180 | spatial_scale=0.25, 181 | sampling_ratio=2, 182 | aligned=True) 183 | 184 | 185 | if keypoint_head1 is None: 186 | keypoint_layers = tuple(64 for _ in range(8)) 187 | keypoint_head1 = KeypointRCNNHeads(out_channels, keypoint_layers) 188 | 189 | if keypoint_predictor1 is None: 190 | keypoint_dim_reduced = 64 191 | keypoint_predictor1 = KeypointRCNNPredictor(keypoint_dim_reduced, num_keypoints) 192 | 193 | if keypoint_head2 is None: 194 | keypoint_layers = tuple(64 for _ in range(8)) 195 | keypoint_head2 = KeypointRCNNHeads(out_channels, keypoint_layers) 196 | 197 | if keypoint_predictor2 is None: 198 | keypoint_dim_reduced = 64 199 | keypoint_predictor2 = KeypointRCNNPredictor(keypoint_dim_reduced, num_keypoints) 200 | 201 | if keypoint_head3 is None: 202 | keypoint_layers = tuple(64 for _ in range(8)) 203 | keypoint_head3 = KeypointRCNNHeads(out_channels, keypoint_layers) 204 | 205 | if keypoint_predictor3 is None: 206 | keypoint_dim_reduced = 64 207 | keypoint_predictor3 = KeypointRCNNPredictor(keypoint_dim_reduced, num_keypoints) 208 | 209 | 210 | super(KeypointRCNN, self).__init__( 211 | backbone, num_classes, 212 | sigma1, sigma2, sigma3, 213 | # transform parameters 214 | min_size, max_size, 215 | image_mean, image_std, 216 | # RPN-specific parameters 217 | rpn_anchor_generator, rpn_head, 218 | rpn_pre_nms_top_n_train, rpn_pre_nms_top_n_test, 219 | rpn_post_nms_top_n_train, rpn_post_nms_top_n_test, 220 | rpn_nms_thresh, 221 | rpn_fg_iou_thresh, rpn_bg_iou_thresh, 222 | rpn_batch_size_per_image, rpn_positive_fraction, 223 | # Box parameters 224 | box_roi_pool, box_head, box_predictor, 225 | box_score_thresh, box_nms_thresh, box_detections_per_img, 226 | box_fg_iou_thresh, box_bg_iou_thresh, 227 | box_batch_size_per_image, box_positive_fraction, 228 | bbox_reg_weights) 229 | 230 | self.roi_heads.keypoint_roi_pool = keypoint_roi_pool 231 | self.roi_heads.keypoint_head1 = keypoint_head1 232 | self.roi_heads.keypoint_predictor1 = keypoint_predictor1 233 | self.roi_heads.keypoint_head2 = keypoint_head2 234 | self.roi_heads.keypoint_predictor2 = keypoint_predictor2 235 | self.roi_heads.keypoint_head3 = keypoint_head3 236 | self.roi_heads.keypoint_predictor3 = keypoint_predictor3 237 | self.roi_heads.multi_precision = True 238 | 239 | 240 | class KeypointRCNNHeads(nn.Sequential): 241 | def __init__(self, in_channels, layers): 242 | d = [] 243 | next_feature = in_channels 244 | for out_channels in layers: 245 | d.append(nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1)) 246 | d.append(nn.ReLU(inplace=True)) 247 | next_feature = out_channels 248 | super(KeypointRCNNHeads, self).__init__(*d) 249 | for m in self.children(): 250 | if isinstance(m, nn.Conv2d): 251 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 252 | nn.init.constant_(m.bias, 0) 253 | 254 | class KeypointRCNNPredictor(nn.Module): 255 | def __init__(self, in_channels, num_keypoints): 256 | super(KeypointRCNNPredictor, self).__init__() 257 | input_features = in_channels 258 | deconv_kernel = 4 259 | self.kps_score_lowres = nn.ConvTranspose2d( 260 | input_features, 261 | num_keypoints, 262 | deconv_kernel, 263 | 
stride=2, 264 | padding=deconv_kernel // 2 - 1, 265 | ) 266 | nn.init.kaiming_normal_( 267 | self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu" 268 | ) 269 | nn.init.constant_(self.kps_score_lowres.bias, 0) 270 | self.up_scale = 2 271 | self.out_channels = num_keypoints 272 | 273 | def forward(self, x): 274 | x = self.kps_score_lowres(x) 275 | return torch.nn.functional.interpolate( 276 | x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False 277 | ) 278 | 279 | 280 | 281 | def keypointrcnn_hrnet(cfg, resume=False, **kwargs): 282 | backbone = get_hrnet(cfg, resume) 283 | model = KeypointRCNN(backbone, num_classes=2, num_keypoints=cfg.N_PTS, sigma1=cfg.sigma1, 284 | sigma2=cfg.sigma2, sigma3=cfg.sigma3, **kwargs) 285 | 286 | return model 287 | 288 | 289 | 290 | 291 | 292 | 293 | -------------------------------------------------------------------------------- /detection/faster_rcnn.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | from torchvision.ops import misc as misc_nn_ops 8 | from torchvision.ops import MultiScaleRoIAlign 9 | 10 | from torchvision.models.utils import load_state_dict_from_url 11 | 12 | from .generalized_rcnn import GeneralizedRCNN 13 | from .rpn import AnchorGenerator, RPNHead, RegionProposalNetwork 14 | from .roi_heads import RoIHeads 15 | from .transform import GeneralizedRCNNTransform 16 | 17 | 18 | __all__ = [ 19 | "FasterRCNN", "fasterrcnn_resnet50_fpn", 20 | ] 21 | 22 | 23 | class FasterRCNN(GeneralizedRCNN): 24 | """ 25 | Implements Faster R-CNN. 26 | 27 | The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each 28 | image, and should be in 0-1 range. Different images can have different sizes. 29 | 30 | The behavior of the model changes depending if it is in training or evaluation mode. 31 | 32 | During training, the model expects both the input tensors, as well as a targets (list of dictionary), 33 | containing: 34 | - boxes (FloatTensor[N, 4]): the ground-truth boxes in [x1, y1, x2, y2] format, with values of x 35 | between 0 and W and values of y between 0 and H 36 | - labels (Int64Tensor[N]): the class label for each ground-truth box 37 | 38 | The model returns a Dict[Tensor] during training, containing the classification and regression 39 | losses for both the RPN and the R-CNN. 40 | 41 | During inference, the model requires only the input tensors, and returns the post-processed 42 | predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as 43 | follows: 44 | - boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with values of x 45 | between 0 and W and values of y between 0 and H 46 | - labels (Int64Tensor[N]): the predicted labels for each image 47 | - scores (Tensor[N]): the scores or each prediction 48 | 49 | Arguments: 50 | backbone (nn.Module): the network used to compute the features for the model. 51 | It should contain a out_channels attribute, which indicates the number of output 52 | channels that each feature map has (and it should be the same for all feature maps). 53 | The backbone should return a single Tensor or and OrderedDict[Tensor]. 54 | num_classes (int): number of output classes of the model (including the background). 55 | If box_predictor is specified, num_classes should be None. 
56 | min_size (int): minimum size of the image to be rescaled before feeding it to the backbone 57 | max_size (int): maximum size of the image to be rescaled before feeding it to the backbone 58 | image_mean (Tuple[float, float, float]): mean values used for input normalization. 59 | They are generally the mean values of the dataset on which the backbone has been trained 60 | on 61 | image_std (Tuple[float, float, float]): std values used for input normalization. 62 | They are generally the std values of the dataset on which the backbone has been trained on 63 | rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature 64 | maps. 65 | rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN 66 | rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training 67 | rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing 68 | rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training 69 | rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing 70 | rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals 71 | rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be 72 | considered as positive during training of the RPN. 73 | rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be 74 | considered as negative during training of the RPN. 75 | rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN 76 | for computing the loss 77 | rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training 78 | of the RPN 79 | box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in 80 | the locations indicated by the bounding boxes 81 | box_head (nn.Module): module that takes the cropped feature maps as input 82 | box_predictor (nn.Module): module that takes the output of box_head and returns the 83 | classification logits and box regression deltas. 84 | box_score_thresh (float): during inference, only return proposals with a classification score 85 | greater than box_score_thresh 86 | box_nms_thresh (float): NMS threshold for the prediction head. Used during inference 87 | box_detections_per_img (int): maximum number of detections per image, for all classes. 
88 | box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be 89 | considered as positive during training of the classification head 90 | box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be 91 | considered as negative during training of the classification head 92 | box_batch_size_per_image (int): number of proposals that are sampled during training of the 93 | classification head 94 | box_positive_fraction (float): proportion of positive proposals in a mini-batch during training 95 | of the classification head 96 | bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the 97 | bounding boxes 98 | 99 | Example:: 100 | 101 | >>> import torch 102 | >>> import torchvision 103 | >>> from torchvision.models.detection import FasterRCNN 104 | >>> from torchvision.models.detection.rpn import AnchorGenerator 105 | >>> # load a pre-trained model for classification and return 106 | >>> # only the features 107 | >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features 108 | >>> # FasterRCNN needs to know the number of 109 | >>> # output channels in a backbone. For mobilenet_v2, it's 1280 110 | >>> # so we need to add it here 111 | >>> backbone.out_channels = 1280 112 | >>> 113 | >>> # let's make the RPN generate 5 x 3 anchors per spatial 114 | >>> # location, with 5 different sizes and 3 different aspect 115 | >>> # ratios. We have a Tuple[Tuple[int]] because each feature 116 | >>> # map could potentially have different sizes and 117 | >>> # aspect ratios 118 | >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),), 119 | >>> aspect_ratios=((0.5, 1.0, 2.0),)) 120 | >>> 121 | >>> # let's define what are the feature maps that we will 122 | >>> # use to perform the region of interest cropping, as well as 123 | >>> # the size of the crop after rescaling. 124 | >>> # if your backbone returns a Tensor, featmap_names is expected to 125 | >>> # be ['0']. More generally, the backbone should return an 126 | >>> # OrderedDict[Tensor], and in featmap_names you can choose which 127 | >>> # feature maps to use. 
128 | >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'], 129 | >>> output_size=7, 130 | >>> sampling_ratio=2) 131 | >>> 132 | >>> # put the pieces together inside a FasterRCNN model 133 | >>> model = FasterRCNN(backbone, 134 | >>> num_classes=2, 135 | >>> rpn_anchor_generator=anchor_generator, 136 | >>> box_roi_pool=roi_pooler) 137 | >>> model.eval() 138 | >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] 139 | >>> predictions = model(x) 140 | """ 141 | 142 | def __init__(self, backbone, num_classes=None, 143 | sigma1=1.5, sigma2=3, sigma3=8, 144 | # transform parameters 145 | min_size=800, max_size=1333, 146 | image_mean=None, image_std=None, 147 | # RPN parameters 148 | rpn_anchor_generator=None, rpn_head=None, 149 | rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000, 150 | rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000, 151 | rpn_nms_thresh=0.7, 152 | rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3, 153 | rpn_batch_size_per_image=256, rpn_positive_fraction=0.5, 154 | # Box parameters 155 | box_roi_pool=None, box_head=None, box_predictor=None, 156 | box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100, 157 | box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5, 158 | box_batch_size_per_image=512, box_positive_fraction=0.25, 159 | bbox_reg_weights=None): 160 | 161 | if not hasattr(backbone, "out_channels"): 162 | raise ValueError( 163 | "backbone should contain an attribute out_channels " 164 | "specifying the number of output channels (assumed to be the " 165 | "same for all the levels)") 166 | 167 | assert isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))) 168 | assert isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))) 169 | 170 | if num_classes is not None: 171 | if box_predictor is not None: 172 | raise ValueError("num_classes should be None when box_predictor is specified") 173 | else: 174 | if box_predictor is None: 175 | raise ValueError("num_classes should not be None when box_predictor " 176 | "is not specified") 177 | 178 | out_channels = backbone.out_channels 179 | 180 | if rpn_anchor_generator is None: 181 | anchor_sizes = ((32,), (64,), (128,), (256,), (512,)) 182 | aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) 183 | rpn_anchor_generator = AnchorGenerator( 184 | anchor_sizes, aspect_ratios 185 | ) 186 | if rpn_head is None: 187 | rpn_head = RPNHead( 188 | out_channels, rpn_anchor_generator.num_anchors_per_location()[0] 189 | ) 190 | 191 | rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test) 192 | rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test) 193 | 194 | rpn = RegionProposalNetwork( 195 | rpn_anchor_generator, rpn_head, 196 | rpn_fg_iou_thresh, rpn_bg_iou_thresh, 197 | rpn_batch_size_per_image, rpn_positive_fraction, 198 | rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh) 199 | 200 | if box_roi_pool is None: 201 | box_roi_pool = MultiScaleRoIAlign( 202 | featmap_names=['0', '1', '2', '3'], 203 | output_size=7, 204 | sampling_ratio=2) 205 | 206 | if box_head is None: 207 | resolution = box_roi_pool.output_size[0] 208 | representation_size = 1024 209 | box_head = TwoMLPHead( 210 | out_channels * resolution ** 2, 211 | representation_size) 212 | 213 | if box_predictor is None: 214 | representation_size = 1024 215 | box_predictor = FastRCNNPredictor( 216 | representation_size, 217 | num_classes) 218 | 219 | roi_heads = RoIHeads( 220 | # Box 221 | box_roi_pool, box_head, box_predictor, 222 | 
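            # box-head matching / sampling settings, box-coder weights, inference-time
            # filtering thresholds, and the sigmas for the three keypoint heads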
box_fg_iou_thresh, box_bg_iou_thresh, 223 | box_batch_size_per_image, box_positive_fraction, 224 | bbox_reg_weights, 225 | box_score_thresh, box_nms_thresh, box_detections_per_img, 226 | sigma1=sigma1, sigma2=sigma2, sigma3=sigma3) 227 | 228 | if image_mean is None: 229 | image_mean = [0.485, 0.456, 0.406] 230 | if image_std is None: 231 | image_std = [0.229, 0.224, 0.225] 232 | transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std) 233 | 234 | super(FasterRCNN, self).__init__(backbone, rpn, roi_heads, transform) 235 | 236 | 237 | class TwoMLPHead(nn.Module): 238 | """ 239 | Standard heads for FPN-based models 240 | 241 | Arguments: 242 | in_channels (int): number of input channels 243 | representation_size (int): size of the intermediate representation 244 | """ 245 | 246 | def __init__(self, in_channels, representation_size): 247 | super(TwoMLPHead, self).__init__() 248 | 249 | self.fc6 = nn.Linear(in_channels, representation_size) 250 | self.fc7 = nn.Linear(representation_size, representation_size) 251 | 252 | def forward(self, x): 253 | x = x.flatten(start_dim=1) 254 | 255 | x = F.relu(self.fc6(x)) 256 | x = F.relu(self.fc7(x)) 257 | 258 | return x 259 | 260 | 261 | class FastRCNNPredictor(nn.Module): 262 | """ 263 | Standard classification + bounding box regression layers 264 | for Fast R-CNN. 265 | 266 | Arguments: 267 | in_channels (int): number of input channels 268 | num_classes (int): number of output classes (including background) 269 | """ 270 | 271 | def __init__(self, in_channels, num_classes): 272 | super(FastRCNNPredictor, self).__init__() 273 | self.cls_score = nn.Linear(in_channels, num_classes) 274 | self.bbox_pred = nn.Linear(in_channels, num_classes * 4) 275 | 276 | def forward(self, x): 277 | if x.dim() == 4: 278 | assert list(x.shape[2:]) == [1, 1] 279 | x = x.flatten(start_dim=1) 280 | scores = self.cls_score(x) 281 | bbox_deltas = self.bbox_pred(x) 282 | 283 | return scores, bbox_deltas 284 | 285 | 286 | model_urls = { 287 | 'fasterrcnn_resnet50_fpn_coco': 288 | 'https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth', 289 | } 290 | 291 | 292 | def fasterrcnn_resnet50_fpn(pretrained=False, progress=True, 293 | num_classes=91, pretrained_backbone=True, trainable_backbone_layers=3, **kwargs): 294 | """ 295 | Constructs a Faster R-CNN model with a ResNet-50-FPN backbone. 296 | 297 | The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each 298 | image, and should be in ``0-1`` range. Different images can have different sizes. 299 | 300 | The behavior of the model changes depending if it is in training or evaluation mode. 301 | 302 | During training, the model expects both the input tensors, as well as a targets (list of dictionary), 303 | containing: 304 | - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with values of ``x`` 305 | between ``0`` and ``W`` and values of ``y`` between ``0`` and ``H`` 306 | - labels (``Int64Tensor[N]``): the class label for each ground-truth box 307 | 308 | The model returns a ``Dict[Tensor]`` during training, containing the classification and regression 309 | losses for both the RPN and the R-CNN. 310 | 311 | During inference, the model requires only the input tensors, and returns the post-processed 312 | predictions as a ``List[Dict[Tensor]]``, one for each input image. 
The fields of the ``Dict`` are as 313 | follows: 314 | - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with values of ``x`` 315 | between ``0`` and ``W`` and values of ``y`` between ``0`` and ``H`` 316 | - labels (``Int64Tensor[N]``): the predicted labels for each image 317 | - scores (``Tensor[N]``): the scores of each prediction 318 | 319 | Faster R-CNN is exportable to ONNX for a fixed batch size with input images of fixed size. 320 | 321 | Example:: 322 | 323 | >>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) 324 | >>> # For training 325 | >>> images, boxes = torch.rand(4, 3, 600, 1200), torch.rand(4, 11, 4) 326 | >>> labels = torch.randint(1, 91, (4, 11)) 327 | >>> images = list(image for image in images) 328 | >>> targets = [] 329 | >>> for i in range(len(images)): 330 | >>> d = {} 331 | >>> d['boxes'] = boxes[i] 332 | >>> d['labels'] = labels[i] 333 | >>> targets.append(d) 334 | >>> output = model(images, targets) 335 | >>> # For inference 336 | >>> model.eval() 337 | >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] 338 | >>> predictions = model(x) 339 | >>> 340 | >>> # optionally, if you want to export the model to ONNX: 341 | >>> torch.onnx.export(model, x, "faster_rcnn.onnx", opset_version=11) 342 | 343 | Arguments: 344 | pretrained (bool): If True, returns a model pre-trained on COCO train2017 345 | progress (bool): If True, displays a progress bar of the download to stderr 346 | pretrained_backbone (bool): If True, returns a model with backbone pre-trained on ImageNet 347 | num_classes (int): number of output classes of the model (including the background) 348 | trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. 349 | Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. 350 | """ 351 | assert trainable_backbone_layers <= 5 and trainable_backbone_layers >= 0 352 | # don't freeze any layers if neither a pretrained model nor a pretrained backbone is used 353 | if not (pretrained or pretrained_backbone): 354 | trainable_backbone_layers = 5 355 | if pretrained: 356 | # no need to download the backbone if pretrained is set 357 | pretrained_backbone = False 358 | backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, trainable_layers=trainable_backbone_layers) 359 | model = FasterRCNN(backbone, num_classes, **kwargs) 360 | if pretrained: 361 | state_dict = load_state_dict_from_url(model_urls['fasterrcnn_resnet50_fpn_coco'], 362 | progress=progress) 363 | model.load_state_dict(state_dict) 364 | return model 365 | -------------------------------------------------------------------------------- /model/hrnet_backbone.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License.
4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os 12 | import logging 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | 18 | BN_MOMENTUM = 0.1 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | def conv3x3(in_planes, out_planes, stride=1): 23 | """3x3 convolution with padding""" 24 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 25 | padding=1, bias=False) 26 | 27 | 28 | class BasicBlock(nn.Module): 29 | expansion = 1 30 | 31 | def __init__(self, inplanes, planes, stride=1, downsample=None): 32 | super(BasicBlock, self).__init__() 33 | self.conv1 = conv3x3(inplanes, planes, stride) 34 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 35 | self.relu = nn.ReLU(inplace=True) 36 | self.conv2 = conv3x3(planes, planes) 37 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 38 | self.downsample = downsample 39 | self.stride = stride 40 | 41 | def forward(self, x): 42 | residual = x 43 | 44 | out = self.conv1(x) 45 | out = self.bn1(out) 46 | out = self.relu(out) 47 | 48 | out = self.conv2(out) 49 | out = self.bn2(out) 50 | 51 | if self.downsample is not None: 52 | residual = self.downsample(x) 53 | 54 | out += residual 55 | out = self.relu(out) 56 | 57 | return out 58 | 59 | 60 | class Bottleneck(nn.Module): 61 | expansion = 4 62 | 63 | def __init__(self, inplanes, planes, stride=1, downsample=None): 64 | super(Bottleneck, self).__init__() 65 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 66 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 67 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 68 | padding=1, bias=False) 69 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 70 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, 71 | bias=False) 72 | self.bn3 = nn.BatchNorm2d(planes * self.expansion, 73 | momentum=BN_MOMENTUM) 74 | self.relu = nn.ReLU(inplace=True) 75 | self.downsample = downsample 76 | self.stride = stride 77 | 78 | def forward(self, x): 79 | residual = x 80 | 81 | out = self.conv1(x) 82 | out = self.bn1(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv2(out) 86 | out = self.bn2(out) 87 | out = self.relu(out) 88 | 89 | out = self.conv3(out) 90 | out = self.bn3(out) 91 | 92 | if self.downsample is not None: 93 | residual = self.downsample(x) 94 | 95 | out += residual 96 | out = self.relu(out) 97 | 98 | return out 99 | 100 | 101 | class HighResolutionModule(nn.Module): 102 | def __init__(self, num_branches, blocks, num_blocks, num_inchannels, 103 | num_channels, fuse_method, multi_scale_output=True): 104 | super(HighResolutionModule, self).__init__() 105 | self._check_branches( 106 | num_branches, blocks, num_blocks, num_inchannels, num_channels) 107 | 108 | self.num_inchannels = num_inchannels 109 | self.fuse_method = fuse_method 110 | self.num_branches = num_branches 111 | 112 | self.multi_scale_output = multi_scale_output 113 | 114 | self.branches = self._make_branches( 115 | num_branches, blocks, num_blocks, num_channels) 116 | self.fuse_layers = self._make_fuse_layers() 117 | self.relu = nn.ReLU(True) 118 | 119 | def _check_branches(self, num_branches, blocks, num_blocks, 120 | num_inchannels, num_channels): 121 | if num_branches != len(num_blocks): 122 | error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( 
123 | num_branches, len(num_blocks)) 124 | logger.error(error_msg) 125 | raise ValueError(error_msg) 126 | 127 | if num_branches != len(num_channels): 128 | error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( 129 | num_branches, len(num_channels)) 130 | logger.error(error_msg) 131 | raise ValueError(error_msg) 132 | 133 | if num_branches != len(num_inchannels): 134 | error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( 135 | num_branches, len(num_inchannels)) 136 | logger.error(error_msg) 137 | raise ValueError(error_msg) 138 | 139 | def _make_one_branch(self, branch_index, block, num_blocks, num_channels, 140 | stride=1): 141 | downsample = None 142 | if stride != 1 or \ 143 | self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: 144 | downsample = nn.Sequential( 145 | nn.Conv2d( 146 | self.num_inchannels[branch_index], 147 | num_channels[branch_index] * block.expansion, 148 | kernel_size=1, stride=stride, bias=False 149 | ), 150 | nn.BatchNorm2d( 151 | num_channels[branch_index] * block.expansion, 152 | momentum=BN_MOMENTUM 153 | ), 154 | ) 155 | 156 | layers = [] 157 | layers.append( 158 | block( 159 | self.num_inchannels[branch_index], 160 | num_channels[branch_index], 161 | stride, 162 | downsample 163 | ) 164 | ) 165 | self.num_inchannels[branch_index] = \ 166 | num_channels[branch_index] * block.expansion 167 | for i in range(1, num_blocks[branch_index]): 168 | layers.append( 169 | block( 170 | self.num_inchannels[branch_index], 171 | num_channels[branch_index] 172 | ) 173 | ) 174 | 175 | return nn.Sequential(*layers) 176 | 177 | def _make_branches(self, num_branches, block, num_blocks, num_channels): 178 | branches = [] 179 | 180 | for i in range(num_branches): 181 | branches.append( 182 | self._make_one_branch(i, block, num_blocks, num_channels) 183 | ) 184 | 185 | return nn.ModuleList(branches) 186 | 187 | def _make_fuse_layers(self): 188 | if self.num_branches == 1: 189 | return None 190 | 191 | num_branches = self.num_branches 192 | num_inchannels = self.num_inchannels 193 | fuse_layers = [] 194 | for i in range(num_branches if self.multi_scale_output else 1): 195 | fuse_layer = [] 196 | for j in range(num_branches): 197 | if j > i: 198 | fuse_layer.append( 199 | nn.Sequential( 200 | nn.Conv2d( 201 | num_inchannels[j], 202 | num_inchannels[i], 203 | 1, 1, 0, bias=False 204 | ), 205 | nn.BatchNorm2d(num_inchannels[i]), 206 | nn.Upsample(scale_factor=2**(j-i), mode='nearest') 207 | ) 208 | ) 209 | elif j == i: 210 | fuse_layer.append(None) 211 | else: 212 | conv3x3s = [] 213 | for k in range(i-j): 214 | if k == i - j - 1: 215 | num_outchannels_conv3x3 = num_inchannels[i] 216 | conv3x3s.append( 217 | nn.Sequential( 218 | nn.Conv2d( 219 | num_inchannels[j], 220 | num_outchannels_conv3x3, 221 | 3, 2, 1, bias=False 222 | ), 223 | nn.BatchNorm2d(num_outchannels_conv3x3) 224 | ) 225 | ) 226 | else: 227 | num_outchannels_conv3x3 = num_inchannels[j] 228 | conv3x3s.append( 229 | nn.Sequential( 230 | nn.Conv2d( 231 | num_inchannels[j], 232 | num_outchannels_conv3x3, 233 | 3, 2, 1, bias=False 234 | ), 235 | nn.BatchNorm2d(num_outchannels_conv3x3), 236 | nn.ReLU(True) 237 | ) 238 | ) 239 | fuse_layer.append(nn.Sequential(*conv3x3s)) 240 | fuse_layers.append(nn.ModuleList(fuse_layer)) 241 | 242 | return nn.ModuleList(fuse_layers) 243 | 244 | def get_num_inchannels(self): 245 | return self.num_inchannels 246 | 247 | def forward(self, x): 248 | if self.num_branches == 1: 249 | return [self.branches[0](x[0])] 250 | 251 | for i in 
range(self.num_branches): 252 | x[i] = self.branches[i](x[i]) 253 | 254 | x_fuse = [] 255 | 256 | for i in range(len(self.fuse_layers)): 257 | y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) 258 | for j in range(1, self.num_branches): 259 | if i == j: 260 | y = y + x[j] 261 | else: 262 | y = y + self.fuse_layers[i][j](x[j]) 263 | x_fuse.append(self.relu(y)) 264 | 265 | return x_fuse 266 | 267 | 268 | blocks_dict = { 269 | 'BASIC': BasicBlock, 270 | 'BOTTLENECK': Bottleneck 271 | } 272 | 273 | 274 | class PoseHighResolutionNet(nn.Module): 275 | 276 | def __init__(self, cfg, **kwargs): 277 | self.inplanes = 64 278 | extra = cfg.MODEL.EXTRA 279 | super(PoseHighResolutionNet, self).__init__() 280 | 281 | # stem net 282 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, 283 | bias=False) 284 | self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) 285 | self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, 286 | bias=False) 287 | self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) 288 | self.relu = nn.ReLU(inplace=True) 289 | self.layer1 = self._make_layer(Bottleneck, 64, 4) 290 | 291 | self.stage2_cfg = cfg['MODEL']['EXTRA']['STAGE2'] 292 | num_channels = self.stage2_cfg['NUM_CHANNELS'] 293 | block = blocks_dict[self.stage2_cfg['BLOCK']] 294 | num_channels = [ 295 | num_channels[i] * block.expansion for i in range(len(num_channels)) 296 | ] 297 | self.transition1 = self._make_transition_layer([256], num_channels) 298 | self.stage2, pre_stage_channels = self._make_stage( 299 | self.stage2_cfg, num_channels) 300 | 301 | self.stage3_cfg = cfg['MODEL']['EXTRA']['STAGE3'] 302 | num_channels = self.stage3_cfg['NUM_CHANNELS'] 303 | block = blocks_dict[self.stage3_cfg['BLOCK']] 304 | num_channels = [ 305 | num_channels[i] * block.expansion for i in range(len(num_channels)) 306 | ] 307 | self.transition2 = self._make_transition_layer( 308 | pre_stage_channels, num_channels) 309 | self.stage3, pre_stage_channels = self._make_stage( 310 | self.stage3_cfg, num_channels) 311 | 312 | self.stage4_cfg = cfg['MODEL']['EXTRA']['STAGE4'] 313 | num_channels = self.stage4_cfg['NUM_CHANNELS'] 314 | block = blocks_dict[self.stage4_cfg['BLOCK']] 315 | num_channels = [ 316 | num_channels[i] * block.expansion for i in range(len(num_channels)) 317 | ] 318 | self.transition3 = self._make_transition_layer( 319 | pre_stage_channels, num_channels) 320 | self.stage4, pre_stage_channels = self._make_stage( 321 | self.stage4_cfg, num_channels, multi_scale_output=False) 322 | 323 | self.out_channels = pre_stage_channels[0] 324 | 325 | self.pretrained_layers = cfg['MODEL']['EXTRA']['PRETRAINED_LAYERS'] 326 | 327 | def _make_transition_layer( 328 | self, num_channels_pre_layer, num_channels_cur_layer): 329 | num_branches_cur = len(num_channels_cur_layer) 330 | num_branches_pre = len(num_channels_pre_layer) 331 | 332 | transition_layers = [] 333 | for i in range(num_branches_cur): 334 | if i < num_branches_pre: 335 | if num_channels_cur_layer[i] != num_channels_pre_layer[i]: 336 | transition_layers.append( 337 | nn.Sequential( 338 | nn.Conv2d( 339 | num_channels_pre_layer[i], 340 | num_channels_cur_layer[i], 341 | 3, 1, 1, bias=False 342 | ), 343 | nn.BatchNorm2d(num_channels_cur_layer[i]), 344 | nn.ReLU(inplace=True) 345 | ) 346 | ) 347 | else: 348 | transition_layers.append(None) 349 | else: 350 | conv3x3s = [] 351 | for j in range(i+1-num_branches_pre): 352 | inchannels = num_channels_pre_layer[-1] 353 | outchannels = num_channels_cur_layer[i] \ 354 | if j == i-num_branches_pre else inchannels 
355 | conv3x3s.append( 356 | nn.Sequential( 357 | nn.Conv2d( 358 | inchannels, outchannels, 3, 2, 1, bias=False 359 | ), 360 | nn.BatchNorm2d(outchannels), 361 | nn.ReLU(inplace=True) 362 | ) 363 | ) 364 | transition_layers.append(nn.Sequential(*conv3x3s)) 365 | 366 | return nn.ModuleList(transition_layers) 367 | 368 | def _make_layer(self, block, planes, blocks, stride=1): 369 | downsample = None 370 | if stride != 1 or self.inplanes != planes * block.expansion: 371 | downsample = nn.Sequential( 372 | nn.Conv2d( 373 | self.inplanes, planes * block.expansion, 374 | kernel_size=1, stride=stride, bias=False 375 | ), 376 | nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), 377 | ) 378 | 379 | layers = [] 380 | layers.append(block(self.inplanes, planes, stride, downsample)) 381 | self.inplanes = planes * block.expansion 382 | for i in range(1, blocks): 383 | layers.append(block(self.inplanes, planes)) 384 | 385 | return nn.Sequential(*layers) 386 | 387 | def _make_stage(self, layer_config, num_inchannels, 388 | multi_scale_output=True): 389 | num_modules = layer_config['NUM_MODULES'] 390 | num_branches = layer_config['NUM_BRANCHES'] 391 | num_blocks = layer_config['NUM_BLOCKS'] 392 | num_channels = layer_config['NUM_CHANNELS'] 393 | block = blocks_dict[layer_config['BLOCK']] 394 | fuse_method = layer_config['FUSE_METHOD'] 395 | 396 | modules = [] 397 | for i in range(num_modules): 398 | # multi_scale_output is only used last module 399 | if not multi_scale_output and i == num_modules - 1: 400 | reset_multi_scale_output = False 401 | else: 402 | reset_multi_scale_output = True 403 | 404 | modules.append( 405 | HighResolutionModule( 406 | num_branches, 407 | block, 408 | num_blocks, 409 | num_inchannels, 410 | num_channels, 411 | fuse_method, 412 | reset_multi_scale_output 413 | ) 414 | ) 415 | num_inchannels = modules[-1].get_num_inchannels() 416 | 417 | return nn.Sequential(*modules), num_inchannels 418 | 419 | def forward(self, x): 420 | 421 | x = self.conv1(x) 422 | x = self.bn1(x) 423 | x = self.relu(x) 424 | x = self.conv2(x) 425 | x = self.bn2(x) 426 | x = self.relu(x) 427 | x = self.layer1(x) 428 | 429 | x_list = [] 430 | for i in range(self.stage2_cfg['NUM_BRANCHES']): 431 | if self.transition1[i] is not None: 432 | x_list.append(self.transition1[i](x)) 433 | else: 434 | x_list.append(x) 435 | y_list = self.stage2(x_list) 436 | 437 | x_list = [] 438 | for i in range(self.stage3_cfg['NUM_BRANCHES']): 439 | if self.transition2[i] is not None: 440 | x_list.append(self.transition2[i](y_list[-1])) 441 | else: 442 | x_list.append(y_list[i]) 443 | y_list = self.stage3(x_list) 444 | 445 | x_list = [] 446 | for i in range(self.stage4_cfg['NUM_BRANCHES']): 447 | if self.transition3[i] is not None: 448 | x_list.append(self.transition3[i](y_list[-1])) 449 | else: 450 | x_list.append(y_list[i]) 451 | y_list = self.stage4(x_list) 452 | 453 | return y_list[0] 454 | 455 | def init_weights(self, pretrained=''): 456 | logger.info('=> init weights from normal distribution') 457 | for m in self.modules(): 458 | if isinstance(m, nn.Conv2d): 459 | # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 460 | nn.init.normal_(m.weight, std=0.001) 461 | for name, _ in m.named_parameters(): 462 | if name in ['bias']: 463 | nn.init.constant_(m.bias, 0) 464 | elif isinstance(m, nn.BatchNorm2d): 465 | nn.init.constant_(m.weight, 1) 466 | nn.init.constant_(m.bias, 0) 467 | elif isinstance(m, nn.ConvTranspose2d): 468 | nn.init.normal_(m.weight, std=0.001) 469 | for name, _ in 
m.named_parameters(): 470 | if name in ['bias']: 471 | nn.init.constant_(m.bias, 0) 472 | 473 | if os.path.isfile(pretrained): 474 | pretrained_state_dict = torch.load(pretrained, map_location='cpu') 475 | logger.info('=> loading pretrained model {}'.format(pretrained)) 476 | 477 | need_init_state_dict = {} 478 | for name, m in pretrained_state_dict.items(): 479 | if name.split('.')[0] in self.pretrained_layers \ 480 | or self.pretrained_layers[0] == '*': 481 | need_init_state_dict[name] = m 482 | self.load_state_dict(need_init_state_dict, strict=False) 483 | elif pretrained: 484 | logger.error('=> please download pre-trained models first!') 485 | raise ValueError('{} does not exist!'.format(pretrained)) 486 | 487 | 488 | def get_hrnet(cfg, resume, **kwargs): 489 | model = PoseHighResolutionNet(cfg, **kwargs) 490 | 491 | if not resume and cfg.MODEL.INIT_WEIGHTS: 492 | model.init_weights(cfg.MODEL.PRETRAINED) 493 | 494 | return model 495 | --------------------------------------------------------------------------------
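PoseHighResolutionNet reads only a handful of keys from cfg['MODEL']['EXTRA'] (it also accesses cfg.MODEL.EXTRA as an attribute, so a yacs-style CfgNode parsed from cfg.yaml satisfies both access patterns). The sketch below lists those keys as they appear in the code; the numeric values are illustrative placeholders, not the settings shipped in the repository's cfg.yaml.

# Illustrative MODEL.EXTRA structure for PoseHighResolutionNet.
# Key names mirror the accesses in hrnet_backbone.py above; the values shown
# here are placeholders, not the repository's actual configuration.
EXTRA = {
    'PRETRAINED_LAYERS': ['*'],
    'STAGE2': {'NUM_MODULES': 1, 'NUM_BRANCHES': 2, 'BLOCK': 'BASIC',
               'NUM_BLOCKS': [4, 4], 'NUM_CHANNELS': [32, 64],
               'FUSE_METHOD': 'SUM'},
    'STAGE3': {'NUM_MODULES': 4, 'NUM_BRANCHES': 3, 'BLOCK': 'BASIC',
               'NUM_BLOCKS': [4, 4, 4], 'NUM_CHANNELS': [32, 64, 128],
               'FUSE_METHOD': 'SUM'},
    'STAGE4': {'NUM_MODULES': 3, 'NUM_BRANCHES': 4, 'BLOCK': 'BASIC',
               'NUM_BLOCKS': [4, 4, 4, 4], 'NUM_CHANNELS': [32, 64, 128, 256],
               'FUSE_METHOD': 'SUM'},
}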
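get_hrnet returns a backbone that exposes out_channels and produces a single feature map (y_list[0]), which is exactly what the FasterRCNN class in detection/faster_rcnn.py requires of a non-FPN backbone. The following is a minimal sketch of how the two pieces could be combined; it is not the repository's actual entry point (see main_lm.py / main_lmo.py / main_ycbv.py), and the AnchorGenerator import path and the build_detector helper name are assumptions made for illustration.

# Minimal sketch of pairing the HRNet backbone with the FasterRCNN class above.
# Not the repository's training script; import paths and the helper name are
# assumptions for illustration.
from torchvision.models.detection.rpn import AnchorGenerator  # path may vary by torchvision version
from torchvision.ops import MultiScaleRoIAlign

from detection.faster_rcnn import FasterRCNN
from model.hrnet_backbone import get_hrnet


def build_detector(cfg, num_classes=2, resume=False):
    # PoseHighResolutionNet sets .out_channels, the only backbone attribute
    # that FasterRCNN.__init__ checks for.
    backbone = get_hrnet(cfg, resume)

    # The backbone returns a single feature map (keyed '0' inside the model),
    # so use a single-level anchor generator and an RoI pooler over '0'.
    anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
                                       aspect_ratios=((0.5, 1.0, 2.0),))
    roi_pooler = MultiScaleRoIAlign(featmap_names=['0'], output_size=7,
                                    sampling_ratio=2)

    return FasterRCNN(backbone, num_classes=num_classes,
                      rpn_anchor_generator=anchor_generator,
                      box_roi_pool=roi_pooler)

As with the docstring example, the resulting model takes a list of [C, H, W] tensors in the 0-1 range, returns loss dictionaries in training mode, and returns per-image prediction dictionaries after model.eval().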
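Since FastRCNNPredictor is just two linear layers over the pooled RoI features, the usual torchvision fine-tuning pattern of swapping the box predictor for a different class count also applies to the COCO-pretrained model returned by fasterrcnn_resnet50_fpn. A brief sketch, assuming the modified RoIHeads used here keeps torchvision's box_predictor attribute (the class count below is illustrative):

# Sketch: adapt the COCO-pretrained detector to a new number of classes by
# replacing its box predictor (background counts as one of the classes).
from detection.faster_rcnn import FastRCNNPredictor, fasterrcnn_resnet50_fpn

num_classes = 2  # one foreground class plus background (illustrative)
model = fasterrcnn_resnet50_fpn(pretrained=True)

# The current predictor's input width gives the RoI box head feature size.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)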