├── .gitignore ├── LICENSE ├── README.md ├── conda_packagelist.txt ├── config.py ├── config ├── CornerNet-multi_scale.json └── CornerNet.json ├── db ├── __init__.py ├── base.py ├── coco.py ├── datasets.py └── detection.py ├── external ├── .gitignore ├── Makefile ├── __init__.py ├── nms.pyx └── setup.py ├── models ├── CornerNet.py ├── __init__.py └── py_utils │ ├── __init__.py │ ├── _cpools │ ├── .gitignore │ ├── __init__.py │ ├── setup.py │ └── src │ │ ├── bottom_pool.cpp │ │ ├── left_pool.cpp │ │ ├── right_pool.cpp │ │ └── top_pool.cpp │ ├── data_parallel.py │ ├── kp.py │ ├── kp_utils.py │ ├── scatter_gather.py │ └── utils.py ├── nnet ├── __init__.py └── py_factory.py ├── sample ├── __init__.py ├── coco.py └── utils.py ├── test.py ├── test ├── __init__.py └── coco.py ├── train.py └── utils ├── __init__.py ├── image.py └── tqdm.py /.gitignore: -------------------------------------------------------------------------------- 1 | cache/ 2 | results/ 3 | 4 | *.swp 5 | 6 | *.pyc 7 | *.o* 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2018, University of Michigan 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CornerNet: Training and Evaluation Code 2 | Update (4/18/2019): please check out [CornerNet-Lite](https://github.com/princeton-vl/CornerNet-Lite), more efficient variants of CornerNet 3 | 4 | Code for reproducing the results in the following paper: 5 | 6 | [**CornerNet: Detecting Objects as Paired Keypoints**](https://arxiv.org/abs/1808.01244) 7 | Hei Law, Jia Deng 8 | *European Conference on Computer Vision (ECCV), 2018* 9 | 10 | ## Getting Started 11 | Please first install [Anaconda](https://anaconda.org) and create an Anaconda environment using the provided package list. 12 | ``` 13 | conda create --name CornerNet --file conda_packagelist.txt 14 | ``` 15 | 16 | After you create the environment, activate it. 17 | ``` 18 | source activate CornerNet 19 | ``` 20 | 21 | Our current implementation only supports GPU, so you need a GPU and CUDA installed on your machine. 22 | 23 | ### Compiling Corner Pooling Layers 24 | You need to compile the C++ implementation of the corner pooling layers. 25 | ``` 26 | cd <CornerNet dir>/models/py_utils/_cpools/ 27 | python setup.py install --user 28 | ``` 29 |
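Once the extension is installed, you can quickly sanity-check the compiled ops. The snippet below is an illustrative sketch rather than part of this repo: it assumes the build succeeded, that you run it from the repo root, and that a CUDA device is available. The reference loop mirrors the scan in `top_pool.cpp`.
```
# Hypothetical sanity check for the compiled corner pooling ops.
import torch
from models.py_utils import TopPool

x = torch.rand(1, 3, 8, 8, device="cuda")
out = TopPool()(x)

# Reference: running maximum propagated from the bottom row to the top,
# matching the select/max_out loop in top_pool.cpp.
ref = x.clone()
for i in range(x.size(2) - 2, -1, -1):
    ref[:, :, i] = torch.max(ref[:, :, i], ref[:, :, i + 1])

assert torch.allclose(out, ref)
```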
30 | ### Compiling NMS 31 | You also need to compile the NMS code (originally from [Faster R-CNN](https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/nms/cpu_nms.pyx) and [Soft-NMS](https://github.com/bharatsingh430/soft-nms/blob/master/lib/nms/cpu_nms.pyx)). 32 | ``` 33 | cd <CornerNet dir>/external 34 | make 35 | ``` 36 | 37 | ### Installing MS COCO APIs 38 | You also need to install the MS COCO APIs. 39 | ``` 40 | cd <CornerNet dir>/data 41 | git clone git@github.com:cocodataset/cocoapi.git coco 42 | cd <CornerNet dir>/data/coco/PythonAPI 43 | make 44 | ``` 45 | 46 | ### Downloading MS COCO Data 47 | - Download the training/validation split we use in our paper from [here](https://drive.google.com/file/d/1dop4188xo5lXDkGtOZUzy2SHOD_COXz4/view?usp=sharing) (originally from [Faster R-CNN](https://github.com/rbgirshick/py-faster-rcnn/tree/master/data)) 48 | - Unzip the file and place `annotations` under `<CornerNet dir>/data/coco` 49 | - Download the images (2014 Train, 2014 Val, 2017 Test) from [here](http://cocodataset.org/#download) 50 | - Create 3 directories, `trainval2014`, `minival2014` and `testdev2017`, under `<CornerNet dir>/data/coco/images/` 51 | - Copy the training/validation/testing images to the corresponding directories according to the annotation files 52 | 53 | ## Training and Evaluation 54 | To train and evaluate a network, you will need to create a configuration file, which defines the hyperparameters, and a model file, which defines the network architecture. The configuration file should be in JSON format and placed in `config/`. Each configuration file should have a corresponding model file in `models/`, i.e. if there is a `<model>.json` in `config/`, there should be a `<model>.py` in `models/`. There is only one exception, which we will mention later. 55 | 56 | To train a model: 57 | ``` 58 | python train.py <model> 59 | ``` 60 | 61 | We provide the configuration file (`CornerNet.json`) and the model file (`CornerNet.py`) for CornerNet in this repo. 62 | 63 | To train CornerNet: 64 | ``` 65 | python train.py CornerNet 66 | ``` 67 | We also provide a trained model for `CornerNet`, which is trained for 500k iterations using 10 Titan X (PASCAL) GPUs. You can download it from [here](https://drive.google.com/open?id=16bbMAyykdZr2_7afiMZrvvn4xkYa-LYk) and put it under `<CornerNet dir>/cache/nnet/CornerNet/` (you may need to create this directory yourself if it does not exist). If you want to train your own CornerNet, please adjust the batch size in `CornerNet.json` to accommodate the number of GPUs that are available to you. 68 | 69 | To use the trained model: 70 | ``` 71 | python test.py CornerNet --testiter 500000 --split <split> 72 | ``` 73 | 74 | If you want to test different hyperparameters at test time and do not want to overwrite the original configuration file, you can do so by creating a configuration file with a suffix (`<model>-<suffix>.json`). You **DO NOT** need to create `<model>-<suffix>.py` in `models/`. 75 | 76 | To use the new configuration file: 77 | ``` 78 | python test.py <model> --testiter <iter> --split <split> --suffix <suffix> 79 | ``` 80 | 81 | We also include a configuration file for multi-scale evaluation, which is `CornerNet-multi_scale.json`, in this repo. 82 | 83 | To use the multi-scale configuration file: 84 | ``` 85 | python test.py CornerNet --testiter <iter> --split <split> --suffix multi_scale 86 | ``` 87 | -------------------------------------------------------------------------------- /conda_packagelist.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name <env> --file <this file> 3 | # platform: linux-64 4 | @EXPLICIT 5 | https://repo.continuum.io/pkgs/main/linux-64/blas-1.0-mkl.tar.bz2 6 | https://repo.continuum.io/pkgs/main/linux-64/bzip2-1.0.6-h9a117a8_4.tar.bz2 7 | https://conda.anaconda.org/conda-forge/linux-64/ca-certificates-2018.4.16-0.tar.bz2 8 | https://conda.anaconda.org/caffe2/linux-64/caffe2-cuda8.0-cudnn7-0.8.dev-py36_2018.05.14.tar.bz2 9 | https://repo.continuum.io/pkgs/main/linux-64/cairo-1.14.12-h7636065_2.tar.bz2 10 | https://repo.continuum.io/pkgs/main/linux-64/certifi-2018.4.16-py36_0.tar.bz2 11 | https://repo.continuum.io/pkgs/main/linux-64/cffi-1.11.5-py36h9745a5d_0.tar.bz2 12 | https://repo.continuum.io/pkgs/free/linux-64/cudatoolkit-8.0-3.tar.bz2 13 | https://repo.continuum.io/pkgs/main/linux-64/cycler-0.10.0-py36h93f1223_0.tar.bz2 14 | https://repo.continuum.io/pkgs/main/linux-64/dbus-1.13.2-h714fa37_1.tar.bz2 15 | https://repo.continuum.io/pkgs/main/linux-64/expat-2.2.5-he0dffb1_0.tar.bz2 16 | https://repo.continuum.io/pkgs/main/linux-64/ffmpeg-3.4-h7264315_0.tar.bz2 17 | https://repo.continuum.io/pkgs/main/linux-64/fontconfig-2.12.6-h49f89f6_0.tar.bz2 18 | https://repo.continuum.io/pkgs/free/linux-64/freeglut-2.8.1-0.tar.bz2 19 | https://repo.continuum.io/pkgs/main/linux-64/freetype-2.8-hab7d2ae_1.tar.bz2 20 | https://repo.continuum.io/pkgs/free/linux-64/future-0.16.0-py36_1.tar.bz2 21 | https://repo.continuum.io/pkgs/main/linux-64/gflags-2.2.1-hf484d3e_0.tar.bz2 22 | https://repo.continuum.io/pkgs/main/linux-64/glib-2.56.1-h000015b_0.tar.bz2 23 | https://repo.continuum.io/pkgs/main/linux-64/glog-0.3.5-hf484d3e_1.tar.bz2 24 | https://repo.continuum.io/pkgs/main/linux-64/graphite2-1.3.11-hf63cedd_1.tar.bz2 25 | https://repo.continuum.io/pkgs/main/linux-64/gst-plugins-base-1.14.0-hbbd80ab_1.tar.bz2 26 | https://repo.continuum.io/pkgs/main/linux-64/gstreamer-1.14.0-hb453b48_1.tar.bz2 27 | https://repo.continuum.io/pkgs/main/linux-64/h5py-2.8.0-py36hca9c191_0.tar.bz2 28 | https://repo.continuum.io/pkgs/main/linux-64/harfbuzz-1.7.6-h5f0a787_1.tar.bz2 29 | https://repo.continuum.io/pkgs/main/linux-64/hdf5-1.8.18-h6792536_1.tar.bz2 30 |
https://repo.continuum.io/pkgs/main/linux-64/icu-58.2-h9c2bf20_1.tar.bz2 31 | https://repo.continuum.io/pkgs/main/linux-64/intel-openmp-2018.0.0-8.tar.bz2 32 | https://repo.continuum.io/pkgs/main/linux-64/jasper-2.0.14-h07fcdf6_0.tar.bz2 33 | https://repo.continuum.io/pkgs/main/linux-64/jpeg-9b-h024ee3a_2.tar.bz2 34 | https://repo.continuum.io/pkgs/main/linux-64/kiwisolver-1.0.1-py36h764f252_0.tar.bz2 35 | https://repo.continuum.io/pkgs/main/linux-64/libedit-3.1-heed3624_0.tar.bz2 36 | https://repo.continuum.io/pkgs/main/linux-64/libffi-3.2.1-hd88cf55_4.tar.bz2 37 | https://repo.continuum.io/pkgs/main/linux-64/libgcc-ng-7.2.0-hdf63c60_3.tar.bz2 38 | https://repo.continuum.io/pkgs/main/linux-64/libgfortran-ng-7.2.0-hdf63c60_3.tar.bz2 39 | https://repo.continuum.io/pkgs/main/linux-64/libglu-9.0.0-h0c0bdc1_1.tar.bz2 40 | https://repo.continuum.io/pkgs/main/linux-64/libopus-1.2.1-hb9ed12e_0.tar.bz2 41 | https://repo.continuum.io/pkgs/main/linux-64/libpng-1.6.34-hb9fc6fc_0.tar.bz2 42 | https://repo.continuum.io/pkgs/main/linux-64/libprotobuf-3.5.2-h6f1eeef_0.tar.bz2 43 | https://repo.continuum.io/pkgs/main/linux-64/libstdcxx-ng-7.2.0-hdf63c60_3.tar.bz2 44 | https://repo.continuum.io/pkgs/main/linux-64/libtiff-4.0.9-h28f6b97_0.tar.bz2 45 | https://repo.continuum.io/pkgs/main/linux-64/libvpx-1.6.1-h888fd40_0.tar.bz2 46 | https://repo.continuum.io/pkgs/main/linux-64/libxcb-1.13-h1bed415_1.tar.bz2 47 | https://repo.continuum.io/pkgs/main/linux-64/libxml2-2.9.8-hf84eae3_0.tar.bz2 48 | https://repo.continuum.io/pkgs/main/linux-64/matplotlib-2.2.2-py36h0e671d2_1.tar.bz2 49 | https://repo.continuum.io/pkgs/main/linux-64/mkl-2018.0.2-1.tar.bz2 50 | https://repo.continuum.io/pkgs/main/linux-64/mkl_fft-1.0.1-py36h3010b51_0.tar.bz2 51 | https://repo.continuum.io/pkgs/main/linux-64/mkl_random-1.0.1-py36h629b387_0.tar.bz2 52 | https://repo.continuum.io/pkgs/main/linux-64/ncurses-6.0-h9df7e31_2.tar.bz2 53 | https://repo.continuum.io/pkgs/main/linux-64/ninja-1.8.2-py36h6bb024c_1.tar.bz2 54 | https://repo.continuum.io/pkgs/main/linux-64/numpy-1.14.3-py36hcd700cb_1.tar.bz2 55 | https://repo.continuum.io/pkgs/main/linux-64/numpy-base-1.14.3-py36h9be14a7_1.tar.bz2 56 | https://repo.continuum.io/pkgs/main/linux-64/olefile-0.45.1-py36_0.tar.bz2 57 | https://repo.continuum.io/pkgs/main/linux-64/opencv-3.3.1-py36h9248ab4_2.tar.bz2 58 | https://repo.continuum.io/pkgs/main/linux-64/openssl-1.0.2o-h20670df_0.tar.bz2 59 | https://repo.continuum.io/pkgs/main/linux-64/pcre-8.42-h439df22_0.tar.bz2 60 | https://repo.continuum.io/pkgs/main/linux-64/pillow-5.1.0-py36h3deb7b8_0.tar.bz2 61 | https://repo.continuum.io/pkgs/main/linux-64/pip-10.0.1-py36_0.tar.bz2 62 | https://repo.continuum.io/pkgs/main/linux-64/pixman-0.34.0-hceecf20_3.tar.bz2 63 | https://conda.anaconda.org/conda-forge/linux-64/protobuf-3.5.2-py36_0.tar.bz2 64 | https://repo.continuum.io/pkgs/main/linux-64/pycparser-2.18-py36hf9f622e_1.tar.bz2 65 | https://repo.continuum.io/pkgs/main/linux-64/pyparsing-2.2.0-py36hee85983_1.tar.bz2 66 | https://repo.continuum.io/pkgs/main/linux-64/pyqt-5.9.2-py36h751905a_0.tar.bz2 67 | https://repo.continuum.io/pkgs/main/linux-64/python-3.6.5-hc3d631a_2.tar.bz2 68 | https://repo.continuum.io/pkgs/main/linux-64/python-dateutil-2.7.2-py36_0.tar.bz2 69 | https://conda.anaconda.org/pytorch/linux-64/pytorch-0.4.0-py36_cuda8.0.61_cudnn7.1.2_1.tar.bz2 70 | https://repo.continuum.io/pkgs/main/linux-64/pytz-2018.4-py36_0.tar.bz2 71 | https://repo.continuum.io/pkgs/main/linux-64/pyyaml-3.12-py36hafb9ca4_1.tar.bz2 72 | 
https://repo.continuum.io/pkgs/main/linux-64/qt-5.9.5-h7e424d6_0.tar.bz2 73 | https://repo.continuum.io/pkgs/main/linux-64/readline-7.0-ha6073c6_4.tar.bz2 74 | https://repo.continuum.io/pkgs/main/linux-64/scikit-learn-0.19.1-py36h7aa7ec6_0.tar.bz2 75 | https://repo.continuum.io/pkgs/main/linux-64/scipy-1.1.0-py36hfc37229_0.tar.bz2 76 | https://repo.continuum.io/pkgs/main/linux-64/setuptools-39.1.0-py36_0.tar.bz2 77 | https://repo.continuum.io/pkgs/main/linux-64/sip-4.19.8-py36hf484d3e_0.tar.bz2 78 | https://repo.continuum.io/pkgs/main/linux-64/six-1.11.0-py36h372c433_1.tar.bz2 79 | https://repo.continuum.io/pkgs/main/linux-64/sqlite-3.23.1-he433501_0.tar.bz2 80 | https://repo.continuum.io/pkgs/main/linux-64/tk-8.6.7-hc745277_3.tar.bz2 81 | https://conda.anaconda.org/pytorch/linux-64/torchvision-0.2.1-py36_1.tar.bz2 82 | https://repo.continuum.io/pkgs/main/linux-64/tornado-5.0.2-py36_0.tar.bz2 83 | https://repo.continuum.io/pkgs/main/linux-64/tqdm-4.23.0-py36_0.tar.bz2 84 | https://repo.continuum.io/pkgs/main/linux-64/wheel-0.31.0-py36_0.tar.bz2 85 | https://repo.continuum.io/pkgs/main/linux-64/xz-5.2.3-h5e939de_4.tar.bz2 86 | https://repo.continuum.io/pkgs/main/linux-64/yaml-0.1.7-had09818_2.tar.bz2 87 | https://repo.continuum.io/pkgs/main/linux-64/zlib-1.2.11-ha838bed_2.tar.bz2 88 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | class Config: 5 | def __init__(self): 6 | self._configs = {} 7 | self._configs["dataset"] = None 8 | self._configs["sampling_function"] = "kp_detection" 9 | 10 | # Training Config 11 | self._configs["display"] = 5 12 | self._configs["snapshot"] = 5000 13 | self._configs["stepsize"] = 450000 14 | self._configs["learning_rate"] = 0.00025 15 | self._configs["decay_rate"] = 10 16 | self._configs["max_iter"] = 500000 17 | self._configs["val_iter"] = 100 18 | self._configs["batch_size"] = 1 19 | self._configs["snapshot_name"] = None 20 | self._configs["prefetch_size"] = 100 21 | self._configs["weight_decay"] = False 22 | self._configs["weight_decay_rate"] = 1e-5 23 | self._configs["weight_decay_type"] = "l2" 24 | self._configs["pretrain"] = None 25 | self._configs["opt_algo"] = "adam" 26 | self._configs["chunk_sizes"] = None 27 | 28 | # Directories 29 | self._configs["data_dir"] = "./data" 30 | self._configs["cache_dir"] = "./cache" 31 | self._configs["config_dir"] = "./config" 32 | self._configs["result_dir"] = "./results" 33 | 34 | # Split 35 | self._configs["train_split"] = "trainval" 36 | self._configs["val_split"] = "minival" 37 | self._configs["test_split"] = "testdev" 38 | 39 | # Rng 40 | self._configs["data_rng"] = np.random.RandomState(123) 41 | self._configs["nnet_rng"] = np.random.RandomState(317) 42 | 43 | @property 44 | def chunk_sizes(self): 45 | return self._configs["chunk_sizes"] 46 | 47 | @property 48 | def train_split(self): 49 | return self._configs["train_split"] 50 | 51 | @property 52 | def val_split(self): 53 | return self._configs["val_split"] 54 | 55 | @property 56 | def test_split(self): 57 | return self._configs["test_split"] 58 | 59 | @property 60 | def full(self): 61 | return self._configs 62 | 63 | @property 64 | def sampling_function(self): 65 | return self._configs["sampling_function"] 66 | 67 | @property 68 | def data_rng(self): 69 | return self._configs["data_rng"] 70 | 71 | @property 72 | def nnet_rng(self): 73 | return self._configs["nnet_rng"] 74 | 75 | @property 76 
| def opt_algo(self): 77 | return self._configs["opt_algo"] 78 | 79 | @property 80 | def weight_decay_type(self): 81 | return self._configs["weight_decay_type"] 82 | 83 | @property 84 | def prefetch_size(self): 85 | return self._configs["prefetch_size"] 86 | 87 | @property 88 | def pretrain(self): 89 | return self._configs["pretrain"] 90 | 91 | @property 92 | def weight_decay_rate(self): 93 | return self._configs["weight_decay_rate"] 94 | 95 | @property 96 | def weight_decay(self): 97 | return self._configs["weight_decay"] 98 | 99 | @property 100 | def result_dir(self): 101 | result_dir = os.path.join(self._configs["result_dir"], self.snapshot_name) 102 | if not os.path.exists(result_dir): 103 | os.makedirs(result_dir) 104 | return result_dir 105 | 106 | @property 107 | def dataset(self): 108 | return self._configs["dataset"] 109 | 110 | @property 111 | def snapshot_name(self): 112 | return self._configs["snapshot_name"] 113 | 114 | @property 115 | def snapshot_dir(self): 116 | snapshot_dir = os.path.join(self.cache_dir, "nnet", self.snapshot_name) 117 | 118 | if not os.path.exists(snapshot_dir): 119 | os.makedirs(snapshot_dir) 120 | 121 | return snapshot_dir 122 | 123 | @property 124 | def snapshot_file(self): 125 | snapshot_file = os.path.join(self.snapshot_dir, self.snapshot_name + "_{}.pkl") 126 | return snapshot_file 127 | 128 | @property 129 | def config_dir(self): 130 | return self._configs["config_dir"] 131 | 132 | @property 133 | def batch_size(self): 134 | return self._configs["batch_size"] 135 | 136 | @property 137 | def max_iter(self): 138 | return self._configs["max_iter"] 139 | 140 | @property 141 | def learning_rate(self): 142 | return self._configs["learning_rate"] 143 | 144 | @property 145 | def decay_rate(self): 146 | return self._configs["decay_rate"] 147 | 148 | @property 149 | def stepsize(self): 150 | return self._configs["stepsize"] 151 | 152 | @property 153 | def snapshot(self): 154 | return self._configs["snapshot"] 155 | 156 | @property 157 | def display(self): 158 | return self._configs["display"] 159 | 160 | @property 161 | def val_iter(self): 162 | return self._configs["val_iter"] 163 | 164 | @property 165 | def data_dir(self): 166 | return self._configs["data_dir"] 167 | 168 | @property 169 | def cache_dir(self): 170 | if not os.path.exists(self._configs["cache_dir"]): 171 | os.makedirs(self._configs["cache_dir"]) 172 | return self._configs["cache_dir"] 173 | 174 | def update_config(self, new): 175 | for key in new: 176 | if key in self._configs: 177 | self._configs[key] = new[key] 178 | 179 | system_configs = Config() 180 | -------------------------------------------------------------------------------- /config/CornerNet-multi_scale.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": { 3 | "dataset": "MSCOCO", 4 | "batch_size": 49, 5 | "sampling_function": "kp_detection", 6 | 7 | "train_split": "trainval", 8 | "val_split": "minival", 9 | 10 | "learning_rate": 0.00025, 11 | "decay_rate": 10, 12 | 13 | "val_iter": 100, 14 | 15 | "opt_algo": "adam", 16 | "prefetch_size": 5, 17 | 18 | "max_iter": 500000, 19 | "stepsize": 450000, 20 | "snapshot": 5000, 21 | 22 | "chunk_sizes": [4, 5, 5, 5, 5, 5, 5, 5, 5, 5], 23 | 24 | "data_dir": "./data" 25 | }, 26 | 27 | "db": { 28 | "rand_scale_min": 0.6, 29 | "rand_scale_max": 1.4, 30 | "rand_scale_step": 0.1, 31 | "rand_scales": null, 32 | 33 | "rand_crop": true, 34 | "rand_color": true, 35 | 36 | "border": 128, 37 | "gaussian_bump": true, 38 | 39 | "input_size": [511, 
511], 40 | "output_sizes": [[128, 128]], 41 | 42 | "test_scales": [0.5, 0.75, 1, 1.25, 1.5], 43 | 44 | "top_k": 100, 45 | "categories": 80, 46 | "ae_threshold": 0.5, 47 | "nms_threshold": 0.5, 48 | 49 | "merge_bbox": true, 50 | "weight_exp": 10, 51 | 52 | "max_per_image": 100 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /config/CornerNet.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": { 3 | "dataset": "MSCOCO", 4 | "batch_size": 49, 5 | "sampling_function": "kp_detection", 6 | 7 | "train_split": "trainval", 8 | "val_split": "minival", 9 | 10 | "learning_rate": 0.00025, 11 | "decay_rate": 10, 12 | 13 | "val_iter": 100, 14 | 15 | "opt_algo": "adam", 16 | "prefetch_size": 5, 17 | 18 | "max_iter": 500000, 19 | "stepsize": 450000, 20 | "snapshot": 5000, 21 | 22 | "chunk_sizes": [4, 5, 5, 5, 5, 5, 5, 5, 5, 5], 23 | 24 | "data_dir": "./data" 25 | }, 26 | 27 | "db": { 28 | "rand_scale_min": 0.6, 29 | "rand_scale_max": 1.4, 30 | "rand_scale_step": 0.1, 31 | "rand_scales": null, 32 | 33 | "rand_crop": true, 34 | "rand_color": true, 35 | 36 | "border": 128, 37 | "gaussian_bump": true, 38 | "gaussian_iou": 0.3, 39 | 40 | "input_size": [511, 511], 41 | "output_sizes": [[128, 128]], 42 | 43 | "test_scales": [1], 44 | 45 | "top_k": 100, 46 | "categories": 80, 47 | "ae_threshold": 0.5, 48 | "nms_threshold": 0.5, 49 | 50 | "max_per_image": 100 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /db/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-vl/CornerNet/e5c39a31a8abef5841976c8eab18da86d6ee5f9a/db/__init__.py -------------------------------------------------------------------------------- /db/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import numpy as np 4 | 5 | from config import system_configs 6 | 7 | class BASE(object): 8 | def __init__(self): 9 | self._split = None 10 | self._db_inds = [] 11 | self._image_ids = [] 12 | 13 | self._data = None 14 | self._image_hdf5 = None 15 | self._image_file = None 16 | self._image_hdf5_file = None 17 | 18 | self._mean = np.zeros((3, ), dtype=np.float32) 19 | self._std = np.ones((3, ), dtype=np.float32) 20 | self._eig_val = np.ones((3, ), dtype=np.float32) 21 | self._eig_vec = np.zeros((3, 3), dtype=np.float32) 22 | 23 | self._configs = {} 24 | self._configs["data_aug"] = True 25 | 26 | self._data_rng = None 27 | 28 | @property 29 | def data(self): 30 | if self._data is None: 31 | raise ValueError("data is not set") 32 | return self._data 33 | 34 | @property 35 | def configs(self): 36 | return self._configs 37 | 38 | @property 39 | def mean(self): 40 | return self._mean 41 | 42 | @property 43 | def std(self): 44 | return self._std 45 | 46 | @property 47 | def eig_val(self): 48 | return self._eig_val 49 | 50 | @property 51 | def eig_vec(self): 52 | return self._eig_vec 53 | 54 | @property 55 | def db_inds(self): 56 | return self._db_inds 57 | 58 | @property 59 | def split(self): 60 | return self._split 61 | 62 | def update_config(self, new): 63 | for key in new: 64 | if key in self._configs: 65 | self._configs[key] = new[key] 66 | 67 | def image_ids(self, ind): 68 | return self._image_ids[ind] 69 | 70 | def image_file(self, ind): 71 | if self._image_file is None: 72 | raise ValueError("Image path is not initialized") 73 | 74 | image_id = 
self._image_ids[ind] 75 | return self._image_file.format(image_id) 76 | 77 | def write_result(self, ind, all_bboxes, all_scores): 78 | pass 79 | 80 | def evaluate(self, name): 81 | pass 82 | 83 | def shuffle_inds(self, quiet=False): 84 | if self._data_rng is None: 85 | self._data_rng = np.random.RandomState(os.getpid()) 86 | 87 | if not quiet: 88 | print("shuffling indices...") 89 | rand_perm = self._data_rng.permutation(len(self._db_inds)) 90 | self._db_inds = self._db_inds[rand_perm] 91 | -------------------------------------------------------------------------------- /db/coco.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, "data/coco/PythonAPI/") 3 | 4 | import os 5 | import json 6 | import numpy as np 7 | import pickle 8 | 9 | from tqdm import tqdm 10 | from db.detection import DETECTION 11 | from config import system_configs 12 | from pycocotools.coco import COCO 13 | from pycocotools.cocoeval import COCOeval 14 | 15 | class MSCOCO(DETECTION): 16 | def __init__(self, db_config, split): 17 | super(MSCOCO, self).__init__(db_config) 18 | data_dir = system_configs.data_dir 19 | result_dir = system_configs.result_dir 20 | cache_dir = system_configs.cache_dir 21 | 22 | self._split = split 23 | self._dataset = { 24 | "trainval": "trainval2014", 25 | "minival": "minival2014", 26 | "testdev": "testdev2017" 27 | }[self._split] 28 | 29 | self._coco_dir = os.path.join(data_dir, "coco") 30 | 31 | self._label_dir = os.path.join(self._coco_dir, "annotations") 32 | self._label_file = os.path.join(self._label_dir, "instances_{}.json") 33 | self._label_file = self._label_file.format(self._dataset) 34 | 35 | self._image_dir = os.path.join(self._coco_dir, "images", self._dataset) 36 | self._image_file = os.path.join(self._image_dir, "{}") 37 | 38 | self._data = "coco" 39 | self._mean = np.array([0.40789654, 0.44719302, 0.47026115], dtype=np.float32) 40 | self._std = np.array([0.28863828, 0.27408164, 0.27809835], dtype=np.float32) 41 | self._eig_val = np.array([0.2141788, 0.01817699, 0.00341571], dtype=np.float32) 42 | self._eig_vec = np.array([ 43 | [-0.58752847, -0.69563484, 0.41340352], 44 | [-0.5832747, 0.00994535, -0.81221408], 45 | [-0.56089297, 0.71832671, 0.41158938] 46 | ], dtype=np.float32) 47 | 48 | self._cat_ids = [ 49 | 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 50 | 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 51 | 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 52 | 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 53 | 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 54 | 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 55 | 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 56 | 82, 84, 85, 86, 87, 88, 89, 90 57 | ] 58 | self._classes = { 59 | ind + 1: cat_id for ind, cat_id in enumerate(self._cat_ids) 60 | } 61 | self._coco_to_class_map = { 62 | value: key for key, value in self._classes.items() 63 | } 64 | 65 | self._cache_file = os.path.join(cache_dir, "coco_{}.pkl".format(self._dataset)) 66 | self._load_data() 67 | self._db_inds = np.arange(len(self._image_ids)) 68 | 69 | self._load_coco_data() 70 | 71 | def _load_data(self): 72 | print("loading from cache file: {}".format(self._cache_file)) 73 | if not os.path.exists(self._cache_file): 74 | print("No cache file found...") 75 | self._extract_data() 76 | with open(self._cache_file, "wb") as f: 77 | pickle.dump([self._detections, self._image_ids], f) 78 | else: 79 | with open(self._cache_file, "rb") as f: 80 | self._detections, self._image_ids = pickle.load(f) 81 | 82 | def _load_coco_data(self): 83 | self._coco = 
COCO(self._label_file) 84 | with open(self._label_file, "r") as f: 85 | data = json.load(f) 86 | 87 | coco_ids = self._coco.getImgIds() 88 | eval_ids = { 89 | self._coco.loadImgs(coco_id)[0]["file_name"]: coco_id 90 | for coco_id in coco_ids 91 | } 92 | 93 | self._coco_categories = data["categories"] 94 | self._coco_eval_ids = eval_ids 95 | 96 | def class_name(self, cid): 97 | cat_id = self._classes[cid] 98 | cat = self._coco.loadCats([cat_id])[0] 99 | return cat["name"] 100 | 101 | def _extract_data(self): 102 | self._coco = COCO(self._label_file) 103 | self._cat_ids = self._coco.getCatIds() 104 | 105 | coco_image_ids = self._coco.getImgIds() 106 | 107 | self._image_ids = [ 108 | self._coco.loadImgs(img_id)[0]["file_name"] 109 | for img_id in coco_image_ids 110 | ] 111 | self._detections = {} 112 | for ind, (coco_image_id, image_id) in enumerate(tqdm(zip(coco_image_ids, self._image_ids))): 113 | image = self._coco.loadImgs(coco_image_id)[0] 114 | bboxes = [] 115 | categories = [] 116 | 117 | for cat_id in self._cat_ids: 118 | annotation_ids = self._coco.getAnnIds(imgIds=image["id"], catIds=cat_id) 119 | annotations = self._coco.loadAnns(annotation_ids) 120 | category = self._coco_to_class_map[cat_id] 121 | for annotation in annotations: 122 | bbox = np.array(annotation["bbox"]) 123 | bbox[[2, 3]] += bbox[[0, 1]] 124 | bboxes.append(bbox) 125 | 126 | categories.append(category) 127 | 128 | bboxes = np.array(bboxes, dtype=float) 129 | categories = np.array(categories, dtype=float) 130 | if bboxes.size == 0 or categories.size == 0: 131 | self._detections[image_id] = np.zeros((0, 5), dtype=np.float32) 132 | else: 133 | self._detections[image_id] = np.hstack((bboxes, categories[:, None])) 134 | 135 | def detections(self, ind): 136 | image_id = self._image_ids[ind] 137 | detections = self._detections[image_id] 138 | 139 | return detections.astype(float).copy() 140 | 141 | def _to_float(self, x): 142 | return float("{:.2f}".format(x)) 143 | 144 | def convert_to_coco(self, all_bboxes): 145 | detections = [] 146 | for image_id in all_bboxes: 147 | coco_id = self._coco_eval_ids[image_id] 148 | for cls_ind in all_bboxes[image_id]: 149 | category_id = self._classes[cls_ind] 150 | for bbox in all_bboxes[image_id][cls_ind]: 151 | bbox[2] -= bbox[0] 152 | bbox[3] -= bbox[1] 153 | 154 | score = bbox[4] 155 | bbox = list(map(self._to_float, bbox[0:4])) 156 | 157 | detection = { 158 | "image_id": coco_id, 159 | "category_id": category_id, 160 | "bbox": bbox, 161 | "score": float("{:.2f}".format(score)) 162 | } 163 | 164 | detections.append(detection) 165 | return detections 166 | 167 | def evaluate(self, result_json, cls_ids, image_ids, gt_json=None): 168 | if self._split == "testdev": 169 | return None 170 | 171 | coco = self._coco if gt_json is None else COCO(gt_json) 172 | 173 | eval_ids = [self._coco_eval_ids[image_id] for image_id in image_ids] 174 | cat_ids = [self._classes[cls_id] for cls_id in cls_ids] 175 | 176 | coco_dets = coco.loadRes(result_json) 177 | coco_eval = COCOeval(coco, coco_dets, "bbox") 178 | coco_eval.params.imgIds = eval_ids 179 | coco_eval.params.catIds = cat_ids 180 | coco_eval.evaluate() 181 | coco_eval.accumulate() 182 | coco_eval.summarize() 183 | return coco_eval.stats[0], coco_eval.stats[12:] 184 | -------------------------------------------------------------------------------- /db/datasets.py: -------------------------------------------------------------------------------- 1 | from db.coco import MSCOCO 2 | 3 | datasets = { 4 | "MSCOCO": MSCOCO 5 | } 6 | 
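The `datasets` registry above maps the `dataset` name from a config's `system` block to its database class. A minimal usage sketch (illustrative only; it mirrors how `train.py`/`test.py` presumably resolve the class, and the `db_config` dict stands in for the `db` block of a JSON config — `snapshot_name` must be set because `MSCOCO` touches `system_configs.result_dir`):
```
# Hypothetical sketch: resolve the dataset class by name and build one split.
from config import system_configs
from db.datasets import datasets

system_configs.update_config({
    "dataset": "MSCOCO",
    "snapshot_name": "CornerNet",  # result_dir/snapshot_dir are derived from this
})

db_config = {"categories": 80, "input_size": [511, 511]}  # "db" block of the JSON
train_db = datasets[system_configs.dataset](db_config, system_configs.train_split)
print(len(train_db.db_inds))  # number of images in the split
```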
-------------------------------------------------------------------------------- /db/detection.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from db.base import BASE 3 | 4 | class DETECTION(BASE): 5 | def __init__(self, db_config): 6 | super(DETECTION, self).__init__() 7 | 8 | self._configs["categories"] = 80 9 | self._configs["rand_scales"] = [1] 10 | self._configs["rand_scale_min"] = 0.8 11 | self._configs["rand_scale_max"] = 1.4 12 | self._configs["rand_scale_step"] = 0.2 13 | 14 | self._configs["input_size"] = [511] 15 | self._configs["output_sizes"] = [[128, 128]] 16 | 17 | self._configs["nms_threshold"] = 0.5 18 | self._configs["max_per_image"] = 100 19 | self._configs["top_k"] = 100 20 | self._configs["ae_threshold"] = 0.5 21 | self._configs["nms_kernel"] = 3 22 | 23 | self._configs["nms_algorithm"] = "exp_soft_nms" 24 | self._configs["weight_exp"] = 8 25 | self._configs["merge_bbox"] = False 26 | 27 | self._configs["data_aug"] = True 28 | self._configs["lighting"] = True 29 | 30 | self._configs["border"] = 128 31 | self._configs["gaussian_bump"] = True 32 | self._configs["gaussian_iou"] = 0.7 33 | self._configs["gaussian_radius"] = -1 34 | self._configs["rand_crop"] = False 35 | self._configs["rand_color"] = False 36 | self._configs["rand_pushes"] = False 37 | self._configs["rand_samples"] = False 38 | self._configs["special_crop"] = False 39 | 40 | self._configs["test_scales"] = [1] 41 | 42 | self.update_config(db_config) 43 | 44 | if self._configs["rand_scales"] is None: 45 | self._configs["rand_scales"] = np.arange( 46 | self._configs["rand_scale_min"], 47 | self._configs["rand_scale_max"], 48 | self._configs["rand_scale_step"] 49 | ) 50 | -------------------------------------------------------------------------------- /external/.gitignore: -------------------------------------------------------------------------------- 1 | bbox.c 2 | bbox.cpython-35m-x86_64-linux-gnu.so 3 | bbox.cpython-36m-x86_64-linux-gnu.so 4 | 5 | nms.c 6 | nms.cpython-35m-x86_64-linux-gnu.so 7 | nms.cpython-36m-x86_64-linux-gnu.so 8 | -------------------------------------------------------------------------------- /external/Makefile: -------------------------------------------------------------------------------- 1 | all: 2 | python setup.py build_ext --inplace 3 | rm -rf build 4 | -------------------------------------------------------------------------------- /external/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-vl/CornerNet/e5c39a31a8abef5841976c8eab18da86d6ee5f9a/external/__init__.py -------------------------------------------------------------------------------- /external/nms.pyx: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ross Girshick 6 | # -------------------------------------------------------- 7 | 8 | # ---------------------------------------------------------- 9 | # Soft-NMS: Improving Object Detection With One Line of Code 10 | # Copyright (c) University of Maryland, College Park 11 | # Licensed under The MIT License [see LICENSE for details] 12 | # Written by Navaneeth Bodla and Bharat Singh 13 | # ---------------------------------------------------------- 14 | 15 | import numpy as np 16 | cimport numpy as 
np 17 | 18 | cdef inline np.float32_t max(np.float32_t a, np.float32_t b): 19 | return a if a >= b else b 20 | 21 | cdef inline np.float32_t min(np.float32_t a, np.float32_t b): 22 | return a if a <= b else b 23 | 24 | def nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh): 25 | cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0] 26 | cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1] 27 | cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2] 28 | cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3] 29 | cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4] 30 | 31 | cdef np.ndarray[np.float32_t, ndim=1] areas = (x2 - x1 + 1) * (y2 - y1 + 1) 32 | cdef np.ndarray[np.int_t, ndim=1] order = scores.argsort()[::-1] 33 | 34 | cdef int ndets = dets.shape[0] 35 | cdef np.ndarray[np.int_t, ndim=1] suppressed = \ 36 | np.zeros((ndets), dtype=np.int) 37 | 38 | # nominal indices 39 | cdef int _i, _j 40 | # sorted indices 41 | cdef int i, j 42 | # temp variables for box i's (the box currently under consideration) 43 | cdef np.float32_t ix1, iy1, ix2, iy2, iarea 44 | # variables for computing overlap with box j (lower scoring box) 45 | cdef np.float32_t xx1, yy1, xx2, yy2 46 | cdef np.float32_t w, h 47 | cdef np.float32_t inter, ovr 48 | 49 | keep = [] 50 | for _i in range(ndets): 51 | i = order[_i] 52 | if suppressed[i] == 1: 53 | continue 54 | keep.append(i) 55 | ix1 = x1[i] 56 | iy1 = y1[i] 57 | ix2 = x2[i] 58 | iy2 = y2[i] 59 | iarea = areas[i] 60 | for _j in range(_i + 1, ndets): 61 | j = order[_j] 62 | if suppressed[j] == 1: 63 | continue 64 | xx1 = max(ix1, x1[j]) 65 | yy1 = max(iy1, y1[j]) 66 | xx2 = min(ix2, x2[j]) 67 | yy2 = min(iy2, y2[j]) 68 | w = max(0.0, xx2 - xx1 + 1) 69 | h = max(0.0, yy2 - yy1 + 1) 70 | inter = w * h 71 | ovr = inter / (iarea + areas[j] - inter) 72 | if ovr >= thresh: 73 | suppressed[j] = 1 74 | 75 | return keep 76 | 77 | def soft_nms(np.ndarray[float, ndim=2] boxes, float sigma=0.5, float Nt=0.3, float threshold=0.001, unsigned int method=0): 78 | cdef unsigned int N = boxes.shape[0] 79 | cdef float iw, ih, box_area 80 | cdef float ua 81 | cdef int pos = 0 82 | cdef float maxscore = 0 83 | cdef int maxpos = 0 84 | cdef float x1,x2,y1,y2,tx1,tx2,ty1,ty2,ts,area,weight,ov 85 | 86 | for i in range(N): 87 | maxscore = boxes[i, 4] 88 | maxpos = i 89 | 90 | tx1 = boxes[i,0] 91 | ty1 = boxes[i,1] 92 | tx2 = boxes[i,2] 93 | ty2 = boxes[i,3] 94 | ts = boxes[i,4] 95 | 96 | pos = i + 1 97 | # get max box 98 | while pos < N: 99 | if maxscore < boxes[pos, 4]: 100 | maxscore = boxes[pos, 4] 101 | maxpos = pos 102 | pos = pos + 1 103 | 104 | # add max box as a detection 105 | boxes[i,0] = boxes[maxpos,0] 106 | boxes[i,1] = boxes[maxpos,1] 107 | boxes[i,2] = boxes[maxpos,2] 108 | boxes[i,3] = boxes[maxpos,3] 109 | boxes[i,4] = boxes[maxpos,4] 110 | 111 | # swap ith box with position of max box 112 | boxes[maxpos,0] = tx1 113 | boxes[maxpos,1] = ty1 114 | boxes[maxpos,2] = tx2 115 | boxes[maxpos,3] = ty2 116 | boxes[maxpos,4] = ts 117 | 118 | tx1 = boxes[i,0] 119 | ty1 = boxes[i,1] 120 | tx2 = boxes[i,2] 121 | ty2 = boxes[i,3] 122 | ts = boxes[i,4] 123 | 124 | pos = i + 1 125 | # NMS iterations, note that N changes if detection boxes fall below threshold 126 | while pos < N: 127 | x1 = boxes[pos, 0] 128 | y1 = boxes[pos, 1] 129 | x2 = boxes[pos, 2] 130 | y2 = boxes[pos, 3] 131 | s = boxes[pos, 4] 132 | 133 | area = (x2 - x1 + 1) * (y2 - y1 + 1) 134 | iw = (min(tx2, x2) - max(tx1, x1) + 1) 135 | if iw > 0: 136 | ih = (min(ty2, y2) - max(ty1, y1) + 1) 137 | if ih > 0: 
138 | ua = float((tx2 - tx1 + 1) * (ty2 - ty1 + 1) + area - iw * ih) 139 | ov = iw * ih / ua #iou between max box and detection box 140 | 141 | if method == 1: # linear 142 | if ov > Nt: 143 | weight = 1 - ov 144 | else: 145 | weight = 1 146 | elif method == 2: # gaussian 147 | weight = np.exp(-(ov * ov)/sigma) 148 | else: # original NMS 149 | if ov > Nt: 150 | weight = 0 151 | else: 152 | weight = 1 153 | 154 | boxes[pos, 4] = weight*boxes[pos, 4] 155 | 156 | # if box score falls below threshold, discard the box by swapping with last box 157 | # update N 158 | if boxes[pos, 4] < threshold: 159 | boxes[pos,0] = boxes[N-1, 0] 160 | boxes[pos,1] = boxes[N-1, 1] 161 | boxes[pos,2] = boxes[N-1, 2] 162 | boxes[pos,3] = boxes[N-1, 3] 163 | boxes[pos,4] = boxes[N-1, 4] 164 | N = N - 1 165 | pos = pos - 1 166 | 167 | pos = pos + 1 168 | 169 | keep = [i for i in range(N)] 170 | return keep 171 | 172 | def soft_nms_merge(np.ndarray[float, ndim=2] boxes, float sigma=0.5, float Nt=0.3, float threshold=0.001, unsigned int method=0, float weight_exp=6): 173 | cdef unsigned int N = boxes.shape[0] 174 | cdef float iw, ih, box_area 175 | cdef float ua 176 | cdef int pos = 0 177 | cdef float maxscore = 0 178 | cdef int maxpos = 0 179 | cdef float x1,x2,y1,y2,tx1,tx2,ty1,ty2,ts,area,weight,ov 180 | cdef float mx1,mx2,my1,my2,mts,mbs,mw 181 | 182 | for i in range(N): 183 | maxscore = boxes[i, 4] 184 | maxpos = i 185 | 186 | tx1 = boxes[i,0] 187 | ty1 = boxes[i,1] 188 | tx2 = boxes[i,2] 189 | ty2 = boxes[i,3] 190 | ts = boxes[i,4] 191 | 192 | pos = i + 1 193 | # get max box 194 | while pos < N: 195 | if maxscore < boxes[pos, 4]: 196 | maxscore = boxes[pos, 4] 197 | maxpos = pos 198 | pos = pos + 1 199 | 200 | # add max box as a detection 201 | boxes[i,0] = boxes[maxpos,0] 202 | boxes[i,1] = boxes[maxpos,1] 203 | boxes[i,2] = boxes[maxpos,2] 204 | boxes[i,3] = boxes[maxpos,3] 205 | boxes[i,4] = boxes[maxpos,4] 206 | 207 | mx1 = boxes[i, 0] * boxes[i, 5] 208 | my1 = boxes[i, 1] * boxes[i, 5] 209 | mx2 = boxes[i, 2] * boxes[i, 6] 210 | my2 = boxes[i, 3] * boxes[i, 6] 211 | mts = boxes[i, 5] 212 | mbs = boxes[i, 6] 213 | 214 | # swap ith box with position of max box 215 | boxes[maxpos,0] = tx1 216 | boxes[maxpos,1] = ty1 217 | boxes[maxpos,2] = tx2 218 | boxes[maxpos,3] = ty2 219 | boxes[maxpos,4] = ts 220 | 221 | tx1 = boxes[i,0] 222 | ty1 = boxes[i,1] 223 | tx2 = boxes[i,2] 224 | ty2 = boxes[i,3] 225 | ts = boxes[i,4] 226 | 227 | pos = i + 1 228 | # NMS iterations, note that N changes if detection boxes fall below threshold 229 | while pos < N: 230 | x1 = boxes[pos, 0] 231 | y1 = boxes[pos, 1] 232 | x2 = boxes[pos, 2] 233 | y2 = boxes[pos, 3] 234 | s = boxes[pos, 4] 235 | 236 | area = (x2 - x1 + 1) * (y2 - y1 + 1) 237 | iw = (min(tx2, x2) - max(tx1, x1) + 1) 238 | if iw > 0: 239 | ih = (min(ty2, y2) - max(ty1, y1) + 1) 240 | if ih > 0: 241 | ua = float((tx2 - tx1 + 1) * (ty2 - ty1 + 1) + area - iw * ih) 242 | ov = iw * ih / ua #iou between max box and detection box 243 | 244 | if method == 1: # linear 245 | if ov > Nt: 246 | weight = 1 - ov 247 | else: 248 | weight = 1 249 | elif method == 2: # gaussian 250 | weight = np.exp(-(ov * ov)/sigma) 251 | else: # original NMS 252 | if ov > Nt: 253 | weight = 0 254 | else: 255 | weight = 1 256 | 257 | mw = (1 - weight) ** weight_exp 258 | mx1 = mx1 + boxes[pos, 0] * boxes[pos, 5] * mw 259 | my1 = my1 + boxes[pos, 1] * boxes[pos, 5] * mw 260 | mx2 = mx2 + boxes[pos, 2] * boxes[pos, 6] * mw 261 | my2 = my2 + boxes[pos, 3] * boxes[pos, 6] * mw 262 | mts = mts + boxes[pos, 
5] * mw 263 | mbs = mbs + boxes[pos, 6] * mw 264 | 265 | boxes[pos, 4] = weight*boxes[pos, 4] 266 | 267 | # if box score falls below threshold, discard the box by swapping with last box 268 | # update N 269 | if boxes[pos, 4] < threshold: 270 | boxes[pos,0] = boxes[N-1, 0] 271 | boxes[pos,1] = boxes[N-1, 1] 272 | boxes[pos,2] = boxes[N-1, 2] 273 | boxes[pos,3] = boxes[N-1, 3] 274 | boxes[pos,4] = boxes[N-1, 4] 275 | N = N - 1 276 | pos = pos - 1 277 | 278 | pos = pos + 1 279 | 280 | boxes[i, 0] = mx1 / mts 281 | boxes[i, 1] = my1 / mts 282 | boxes[i, 2] = mx2 / mbs 283 | boxes[i, 3] = my2 / mbs 284 | 285 | keep = [i for i in range(N)] 286 | return keep 287 | -------------------------------------------------------------------------------- /external/setup.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | from distutils.core import setup 3 | from distutils.extension import Extension 4 | from Cython.Build import cythonize 5 | 6 | extensions = [ 7 | Extension( 8 | "nms", 9 | ["nms.pyx"], 10 | extra_compile_args=["-Wno-cpp", "-Wno-unused-function"] 11 | ) 12 | ] 13 | 14 | setup( 15 | name="coco", 16 | ext_modules=cythonize(extensions), 17 | include_dirs=[numpy.get_include()] 18 | ) 19 | -------------------------------------------------------------------------------- /models/CornerNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .py_utils import kp, AELoss, _neg_loss, convolution, residual 5 | from .py_utils import TopPool, BottomPool, LeftPool, RightPool 6 | 7 | class pool(nn.Module): 8 | def __init__(self, dim, pool1, pool2): 9 | super(pool, self).__init__() 10 | self.p1_conv1 = convolution(3, dim, 128) 11 | self.p2_conv1 = convolution(3, dim, 128) 12 | 13 | self.p_conv1 = nn.Conv2d(128, dim, (3, 3), padding=(1, 1), bias=False) 14 | self.p_bn1 = nn.BatchNorm2d(dim) 15 | 16 | self.conv1 = nn.Conv2d(dim, dim, (1, 1), bias=False) 17 | self.bn1 = nn.BatchNorm2d(dim) 18 | self.relu1 = nn.ReLU(inplace=True) 19 | 20 | self.conv2 = convolution(3, dim, dim) 21 | 22 | self.pool1 = pool1() 23 | self.pool2 = pool2() 24 | 25 | def forward(self, x): 26 | # pool 1 27 | p1_conv1 = self.p1_conv1(x) 28 | pool1 = self.pool1(p1_conv1) 29 | 30 | # pool 2 31 | p2_conv1 = self.p2_conv1(x) 32 | pool2 = self.pool2(p2_conv1) 33 | 34 | # pool 1 + pool 2 35 | p_conv1 = self.p_conv1(pool1 + pool2) 36 | p_bn1 = self.p_bn1(p_conv1) 37 | 38 | conv1 = self.conv1(x) 39 | bn1 = self.bn1(conv1) 40 | relu1 = self.relu1(p_bn1 + bn1) 41 | 42 | conv2 = self.conv2(relu1) 43 | return conv2 44 | 45 | class tl_pool(pool): 46 | def __init__(self, dim): 47 | super(tl_pool, self).__init__(dim, TopPool, LeftPool) 48 | 49 | class br_pool(pool): 50 | def __init__(self, dim): 51 | super(br_pool, self).__init__(dim, BottomPool, RightPool) 52 | 53 | def make_tl_layer(dim): 54 | return tl_pool(dim) 55 | 56 | def make_br_layer(dim): 57 | return br_pool(dim) 58 | 59 | def make_pool_layer(dim): 60 | return nn.Sequential() 61 | 62 | def make_hg_layer(kernel, dim0, dim1, mod, layer=convolution, **kwargs): 63 | layers = [layer(kernel, dim0, dim1, stride=2)] 64 | layers += [layer(kernel, dim1, dim1) for _ in range(mod - 1)] 65 | return nn.Sequential(*layers) 66 | 67 | class model(kp): 68 | def __init__(self, db): 69 | n = 5 70 | dims = [256, 256, 384, 384, 384, 512] 71 | modules = [2, 2, 2, 2, 2, 4] 72 | out_dim = 80 73 | 74 | super(model, self).__init__( 75 | n, 2, dims, modules, out_dim, 76 | 
make_tl_layer=make_tl_layer, 77 | make_br_layer=make_br_layer, 78 | make_pool_layer=make_pool_layer, 79 | make_hg_layer=make_hg_layer, 80 | kp_layer=residual, cnv_dim=256 81 | ) 82 | 83 | loss = AELoss(pull_weight=1e-1, push_weight=1e-1, focal_loss=_neg_loss) 84 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-vl/CornerNet/e5c39a31a8abef5841976c8eab18da86d6ee5f9a/models/__init__.py -------------------------------------------------------------------------------- /models/py_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .kp import kp, AELoss 2 | from .kp_utils import _neg_loss 3 | 4 | from .utils import convolution, fully_connected, residual 5 | 6 | from ._cpools import TopPool, BottomPool, LeftPool, RightPool 7 | -------------------------------------------------------------------------------- /models/py_utils/_cpools/.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | cpools.egg-info/ 3 | dist/ 4 | -------------------------------------------------------------------------------- /models/py_utils/_cpools/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch import nn 4 | from torch.autograd import Function 5 | 6 | import top_pool, bottom_pool, left_pool, right_pool 7 | 8 | class TopPoolFunction(Function): 9 | @staticmethod 10 | def forward(ctx, input): 11 | output = top_pool.forward(input)[0] 12 | ctx.save_for_backward(input) 13 | return output 14 | 15 | @staticmethod 16 | def backward(ctx, grad_output): 17 | input = ctx.saved_variables[0] 18 | output = top_pool.backward(input, grad_output)[0] 19 | return output 20 | 21 | class BottomPoolFunction(Function): 22 | @staticmethod 23 | def forward(ctx, input): 24 | output = bottom_pool.forward(input)[0] 25 | ctx.save_for_backward(input) 26 | return output 27 | 28 | @staticmethod 29 | def backward(ctx, grad_output): 30 | input = ctx.saved_variables[0] 31 | output = bottom_pool.backward(input, grad_output)[0] 32 | return output 33 | 34 | class LeftPoolFunction(Function): 35 | @staticmethod 36 | def forward(ctx, input): 37 | output = left_pool.forward(input)[0] 38 | ctx.save_for_backward(input) 39 | return output 40 | 41 | @staticmethod 42 | def backward(ctx, grad_output): 43 | input = ctx.saved_variables[0] 44 | output = left_pool.backward(input, grad_output)[0] 45 | return output 46 | 47 | class RightPoolFunction(Function): 48 | @staticmethod 49 | def forward(ctx, input): 50 | output = right_pool.forward(input)[0] 51 | ctx.save_for_backward(input) 52 | return output 53 | 54 | @staticmethod 55 | def backward(ctx, grad_output): 56 | input = ctx.saved_variables[0] 57 | output = right_pool.backward(input, grad_output)[0] 58 | return output 59 | 60 | class TopPool(nn.Module): 61 | def forward(self, x): 62 | return TopPoolFunction.apply(x) 63 | 64 | class BottomPool(nn.Module): 65 | def forward(self, x): 66 | return BottomPoolFunction.apply(x) 67 | 68 | class LeftPool(nn.Module): 69 | def forward(self, x): 70 | return LeftPoolFunction.apply(x) 71 | 72 | class RightPool(nn.Module): 73 | def forward(self, x): 74 | return RightPoolFunction.apply(x) 75 | -------------------------------------------------------------------------------- /models/py_utils/_cpools/setup.py: 
-------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CppExtension 3 | 4 | setup( 5 | name="cpools", 6 | ext_modules=[ 7 | CppExtension("top_pool", ["src/top_pool.cpp"]), 8 | CppExtension("bottom_pool", ["src/bottom_pool.cpp"]), 9 | CppExtension("left_pool", ["src/left_pool.cpp"]), 10 | CppExtension("right_pool", ["src/right_pool.cpp"]) 11 | ], 12 | cmdclass={ 13 | "build_ext": BuildExtension 14 | } 15 | ) 16 | -------------------------------------------------------------------------------- /models/py_utils/_cpools/src/bottom_pool.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/torch.h> 2 | 3 | #include <vector> 4 | 5 | std::vector<at::Tensor> pool_forward( 6 | at::Tensor input 7 | ) { 8 | // Initialize output 9 | at::Tensor output = at::zeros_like(input); 10 | 11 | // Get height 12 | int64_t height = input.size(2); 13 | 14 | // Copy the first row 15 | at::Tensor input_temp = input.select(2, 0); 16 | at::Tensor output_temp = output.select(2, 0); 17 | output_temp.copy_(input_temp); 18 | 19 | at::Tensor max_temp; 20 | for (int64_t ind = 0; ind < height - 1; ++ind) { 21 | input_temp = input.select(2, ind + 1); 22 | output_temp = output.select(2, ind); 23 | max_temp = output.select(2, ind + 1); 24 | 25 | at::max_out(max_temp, input_temp, output_temp); 26 | } 27 | 28 | return { 29 | output 30 | }; 31 | } 32 | 33 | std::vector<at::Tensor> pool_backward( 34 | at::Tensor input, 35 | at::Tensor grad_output 36 | ) { 37 | auto output = at::zeros_like(input); 38 | 39 | int32_t batch = input.size(0); 40 | int32_t channel = input.size(1); 41 | int32_t height = input.size(2); 42 | int32_t width = input.size(3); 43 | 44 | auto max_val = at::zeros(torch::CUDA(at::kFloat), {batch, channel, width}); 45 | auto max_ind = at::zeros(torch::CUDA(at::kLong), {batch, channel, width}); 46 | 47 | auto input_temp = input.select(2, 0); 48 | max_val.copy_(input_temp); 49 | 50 | max_ind.fill_(0); 51 | 52 | auto output_temp = output.select(2, 0); 53 | auto grad_output_temp = grad_output.select(2, 0); 54 | output_temp.copy_(grad_output_temp); 55 | 56 | auto un_max_ind = max_ind.unsqueeze(2); 57 | auto gt_mask = at::zeros(torch::CUDA(at::kByte), {batch, channel, width}); 58 | auto max_temp = at::zeros(torch::CUDA(at::kFloat), {batch, channel, width}); 59 | for (int32_t ind = 0; ind < height - 1; ++ind) { 60 | input_temp = input.select(2, ind + 1); 61 | at::gt_out(gt_mask, input_temp, max_val); 62 | 63 | at::masked_select_out(max_temp, input_temp, gt_mask); 64 | max_val.masked_scatter_(gt_mask, max_temp); 65 | max_ind.masked_fill_(gt_mask, ind + 1); 66 | 67 | grad_output_temp = grad_output.select(2, ind + 1).unsqueeze(2); 68 | output.scatter_add_(2, un_max_ind, grad_output_temp); 69 | } 70 | 71 | return { 72 | output 73 | }; 74 | } 75 | 76 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 77 | m.def( 78 | "forward", &pool_forward, "Bottom Pool Forward", 79 | py::call_guard<py::gil_scoped_release>() 80 | ); 81 | m.def( 82 | "backward", &pool_backward, "Bottom Pool Backward", 83 | py::call_guard<py::gil_scoped_release>() 84 | ); 85 | } 86 |
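For reference, the forward pass above is just a running maximum down the height axis. A pure-Python sketch of the same computation (illustrative only, not part of the repo; the compiled op exists because this per-row loop would be slow in Python):
```
import torch

def bottom_pool_reference(x):
    # out[:, :, i] holds the max of x[:, :, 0..i], matching the
    # select/max_out loop in bottom_pool.cpp.
    out = x.clone()
    for i in range(1, x.size(2)):
        out[:, :, i] = torch.max(out[:, :, i - 1], x[:, :, i])
    return out
```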
-------------------------------------------------------------------------------- /models/py_utils/_cpools/src/left_pool.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/torch.h> 2 | 3 | #include <vector> 4 | 5 | std::vector<at::Tensor> pool_forward( 6 | at::Tensor input 7 | ) { 8 | // Initialize output 9 | at::Tensor output = at::zeros_like(input); 10 | 11 | // Get width 12 | int64_t width = input.size(3); 13 | 14 | // Copy the last column 15 | at::Tensor input_temp = input.select(3, width - 1); 16 | at::Tensor output_temp = output.select(3, width - 1); 17 | output_temp.copy_(input_temp); 18 | 19 | at::Tensor max_temp; 20 | for (int64_t ind = 1; ind < width; ++ind) { 21 | input_temp = input.select(3, width - ind - 1); 22 | output_temp = output.select(3, width - ind); 23 | max_temp = output.select(3, width - ind - 1); 24 | 25 | at::max_out(max_temp, input_temp, output_temp); 26 | } 27 | 28 | return { 29 | output 30 | }; 31 | } 32 | 33 | std::vector<at::Tensor> pool_backward( 34 | at::Tensor input, 35 | at::Tensor grad_output 36 | ) { 37 | auto output = at::zeros_like(input); 38 | 39 | int32_t batch = input.size(0); 40 | int32_t channel = input.size(1); 41 | int32_t height = input.size(2); 42 | int32_t width = input.size(3); 43 | 44 | auto max_val = at::zeros(torch::CUDA(at::kFloat), {batch, channel, height}); 45 | auto max_ind = at::zeros(torch::CUDA(at::kLong), {batch, channel, height}); 46 | 47 | auto input_temp = input.select(3, width - 1); 48 | max_val.copy_(input_temp); 49 | 50 | max_ind.fill_(width - 1); 51 | 52 | auto output_temp = output.select(3, width - 1); 53 | auto grad_output_temp = grad_output.select(3, width - 1); 54 | output_temp.copy_(grad_output_temp); 55 | 56 | auto un_max_ind = max_ind.unsqueeze(3); 57 | auto gt_mask = at::zeros(torch::CUDA(at::kByte), {batch, channel, height}); 58 | auto max_temp = at::zeros(torch::CUDA(at::kFloat), {batch, channel, height}); 59 | for (int32_t ind = 1; ind < width; ++ind) { 60 | input_temp = input.select(3, width - ind - 1); 61 | at::gt_out(gt_mask, input_temp, max_val); 62 | 63 | at::masked_select_out(max_temp, input_temp, gt_mask); 64 | max_val.masked_scatter_(gt_mask, max_temp); 65 | max_ind.masked_fill_(gt_mask, width - ind - 1); 66 | 67 | grad_output_temp = grad_output.select(3, width - ind - 1).unsqueeze(3); 68 | output.scatter_add_(3, un_max_ind, grad_output_temp); 69 | } 70 | 71 | return { 72 | output 73 | }; 74 | } 75 | 76 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 77 | m.def( 78 | "forward", &pool_forward, "Left Pool Forward", 79 | py::call_guard<py::gil_scoped_release>() 80 | ); 81 | m.def( 82 | "backward", &pool_backward, "Left Pool Backward", 83 | py::call_guard<py::gil_scoped_release>() 84 | ); 85 | } 86 |
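The backward pass tracks, per output position, which input cell produced the running max (`max_ind`) and accumulates the incoming gradient there with `scatter_add_`. A hypothetical illustration of that routing (assumes the ops are built; the backward kernels allocate CUDA buffers, so a GPU is required):
```
# Each input cell receives gradient once per output position it dominated,
# so with a sum loss the gradient holds integer "win counts".
import torch
from models.py_utils import LeftPool

x = torch.rand(1, 1, 2, 6, device="cuda", requires_grad=True)
y = LeftPool()(x)
y.sum().backward()
print(x.grad)  # zero everywhere except the running-max argmax cells
```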
-------------------------------------------------------------------------------- /models/py_utils/_cpools/src/right_pool.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/torch.h> 2 | 3 | #include <vector> 4 | 5 | std::vector<at::Tensor> pool_forward( 6 | at::Tensor input 7 | ) { 8 | // Initialize output 9 | at::Tensor output = at::zeros_like(input); 10 | 11 | // Get width 12 | int64_t width = input.size(3); 13 | 14 | // Copy the first column 15 | at::Tensor input_temp = input.select(3, 0); 16 | at::Tensor output_temp = output.select(3, 0); 17 | output_temp.copy_(input_temp); 18 | 19 | at::Tensor max_temp; 20 | for (int64_t ind = 0; ind < width - 1; ++ind) { 21 | input_temp = input.select(3, ind + 1); 22 | output_temp = output.select(3, ind); 23 | max_temp = output.select(3, ind + 1); 24 | 25 | at::max_out(max_temp, input_temp, output_temp); 26 | } 27 | 28 | return { 29 | output 30 | }; 31 | } 32 | 33 | std::vector<at::Tensor> pool_backward( 34 | at::Tensor input, 35 | at::Tensor grad_output 36 | ) { 37 | at::Tensor output = at::zeros_like(input); 38 | 39 | int32_t batch = input.size(0); 40 | int32_t channel = input.size(1); 41 | int32_t height = input.size(2); 42 | int32_t width = input.size(3); 43 | 44 | auto max_val = at::zeros(torch::CUDA(at::kFloat), {batch, channel, height}); 45 | auto max_ind = at::zeros(torch::CUDA(at::kLong), {batch, channel, height}); 46 | 47 | auto input_temp = input.select(3, 0); 48 | max_val.copy_(input_temp); 49 | 50 | max_ind.fill_(0); 51 | 52 | auto output_temp = output.select(3, 0); 53 | auto grad_output_temp = grad_output.select(3, 0); 54 | output_temp.copy_(grad_output_temp); 55 | 56 | auto un_max_ind = max_ind.unsqueeze(3); 57 | auto gt_mask = at::zeros(torch::CUDA(at::kByte), {batch, channel, height}); 58 | auto max_temp = at::zeros(torch::CUDA(at::kFloat), {batch, channel, height}); 59 | for (int32_t ind = 0; ind < width - 1; ++ind) { 60 | input_temp = input.select(3, ind + 1); 61 | at::gt_out(gt_mask, input_temp, max_val); 62 | 63 | at::masked_select_out(max_temp, input_temp, gt_mask); 64 | max_val.masked_scatter_(gt_mask, max_temp); 65 | max_ind.masked_fill_(gt_mask, ind + 1); 66 | 67 | grad_output_temp = grad_output.select(3, ind + 1).unsqueeze(3); 68 | output.scatter_add_(3, un_max_ind, grad_output_temp); 69 | } 70 | 71 | return { 72 | output 73 | }; 74 | } 75 | 76 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 77 | m.def( 78 | "forward", &pool_forward, "Right Pool Forward", 79 | py::call_guard<py::gil_scoped_release>() 80 | ); 81 | m.def( 82 | "backward", &pool_backward, "Right Pool Backward", 83 | py::call_guard<py::gil_scoped_release>() 84 | ); 85 | } 86 |
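The four pooling files are mirror images of one another: bottom/right scan forward along their axis while top/left scan backward. A hypothetical check of that symmetry (assumes the ops are built and a CUDA device is available):
```
# RightPool should equal LeftPool applied to a width-flipped tensor.
import torch
from models.py_utils import LeftPool, RightPool

x = torch.rand(2, 4, 8, 8, device="cuda")
assert torch.allclose(RightPool()(x), LeftPool()(x.flip(3)).flip(3))
```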
max_ind.masked_fill_(gt_mask, height - ind - 1); 66 | 67 | grad_output_temp = grad_output.select(2, height - ind - 1).unsqueeze(2); 68 | output.scatter_add_(2, un_max_ind, grad_output_temp); 69 | } 70 | 71 | return { 72 | output 73 | }; 74 | } 75 | 76 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 77 | m.def( 78 | "forward", &top_pool_forward, "Top Pool Forward", 79 | py::call_guard<py::gil_scoped_release>() 80 | ); 81 | m.def( 82 | "backward", &top_pool_backward, "Top Pool Backward", 83 | py::call_guard<py::gil_scoped_release>() 84 | ); 85 | } 86 | -------------------------------------------------------------------------------- /models/py_utils/data_parallel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.modules import Module 3 | from torch.nn.parallel.scatter_gather import gather 4 | from torch.nn.parallel.replicate import replicate 5 | from torch.nn.parallel.parallel_apply import parallel_apply 6 | 7 | from .scatter_gather import scatter_kwargs 8 | 9 | class DataParallel(Module): 10 | r"""Implements data parallelism at the module level. 11 | 12 | This container parallelizes the application of the given module by 13 | splitting the input across the specified devices by chunking in the batch 14 | dimension. In the forward pass, the module is replicated on each device, 15 | and each replica handles a portion of the input. During the backwards 16 | pass, gradients from each replica are summed into the original module. 17 | 18 | The batch size should be larger than the number of GPUs used. It should 19 | also be an integer multiple of the number of GPUs so that each chunk is the 20 | same size (so that each GPU processes the same number of samples). 21 | 22 | See also: :ref:`cuda-nn-dataparallel-instead` 23 | 24 | Arbitrary positional and keyword inputs are allowed to be passed into 25 | DataParallel EXCEPT Tensors. All variables will be scattered on dim 26 | specified (default 0). Primitive types will be broadcasted, but all 27 | other types will be a shallow copy and can be corrupted if written to in 28 | the model's forward pass.
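    A note on this modified copy: unlike the stock ``torch.nn.DataParallel``,
    it also accepts a ``chunk_sizes`` argument (see ``__init__`` below) that
    lets the batch be split unevenly across GPUs, e.g. to give a smaller
    chunk to the GPU that also gathers the outputs.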
29 | 30 | Args: 31 | module: module to be parallelized 32 | device_ids: CUDA devices (default: all devices) 33 | output_device: device location of output (default: device_ids[0]) 34 | 35 | Example:: 36 | 37 | >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2]) 38 | >>> output = net(input_var) 39 | """ 40 | 41 | # TODO: update notes/cuda.rst when this class handles 8+ GPUs well 42 | 43 | def __init__(self, module, device_ids=None, output_device=None, dim=0, chunk_sizes=None): 44 | super(DataParallel, self).__init__() 45 | 46 | if not torch.cuda.is_available(): 47 | self.module = module 48 | self.device_ids = [] 49 | return 50 | 51 | if device_ids is None: 52 | device_ids = list(range(torch.cuda.device_count())) 53 | if output_device is None: 54 | output_device = device_ids[0] 55 | self.dim = dim 56 | self.module = module 57 | self.device_ids = device_ids 58 | self.chunk_sizes = chunk_sizes 59 | self.output_device = output_device 60 | if len(self.device_ids) == 1: 61 | self.module.cuda(device_ids[0]) 62 | 63 | def forward(self, *inputs, **kwargs): 64 | if not self.device_ids: 65 | return self.module(*inputs, **kwargs) 66 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids, self.chunk_sizes) 67 | if len(self.device_ids) == 1: 68 | return self.module(*inputs[0], **kwargs[0]) 69 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) 70 | outputs = self.parallel_apply(replicas, inputs, kwargs) 71 | return self.gather(outputs, self.output_device) 72 | 73 | def replicate(self, module, device_ids): 74 | return replicate(module, device_ids) 75 | 76 | def scatter(self, inputs, kwargs, device_ids, chunk_sizes): 77 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim, chunk_sizes=self.chunk_sizes) 78 | 79 | def parallel_apply(self, replicas, inputs, kwargs): 80 | return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) 81 | 82 | def gather(self, outputs, output_device): 83 | return gather(outputs, output_device, dim=self.dim) 84 | 85 | 86 | def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None): 87 | r"""Evaluates module(input) in parallel across the GPUs given in device_ids. 88 | 89 | This is the functional version of the DataParallel module. 90 | 91 | Args: 92 | module: the module to evaluate in parallel 93 | inputs: inputs to the module 94 | device_ids: GPU ids on which to replicate module 95 | output_device: GPU location of the output. Use -1 to indicate the CPU.
96 | (default: device_ids[0]) 97 | Returns: 98 | a Variable containing the result of module(input) located on 99 | output_device 100 | """ 101 | if not isinstance(inputs, tuple): 102 | inputs = (inputs,) 103 | 104 | if device_ids is None: 105 | device_ids = list(range(torch.cuda.device_count())) 106 | 107 | if output_device is None: 108 | output_device = device_ids[0] 109 | 110 | inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim) 111 | if len(device_ids) == 1: 112 | return module(*inputs[0], **module_kwargs[0]) 113 | used_device_ids = device_ids[:len(inputs)] 114 | replicas = replicate(module, used_device_ids) 115 | outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids) 116 | return gather(outputs, output_device, dim) 117 | -------------------------------------------------------------------------------- /models/py_utils/kp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from .utils import convolution, residual 6 | from .utils import make_layer, make_layer_revr 7 | 8 | from .kp_utils import _tranpose_and_gather_feat, _decode 9 | from .kp_utils import _sigmoid, _ae_loss, _regr_loss, _neg_loss 10 | from .kp_utils import make_tl_layer, make_br_layer, make_kp_layer 11 | from .kp_utils import make_pool_layer, make_unpool_layer 12 | from .kp_utils import make_merge_layer, make_inter_layer, make_cnv_layer 13 | 14 | class kp_module(nn.Module): 15 | def __init__( 16 | self, n, dims, modules, layer=residual, 17 | make_up_layer=make_layer, make_low_layer=make_layer, 18 | make_hg_layer=make_layer, make_hg_layer_revr=make_layer_revr, 19 | make_pool_layer=make_pool_layer, make_unpool_layer=make_unpool_layer, 20 | make_merge_layer=make_merge_layer, **kwargs 21 | ): 22 | super(kp_module, self).__init__() 23 | 24 | self.n = n 25 | 26 | curr_mod = modules[0] 27 | next_mod = modules[1] 28 | 29 | curr_dim = dims[0] 30 | next_dim = dims[1] 31 | 32 | self.up1 = make_up_layer( 33 | 3, curr_dim, curr_dim, curr_mod, 34 | layer=layer, **kwargs 35 | ) 36 | self.max1 = make_pool_layer(curr_dim) 37 | self.low1 = make_hg_layer( 38 | 3, curr_dim, next_dim, curr_mod, 39 | layer=layer, **kwargs 40 | ) 41 | self.low2 = kp_module( 42 | n - 1, dims[1:], modules[1:], layer=layer, 43 | make_up_layer=make_up_layer, 44 | make_low_layer=make_low_layer, 45 | make_hg_layer=make_hg_layer, 46 | make_hg_layer_revr=make_hg_layer_revr, 47 | make_pool_layer=make_pool_layer, 48 | make_unpool_layer=make_unpool_layer, 49 | make_merge_layer=make_merge_layer, 50 | **kwargs 51 | ) if self.n > 1 else \ 52 | make_low_layer( 53 | 3, next_dim, next_dim, next_mod, 54 | layer=layer, **kwargs 55 | ) 56 | self.low3 = make_hg_layer_revr( 57 | 3, next_dim, curr_dim, curr_mod, 58 | layer=layer, **kwargs 59 | ) 60 | self.up2 = make_unpool_layer(curr_dim) 61 | 62 | self.merge = make_merge_layer(curr_dim) 63 | 64 | def forward(self, x): 65 | up1 = self.up1(x) 66 | max1 = self.max1(x) 67 | low1 = self.low1(max1) 68 | low2 = self.low2(low1) 69 | low3 = self.low3(low2) 70 | up2 = self.up2(low3) 71 | return self.merge(up1, up2) 72 | 73 | class kp(nn.Module): 74 | def __init__( 75 | self, n, nstack, dims, modules, out_dim, pre=None, cnv_dim=256, 76 | make_tl_layer=make_tl_layer, make_br_layer=make_br_layer, 77 | make_cnv_layer=make_cnv_layer, make_heat_layer=make_kp_layer, 78 | make_tag_layer=make_kp_layer, make_regr_layer=make_kp_layer, 79 | make_up_layer=make_layer, make_low_layer=make_layer, 80 | 
make_hg_layer=make_layer, make_hg_layer_revr=make_layer_revr, 81 | make_pool_layer=make_pool_layer, make_unpool_layer=make_unpool_layer, 82 | make_merge_layer=make_merge_layer, make_inter_layer=make_inter_layer, 83 | kp_layer=residual 84 | ): 85 | super(kp, self).__init__() 86 | 87 | self.nstack = nstack 88 | self._decode = _decode 89 | 90 | curr_dim = dims[0] 91 | 92 | self.pre = nn.Sequential( 93 | convolution(7, 3, 128, stride=2), 94 | residual(3, 128, 256, stride=2) 95 | ) if pre is None else pre 96 | 97 | self.kps = nn.ModuleList([ 98 | kp_module( 99 | n, dims, modules, layer=kp_layer, 100 | make_up_layer=make_up_layer, 101 | make_low_layer=make_low_layer, 102 | make_hg_layer=make_hg_layer, 103 | make_hg_layer_revr=make_hg_layer_revr, 104 | make_pool_layer=make_pool_layer, 105 | make_unpool_layer=make_unpool_layer, 106 | make_merge_layer=make_merge_layer 107 | ) for _ in range(nstack) 108 | ]) 109 | self.cnvs = nn.ModuleList([ 110 | make_cnv_layer(curr_dim, cnv_dim) for _ in range(nstack) 111 | ]) 112 | 113 | self.tl_cnvs = nn.ModuleList([ 114 | make_tl_layer(cnv_dim) for _ in range(nstack) 115 | ]) 116 | self.br_cnvs = nn.ModuleList([ 117 | make_br_layer(cnv_dim) for _ in range(nstack) 118 | ]) 119 | 120 | ## keypoint heatmaps 121 | self.tl_heats = nn.ModuleList([ 122 | make_heat_layer(cnv_dim, curr_dim, out_dim) for _ in range(nstack) 123 | ]) 124 | self.br_heats = nn.ModuleList([ 125 | make_heat_layer(cnv_dim, curr_dim, out_dim) for _ in range(nstack) 126 | ]) 127 | 128 | ## tags 129 | self.tl_tags = nn.ModuleList([ 130 | make_tag_layer(cnv_dim, curr_dim, 1) for _ in range(nstack) 131 | ]) 132 | self.br_tags = nn.ModuleList([ 133 | make_tag_layer(cnv_dim, curr_dim, 1) for _ in range(nstack) 134 | ]) 135 | 136 | for tl_heat, br_heat in zip(self.tl_heats, self.br_heats): 137 | tl_heat[-1].bias.data.fill_(-2.19) 138 | br_heat[-1].bias.data.fill_(-2.19) 139 | 140 | self.inters = nn.ModuleList([ 141 | make_inter_layer(curr_dim) for _ in range(nstack - 1) 142 | ]) 143 | 144 | self.inters_ = nn.ModuleList([ 145 | nn.Sequential( 146 | nn.Conv2d(curr_dim, curr_dim, (1, 1), bias=False), 147 | nn.BatchNorm2d(curr_dim) 148 | ) for _ in range(nstack - 1) 149 | ]) 150 | self.cnvs_ = nn.ModuleList([ 151 | nn.Sequential( 152 | nn.Conv2d(cnv_dim, curr_dim, (1, 1), bias=False), 153 | nn.BatchNorm2d(curr_dim) 154 | ) for _ in range(nstack - 1) 155 | ]) 156 | 157 | self.tl_regrs = nn.ModuleList([ 158 | make_regr_layer(cnv_dim, curr_dim, 2) for _ in range(nstack) 159 | ]) 160 | self.br_regrs = nn.ModuleList([ 161 | make_regr_layer(cnv_dim, curr_dim, 2) for _ in range(nstack) 162 | ]) 163 | 164 | self.relu = nn.ReLU(inplace=True) 165 | 166 | def _train(self, *xs): 167 | image = xs[0] 168 | tl_inds = xs[1] 169 | br_inds = xs[2] 170 | 171 | inter = self.pre(image) 172 | outs = [] 173 | 174 | layers = zip( 175 | self.kps, self.cnvs, 176 | self.tl_cnvs, self.br_cnvs, 177 | self.tl_heats, self.br_heats, 178 | self.tl_tags, self.br_tags, 179 | self.tl_regrs, self.br_regrs 180 | ) 181 | for ind, layer in enumerate(layers): 182 | kp_, cnv_ = layer[0:2] 183 | tl_cnv_, br_cnv_ = layer[2:4] 184 | tl_heat_, br_heat_ = layer[4:6] 185 | tl_tag_, br_tag_ = layer[6:8] 186 | tl_regr_, br_regr_ = layer[8:10] 187 | 188 | kp = kp_(inter) 189 | cnv = cnv_(kp) 190 | 191 | tl_cnv = tl_cnv_(cnv) 192 | br_cnv = br_cnv_(cnv) 193 | 194 | tl_heat, br_heat = tl_heat_(tl_cnv), br_heat_(br_cnv) 195 | tl_tag, br_tag = tl_tag_(tl_cnv), br_tag_(br_cnv) 196 | tl_regr, br_regr = tl_regr_(tl_cnv), br_regr_(br_cnv) 197 | 198 | tl_tag = 
_tranpose_and_gather_feat(tl_tag, tl_inds) 199 | br_tag = _tranpose_and_gather_feat(br_tag, br_inds) 200 | tl_regr = _tranpose_and_gather_feat(tl_regr, tl_inds) 201 | br_regr = _tranpose_and_gather_feat(br_regr, br_inds) 202 | 203 | outs += [tl_heat, br_heat, tl_tag, br_tag, tl_regr, br_regr] 204 | 205 | if ind < self.nstack - 1: 206 | inter = self.inters_[ind](inter) + self.cnvs_[ind](cnv) 207 | inter = self.relu(inter) 208 | inter = self.inters[ind](inter) 209 | return outs 210 | 211 | def _test(self, *xs, **kwargs): 212 | image = xs[0] 213 | 214 | inter = self.pre(image) 215 | outs = [] 216 | 217 | layers = zip( 218 | self.kps, self.cnvs, 219 | self.tl_cnvs, self.br_cnvs, 220 | self.tl_heats, self.br_heats, 221 | self.tl_tags, self.br_tags, 222 | self.tl_regrs, self.br_regrs 223 | ) 224 | for ind, layer in enumerate(layers): 225 | kp_, cnv_ = layer[0:2] 226 | tl_cnv_, br_cnv_ = layer[2:4] 227 | tl_heat_, br_heat_ = layer[4:6] 228 | tl_tag_, br_tag_ = layer[6:8] 229 | tl_regr_, br_regr_ = layer[8:10] 230 | 231 | kp = kp_(inter) 232 | cnv = cnv_(kp) 233 | 234 | if ind == self.nstack - 1: 235 | tl_cnv = tl_cnv_(cnv) 236 | br_cnv = br_cnv_(cnv) 237 | 238 | tl_heat, br_heat = tl_heat_(tl_cnv), br_heat_(br_cnv) 239 | tl_tag, br_tag = tl_tag_(tl_cnv), br_tag_(br_cnv) 240 | tl_regr, br_regr = tl_regr_(tl_cnv), br_regr_(br_cnv) 241 | 242 | outs += [tl_heat, br_heat, tl_tag, br_tag, tl_regr, br_regr] 243 | 244 | if ind < self.nstack - 1: 245 | inter = self.inters_[ind](inter) + self.cnvs_[ind](cnv) 246 | inter = self.relu(inter) 247 | inter = self.inters[ind](inter) 248 | 249 | return self._decode(*outs[-6:], **kwargs) 250 | 251 | def forward(self, *xs, **kwargs): 252 | if len(xs) > 1: 253 | return self._train(*xs, **kwargs) 254 | return self._test(*xs, **kwargs) 255 | 256 | class AELoss(nn.Module): 257 | def __init__(self, pull_weight=1, push_weight=1, regr_weight=1, focal_loss=_neg_loss): 258 | super(AELoss, self).__init__() 259 | 260 | self.pull_weight = pull_weight 261 | self.push_weight = push_weight 262 | self.regr_weight = regr_weight 263 | self.focal_loss = focal_loss 264 | self.ae_loss = _ae_loss 265 | self.regr_loss = _regr_loss 266 | 267 | def forward(self, outs, targets): 268 | stride = 6 269 | 270 | tl_heats = outs[0::stride] 271 | br_heats = outs[1::stride] 272 | tl_tags = outs[2::stride] 273 | br_tags = outs[3::stride] 274 | tl_regrs = outs[4::stride] 275 | br_regrs = outs[5::stride] 276 | 277 | gt_tl_heat = targets[0] 278 | gt_br_heat = targets[1] 279 | gt_mask = targets[2] 280 | gt_tl_regr = targets[3] 281 | gt_br_regr = targets[4] 282 | 283 | # focal loss 284 | focal_loss = 0 285 | 286 | tl_heats = [_sigmoid(t) for t in tl_heats] 287 | br_heats = [_sigmoid(b) for b in br_heats] 288 | 289 | focal_loss += self.focal_loss(tl_heats, gt_tl_heat) 290 | focal_loss += self.focal_loss(br_heats, gt_br_heat) 291 | 292 | # tag loss 293 | pull_loss = 0 294 | push_loss = 0 295 | 296 | for tl_tag, br_tag in zip(tl_tags, br_tags): 297 | pull, push = self.ae_loss(tl_tag, br_tag, gt_mask) 298 | pull_loss += pull 299 | push_loss += push 300 | pull_loss = self.pull_weight * pull_loss 301 | push_loss = self.push_weight * push_loss 302 | 303 | regr_loss = 0 304 | for tl_regr, br_regr in zip(tl_regrs, br_regrs): 305 | regr_loss += self.regr_loss(tl_regr, gt_tl_regr, gt_mask) 306 | regr_loss += self.regr_loss(br_regr, gt_br_regr, gt_mask) 307 | regr_loss = self.regr_weight * regr_loss 308 | 309 | loss = (focal_loss + pull_loss + push_loss + regr_loss) / len(tl_heats) 310 | return loss.unsqueeze(0) 311 
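# Editorial illustration (comments only, not part of the original file):
# kp._train appends six tensors per hourglass stack in a fixed order, which
# is why AELoss.forward slices `outs` with stride 6:
#
#   outs = [tl_heat_0, br_heat_0, tl_tag_0, br_tag_0, tl_regr_0, br_regr_0,
#           tl_heat_1, br_heat_1, tl_tag_1, br_tag_1, tl_regr_1, br_regr_1,
#           ...]
#
#   tl_heats = outs[0::6]   # one top-left heatmap per stack, and so on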
| -------------------------------------------------------------------------------- /models/py_utils/kp_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .utils import convolution, residual 5 | 6 | class MergeUp(nn.Module): 7 | def forward(self, up1, up2): 8 | return up1 + up2 9 | 10 | def make_merge_layer(dim): 11 | return MergeUp() 12 | 13 | def make_tl_layer(dim): 14 | return None 15 | 16 | def make_br_layer(dim): 17 | return None 18 | 19 | def make_pool_layer(dim): 20 | return nn.MaxPool2d(kernel_size=2, stride=2) 21 | 22 | def make_unpool_layer(dim): 23 | return nn.Upsample(scale_factor=2) 24 | 25 | def make_kp_layer(cnv_dim, curr_dim, out_dim): 26 | return nn.Sequential( 27 | convolution(3, cnv_dim, curr_dim, with_bn=False), 28 | nn.Conv2d(curr_dim, out_dim, (1, 1)) 29 | ) 30 | 31 | def make_inter_layer(dim): 32 | return residual(3, dim, dim) 33 | 34 | def make_cnv_layer(inp_dim, out_dim): 35 | return convolution(3, inp_dim, out_dim) 36 | 37 | def _gather_feat(feat, ind, mask=None): 38 | dim = feat.size(2) 39 | ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim) 40 | feat = feat.gather(1, ind) 41 | if mask is not None: 42 | mask = mask.unsqueeze(2).expand_as(feat) 43 | feat = feat[mask] 44 | feat = feat.view(-1, dim) 45 | return feat 46 | 47 | def _nms(heat, kernel=1): 48 | pad = (kernel - 1) // 2 49 | 50 | hmax = nn.functional.max_pool2d(heat, (kernel, kernel), stride=1, padding=pad) 51 | keep = (hmax == heat).float() 52 | return heat * keep 53 | 54 | def _tranpose_and_gather_feat(feat, ind): 55 | feat = feat.permute(0, 2, 3, 1).contiguous() 56 | feat = feat.view(feat.size(0), -1, feat.size(3)) 57 | feat = _gather_feat(feat, ind) 58 | return feat 59 | 60 | def _topk(scores, K=20): 61 | batch, cat, height, width = scores.size() 62 | 63 | topk_scores, topk_inds = torch.topk(scores.view(batch, -1), K) 64 | 65 | topk_clses = (topk_inds / (height * width)).int() 66 | 67 | topk_inds = topk_inds % (height * width) 68 | topk_ys = (topk_inds / width).int().float() 69 | topk_xs = (topk_inds % width).int().float() 70 | return topk_scores, topk_inds, topk_clses, topk_ys, topk_xs 71 | 72 | def _decode( 73 | tl_heat, br_heat, tl_tag, br_tag, tl_regr, br_regr, 74 | K=100, kernel=1, ae_threshold=1, num_dets=1000 75 | ): 76 | batch, cat, height, width = tl_heat.size() 77 | 78 | tl_heat = torch.sigmoid(tl_heat) 79 | br_heat = torch.sigmoid(br_heat) 80 | 81 | # perform nms on heatmaps 82 | tl_heat = _nms(tl_heat, kernel=kernel) 83 | br_heat = _nms(br_heat, kernel=kernel) 84 | 85 | tl_scores, tl_inds, tl_clses, tl_ys, tl_xs = _topk(tl_heat, K=K) 86 | br_scores, br_inds, br_clses, br_ys, br_xs = _topk(br_heat, K=K) 87 | 88 | tl_ys = tl_ys.view(batch, K, 1).expand(batch, K, K) 89 | tl_xs = tl_xs.view(batch, K, 1).expand(batch, K, K) 90 | br_ys = br_ys.view(batch, 1, K).expand(batch, K, K) 91 | br_xs = br_xs.view(batch, 1, K).expand(batch, K, K) 92 | 93 | if tl_regr is not None and br_regr is not None: 94 | tl_regr = _tranpose_and_gather_feat(tl_regr, tl_inds) 95 | tl_regr = tl_regr.view(batch, K, 1, 2) 96 | br_regr = _tranpose_and_gather_feat(br_regr, br_inds) 97 | br_regr = br_regr.view(batch, 1, K, 2) 98 | 99 | tl_xs = tl_xs + tl_regr[..., 0] 100 | tl_ys = tl_ys + tl_regr[..., 1] 101 | br_xs = br_xs + br_regr[..., 0] 102 | br_ys = br_ys + br_regr[..., 1] 103 | 104 | # all possible boxes based on top k corners (ignoring class) 105 | bboxes = torch.stack((tl_xs, tl_ys, br_xs, br_ys), dim=3) 106 | 107 
| tl_tag = _tranpose_and_gather_feat(tl_tag, tl_inds) 108 | tl_tag = tl_tag.view(batch, K, 1) 109 | br_tag = _tranpose_and_gather_feat(br_tag, br_inds) 110 | br_tag = br_tag.view(batch, 1, K) 111 | dists = torch.abs(tl_tag - br_tag) 112 | 113 | tl_scores = tl_scores.view(batch, K, 1).expand(batch, K, K) 114 | br_scores = br_scores.view(batch, 1, K).expand(batch, K, K) 115 | scores = (tl_scores + br_scores) / 2 116 | 117 | # reject boxes based on classes 118 | tl_clses = tl_clses.view(batch, K, 1).expand(batch, K, K) 119 | br_clses = br_clses.view(batch, 1, K).expand(batch, K, K) 120 | cls_inds = (tl_clses != br_clses) 121 | 122 | # reject boxes based on distances 123 | dist_inds = (dists > ae_threshold) 124 | 125 | # reject boxes based on widths and heights 126 | width_inds = (br_xs < tl_xs) 127 | height_inds = (br_ys < tl_ys) 128 | 129 | scores[cls_inds] = -1 130 | scores[dist_inds] = -1 131 | scores[width_inds] = -1 132 | scores[height_inds] = -1 133 | 134 | scores = scores.view(batch, -1) 135 | scores, inds = torch.topk(scores, num_dets) 136 | scores = scores.unsqueeze(2) 137 | 138 | bboxes = bboxes.view(batch, -1, 4) 139 | bboxes = _gather_feat(bboxes, inds) 140 | 141 | clses = tl_clses.contiguous().view(batch, -1, 1) 142 | clses = _gather_feat(clses, inds).float() 143 | 144 | tl_scores = tl_scores.contiguous().view(batch, -1, 1) 145 | tl_scores = _gather_feat(tl_scores, inds).float() 146 | br_scores = br_scores.contiguous().view(batch, -1, 1) 147 | br_scores = _gather_feat(br_scores, inds).float() 148 | 149 | detections = torch.cat([bboxes, scores, tl_scores, br_scores, clses], dim=2) 150 | return detections 151 | 152 | def _neg_loss(preds, gt): 153 | pos_inds = gt.eq(1) 154 | neg_inds = gt.lt(1) 155 | 156 | neg_weights = torch.pow(1 - gt[neg_inds], 4) 157 | 158 | loss = 0 159 | for pred in preds: 160 | pos_pred = pred[pos_inds] 161 | neg_pred = pred[neg_inds] 162 | 163 | pos_loss = torch.log(pos_pred) * torch.pow(1 - pos_pred, 2) 164 | neg_loss = torch.log(1 - neg_pred) * torch.pow(neg_pred, 2) * neg_weights 165 | 166 | num_pos = pos_inds.float().sum() 167 | pos_loss = pos_loss.sum() 168 | neg_loss = neg_loss.sum() 169 | 170 | if pos_pred.nelement() == 0: 171 | loss = loss - neg_loss 172 | else: 173 | loss = loss - (pos_loss + neg_loss) / num_pos 174 | return loss 175 | 176 | def _sigmoid(x): 177 | x = torch.clamp(x.sigmoid_(), min=1e-4, max=1-1e-4) 178 | return x 179 | 180 | def _ae_loss(tag0, tag1, mask): 181 | num = mask.sum(dim=1, keepdim=True).float() 182 | tag0 = tag0.squeeze() 183 | tag1 = tag1.squeeze() 184 | 185 | tag_mean = (tag0 + tag1) / 2 186 | 187 | tag0 = torch.pow(tag0 - tag_mean, 2) / (num + 1e-4) 188 | tag0 = tag0[mask].sum() 189 | tag1 = torch.pow(tag1 - tag_mean, 2) / (num + 1e-4) 190 | tag1 = tag1[mask].sum() 191 | pull = tag0 + tag1 192 | 193 | mask = mask.unsqueeze(1) + mask.unsqueeze(2) 194 | mask = mask.eq(2) 195 | num = num.unsqueeze(2) 196 | num2 = (num - 1) * num 197 | dist = tag_mean.unsqueeze(1) - tag_mean.unsqueeze(2) 198 | dist = 1 - torch.abs(dist) 199 | dist = nn.functional.relu(dist, inplace=True) 200 | dist = dist - 1 / (num + 1e-4) 201 | dist = dist / (num2 + 1e-4) 202 | dist = dist[mask] 203 | push = dist.sum() 204 | return pull, push 205 | 206 | def _regr_loss(regr, gt_regr, mask): 207 | num = mask.float().sum() 208 | mask = mask.unsqueeze(2).expand_as(gt_regr) 209 | 210 | regr = regr[mask] 211 | gt_regr = gt_regr[mask] 212 | 213 | regr_loss = nn.functional.smooth_l1_loss(regr, gt_regr, size_average=False) 214 | regr_loss = regr_loss / (num + 
1e-4) 215 | return regr_loss 216 | -------------------------------------------------------------------------------- /models/py_utils/scatter_gather.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from torch.nn.parallel._functions import Scatter, Gather 4 | 5 | 6 | def scatter(inputs, target_gpus, dim=0, chunk_sizes=None): 7 | r""" 8 | Slices variables into approximately equal chunks and 9 | distributes them across given GPUs. Duplicates 10 | references to objects that are not variables. Does not 11 | support Tensors. 12 | """ 13 | def scatter_map(obj): 14 | if isinstance(obj, Variable): 15 | return Scatter.apply(target_gpus, chunk_sizes, dim, obj) 16 | assert not torch.is_tensor(obj), "Tensors not supported in scatter." 17 | if isinstance(obj, tuple): 18 | return list(zip(*map(scatter_map, obj))) 19 | if isinstance(obj, list): 20 | return list(map(list, zip(*map(scatter_map, obj)))) 21 | if isinstance(obj, dict): 22 | return list(map(type(obj), zip(*map(scatter_map, obj.items())))) 23 | return [obj for targets in target_gpus] 24 | 25 | return scatter_map(inputs) 26 | 27 | 28 | def scatter_kwargs(inputs, kwargs, target_gpus, dim=0, chunk_sizes=None): 29 | r"""Scatter with support for kwargs dictionary""" 30 | inputs = scatter(inputs, target_gpus, dim, chunk_sizes) if inputs else [] 31 | kwargs = scatter(kwargs, target_gpus, dim, chunk_sizes) if kwargs else [] 32 | if len(inputs) < len(kwargs): 33 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) 34 | elif len(kwargs) < len(inputs): 35 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) 36 | inputs = tuple(inputs) 37 | kwargs = tuple(kwargs) 38 | return inputs, kwargs 39 | -------------------------------------------------------------------------------- /models/py_utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class convolution(nn.Module): 5 | def __init__(self, k, inp_dim, out_dim, stride=1, with_bn=True): 6 | super(convolution, self).__init__() 7 | 8 | pad = (k - 1) // 2 9 | self.conv = nn.Conv2d(inp_dim, out_dim, (k, k), padding=(pad, pad), stride=(stride, stride), bias=not with_bn) 10 | self.bn = nn.BatchNorm2d(out_dim) if with_bn else nn.Sequential() 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | conv = self.conv(x) 15 | bn = self.bn(conv) 16 | relu = self.relu(bn) 17 | return relu 18 | 19 | class fully_connected(nn.Module): 20 | def __init__(self, inp_dim, out_dim, with_bn=True): 21 | super(fully_connected, self).__init__() 22 | self.with_bn = with_bn 23 | 24 | self.linear = nn.Linear(inp_dim, out_dim) 25 | if self.with_bn: 26 | self.bn = nn.BatchNorm1d(out_dim) 27 | self.relu = nn.ReLU(inplace=True) 28 | 29 | def forward(self, x): 30 | linear = self.linear(x) 31 | bn = self.bn(linear) if self.with_bn else linear 32 | relu = self.relu(bn) 33 | return relu 34 | 35 | class residual(nn.Module): 36 | def __init__(self, k, inp_dim, out_dim, stride=1, with_bn=True): 37 | super(residual, self).__init__() 38 | 39 | self.conv1 = nn.Conv2d(inp_dim, out_dim, (3, 3), padding=(1, 1), stride=(stride, stride), bias=False) 40 | self.bn1 = nn.BatchNorm2d(out_dim) 41 | self.relu1 = nn.ReLU(inplace=True) 42 | 43 | self.conv2 = nn.Conv2d(out_dim, out_dim, (3, 3), padding=(1, 1), bias=False) 44 | self.bn2 = nn.BatchNorm2d(out_dim) 45 | 46 | self.skip = nn.Sequential( 47 | nn.Conv2d(inp_dim, out_dim, (1, 1), 
stride=(stride, stride), bias=False), 48 | nn.BatchNorm2d(out_dim) 49 | ) if stride != 1 or inp_dim != out_dim else nn.Sequential() 50 | self.relu = nn.ReLU(inplace=True) 51 | 52 | def forward(self, x): 53 | conv1 = self.conv1(x) 54 | bn1 = self.bn1(conv1) 55 | relu1 = self.relu1(bn1) 56 | 57 | conv2 = self.conv2(relu1) 58 | bn2 = self.bn2(conv2) 59 | 60 | skip = self.skip(x) 61 | return self.relu(bn2 + skip) 62 | 63 | def make_layer(k, inp_dim, out_dim, modules, layer=convolution, **kwargs): 64 | layers = [layer(k, inp_dim, out_dim, **kwargs)] 65 | for _ in range(1, modules): 66 | layers.append(layer(k, out_dim, out_dim, **kwargs)) 67 | return nn.Sequential(*layers) 68 | 69 | def make_layer_revr(k, inp_dim, out_dim, modules, layer=convolution, **kwargs): 70 | layers = [] 71 | for _ in range(modules - 1): 72 | layers.append(layer(k, inp_dim, inp_dim, **kwargs)) 73 | layers.append(layer(k, inp_dim, out_dim, **kwargs)) 74 | return nn.Sequential(*layers) 75 | -------------------------------------------------------------------------------- /nnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-vl/CornerNet/e5c39a31a8abef5841976c8eab18da86d6ee5f9a/nnet/__init__.py -------------------------------------------------------------------------------- /nnet/py_factory.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import importlib 4 | import torch.nn as nn 5 | 6 | from config import system_configs 7 | from models.py_utils.data_parallel import DataParallel 8 | 9 | torch.manual_seed(317) 10 | 11 | class Network(nn.Module): 12 | def __init__(self, model, loss): 13 | super(Network, self).__init__() 14 | 15 | self.model = model 16 | self.loss = loss 17 | 18 | def forward(self, xs, ys, **kwargs): 19 | preds = self.model(*xs, **kwargs) 20 | loss = self.loss(preds, ys, **kwargs) 21 | return loss 22 | 23 | # for model backward compatibility 24 | # previously model was wrapped by DataParallel module 25 | class DummyModule(nn.Module): 26 | def __init__(self, model): 27 | super(DummyModule, self).__init__() 28 | self.module = model 29 | 30 | def forward(self, *xs, **kwargs): 31 | return self.module(*xs, **kwargs) 32 | 33 | class NetworkFactory(object): 34 | def __init__(self, db): 35 | super(NetworkFactory, self).__init__() 36 | 37 | module_file = "models.{}".format(system_configs.snapshot_name) 38 | print("module_file: {}".format(module_file)) 39 | nnet_module = importlib.import_module(module_file) 40 | 41 | self.model = DummyModule(nnet_module.model(db)) 42 | self.loss = nnet_module.loss 43 | self.network = Network(self.model, self.loss) 44 | self.network = DataParallel(self.network, chunk_sizes=system_configs.chunk_sizes) 45 | 46 | total_params = 0 47 | for params in self.model.parameters(): 48 | num_params = 1 49 | for x in params.size(): 50 | num_params *= x 51 | total_params += num_params 52 | print("total parameters: {}".format(total_params)) 53 | 54 | if system_configs.opt_algo == "adam": 55 | self.optimizer = torch.optim.Adam( 56 | filter(lambda p: p.requires_grad, self.model.parameters()) 57 | ) 58 | elif system_configs.opt_algo == "sgd": 59 | self.optimizer = torch.optim.SGD( 60 | filter(lambda p: p.requires_grad, self.model.parameters()), 61 | lr=system_configs.learning_rate, 62 | momentum=0.9, weight_decay=0.0001 63 | ) 64 | else: 65 | raise ValueError("unknown optimizer") 66 | 67 | def cuda(self): 68 | self.model.cuda() 69 | 70 | def 
train_mode(self): 71 | self.network.train() 72 | 73 | def eval_mode(self): 74 | self.network.eval() 75 | 76 | def train(self, xs, ys, **kwargs): 77 | xs = [x.cuda(non_blocking=True) for x in xs] 78 | ys = [y.cuda(non_blocking=True) for y in ys] 79 | 80 | self.optimizer.zero_grad() 81 | loss = self.network(xs, ys) 82 | loss = loss.mean() 83 | loss.backward() 84 | self.optimizer.step() 85 | return loss 86 | 87 | def validate(self, xs, ys, **kwargs): 88 | with torch.no_grad(): 89 | xs = [x.cuda(non_blocking=True) for x in xs] 90 | ys = [y.cuda(non_blocking=True) for y in ys] 91 | 92 | loss = self.network(xs, ys) 93 | loss = loss.mean() 94 | return loss 95 | 96 | def test(self, xs, **kwargs): 97 | with torch.no_grad(): 98 | xs = [x.cuda(non_blocking=True) for x in xs] 99 | return self.model(*xs, **kwargs) 100 | 101 | def set_lr(self, lr): 102 | print("setting learning rate to: {}".format(lr)) 103 | for param_group in self.optimizer.param_groups: 104 | param_group["lr"] = lr 105 | 106 | def load_pretrained_params(self, pretrained_model): 107 | print("loading from {}".format(pretrained_model)) 108 | with open(pretrained_model, "rb") as f: 109 | params = torch.load(f) 110 | self.model.load_state_dict(params) 111 | 112 | def load_params(self, iteration): 113 | cache_file = system_configs.snapshot_file.format(iteration) 114 | print("loading model from {}".format(cache_file)) 115 | with open(cache_file, "rb") as f: 116 | params = torch.load(f) 117 | self.model.load_state_dict(params) 118 | 119 | def save_params(self, iteration): 120 | cache_file = system_configs.snapshot_file.format(iteration) 121 | print("saving model to {}".format(cache_file)) 122 | with open(cache_file, "wb") as f: 123 | params = self.model.state_dict() 124 | torch.save(params, f) 125 | -------------------------------------------------------------------------------- /sample/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-vl/CornerNet/e5c39a31a8abef5841976c8eab18da86d6ee5f9a/sample/__init__.py -------------------------------------------------------------------------------- /sample/coco.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | import torch 5 | import random 6 | import string 7 | 8 | from config import system_configs 9 | from utils import crop_image, normalize_, color_jittering_, lighting_ 10 | from .utils import random_crop, draw_gaussian, gaussian_radius 11 | 12 | def _full_image_crop(image, detections): 13 | detections = detections.copy() 14 | height, width = image.shape[0:2] 15 | 16 | max_hw = max(height, width) 17 | center = [height // 2, width // 2] 18 | size = [max_hw, max_hw] 19 | 20 | image, border, offset = crop_image(image, center, size) 21 | detections[:, 0:4:2] += border[2] 22 | detections[:, 1:4:2] += border[0] 23 | return image, detections 24 | 25 | def _resize_image(image, detections, size): 26 | detections = detections.copy() 27 | height, width = image.shape[0:2] 28 | new_height, new_width = size 29 | 30 | image = cv2.resize(image, (new_width, new_height)) 31 | 32 | height_ratio = new_height / height 33 | width_ratio = new_width / width 34 | detections[:, 0:4:2] *= width_ratio 35 | detections[:, 1:4:2] *= height_ratio 36 | return image, detections 37 | 38 | def _clip_detections(image, detections): 39 | detections = detections.copy() 40 | height, width = image.shape[0:2] 41 | 42 | detections[:, 0:4:2] = np.clip(detections[:, 
0:4:2], 0, width - 1) 43 | detections[:, 1:4:2] = np.clip(detections[:, 1:4:2], 0, height - 1) 44 | keep_inds = ((detections[:, 2] - detections[:, 0]) > 0) & \ 45 | ((detections[:, 3] - detections[:, 1]) > 0) 46 | detections = detections[keep_inds] 47 | return detections 48 | 49 | def kp_detection(db, k_ind, data_aug, debug): 50 | data_rng = system_configs.data_rng 51 | batch_size = system_configs.batch_size 52 | 53 | categories = db.configs["categories"] 54 | input_size = db.configs["input_size"] 55 | output_size = db.configs["output_sizes"][0] 56 | 57 | border = db.configs["border"] 58 | lighting = db.configs["lighting"] 59 | rand_crop = db.configs["rand_crop"] 60 | rand_color = db.configs["rand_color"] 61 | rand_scales = db.configs["rand_scales"] 62 | gaussian_bump = db.configs["gaussian_bump"] 63 | gaussian_iou = db.configs["gaussian_iou"] 64 | gaussian_rad = db.configs["gaussian_radius"] 65 | 66 | max_tag_len = 128 67 | 68 | # allocating memory 69 | images = np.zeros((batch_size, 3, input_size[0], input_size[1]), dtype=np.float32) 70 | tl_heatmaps = np.zeros((batch_size, categories, output_size[0], output_size[1]), dtype=np.float32) 71 | br_heatmaps = np.zeros((batch_size, categories, output_size[0], output_size[1]), dtype=np.float32) 72 | tl_regrs = np.zeros((batch_size, max_tag_len, 2), dtype=np.float32) 73 | br_regrs = np.zeros((batch_size, max_tag_len, 2), dtype=np.float32) 74 | tl_tags = np.zeros((batch_size, max_tag_len), dtype=np.int64) 75 | br_tags = np.zeros((batch_size, max_tag_len), dtype=np.int64) 76 | tag_masks = np.zeros((batch_size, max_tag_len), dtype=np.uint8) 77 | tag_lens = np.zeros((batch_size, ), dtype=np.int32) 78 | 79 | db_size = db.db_inds.size 80 | for b_ind in range(batch_size): 81 | if not debug and k_ind == 0: 82 | db.shuffle_inds() 83 | 84 | db_ind = db.db_inds[k_ind] 85 | k_ind = (k_ind + 1) % db_size 86 | 87 | # reading image 88 | image_file = db.image_file(db_ind) 89 | image = cv2.imread(image_file) 90 | 91 | # reading detections 92 | detections = db.detections(db_ind) 93 | 94 | # cropping an image randomly 95 | if not debug and rand_crop: 96 | image, detections = random_crop(image, detections, rand_scales, input_size, border=border) 97 | else: 98 | image, detections = _full_image_crop(image, detections) 99 | 100 | image, detections = _resize_image(image, detections, input_size) 101 | detections = _clip_detections(image, detections) 102 | 103 | width_ratio = output_size[1] / input_size[1] 104 | height_ratio = output_size[0] / input_size[0] 105 | 106 | # flipping an image randomly 107 | if not debug and np.random.uniform() > 0.5: 108 | image[:] = image[:, ::-1, :] 109 | width = image.shape[1] 110 | detections[:, [0, 2]] = width - detections[:, [2, 0]] - 1 111 | 112 | if not debug: 113 | image = image.astype(np.float32) / 255. 
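            # Editorial comment (not part of the original file): the calls
            # below mutate `image` in place on the [0, 1] float array:
            # color_jittering_ applies brightness/contrast/saturation jitter
            # in random order, lighting_ adds noise along the color covariance
            # eigenvectors, and normalize_ subtracts db.mean and divides by
            # db.std before the HWC -> CHW transpose.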
114 | if rand_color: 115 | color_jittering_(data_rng, image) 116 | if lighting: 117 | lighting_(data_rng, image, 0.1, db.eig_val, db.eig_vec) 118 | normalize_(image, db.mean, db.std) 119 | images[b_ind] = image.transpose((2, 0, 1)) 120 | 121 | for ind, detection in enumerate(detections): 122 | category = int(detection[-1]) - 1 123 | 124 | xtl, ytl = detection[0], detection[1] 125 | xbr, ybr = detection[2], detection[3] 126 | 127 | fxtl = (xtl * width_ratio) 128 | fytl = (ytl * height_ratio) 129 | fxbr = (xbr * width_ratio) 130 | fybr = (ybr * height_ratio) 131 | 132 | xtl = int(fxtl) 133 | ytl = int(fytl) 134 | xbr = int(fxbr) 135 | ybr = int(fybr) 136 | 137 | if gaussian_bump: 138 | width = detection[2] - detection[0] 139 | height = detection[3] - detection[1] 140 | 141 | width = math.ceil(width * width_ratio) 142 | height = math.ceil(height * height_ratio) 143 | 144 | if gaussian_rad == -1: 145 | radius = gaussian_radius((height, width), gaussian_iou) 146 | radius = max(0, int(radius)) 147 | else: 148 | radius = gaussian_rad 149 | 150 | draw_gaussian(tl_heatmaps[b_ind, category], [xtl, ytl], radius) 151 | draw_gaussian(br_heatmaps[b_ind, category], [xbr, ybr], radius) 152 | else: 153 | tl_heatmaps[b_ind, category, ytl, xtl] = 1 154 | br_heatmaps[b_ind, category, ybr, xbr] = 1 155 | 156 | tag_ind = tag_lens[b_ind] 157 | tl_regrs[b_ind, tag_ind, :] = [fxtl - xtl, fytl - ytl] 158 | br_regrs[b_ind, tag_ind, :] = [fxbr - xbr, fybr - ybr] 159 | tl_tags[b_ind, tag_ind] = ytl * output_size[1] + xtl 160 | br_tags[b_ind, tag_ind] = ybr * output_size[1] + xbr 161 | tag_lens[b_ind] += 1 162 | 163 | for b_ind in range(batch_size): 164 | tag_len = tag_lens[b_ind] 165 | tag_masks[b_ind, :tag_len] = 1 166 | 167 | images = torch.from_numpy(images) 168 | tl_heatmaps = torch.from_numpy(tl_heatmaps) 169 | br_heatmaps = torch.from_numpy(br_heatmaps) 170 | tl_regrs = torch.from_numpy(tl_regrs) 171 | br_regrs = torch.from_numpy(br_regrs) 172 | tl_tags = torch.from_numpy(tl_tags) 173 | br_tags = torch.from_numpy(br_tags) 174 | tag_masks = torch.from_numpy(tag_masks) 175 | 176 | return { 177 | "xs": [images, tl_tags, br_tags], 178 | "ys": [tl_heatmaps, br_heatmaps, tag_masks, tl_regrs, br_regrs] 179 | }, k_ind 180 | 181 | def sample_data(db, k_ind, data_aug=True, debug=False): 182 | return globals()[system_configs.sampling_function](db, k_ind, data_aug, debug) 183 | -------------------------------------------------------------------------------- /sample/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | def gaussian2D(shape, sigma=1): 5 | m, n = [(ss - 1.) / 2. 
for ss in shape] 6 | y, x = np.ogrid[-m:m+1,-n:n+1] 7 | 8 | h = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) 9 | h[h < np.finfo(h.dtype).eps * h.max()] = 0 10 | return h 11 | 12 | def draw_gaussian(heatmap, center, radius, k=1): 13 | diameter = 2 * radius + 1 14 | gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6) 15 | 16 | x, y = center 17 | 18 | height, width = heatmap.shape[0:2] 19 | 20 | left, right = min(x, radius), min(width - x, radius + 1) 21 | top, bottom = min(y, radius), min(height - y, radius + 1) 22 | 23 | masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right] 24 | masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right] 25 | np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap) 26 | 27 | def gaussian_radius(det_size, min_overlap): 28 | height, width = det_size 29 | 30 | a1 = 1 31 | b1 = (height + width) 32 | c1 = width * height * (1 - min_overlap) / (1 + min_overlap) 33 | sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1) 34 | r1 = (b1 - sq1) / (2 * a1) 35 | 36 | a2 = 4 37 | b2 = 2 * (height + width) 38 | c2 = (1 - min_overlap) * width * height 39 | sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2) 40 | r2 = (b2 - sq2) / (2 * a2) 41 | 42 | a3 = 4 * min_overlap 43 | b3 = -2 * min_overlap * (height + width) 44 | c3 = (min_overlap - 1) * width * height 45 | sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3) 46 | r3 = (b3 + sq3) / (2 * a3) 47 | return min(r1, r2, r3) 48 | 49 | def _get_border(border, size): 50 | i = 1 51 | while size - border // i <= border // i: 52 | i *= 2 53 | return border // i 54 | 55 | def random_crop(image, detections, random_scales, view_size, border=64): 56 | view_height, view_width = view_size 57 | image_height, image_width = image.shape[0:2] 58 | 59 | scale = np.random.choice(random_scales) 60 | height = int(view_height * scale) 61 | width = int(view_width * scale) 62 | 63 | cropped_image = np.zeros((height, width, 3), dtype=image.dtype) 64 | 65 | w_border = _get_border(border, image_width) 66 | h_border = _get_border(border, image_height) 67 | 68 | ctx = np.random.randint(low=w_border, high=image_width - w_border) 69 | cty = np.random.randint(low=h_border, high=image_height - h_border) 70 | 71 | x0, x1 = max(ctx - width // 2, 0), min(ctx + width // 2, image_width) 72 | y0, y1 = max(cty - height // 2, 0), min(cty + height // 2, image_height) 73 | 74 | left_w, right_w = ctx - x0, x1 - ctx 75 | top_h, bottom_h = cty - y0, y1 - cty 76 | 77 | # crop image 78 | cropped_ctx, cropped_cty = width // 2, height // 2 79 | x_slice = slice(cropped_ctx - left_w, cropped_ctx + right_w) 80 | y_slice = slice(cropped_cty - top_h, cropped_cty + bottom_h) 81 | cropped_image[y_slice, x_slice, :] = image[y0:y1, x0:x1, :] 82 | 83 | # crop detections 84 | cropped_detections = detections.copy() 85 | cropped_detections[:, 0:4:2] -= x0 86 | cropped_detections[:, 1:4:2] -= y0 87 | cropped_detections[:, 0:4:2] += cropped_ctx - left_w 88 | cropped_detections[:, 1:4:2] += cropped_cty - top_h 89 | 90 | return cropped_image, cropped_detections 91 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import json 4 | import torch 5 | import pprint 6 | import argparse 7 | import importlib 8 | import numpy as np 9 | 10 | import matplotlib 11 | matplotlib.use("Agg") 12 | 13 | from config import system_configs 14 | from nnet.py_factory import NetworkFactory 15 | from db.datasets import datasets 16 | 17 
| torch.backends.cudnn.benchmark = False 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser(description="Test CornerNet") 21 | parser.add_argument("cfg_file", help="config file", type=str) 22 | parser.add_argument("--testiter", dest="testiter", 23 | help="test at iteration i", 24 | default=None, type=int) 25 | parser.add_argument("--split", dest="split", 26 | help="which split to use", 27 | default="validation", type=str) 28 | parser.add_argument("--suffix", dest="suffix", default=None, type=str) 29 | parser.add_argument("--debug", action="store_true") 30 | 31 | args = parser.parse_args() 32 | return args 33 | 34 | def make_dirs(directories): 35 | for directory in directories: 36 | if not os.path.exists(directory): 37 | os.makedirs(directory) 38 | 39 | def test(db, split, testiter, debug=False, suffix=None): 40 | result_dir = system_configs.result_dir 41 | result_dir = os.path.join(result_dir, str(testiter), split) 42 | 43 | if suffix is not None: 44 | result_dir = os.path.join(result_dir, suffix) 45 | 46 | make_dirs([result_dir]) 47 | 48 | test_iter = system_configs.max_iter if testiter is None else testiter 49 | print("loading parameters at iteration: {}".format(test_iter)) 50 | 51 | print("building neural network...") 52 | nnet = NetworkFactory(db) 53 | print("loading parameters...") 54 | nnet.load_params(test_iter) 55 | 56 | test_file = "test.{}".format(db.data) 57 | testing = importlib.import_module(test_file).testing 58 | 59 | nnet.cuda() 60 | nnet.eval_mode() 61 | testing(db, nnet, result_dir, debug=debug) 62 | 63 | if __name__ == "__main__": 64 | args = parse_args() 65 | 66 | if args.suffix is None: 67 | cfg_file = os.path.join(system_configs.config_dir, args.cfg_file + ".json") 68 | else: 69 | cfg_file = os.path.join(system_configs.config_dir, args.cfg_file + "-{}.json".format(args.suffix)) 70 | print("cfg_file: {}".format(cfg_file)) 71 | 72 | with open(cfg_file, "r") as f: 73 | configs = json.load(f) 74 | 75 | configs["system"]["snapshot_name"] = args.cfg_file 76 | system_configs.update_config(configs["system"]) 77 | 78 | train_split = system_configs.train_split 79 | val_split = system_configs.val_split 80 | test_split = system_configs.test_split 81 | 82 | split = { 83 | "training": train_split, 84 | "validation": val_split, 85 | "testing": test_split 86 | }[args.split] 87 | 88 | print("loading all datasets...") 89 | dataset = system_configs.dataset 90 | print("split: {}".format(split)) 91 | testing_db = datasets[dataset](configs["db"], split) 92 | 93 | print("system config...") 94 | pprint.pprint(system_configs.full) 95 | 96 | print("db config...") 97 | pprint.pprint(testing_db.configs) 98 | 99 | test(testing_db, args.split, args.testiter, args.debug, args.suffix) 100 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-vl/CornerNet/e5c39a31a8abef5841976c8eab18da86d6ee5f9a/test/__init__.py -------------------------------------------------------------------------------- /test/coco.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import json 4 | import numpy as np 5 | import torch 6 | import matplotlib.pyplot as plt 7 | 8 | from tqdm import tqdm 9 | from config import system_configs 10 | from utils import crop_image, normalize_ 11 | from external.nms import soft_nms, soft_nms_merge 12 | 13 | def _rescale_dets(detections, 
ratios, borders, sizes): 14 | xs, ys = detections[..., 0:4:2], detections[..., 1:4:2] 15 | xs /= ratios[:, 1][:, None, None] 16 | ys /= ratios[:, 0][:, None, None] 17 | xs -= borders[:, 2][:, None, None] 18 | ys -= borders[:, 0][:, None, None] 19 | np.clip(xs, 0, sizes[:, 1][:, None, None], out=xs) 20 | np.clip(ys, 0, sizes[:, 0][:, None, None], out=ys) 21 | 22 | def save_image(data, fn): 23 | sizes = np.shape(data) 24 | height = float(sizes[0]) 25 | width = float(sizes[1]) 26 | 27 | fig = plt.figure() 28 | fig.set_size_inches(width/height, 1, forward=False) 29 | ax = plt.Axes(fig, [0., 0., 1., 1.]) 30 | ax.set_axis_off() 31 | fig.add_axes(ax) 32 | 33 | ax.imshow(data) 34 | plt.savefig(fn, dpi = height) 35 | plt.close() 36 | 37 | def kp_decode(nnet, images, K, ae_threshold=0.5, kernel=3): 38 | detections = nnet.test([images], ae_threshold=ae_threshold, K=K, kernel=kernel) 39 | detections = detections.data.cpu().numpy() 40 | return detections 41 | 42 | def kp_detection(db, nnet, result_dir, debug=False, decode_func=kp_decode): 43 | debug_dir = os.path.join(result_dir, "debug") 44 | if not os.path.exists(debug_dir): 45 | os.makedirs(debug_dir) 46 | 47 | if db.split != "trainval": 48 | db_inds = db.db_inds[:100] if debug else db.db_inds 49 | else: 50 | db_inds = db.db_inds[:100] if debug else db.db_inds[:5000] 51 | num_images = db_inds.size 52 | 53 | K = db.configs["top_k"] 54 | ae_threshold = db.configs["ae_threshold"] 55 | nms_kernel = db.configs["nms_kernel"] 56 | 57 | scales = db.configs["test_scales"] 58 | weight_exp = db.configs["weight_exp"] 59 | merge_bbox = db.configs["merge_bbox"] 60 | categories = db.configs["categories"] 61 | nms_threshold = db.configs["nms_threshold"] 62 | max_per_image = db.configs["max_per_image"] 63 | nms_algorithm = { 64 | "nms": 0, 65 | "linear_soft_nms": 1, 66 | "exp_soft_nms": 2 67 | }[db.configs["nms_algorithm"]] 68 | 69 | top_bboxes = {} 70 | for ind in tqdm(range(0, num_images), ncols=80, desc="locating kps"): 71 | db_ind = db_inds[ind] 72 | 73 | image_id = db.image_ids(db_ind) 74 | image_file = db.image_file(db_ind) 75 | image = cv2.imread(image_file) 76 | 77 | height, width = image.shape[0:2] 78 | 79 | detections = [] 80 | 81 | for scale in scales: 82 | new_height = int(height * scale) 83 | new_width = int(width * scale) 84 | new_center = np.array([new_height // 2, new_width // 2]) 85 | 86 | inp_height = new_height | 127 87 | inp_width = new_width | 127 88 | 89 | images = np.zeros((1, 3, inp_height, inp_width), dtype=np.float32) 90 | ratios = np.zeros((1, 2), dtype=np.float32) 91 | borders = np.zeros((1, 4), dtype=np.float32) 92 | sizes = np.zeros((1, 2), dtype=np.float32) 93 | 94 | out_height, out_width = (inp_height + 1) // 4, (inp_width + 1) // 4 95 | height_ratio = out_height / inp_height 96 | width_ratio = out_width / inp_width 97 | 98 | resized_image = cv2.resize(image, (new_width, new_height)) 99 | resized_image, border, offset = crop_image(resized_image, new_center, [inp_height, inp_width]) 100 | 101 | resized_image = resized_image / 255. 
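            # Editorial comment (not part of the original file): normalize_
            # standardizes the resized crop in place with the dataset
            # mean/std. The batch is then doubled with a horizontally flipped
            # copy (see the concatenate call below), and the flipped
            # detections are mirrored back via `out_width - x` before the two
            # views are merged, i.e. test-time flip augmentation.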
102 | normalize_(resized_image, db.mean, db.std) 103 | 104 | images[0] = resized_image.transpose((2, 0, 1)) 105 | borders[0] = border 106 | sizes[0] = [int(height * scale), int(width * scale)] 107 | ratios[0] = [height_ratio, width_ratio] 108 | 109 | images = np.concatenate((images, images[:, :, :, ::-1]), axis=0) 110 | images = torch.from_numpy(images) 111 | dets = decode_func(nnet, images, K, ae_threshold=ae_threshold, kernel=nms_kernel) 112 | dets = dets.reshape(2, -1, 8) 113 | dets[1, :, [0, 2]] = out_width - dets[1, :, [2, 0]] 114 | dets = dets.reshape(1, -1, 8) 115 | 116 | _rescale_dets(dets, ratios, borders, sizes) 117 | dets[:, :, 0:4] /= scale 118 | detections.append(dets) 119 | 120 | detections = np.concatenate(detections, axis=1) 121 | 122 | classes = detections[..., -1] 123 | classes = classes[0] 124 | detections = detections[0] 125 | 126 | # reject detections with negative scores 127 | keep_inds = (detections[:, 4] > -1) 128 | detections = detections[keep_inds] 129 | classes = classes[keep_inds] 130 | 131 | top_bboxes[image_id] = {} 132 | for j in range(categories): 133 | keep_inds = (classes == j) 134 | top_bboxes[image_id][j + 1] = detections[keep_inds][:, 0:7].astype(np.float32) 135 | if merge_bbox: 136 | soft_nms_merge(top_bboxes[image_id][j + 1], Nt=nms_threshold, method=nms_algorithm, weight_exp=weight_exp) 137 | else: 138 | soft_nms(top_bboxes[image_id][j + 1], Nt=nms_threshold, method=nms_algorithm) 139 | top_bboxes[image_id][j + 1] = top_bboxes[image_id][j + 1][:, 0:5] 140 | 141 | scores = np.hstack([ 142 | top_bboxes[image_id][j][:, -1] 143 | for j in range(1, categories + 1) 144 | ]) 145 | if len(scores) > max_per_image: 146 | kth = len(scores) - max_per_image 147 | thresh = np.partition(scores, kth)[kth] 148 | for j in range(1, categories + 1): 149 | keep_inds = (top_bboxes[image_id][j][:, -1] >= thresh) 150 | top_bboxes[image_id][j] = top_bboxes[image_id][j][keep_inds] 151 | 152 | if debug: 153 | image_file = db.image_file(db_ind) 154 | image = cv2.imread(image_file) 155 | 156 | bboxes = {} 157 | for j in range(1, categories + 1): 158 | keep_inds = (top_bboxes[image_id][j][:, -1] > 0.5) 159 | cat_name = db.class_name(j) 160 | cat_size = cv2.getTextSize(cat_name, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0] 161 | color = np.random.random((3, )) * 0.6 + 0.4 162 | color = color * 255 163 | color = color.astype(np.int32).tolist() 164 | for bbox in top_bboxes[image_id][j][keep_inds]: 165 | bbox = bbox[0:4].astype(np.int32) 166 | if bbox[1] - cat_size[1] - 2 < 0: 167 | cv2.rectangle(image, 168 | (bbox[0], bbox[1] + 2), 169 | (bbox[0] + cat_size[0], bbox[1] + cat_size[1] + 2), 170 | color, -1 171 | ) 172 | cv2.putText(image, cat_name, 173 | (bbox[0], bbox[1] + cat_size[1] + 2), 174 | cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), thickness=1 175 | ) 176 | else: 177 | cv2.rectangle(image, 178 | (bbox[0], bbox[1] - cat_size[1] - 2), 179 | (bbox[0] + cat_size[0], bbox[1] - 2), 180 | color, -1 181 | ) 182 | cv2.putText(image, cat_name, 183 | (bbox[0], bbox[1] - 2), 184 | cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), thickness=1 185 | ) 186 | cv2.rectangle(image, 187 | (bbox[0], bbox[1]), 188 | (bbox[2], bbox[3]), 189 | color, 2 190 | ) 191 | debug_file = os.path.join(debug_dir, "{}.jpg".format(db_ind)) 192 | 193 | result_json = os.path.join(result_dir, "results.json") 194 | detections = db.convert_to_coco(top_bboxes) 195 | with open(result_json, "w") as f: 196 | json.dump(detections, f) 197 | 198 | cls_ids = list(range(1, categories + 1)) 199 | image_ids = [db.image_ids(ind) for ind in 
db_inds] 200 | db.evaluate(result_json, cls_ids, image_ids) 201 | return 0 202 | 203 | def testing(db, nnet, result_dir, debug=False): 204 | return globals()[system_configs.sampling_function](db, nnet, result_dir, debug=debug) 205 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | 4 | import json 5 | import torch 6 | import numpy as np 7 | import queue 8 | import pprint 9 | import random 10 | import argparse 11 | import importlib 12 | import threading 13 | import traceback 14 | 15 | from tqdm import tqdm 16 | from utils import stdout_to_tqdm 17 | from config import system_configs 18 | from nnet.py_factory import NetworkFactory 19 | from torch.multiprocessing import Process, Queue, Pool 20 | from db.datasets import datasets 21 | 22 | torch.backends.cudnn.enabled = True 23 | torch.backends.cudnn.benchmark = True 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser(description="Train CornerNet") 27 | parser.add_argument("cfg_file", help="config file", type=str) 28 | parser.add_argument("--iter", dest="start_iter", 29 | help="train at iteration i", 30 | default=0, type=int) 31 | parser.add_argument("--threads", dest="threads", default=4, type=int) 32 | 33 | args = parser.parse_args() 34 | return args 35 | 36 | def prefetch_data(db, queue, sample_data, data_aug): 37 | ind = 0 38 | print("start prefetching data...") 39 | np.random.seed(os.getpid()) 40 | while True: 41 | try: 42 | data, ind = sample_data(db, ind, data_aug=data_aug) 43 | queue.put(data) 44 | except Exception as e: 45 | traceback.print_exc() 46 | raise e 47 | 48 | def pin_memory(data_queue, pinned_data_queue, sema): 49 | while True: 50 | data = data_queue.get() 51 | 52 | data["xs"] = [x.pin_memory() for x in data["xs"]] 53 | data["ys"] = [y.pin_memory() for y in data["ys"]] 54 | 55 | pinned_data_queue.put(data) 56 | 57 | if sema.acquire(blocking=False): 58 | return 59 | 60 | def init_parallel_jobs(dbs, queue, fn, data_aug): 61 | tasks = [Process(target=prefetch_data, args=(db, queue, fn, data_aug)) for db in dbs] 62 | for task in tasks: 63 | task.daemon = True 64 | task.start() 65 | return tasks 66 | 67 | def train(training_dbs, validation_db, start_iter=0): 68 | learning_rate = system_configs.learning_rate 69 | max_iteration = system_configs.max_iter 70 | pretrained_model = system_configs.pretrain 71 | snapshot = system_configs.snapshot 72 | val_iter = system_configs.val_iter 73 | display = system_configs.display 74 | decay_rate = system_configs.decay_rate 75 | stepsize = system_configs.stepsize 76 | 77 | # getting the size of each database 78 | training_size = len(training_dbs[0].db_inds) 79 | validation_size = len(validation_db.db_inds) 80 | 81 | # queues storing data for training 82 | training_queue = Queue(system_configs.prefetch_size) 83 | validation_queue = Queue(5) 84 | 85 | # queues storing pinned data for training 86 | pinned_training_queue = queue.Queue(system_configs.prefetch_size) 87 | pinned_validation_queue = queue.Queue(5) 88 | 89 | # load data sampling function 90 | data_file = "sample.{}".format(training_dbs[0].data) 91 | sample_data = importlib.import_module(data_file).sample_data 92 | 93 | # allocating resources for parallel reading 94 | training_tasks = init_parallel_jobs(training_dbs, training_queue, sample_data, True) 95 | if val_iter: 96 | validation_tasks = init_parallel_jobs([validation_db], validation_queue, sample_data, False) 
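    # Editorial sketch (not part of the original file) of the data pipeline
    # assembled above: sampler processes feed a multiprocessing Queue, a
    # pin-memory thread copies each batch into page-locked host memory, and
    # the training loop consumes from a thread-local queue:
    #
    #   prefetch_data (processes) -> training_queue -> pin_memory (thread)
    #       -> pinned_training_queue -> nnet.train(...)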
97 | 98 | training_pin_semaphore = threading.Semaphore() 99 | validation_pin_semaphore = threading.Semaphore() 100 | training_pin_semaphore.acquire() 101 | validation_pin_semaphore.acquire() 102 | 103 | training_pin_args = (training_queue, pinned_training_queue, training_pin_semaphore) 104 | training_pin_thread = threading.Thread(target=pin_memory, args=training_pin_args) 105 | training_pin_thread.daemon = True 106 | training_pin_thread.start() 107 | 108 | validation_pin_args = (validation_queue, pinned_validation_queue, validation_pin_semaphore) 109 | validation_pin_thread = threading.Thread(target=pin_memory, args=validation_pin_args) 110 | validation_pin_thread.daemon = True 111 | validation_pin_thread.start() 112 | 113 | print("building model...") 114 | nnet = NetworkFactory(training_dbs[0]) 115 | 116 | if pretrained_model is not None: 117 | if not os.path.exists(pretrained_model): 118 | raise ValueError("pretrained model does not exist") 119 | print("loading from pretrained model") 120 | nnet.load_pretrained_params(pretrained_model) 121 | 122 | if start_iter: 123 | learning_rate /= (decay_rate ** (start_iter // stepsize)) 124 | 125 | nnet.load_params(start_iter) 126 | nnet.set_lr(learning_rate) 127 | print("training starts from iteration {} with learning_rate {}".format(start_iter + 1, learning_rate)) 128 | else: 129 | nnet.set_lr(learning_rate) 130 | 131 | print("training start...") 132 | nnet.cuda() 133 | nnet.train_mode() 134 | with stdout_to_tqdm() as save_stdout: 135 | for iteration in tqdm(range(start_iter + 1, max_iteration + 1), file=save_stdout, ncols=80): 136 | training = pinned_training_queue.get(block=True) 137 | training_loss = nnet.train(**training) 138 | 139 | if display and iteration % display == 0: 140 | print("training loss at iteration {}: {}".format(iteration, training_loss.item())) 141 | del training_loss 142 | 143 | if val_iter and validation_db.db_inds.size and iteration % val_iter == 0: 144 | nnet.eval_mode() 145 | validation = pinned_validation_queue.get(block=True) 146 | validation_loss = nnet.validate(**validation) 147 | print("validation loss at iteration {}: {}".format(iteration, validation_loss.item())) 148 | nnet.train_mode() 149 | 150 | if iteration % snapshot == 0: 151 | nnet.save_params(iteration) 152 | 153 | if iteration % stepsize == 0: 154 | learning_rate /= decay_rate 155 | nnet.set_lr(learning_rate) 156 | 157 | # sending signal to kill the pin-memory threads 158 | training_pin_semaphore.release() 159 | validation_pin_semaphore.release() 160 | 161 | # terminating data fetching processes 162 | for training_task in training_tasks: 163 | training_task.terminate() 164 | if val_iter: # guard added in editing: validation_tasks exists only when val_iter is set 165 | for validation_task in validation_tasks: validation_task.terminate() 166 | 167 | if __name__ == "__main__": 168 | args = parse_args() 169 | 170 | cfg_file = os.path.join(system_configs.config_dir, args.cfg_file + ".json") 171 | with open(cfg_file, "r") as f: 172 | configs = json.load(f) 173 | 174 | configs["system"]["snapshot_name"] = args.cfg_file 175 | system_configs.update_config(configs["system"]) 176 | 177 | train_split = system_configs.train_split 178 | val_split = system_configs.val_split 179 | 180 | print("loading all datasets...") 181 | dataset = system_configs.dataset 182 | # threads = max(torch.cuda.device_count() * 2, 4) 183 | threads = args.threads 184 | print("using {} threads".format(threads)) 185 | training_dbs = [datasets[dataset](configs["db"], train_split) for _ in range(threads)] 186 | validation_db = datasets[dataset](configs["db"], val_split) 187 | 188 | print("system
config...") 189 | pprint.pprint(system_configs.full) 190 | 191 | print("db config...") 192 | pprint.pprint(training_dbs[0].configs) 193 | 194 | print("len of db: {}".format(len(training_dbs[0].db_inds))) 195 | train(training_dbs, validation_db, args.start_iter) 196 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .tqdm import stdout_to_tqdm 2 | 3 | from .image import crop_image 4 | from .image import color_jittering_, lighting_, normalize_ 5 | -------------------------------------------------------------------------------- /utils/image.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import random 4 | 5 | def grayscale(image): 6 | return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 7 | 8 | def normalize_(image, mean, std): 9 | image -= mean 10 | image /= std 11 | 12 | def lighting_(data_rng, image, alphastd, eigval, eigvec): 13 | alpha = data_rng.normal(scale=alphastd, size=(3, )) 14 | image += np.dot(eigvec, eigval * alpha) 15 | 16 | def blend_(alpha, image1, image2): 17 | image1 *= alpha 18 | image2 *= (1 - alpha) 19 | image1 += image2 20 | 21 | def saturation_(data_rng, image, gs, gs_mean, var): 22 | alpha = 1. + data_rng.uniform(low=-var, high=var) 23 | blend_(alpha, image, gs[:, :, None]) 24 | 25 | def brightness_(data_rng, image, gs, gs_mean, var): 26 | alpha = 1. + data_rng.uniform(low=-var, high=var) 27 | image *= alpha 28 | 29 | def contrast_(data_rng, image, gs, gs_mean, var): 30 | alpha = 1. + data_rng.uniform(low=-var, high=var) 31 | blend_(alpha, image, gs_mean) 32 | 33 | def color_jittering_(data_rng, image): 34 | functions = [brightness_, contrast_, saturation_] 35 | random.shuffle(functions) 36 | 37 | gs = grayscale(image) 38 | gs_mean = gs.mean() 39 | for f in functions: 40 | f(data_rng, image, gs, gs_mean, 0.4) 41 | 42 | def crop_image(image, center, size): 43 | cty, ctx = center 44 | height, width = size 45 | im_height, im_width = image.shape[0:2] 46 | cropped_image = np.zeros((height, width, 3), dtype=image.dtype) 47 | 48 | x0, x1 = max(0, ctx - width // 2), min(ctx + width // 2, im_width) 49 | y0, y1 = max(0, cty - height // 2), min(cty + height // 2, im_height) 50 | 51 | left, right = ctx - x0, x1 - ctx 52 | top, bottom = cty - y0, y1 - cty 53 | 54 | cropped_cty, cropped_ctx = height // 2, width // 2 55 | y_slice = slice(cropped_cty - top, cropped_cty + bottom) 56 | x_slice = slice(cropped_ctx - left, cropped_ctx + right) 57 | cropped_image[y_slice, x_slice, :] = image[y0:y1, x0:x1, :] 58 | 59 | border = np.array([ 60 | cropped_cty - top, 61 | cropped_cty + bottom, 62 | cropped_ctx - left, 63 | cropped_ctx + right 64 | ], dtype=np.float32) 65 | 66 | offset = np.array([ 67 | cty - height // 2, 68 | ctx - width // 2 69 | ]) 70 | 71 | return cropped_image, border, offset 72 | -------------------------------------------------------------------------------- /utils/tqdm.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import contextlib 4 | 5 | from tqdm import tqdm 6 | 7 | class TqdmFile(object): 8 | dummy_file = None 9 | def __init__(self, dummy_file): 10 | self.dummy_file = dummy_file 11 | 12 | def write(self, x): 13 | if len(x.rstrip()) > 0: 14 | tqdm.write(x, file=self.dummy_file) 15 | 16 | @contextlib.contextmanager 17 | def stdout_to_tqdm(): 18 | save_stdout = sys.stdout 19 | 
-------------------------------------------------------------------------------- /utils/tqdm.py: --------------------------------------------------------------------------------
import sys
import contextlib

from tqdm import tqdm

class TqdmFile(object):
    dummy_file = None
    def __init__(self, dummy_file):
        self.dummy_file = dummy_file

    def write(self, x):
        # route non-empty writes through tqdm so active progress bars are not garbled
        if len(x.rstrip()) > 0:
            tqdm.write(x, file=self.dummy_file)

    def flush(self):
        # some callers flush stdout explicitly; delegate instead of raising AttributeError
        return getattr(self.dummy_file, "flush", lambda: None)()

@contextlib.contextmanager
def stdout_to_tqdm():
    save_stdout = sys.stdout
    try:
        sys.stdout = TqdmFile(sys.stdout)
        yield save_stdout
    finally:
        # always restore the real stdout, even if the body raises
        sys.stdout = save_stdout
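
# Minimal usage sketch (a hypothetical example, matching how train.py drives
# this helper): anything printed inside the context manager is rerouted through
# tqdm.write, so log lines appear above an active progress bar instead of
# breaking it.
#
#   from tqdm import tqdm
#   from utils import stdout_to_tqdm
#
#   with stdout_to_tqdm() as save_stdout:
#       for i in tqdm(range(100), file=save_stdout, ncols=80):
#           print("step {}".format(i))  # emitted via tqdm.write
--------------------------------------------------------------------------------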