├── .gitignore
├── .gitmodules
├── CMakeLists.txt
├── README.md
├── convert_weights.ipynb
├── examples
    ├── CMakeLists.txt
    ├── opencv_realtime_webcam_human_segmentation.cpp
    ├── opencv_realtime_webcam_imagenet_classification.cpp
    ├── pytorch_results_deviation.cpp
    ├── read_allocated_gpu_memory.cpp
    ├── resnet_18_16s_benchmark.cpp
    ├── resnet_18_8s_benchmark.cpp
    ├── resnet_9_8s_benchmark.cpp
    └── segmentation_demo_preview.gif
└── src
    ├── imagenet_classes.cpp
    └── pytorch.cpp

/.gitignore:
--------------------------------------------------------------------------------
1 | build/
2 | ATen/build/
3 | .ipynb_checkpoints/
4 | CMake-hdf5-1.8.20/
5 | *.h5
6 | *.ipynb
7 | *.swp
8 |
9 | !convert_weights.ipynb
10 |

--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "ATen"]
2 | 	path = ATen
3 | 	url = git@github.com:warmspringwinds/ATen.git
4 |

--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.0)
2 | project(boo)
3 |
4 | include(CMakeParseArguments)
5 |
6 | # ATen is built inside the source tree (see README.md): its sources live in
7 | # ./ATen and its build output in ./ATen/build.
8 | set(ATen_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/ATen)
9 | set(ATen_BINARY_DIR ${CMAKE_CURRENT_SOURCE_DIR}/ATen/build)
10 | set(ATen_LIBRARY ${ATen_BINARY_DIR}/src/ATen/libATen.so.1)
11 |
12 | if(NOT EXISTS "${ATen_LIBRARY}")
13 |   message(WARNING "ATen library not found at ${ATen_LIBRARY}; build ATen before building the examples.")
14 | endif()
15 |
16 | # C++11
17 | if(CMAKE_VERSION VERSION_LESS "3.1")
18 |   set(CMAKE_CXX_FLAGS "--std=c++11 ${CMAKE_CXX_FLAGS}")
19 | else()
20 |   set(CMAKE_CXX_STANDARD 11)
21 |   # Fail at configure time instead of silently falling back to an older standard.
22 |   set(CMAKE_CXX_STANDARD_REQUIRED ON)
23 | endif()
24 |
25 | # Every example includes the CUDA headers and links the CUDA libraries, so a
26 | # missing toolkit must stop the configure step instead of breaking the build later.
27 | find_package(CUDA 5.5 REQUIRED)
28 | find_package(OpenCV REQUIRED)
29 |
30 | # HDF5 is built from the CMake-hdf5 bundle unpacked in the repository root.
31 | # Set PYTORCH_CPP_HDF5_VERSION when a different bundle version was downloaded.
32 | set(PYTORCH_CPP_HDF5_VERSION "1.8.20" CACHE STRING
33 |     "Version of the CMake-hdf5 bundle unpacked in the repository root")
34 | set(_hdf5_root ${CMAKE_CURRENT_SOURCE_DIR}/CMake-hdf5-${PYTORCH_CPP_HDF5_VERSION})
35 |
36 | set(_hdf5_libs
37 |   ${_hdf5_root}/build/bin/libhdf5_cpp.a
38 |   ${_hdf5_root}/build/bin/libhdf5.a
39 |   ${_hdf5_root}/build/bin/libz.a
40 |   ${_hdf5_root}/build/bin/libszip.a
41 |   # Portable spelling of -ldl (empty on platforms that do not need it).
42 |   ${CMAKE_DL_LIBS}
43 | )
44 |
45 | # Header search paths shared by all example executables.
46 | set(_pytorch_cpp_include_dirs
47 |   # dense
48 |   ${ATen_SOURCE_DIR}/lib/TH
49 |   ${ATen_SOURCE_DIR}/lib/THC
50 |   ${ATen_BINARY_DIR}/lib/TH
51 |   ${ATen_BINARY_DIR}/lib/THC
52 |   # sparse
53 |   ${ATen_SOURCE_DIR}/lib/THS
54 |   ${ATen_SOURCE_DIR}/lib/THCS
55 |   ${ATen_BINARY_DIR}/lib/THS
56 |   ${ATen_BINARY_DIR}/lib/THCS
57 |
58 |   ${ATen_SOURCE_DIR}/lib
59 |   ${ATen_BINARY_DIR}/lib
60 |
61 |   ${ATen_SOURCE_DIR}/lib/THNN
62 |   ${ATen_SOURCE_DIR}/lib/THCUNN
63 |
64 |   ${ATen_SOURCE_DIR}/src
65 |   ${ATen_BINARY_DIR}/src/ATen
66 |
67 |   ${CUDA_INCLUDE_DIRS}
68 |
69 |   ${_hdf5_root}/hdf5-${PYTORCH_CPP_HDF5_VERSION}/src
70 |   ${_hdf5_root}/hdf5-${PYTORCH_CPP_HDF5_VERSION}/c++/src
71 |   ${_hdf5_root}/build
72 |
73 |   # TODO: structure project in a better way
74 |   # Temporary solution -- change to a normal cpp project structure later
75 |   ${CMAKE_CURRENT_SOURCE_DIR}/src
76 | )
77 |
78 | # pytorch_cpp_add_example(NAME [HDF5] [OPENCV])
79 | #
80 | # Builds examples/NAME.cpp into an executable called NAME and links it
81 | # against ATen and CUDA. The HDF5 flag additionally links the static HDF5
82 | # libraries; the OPENCV flag additionally links OpenCV.
83 | function(pytorch_cpp_add_example name)
84 |   cmake_parse_arguments(arg "HDF5;OPENCV" "" "" ${ARGN})
85 |   if(arg_UNPARSED_ARGUMENTS)
86 |     message(FATAL_ERROR "pytorch_cpp_add_example: unknown arguments: ${arg_UNPARSED_ARGUMENTS}")
87 |   endif()
88 |
89 |   add_executable(${name} examples/${name}.cpp)
90 |   target_include_directories(${name} PRIVATE ${_pytorch_cpp_include_dirs})
91 |   target_link_libraries(${name} PRIVATE ${ATen_LIBRARY} ${CUDA_LIBRARIES})
92 |   if(arg_HDF5)
93 |     target_link_libraries(${name} PRIVATE ${_hdf5_libs})
94 |   endif()
95 |   if(arg_OPENCV)
96 |     target_link_libraries(${name} PRIVATE ${OpenCV_LIBS})
97 |   endif()
98 | endfunction()
99 |
100 | pytorch_cpp_add_example(read_allocated_gpu_memory)
101 | pytorch_cpp_add_example(opencv_realtime_webcam_human_segmentation HDF5 OPENCV)
102 | pytorch_cpp_add_example(pytorch_results_deviation HDF5)
103 | pytorch_cpp_add_example(resnet_18_8s_benchmark HDF5)
104 | pytorch_cpp_add_example(resnet_18_16s_benchmark HDF5)
105 | pytorch_cpp_add_example(resnet_9_8s_benchmark HDF5)
106 | pytorch_cpp_add_example(opencv_realtime_webcam_imagenet_classification HDF5 OPENCV)
107 |

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Pytorch-C++
2 |
3 | ```Pytorch-C++``` is a simple C++ 11 library which provides a [Pytorch](http://pytorch.org/)-like
4 | interface for building neural networks and inference (so far only forward pass is supported). The library
5 | respects the semantics of ```torch.nn``` module of PyTorch. Models from [pytorch/vision](https://github.com/pytorch/vision)
6 | are supported and can be [easily converted](convert_weights.ipynb). We also support all the models from [our image segmentation repository](https://github.com/warmspringwinds/pytorch-segmentation-detection) (scroll down for the gif with example output of one of our segmentation models).
7 |
8 | The library heavily relies on an amazing [ATen](https://github.com/zdevito/ATen) library and was inspired by
9 | [cunnproduction](https://github.com/szagoruyko/cunnproduction).
10 |
11 | The structure of the project and CMake will be changed in the future, as it is not optimal now.
12 |
13 | ## Table of contents
14 |
15 | Use-cases
16 | Examples
17 | Implemented layers
18 | Implemented models
19 | Demos
20 | Installation
21 | About
22 | Contributors
23 | 24 | 25 | ## Use-cases 26 | 27 | The library can be used in cases where you want to integrate your trained ```Pytorch``` 28 | networks into an existing C++ stack and you don't want to convert your weights to other libraries 29 | like ```Caffe/Caffe2/Tensorflow```. The library respects the semantics of the ```Pytorch``` and uses 30 | the same underlying C library to perform all the operations. 31 | 32 | You can achieve more low-level control over your memory. For example, 33 | you can use a memory that was already allocated on GPU. This way you can accept memory from other 34 | application on GPU and avoid expensive transfer to CPU. See [this example](examples/read_allocated_gpu_memory.cpp). 35 | 36 | Conversion from other image types like OpenCV's ```mat``` to ```Tensor``` can be easily performed and all the post-processing 37 | can be done using numpy-like optimized operations, thanks to [ATen](https://github.com/zdevito/ATen) library. 38 | See examples [here](examples/opencv_realtime_webcam_human_segmentation.cpp). 
39 | 40 | 41 | ## Some examples 42 | 43 | ### Inference 44 | 45 | ```c++ 46 | auto net = torch::resnet50_imagenet(); 47 | 48 | net->load_weights("../resnet50_imagenet.h5"); 49 | 50 | // Transfer network to GPU 51 | net->cuda(); 52 | 53 | // Generate a dummy tensor on GPU of type float 54 | Tensor dummy_input = CUDA(kFloat).ones({1, 3, 224, 224}); 55 | 56 | // Perform inference 57 | auto result = net->forward(dummy_input); 58 | 59 | map<string, Tensor> dict; 60 | 61 | // Get the result of the inference back to CPU 62 | dict["main"] = result.toBackend(Backend::CPU); 63 | 64 | // Save the result of the inference in the HDF5 file 65 | torch::save("resnet50_output.h5", dict); 66 | ``` 67 | 68 | ### Display network's architecture 69 | 70 | ```c++ 71 | 72 | auto net = torch::resnet50_imagenet(); 73 | 74 | net->load_weights("../resnet50_imagenet.h5"); 75 | 76 | cout << net->tostring() << endl; 77 | 78 | ``` 79 | 80 | Output: 81 | 82 | ``` 83 | ResNet ( 84 | (conv1) Conv2d( in_channels=3 out_channels=64 kernel_size=(7, 7) stride=(2, 2) padding=(3, 3) dilation=(1, 1) groups=1 bias=0 ) 85 | (bn1) BatchNorm2d( num_features=64 eps=0.000010 momentum=0.100000 ) 86 | (relu) ReLU 87 | (maxpool) MaxPool2d( kernel_size=(3, 3) stride=(2, 2) padding=(1, 1) ) 88 | (layer1) Sequential ( 89 | (0) Bottleneck ( 90 | (conv1) Conv2d( in_channels=64 out_channels=64 kernel_size=(1, 1) stride=(1, 1) padding=(0, 0) dilation=(1, 1) groups=1 bias=0 ) 91 | (bn1) BatchNorm2d( num_features=64 eps=0.000010 momentum=0.100000 ) 92 | (conv2) Conv2d( in_channels=64 out_channels=64 kernel_size=(3, 3) stride=(1, 1) padding=(1, 1) dilation=(1, 1) groups=1 bias=0 ) 93 | (bn2) BatchNorm2d( num_features=64 eps=0.000010 momentum=0.100000 ) 94 | (conv3) Conv2d( in_channels=64 out_channels=256 kernel_size=(1, 1) stride=(1, 1) padding=(0, 0) dilation=(1, 1) groups=1 bias=0 ) 95 | (bn3) BatchNorm2d( num_features=256 eps=0.000010 momentum=0.100000 ) 96 | (downsample) Sequential ( 97 | (0) Conv2d( in_channels=64 out_channels=256 
kernel_size=(1, 1) stride=(1, 1) padding=(0, 0) dilation=(1, 1) groups=1 bias=0 ) 98 | (1) BatchNorm2d( num_features=256 eps=0.000010 momentum=0.100000 ) 99 | ) 100 | 101 | ) 102 | 103 | (1) Bottleneck ( 104 | (conv1) Conv2d( in_channels=256 out_channels=64 kernel_size=(1, 1) stride=(1, 1) padding=(0, 0) dilation=(1, 1) groups=1 bias=0 ) 105 | (bn1) BatchNorm2d( num_features=64 eps=0.000010 momentum=0.100000 ) 106 | (conv2) Conv2d( in_channels=64 out_channels=64 kernel_size=(3, 3) stride=(1, 1) padding=(1, 1) dilation=(1, 1) groups=1 bias=0 ) 107 | (bn2) BatchNorm2d( num_features=64 eps=0.000010 momentum=0.100000 ) 108 | (conv3) Conv2d( in_channels=256 out_channels=256 kernel_size=(1, 1) stride=(1, 1) padding=(0, 0) dilation=(1, 1) groups=1 bias=0 ) 109 | (bn3) BatchNorm2d( num_features=256 eps=0.000010 momentum=0.100000 ) 110 | ) 111 | 112 | (2) Bottleneck ( 113 | (conv1) Conv2d( in_channels=256 out_channels=64 kernel_size=(1, 1) stride=(1, 1) padding=(0, 0) dilation=(1, 1) groups=1 bias=0 ) 114 | (bn1) BatchNorm2d( num_features=64 eps=0.000010 momentum=0.100000 ) 115 | (conv2) Conv2d( in_channels=64 out_channels=64 kernel_size=(3, 3) stride=(1, 1) padding=(1, 1) dilation=(1, 1) groups=1 bias=0 ) 116 | (bn2) BatchNorm2d( num_features=64 eps=0.000010 momentum=0.100000 ) 117 | (conv3) Conv2d( in_channels=256 out_channels=256 kernel_size=(1, 1) stride=(1, 1) padding=(0, 0) dilation=(1, 1) groups=1 bias=0 ) 118 | (bn3) BatchNorm2d( num_features=256 eps=0.000010 momentum=0.100000 ) 119 | ) 120 | 121 | ) 122 | 123 | /* .... 
*/ 124 | 125 | (avgpool) AvgPool2d( kernel_size=(7, 7) stride=(1, 1) padding=(0, 0) ) 126 | (fc) nn.Linear( in_features=2048 out_features=1000 bias=1 ) 127 | ) 128 | ``` 129 | 130 | ### Inspect a Tensor 131 | 132 | 133 | ```c++ 134 | auto net = torch::resnet50_imagenet(); 135 | 136 | net->load_weights("../resnet50_imagenet.h5"); 137 | net->cuda(); 138 | 139 | Tensor dummy_input = CUDA(kFloat).ones({1, 3, 224, 224}); 140 | 141 | auto result = net->forward(dummy_input); 142 | 143 | cout << result << endl; 144 | ``` 145 | 146 | 147 | ``` 148 | Columns 1 to 10-0.3081 0.0798 -1.1900 -1.4837 -0.5136 0.3683 -2.1639 -0.8705 -1.8812 -0.1608 149 | 150 | Columns 11 to 20 0.2168 -0.9283 -1.2954 -1.0791 -1.4445 -0.8946 -0.0959 -1.3099 -1.2062 -1.2327 151 | 152 | Columns 21 to 30-1.0658 0.9427 0.5739 -0.2746 -1.0189 -0.3583 -0.1826 0.2785 0.2209 -0.3340 153 | 154 | Columns 31 to 40-1.9800 -0.5552 -1.0804 -0.8056 -0.0005 -1.8402 -0.7979 -1.4823 1.3657 -0.8970 155 | 156 | /* .... */ 157 | 158 | Columns 961 to 970-0.0557 -0.7405 -0.5501 -1.7207 -0.7043 -1.0925 1.5812 -0.1215 0.8915 0.9794 159 | 160 | Columns 971 to 980-1.1422 -0.1235 -0.5999 -2.1338 -0.0775 -0.8374 -0.2350 -0.0104 -0.0416 -1.0296 161 | 162 | Columns 981 to 990-0.2914 -0.2242 -0.8063 -0.7818 -0.2714 0.0002 -1.2355 0.1238 0.0183 -0.6904 163 | 164 | Columns 991 to 1000 0.5216 -1.8008 -1.7826 -1.2970 -1.6565 -1.3306 -0.6564 -1.6531 0.1178 0.2436 165 | [ CUDAFloatTensor{1,1000} ] 166 | ``` 167 | 168 | ### Create a network 169 | 170 | 171 | ```c++ 172 | auto new_net = std::make_shared(); 173 | new_net->add(std::make_shared(3, 10, 3, 3)); 174 | new_net->add(std::make_shared(10)); 175 | new_net->add(std::make_shared()); 176 | new_net->add(std::make_shared(10, 3)); 177 | ``` 178 | ## Implemented layers 179 | 180 | So far, these layers are available which respect the Pytorch's layers semantics which 181 | can be found [here](http://pytorch.org/docs/0.1.12/nn.html#convolution-layers). 
182 | 183 | 184 | - [x] nn.Sequential 185 | - [x] nn.Conv2d 186 | - [x] nn.MaxPool2d 187 | - [x] nn.AvgPool2d 188 | - [x] nn.ReLU 189 | - [x] nn.Linear 190 | - [x] nn.SoftMax 191 | - [x] nn.BatchNorm2d 192 | - [ ] nn.Dropout2d 193 | - [ ] nn.DataParallel 194 | - [ ] nn.AdaptiveMaxPool2d 195 | - [ ] nn.Sigmoid 196 | and others. 197 | 198 | ## Implemented models 199 | 200 | Some converted models are provided for ease of access. Other models can be [easily converted](convert_weights.ipynb). 201 | 202 | ### Imagenet models 203 | 204 | All models were converted from [pytorch/vision](https://github.com/pytorch/vision) and checked for 205 | correctness. 206 | 207 | - [x] Resnet-18 208 | - [x] Resnet-34 209 | - [x] [Resnet-50](https://www.dropbox.com/s/bukezzx17dr8qdd/resnet50_imagenet.h5?dl=0) 210 | - [x] Resnet-101 211 | - [x] Resnet-150 212 | - [x] Resnet-152 213 | - [ ] All VGG models 214 | - [ ] All Densenet models 215 | - [ ] All Inception models 216 | - [ ] All squeezenet models 217 | - [ ] Alexnet 218 | 219 | ### Segmentation PASCAL VOC 220 | 221 | All models were converted from [this repository](https://github.com/warmspringwinds/dense-ai) and checked for 222 | correctness. 223 | 224 | - [x] Resnet-18-8S 225 | - [x] [Resnet-34-8S](https://www.dropbox.com/s/104my8hr5zm6l7d/resnet34_fcn_pascal.h5?dl=0) 226 | - [ ] Resnet-50-8S 227 | - [ ] Resnet-101-8S 228 | - [ ] Resnet-152-8S 229 | - [x] FCN-32s 230 | - [ ] FCN-16s 231 | - [ ] FCN-8s 232 | 233 | ## Demos 234 | 235 | We created a couple of [demos](examples) where we grab frames using opencv and classify 236 | or segment them. 
237 | 238 | Here you can see an example of real-time segmentation: 239 | 240 | 241 | ![Alt text](examples/segmentation_demo_preview.gif?raw=true "Title") 242 | 243 | ## Installation 244 | 245 | ### ATen 246 | 247 | [ATen](https://github.com/zdevito/ATen) is a C++ 11 library that wraps a powerful C Tensor library with 248 | implementation of numpy-like operations (CPU/CUDA/SPARSE/CUDA-SPARSE backends). 249 | Follow these steps to install it: 250 | 251 | 0. Make sure you have [dependencies](https://github.com/zdevito/ATen#installation) of ```ATen``` installed. 252 | 1. ```git clone --recursive https://github.com/warmspringwinds/pytorch-cpp``` 253 | 2. ```cd pytorch-cpp/ATen;mkdir build;cd build;cmake-gui .. ``` and specify ```CUDA_TOOLKIT_ROOT_DIR```. 254 | 3. ```make``` or better ```make -j7``` (replace ```7``` with a number of cores that you have). 255 | 4. ```cd ../../``` -- returns you back to the root directory (necessary for the next step). 256 | 257 | ### HDF5 258 | 259 | We use ```HDF5``` to be able to [easily convert](convert_weights.ipynb) weights between ```Pytorch``` and ```Pytorch-C++```. 260 | 261 | 0. ```wget https://support.hdfgroup.org/ftp/HDF5/current18/src/CMake-hdf5-1.8.20.tar.gz; tar xvzf CMake-hdf5-1.8.20.tar.gz``` 262 | 1. ```cd CMake-hdf5-1.8.20; ./build-unix.sh``` 263 | 2. ```cd ../``` -- return back. 264 | 265 | Additional information: ```HDF5``` gets updated from time to time and there is a good chance that my link might be outdated. 266 | If it's the case, grab the latest version from [the official website](https://support.hdfgroup.org/ftp/HDF5/current18/src/). 267 | 268 | Also, after you do this don't forget to update the ```CMakeLists.txt``` file with the new hdf5 folder name. 269 | 270 | ### Opencv 271 | 272 | We need ```OpenCV``` for a couple of examples which grab frames from a web camera. 273 | It is not a dependency and can be removed if necessary. 
274 | This was tested on ```Ubuntu-16``` and might need some changes on a different system. 275 | 276 | 0. ```sudo apt-get install libopencv-dev python-opencv``` 277 | 278 | 279 | ### Pytorch-C++ 280 | 281 | ```Pytorch-C++``` is a library on top of ```ATen``` that provides a [Pytorch](http://pytorch.org/)-like 282 | interface for building neural networks and inference (so far only forward pass is supported) 283 | inspired by [cunnproduction](https://github.com/szagoruyko/cunnproduction) library. To install it, follow 284 | these steps: 285 | 286 | 0. ```mkdir build; cd build; cmake-gui ..``` and specify ```CUDA_TOOLKIT_ROOT_DIR```. 287 | 1. ```make``` 288 | 2. ```cd ../``` -- return back 289 | 290 | ### Problems with the build 291 | 292 | It was noticed that if you have anaconda installed and your ```PATH``` variable is modified to include 293 | its folder, it can lead to a failed build (caused by the fact that anaconda uses a different version of ```gcc```). 294 | To solve this problem, remove the path to anaconda from ```PATH``` for the time of the build. 295 | 296 | If you face any problems or some steps are not clear, please open an issue. Note: every time you enter the ```cmake-gui``` 297 | press ```configure``` first, then specify your ```CUDA``` path and then press ```generate```, after that you can build. 
298 | 299 | 300 | ## About 301 | 302 | If you used the code for your research, please, cite the paper: 303 | 304 | @article{pakhomov2017deep, 305 | title={Deep Residual Learning for Instrument Segmentation in Robotic Surgery}, 306 | author={Pakhomov, Daniil and Premachandran, Vittal and Allan, Max and Azizian, Mahdi and Navab, Nassir}, 307 | journal={arXiv preprint arXiv:1703.08580}, 308 | year={2017} 309 | } 310 | 311 | During implementation, some preliminary experiments and notes were reported: 312 | - [Converting Image Classification network into FCN](http://warmspringwinds.github.io/tensorflow/tf-slim/2016/10/30/image-classification-and-segmentation-using-tensorflow-and-tf-slim/) 313 | - [Performing upsampling using transposed convolution](http://warmspringwinds.github.io/tensorflow/tf-slim/2016/11/22/upsampling-and-image-segmentation-with-tensorflow-and-tf-slim/) 314 | - [Conditional Random Fields for Refining of Segmentation and Coarseness of FCN-32s model segmentations](http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/18/image-segmentation-with-tensorflow-using-cnns-and-conditional-random-fields/) 315 | - [TF-records usage](http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/) 316 | 317 | ## Contributors 318 | 319 | - Daniil Pakhomov 320 | -------------------------------------------------------------------------------- /convert_weights.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Converting weights between Pytorch and Pytorch-C++" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 2, 13 | "metadata": { 14 | "collapsed": true 15 | }, 16 | "outputs": [], 17 | "source": [ 18 | "import h5py\n", 19 | "import torch\n", 20 | "import torchvision.models as models\n", 21 | "\n", 22 | "test = models.resnet18(pretrained=True)\n", 23 | "\n", 24 | "state_dict = 
test.state_dict()\n", 25 | "\n", 26 | "h5f = h5py.File('resnet18.h5', 'w')\n", 27 | "\n", 28 | "for key in state_dict:\n", 29 | " \n", 30 | " h5f.create_dataset(key, data=state_dict[key].numpy())\n", 31 | "\n", 32 | "h5f.close()" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "## Comparing the outputs of two classifiction networks" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 1, 45 | "metadata": { 46 | "collapsed": true 47 | }, 48 | "outputs": [], 49 | "source": [ 50 | "import torch\n", 51 | "import torchvision.models as models\n", 52 | "\n", 53 | "net = models.resnet50(pretrained=True)\n", 54 | "net = net.eval()\n", 55 | "\n", 56 | "ones_input = torch.autograd.Variable( torch.ones(1, 3, 224, 224) )\n", 57 | "\n", 58 | "pytorch_inference_result = net(ones_input).data.numpy()" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 2, 64 | "metadata": { 65 | "collapsed": true 66 | }, 67 | "outputs": [], 68 | "source": [ 69 | "import h5py\n", 70 | "import torch\n", 71 | "import torchvision.models as models\n", 72 | "\n", 73 | "\n", 74 | "h5f = h5py.File('build/resnet50_output.h5', 'r')\n", 75 | "pytorch_cpp_inference_result = h5f['main'][:]\n", 76 | "h5f.close()" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 3, 82 | "metadata": {}, 83 | "outputs": [ 84 | { 85 | "data": { 86 | "text/plain": [ 87 | "True" 88 | ] 89 | }, 90 | "execution_count": 3, 91 | "metadata": {}, 92 | "output_type": "execute_result" 93 | } 94 | ], 95 | "source": [ 96 | "import numpy as np\n", 97 | "\n", 98 | "# Equal up to 1e-4 by absolute value\n", 99 | "np.allclose(pytorch_cpp_inference_result, pytorch_inference_result, atol=1e-4, rtol=0)" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "metadata": {}, 105 | "source": [ 106 | "## Comparing the outputs of two segmentation networks" 107 | ] 108 | } 109 | ], 110 | "metadata": { 111 | "kernelspec": { 112 | "display_name": 
"Python 2", 113 | "language": "python", 114 | "name": "python2" 115 | }, 116 | "language_info": { 117 | "codemirror_mode": { 118 | "name": "ipython", 119 | "version": 2 120 | }, 121 | "file_extension": ".py", 122 | "mimetype": "text/x-python", 123 | "name": "python", 124 | "nbconvert_exporter": "python", 125 | "pygments_lexer": "ipython2", 126 | "version": "2.7.13" 127 | } 128 | }, 129 | "nbformat": 4, 130 | "nbformat_minor": 2 131 | } 132 | -------------------------------------------------------------------------------- /examples/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | 2 | #ADD_EXECUTABLE(read_allocated_gpu_memory read_allocated_gpu_memory.cpp) 3 | #TARGET_LINK_LIBRARIES(read_allocated_gpu_memory ${ATen_BINARY_DIR}/src/ATen/libATen.so.1 ${CUDA_LIBRARIES}) 4 | 5 | #ADD_EXECUTABLE(sequential_test sequential_test.cpp) 6 | #TARGET_LINK_LIBRARIES(sequential_test ${ATen_BINARY_DIR}/src/ATen/libATen.so.1 ${CUDA_LIBRARIES}) 7 | 8 | -------------------------------------------------------------------------------- /examples/opencv_realtime_webcam_human_segmentation.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Example shows a real-time segmentation of human class from PASCAL VOC. 3 | The network ouputs probabilities of each pixels belonging to the human class. 4 | These probabilities are later on are used as a transparancy mask for the input image. 5 | The final fused image is displayed in the window of the application. 
6 | */ 7 | 8 | #include "ATen/ATen.h" 9 | #include "ATen/Type.h" 10 | #include 11 | 12 | #include 13 | 14 | #include 15 | 16 | using namespace at; 17 | 18 | using std::map; 19 | using std::string; 20 | 21 | using namespace cv; 22 | 23 | int main() 24 | { 25 | 26 | 27 | // Structure the project in a better way 28 | 29 | // Add a correct linking to Opencv on the local machine 30 | 31 | // Get the build running on laptop for demo 32 | 33 | // upload all the transferred models 34 | 35 | 36 | // ----- 37 | 38 | // * Should we convert the renset 50 and 101? 39 | // * we don't have any segmentatin models trained using them 40 | // * maybe only to make the framework more complete? (check) 41 | 42 | // * Make the classification demo? 43 | // * need to put a softmax on top -- should be very easy 44 | // * need a dict with number --> class name mapping (check) 45 | 46 | // * Structure the whole project (check) 47 | 48 | // * write docs on how to build it 49 | 50 | // * write missing parts -- good for future contributions 51 | 52 | // * Write the dataloaders for the new surgical datasets 53 | 54 | // * start the training 55 | 56 | 57 | auto net = torch::resnet34_8s_pascal_voc(); 58 | 59 | net->load_weights("../resnet34_fcn_pascal.h5"); 60 | net->cuda(); 61 | 62 | VideoCapture cap(0); // open the default camera 63 | 64 | if(!cap.isOpened()) // check if we succeeded 65 | return -1; 66 | 67 | Mat frame; 68 | 69 | for(;;) 70 | { 71 | 72 | cap >> frame; 73 | 74 | // BGR to RGB which is what our network was trained on 75 | cvtColor(frame, frame, COLOR_BGR2RGB); 76 | 77 | // Resizing while preserving aspect ratio, comment out to run 78 | // it on the whole input image. 
79 | resize(frame, frame, Size(0, 0), 0.5, 0.5, INTER_LINEAR); 80 | 81 | // Outputs height x width x 3 tensor converted from Opencv's Mat with 0-255 values 82 | // and convert to 0-1 range 83 | auto image_tensor = torch::convert_opencv_mat_image_to_tensor(frame).toType(CPU(kFloat)) / 255; 84 | 85 | auto output_height = image_tensor.size(0); 86 | auto output_width = image_tensor.size(1); 87 | 88 | // Reshape image into 1 x 3 x height x width 89 | auto image_batch_tensor = torch::convert_image_to_batch(image_tensor); 90 | 91 | // Subtract the mean and divide by standart deivation 92 | auto image_batch_normalized_tensor = torch::preprocess_batch(image_batch_tensor); 93 | 94 | auto input_tensor_gpu = image_batch_normalized_tensor.toBackend(Backend::CUDA); 95 | 96 | auto full_prediction = net->forward(input_tensor_gpu); 97 | 98 | // This is necessary to correctly apply softmax, 99 | // last dimension should represent logits 100 | auto full_prediction_flattned = full_prediction.squeeze(0) 101 | .view({21, -1}) 102 | .transpose(0, 1); 103 | 104 | // Converting logits to probabilities 105 | auto softmaxed = torch::softmax(full_prediction_flattned).transpose(0, 1); 106 | 107 | // 15 is a class for a person 108 | auto layer = softmaxed[15].contiguous().view({output_height, output_width, 1}).toBackend(Backend::CPU); 109 | 110 | // Fuse the prediction probabilities and the actual image to form a masked image. 
111 | auto masked_image = ( image_tensor * layer.expand({output_height, output_width, 3}) ) * 255 ; 112 | 113 | // A function to convert Tensor to a Mat 114 | auto layer_cpu = masked_image.toType(CPU(kByte)); 115 | 116 | auto converted = Mat(output_height, output_width, CV_8UC3, layer_cpu.data_ptr()); 117 | 118 | // OpenCV wants BGR not RGB 119 | cvtColor(converted, converted, COLOR_RGB2BGR); 120 | 121 | imshow("Masked image", converted); 122 | 123 | if(waitKey(30) >= 0 ) break; 124 | } 125 | 126 | 127 | return 0; 128 | } -------------------------------------------------------------------------------- /examples/opencv_realtime_webcam_imagenet_classification.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Example shows a real-time classification. The name of the most probable class 3 | is printed over the image. 4 | */ 5 | 6 | #include "ATen/ATen.h" 7 | #include "ATen/Type.h" 8 | #include 9 | 10 | #include 11 | 12 | #include 13 | #include 14 | 15 | using namespace at; 16 | 17 | using std::map; 18 | using std::string; 19 | using std::tie; 20 | 21 | using namespace cv; 22 | 23 | Mat resize_and_center_square_crop(Mat input_image, int square_size=224) 24 | { 25 | 26 | // Resize image so that the smallest side == square_size 27 | // and do a center crop along the biggest side. 28 | // This way we preserve the aspect ratio and prepare the image 29 | // for network. 30 | 31 | int width = input_image.cols, 32 | height = input_image.rows; 33 | 34 | int min_dim = ( width >= height ) ? 
height : width; 35 | float scale = ( ( float ) square_size ) / min_dim; 36 | 37 | resize(input_image, input_image, Size(0, 0), scale, scale, INTER_LINEAR); 38 | 39 | Rect roi; 40 | 41 | if ( height >= width ) 42 | { 43 | roi.width = square_size; 44 | roi.x = 0; 45 | 46 | roi.height = square_size; 47 | roi.y = ( input_image.rows - roi.height ) / 2; 48 | } 49 | else 50 | { 51 | roi.y = 0; 52 | roi.height = square_size; 53 | 54 | 55 | roi.width = square_size; 56 | roi.x = ( input_image.cols - roi.width ) / 2; 57 | } 58 | 59 | Mat square_crop = input_image(roi); 60 | 61 | return square_crop; 62 | } 63 | 64 | 65 | int main() 66 | { 67 | 68 | auto net = torch::resnet50_imagenet(); 69 | 70 | net->load_weights("../resnet50_imagenet.h5"); 71 | net->cuda(); 72 | 73 | VideoCapture cap(0); // open the default camera 74 | 75 | 76 | if(!cap.isOpened()) // check if we succeeded 77 | return -1; 78 | 79 | Mat frame; 80 | Mat resized_img; 81 | Mat tmp; 82 | 83 | for(;;) 84 | { 85 | 86 | cap.read(frame); 87 | 88 | // BGR to RGB which is what our network was trained on 89 | cvtColor(frame, tmp, COLOR_BGR2RGB); 90 | 91 | // Be carefull: convert_opencv_mat_image_to_tensor() sometimes fails because 92 | // of different management of underlaying image representation, this is why we 93 | // do .clone() 94 | // TODO: investigate it further 95 | resized_img = resize_and_center_square_crop(tmp).clone(); 96 | 97 | // Outputs height x width x 3 tensor converted from Opencv's Mat with 0-255 values 98 | // and convert to 0-1 range 99 | auto image_tensor = torch::convert_opencv_mat_image_to_tensor(resized_img).toType(CPU(kFloat)) / 255; 100 | 101 | // Reshape image into 1 x 3 x height x width 102 | auto image_batch_tensor = torch::convert_image_to_batch(image_tensor); 103 | 104 | auto image_batch_normalized_tensor = torch::preprocess_batch(image_batch_tensor); 105 | 106 | auto input_tensor_gpu = image_batch_normalized_tensor.toBackend(Backend::CUDA); 107 | 108 | auto full_prediction = 
net->forward(input_tensor_gpu); 109 | 110 | auto softmaxed = torch::softmax(full_prediction); 111 | 112 | Tensor top_probability_indexes; 113 | Tensor top_probabilies; 114 | 115 | tie(top_probabilies, top_probability_indexes) = topk(softmaxed, 5, 1, true); 116 | 117 | top_probability_indexes = top_probability_indexes.toBackend(Backend::CPU).view({-1}); 118 | 119 | auto accessor = top_probability_indexes.accessor(); 120 | 121 | putText(frame, imagenet_classes[ accessor[0] ], cvPoint(30,30), 122 | FONT_HERSHEY_COMPLEX_SMALL, 0.8, cvScalar(200,200,250), 1, CV_AA); 123 | 124 | imshow("Masked image", frame); 125 | 126 | if(waitKey(30) >= 0 ) break; 127 | } 128 | 129 | } -------------------------------------------------------------------------------- /examples/pytorch_results_deviation.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Example shows how to run a resnet 50 imagenet-trained classification 3 | model on a dummy input and save it to an hdf5 file. This output can be 4 | later on compared to the output acquired from pytorch in a provided .ipynb 5 | notebook -- results differ no more than 10^{-5}. 
6 | */ 7 | 8 | #include "ATen/ATen.h" 9 | #include "ATen/Type.h" 10 | #include 11 | 12 | #include 13 | 14 | using namespace at; 15 | 16 | using std::map; 17 | using std::string; 18 | 19 | 20 | int main() 21 | { 22 | 23 | auto net = torch::resnet50_imagenet(); 24 | 25 | net->load_weights("../resnet50_imagenet.h5"); 26 | net->cuda(); 27 | 28 | Tensor dummy_input = CUDA(kFloat).ones({1, 3, 224, 224}); 29 | 30 | auto result = net->forward(dummy_input); 31 | 32 | map dict; 33 | 34 | dict["main"] = result.toBackend(Backend::CPU); /* the output is converted to the CPU backend and stored under the key "main" */ 35 | 36 | torch::save("resnet50_output.h5", dict); 37 | 38 | return 0; 39 | } -------------------------------------------------------------------------------- /examples/read_allocated_gpu_memory.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Example shows how already allocated GPU memory can be reused as a tensor. 3 | This is a common case when the memory has to be used without transferring it to the CPU 4 | and back to the GPU. 5 | */ 6 | 7 | #include "ATen/ATen.h" 8 | #include 9 | 10 | using namespace at; // assumed in the following 11 | 12 | 13 | int main() 14 | { 15 | 16 | int width = 300; 17 | int height = 300; 18 | 19 | // Dummy CPU image -- RGBA, 4 bytes per pixel: bands selected by x, each shaded by a gradient along y 20 | std::vector image(4 * width * height); 21 | for (size_t y = 0; y < height; ++y) { 22 | for (size_t x = 0; x < width; ++x) { 23 | size_t idx = y * width + x; 24 | unsigned char value = (float) (y + 1) / height * 255; 25 | if (x < 0.03125*width) { 26 | image[ idx * 4 + 0] = 255-value; 27 | image[ idx * 4 + 1] = 255-value; 28 | image[ idx * 4 + 2] = 255-value; 29 | image[ idx * 4 + 3] = 255; 30 | } 31 | else if (x < 0.34375*width) { 32 | image[ idx * 4 + 0] = value; 33 | image[ idx * 4 + 1] = 0; 34 | image[ idx * 4 + 2] = 0; 35 | image[ idx * 4 + 3] = 255; 36 | } 37 | else if (x < 0.65625*width) { 38 | image[ idx * 4 + 0] = 0; 39 | image[ idx * 4 + 1] = 255-value; 40 | image[ idx * 4 + 2] = 0; 41 | image[ idx * 4 + 3] = 255; 42 | } 43 | else if (x < 0.96875*width) { 44
| image[ idx * 4 + 0] = 0; 45 | image[ idx * 4 + 1] = 0; 46 | image[ idx * 4 + 2] = value; 47 | image[ idx * 4 + 3] = 255; 48 | } 49 | else { 50 | image[ idx * 4 + 0] = value; 51 | image[ idx * 4 + 1] = value; 52 | image[ idx * 4 + 2] = value; 53 | image[ idx * 4 + 3] = 255; 54 | } 55 | } 56 | } 57 | 58 | // Load the dummy image to GPU (the return codes of the CUDA calls are not checked) 59 | unsigned char * cuda_pointer; 60 | cudaMalloc(&cuda_pointer, 4 * width * height * sizeof(unsigned char)); 61 | cudaMemcpy(cuda_pointer, image.data(), sizeof(unsigned char) * 4 * width * height, cudaMemcpyHostToDevice); 62 | 63 | // Read the dummy image from GPU and use it as a tensor later on 64 | auto f = CUDA(kByte).tensorFromBlob(cuda_pointer, {4 * width * height}); 65 | auto new_one = f.toType(CPU(kByte)); /* NOTE(review): new_one is not used afterwards; the print statement below streams f */ 66 | 67 | // Nicely print out the contents of the variable 68 | std::cout << f << std::endl; 69 | 70 | cudaFree(cuda_pointer); 71 | 72 | } 73 | -------------------------------------------------------------------------------- /examples/resnet_18_16s_benchmark.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Example shows how to measure the average execution time spent on one image. 3 | Here we test resnet 18 with the output stride of 16 which shows execution time of 10.42 ms 4 | per frame of size 512x512 on average. 5 | */ 6 | 7 | #include "ATen/ATen.h" 8 | #include "ATen/Type.h" 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | using namespace at; 16 | 17 | using std::map; 18 | using std::string; 19 | 20 | using namespace std; 21 | using namespace std::chrono; 22 | 23 | 24 | int main() 25 | { 26 | 27 | // The reason we do a first run before measuring the time is 28 | // because the first run is slow and doesn't represent the actual speed.
29 | auto net = torch::resnet18_16s_pascal_voc(); 30 | 31 | net->cuda(); 32 | 33 | Tensor dummy_input = CUDA(kFloat).ones({1, 3, 512, 512}); 34 | 35 | high_resolution_clock::time_point t1; 36 | high_resolution_clock::time_point t2; 37 | 38 | cudaDeviceSynchronize(); 39 | 40 | t1 = high_resolution_clock::now(); 41 | 42 | auto result = net->forward(dummy_input); 43 | 44 | cudaDeviceSynchronize(); /* wait for the GPU to finish so that t2 covers the whole forward pass */ 45 | 46 | t2 = high_resolution_clock::now(); 47 | 48 | auto duration = duration_cast( t2 - t1 ).count(); /* warm-up timing; the value is overwritten in the loop below and never reported */ 49 | 50 | // Now running in a loop and getting an average result. 51 | 52 | int number_of_iterations = 100; 53 | int overall_miliseconds_count = 0; 54 | 55 | for (int i = 0; i < number_of_iterations; ++i) 56 | { 57 | 58 | t1 = high_resolution_clock::now(); 59 | 60 | result = net->forward(dummy_input); 61 | 62 | cudaDeviceSynchronize(); /* wait for the GPU before stopping the timer */ 63 | 64 | t2 = high_resolution_clock::now(); 65 | 66 | duration = duration_cast( t2 - t1 ).count(); 67 | 68 | overall_miliseconds_count += duration; 69 | 70 | } 71 | 72 | cout << "Average execution time: " << overall_miliseconds_count / float(number_of_iterations) << " ms" << endl; 73 | 74 | // On our system it outputs: 10.42 ms per frame. 75 | 76 | return 0; 77 | 78 | } -------------------------------------------------------------------------------- /examples/resnet_18_8s_benchmark.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Example shows how to measure the average execution time spent on one image. 3 | Here we test resnet 18 with the output stride of 8 which shows execution time of 25 ms 4 | per frame of size 512x512 on average.
5 | */ 6 | 7 | #include "ATen/ATen.h" 8 | #include "ATen/Type.h" 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | using namespace at; 16 | 17 | using std::map; 18 | using std::string; 19 | 20 | using namespace std; 21 | using namespace std::chrono; 22 | 23 | 24 | int main() 25 | { 26 | 27 | // The reason we do a first run before measuring the time is 28 | // because the first run is slow and doesn't represent the actual speed. 29 | auto net = torch::resnet18_8s_pascal_voc(); 30 | 31 | net->cuda(); 32 | 33 | Tensor dummy_input = CUDA(kFloat).ones({1, 3, 512, 512}); 34 | 35 | high_resolution_clock::time_point t1; 36 | high_resolution_clock::time_point t2; 37 | 38 | cudaDeviceSynchronize(); 39 | 40 | t1 = high_resolution_clock::now(); 41 | 42 | auto result = net->forward(dummy_input); 43 | 44 | cudaDeviceSynchronize(); /* wait for the GPU to finish so that t2 covers the whole forward pass */ 45 | 46 | t2 = high_resolution_clock::now(); 47 | 48 | auto duration = duration_cast( t2 - t1 ).count(); /* warm-up timing; the value is overwritten in the loop below and never reported */ 49 | 50 | // Now running in a loop and getting an average result. 51 | 52 | int number_of_iterations = 100; 53 | int overall_miliseconds_count = 0; 54 | 55 | for (int i = 0; i < number_of_iterations; ++i) 56 | { 57 | 58 | t1 = high_resolution_clock::now(); 59 | 60 | result = net->forward(dummy_input); 61 | 62 | cudaDeviceSynchronize(); /* wait for the GPU before stopping the timer */ 63 | 64 | t2 = high_resolution_clock::now(); 65 | 66 | duration = duration_cast( t2 - t1 ).count(); 67 | 68 | overall_miliseconds_count += duration; 69 | 70 | } 71 | 72 | cout << "Average execution time: " << overall_miliseconds_count / float(number_of_iterations) << " ms" << endl; 73 | 74 | // On our system it outputs: 25 ms per frame. 75 | 76 | return 0; 77 | 78 | } -------------------------------------------------------------------------------- /examples/resnet_9_8s_benchmark.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Example shows how to measure the average execution time spent on one image.
3 | Here we test resnet 9 with the output stride of 8 which shows execution time of 12 ms 4 | per frame of size 512x512 on average. 5 | */ 6 | 7 | #include "ATen/ATen.h" 8 | #include "ATen/Type.h" 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | using namespace at; 16 | 17 | using std::map; 18 | using std::string; 19 | 20 | using namespace std; 21 | using namespace std::chrono; 22 | 23 | 24 | int main() 25 | { 26 | 27 | // The reason we do a first run before measuring the time is 28 | // because the first run is slow and doesn't represent the actual speed. 29 | auto net = torch::resnet9_8s_endovis_binary(); 30 | 31 | net->cuda(); 32 | 33 | Tensor dummy_input = CUDA(kFloat).ones({1, 3, 512, 512}); 34 | 35 | high_resolution_clock::time_point t1; 36 | high_resolution_clock::time_point t2; 37 | 38 | cudaDeviceSynchronize(); 39 | 40 | t1 = high_resolution_clock::now(); 41 | 42 | auto result = net->forward(dummy_input); 43 | 44 | cudaDeviceSynchronize(); /* wait for the GPU to finish so that t2 covers the whole forward pass */ 45 | 46 | t2 = high_resolution_clock::now(); 47 | 48 | auto duration = duration_cast( t2 - t1 ).count(); /* warm-up timing; the value is overwritten in the loop below and never reported */ 49 | 50 | // Now running in a loop and getting an average result.
51 | 52 | int number_of_iterations = 100; 53 | int overall_miliseconds_count = 0; 54 | 55 | for (int i = 0; i < number_of_iterations; ++i) 56 | { 57 | 58 | t1 = high_resolution_clock::now(); 59 | 60 | result = net->forward(dummy_input); 61 | 62 | cudaDeviceSynchronize(); /* wait for the GPU before stopping the timer */ 63 | 64 | t2 = high_resolution_clock::now(); 65 | 66 | duration = duration_cast( t2 - t1 ).count(); 67 | 68 | overall_miliseconds_count += duration; 69 | 70 | } 71 | 72 | cout << "Average execution time: " << overall_miliseconds_count / float(number_of_iterations) << " ms" << endl; 73 | 74 | // Average execution time: 12.04 ms 75 | 76 | return 0; 77 | 78 | } -------------------------------------------------------------------------------- /examples/segmentation_demo_preview.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmspringwinds/pytorch-cpp/56591806f6b204a2f3217ccf7c1eb00752824ca5/examples/segmentation_demo_preview.gif -------------------------------------------------------------------------------- /src/imagenet_classes.cpp: -------------------------------------------------------------------------------- 1 | 2 | using std::string; 3 | 4 | std::vector imagenet_classes = 5 | {"tench, Tinca tinca", 6 | "goldfish, Carassius auratus", 7 | "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias", 8 | "tiger shark, Galeocerdo cuvieri", 9 | "hammerhead, hammerhead shark", 10 | "electric ray, crampfish, numbfish, torpedo", 11 | "stingray", 12 | "cock", 13 | "hen", 14 | "ostrich, Struthio camelus", 15 | "brambling, Fringilla montifringilla", 16 | "goldfinch, Carduelis carduelis", 17 | "house finch, linnet, Carpodacus mexicanus", 18 | "junco, snowbird", 19 | "indigo bunting, indigo finch, indigo bird, Passerina cyanea", 20 | "robin, American robin, Turdus migratorius", 21 | "bulbul", 22 | "jay", 23 | "magpie", 24 | "chickadee", 25 | "water ouzel, dipper", 26 | "kite", 27 | "bald eagle, American eagle,
Haliaeetus leucocephalus", 28 | "vulture", 29 | "great grey owl, great gray owl, Strix nebulosa", 30 | "European fire salamander, Salamandra salamandra", 31 | "common newt, Triturus vulgaris", 32 | "eft", 33 | "spotted salamander, Ambystoma maculatum", 34 | "axolotl, mud puppy, Ambystoma mexicanum", 35 | "bullfrog, Rana catesbeiana", 36 | "tree frog, tree-frog", 37 | "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui", 38 | "loggerhead, loggerhead turtle, Caretta caretta", 39 | "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea", 40 | "mud turtle", 41 | "terrapin", 42 | "box turtle, box tortoise", 43 | "banded gecko", 44 | "common iguana, iguana, Iguana iguana", 45 | "American chameleon, anole, Anolis carolinensis", 46 | "whiptail, whiptail lizard", 47 | "agama", 48 | "frilled lizard, Chlamydosaurus kingi", 49 | "alligator lizard", 50 | "Gila monster, Heloderma suspectum", 51 | "green lizard, Lacerta viridis", 52 | "African chameleon, Chamaeleo chamaeleon", 53 | "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis", 54 | "African crocodile, Nile crocodile, Crocodylus niloticus", 55 | "American alligator, Alligator mississipiensis", 56 | "triceratops", 57 | "thunder snake, worm snake, Carphophis amoenus", 58 | "ringneck snake, ring-necked snake, ring snake", 59 | "hognose snake, puff adder, sand viper", 60 | "green snake, grass snake", 61 | "king snake, kingsnake", 62 | "garter snake, grass snake", 63 | "water snake", 64 | "vine snake", 65 | "night snake, Hypsiglena torquata", 66 | "boa constrictor, Constrictor constrictor", 67 | "rock python, rock snake, Python sebae", 68 | "Indian cobra, Naja naja", 69 | "green mamba", 70 | "sea snake", 71 | "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus", 72 | "diamondback, diamondback rattlesnake, Crotalus adamanteus", 73 | "sidewinder, horned rattlesnake, Crotalus cerastes", 74 | "trilobite", 75 | "harvestman, daddy longlegs, Phalangium opilio", 
76 | "scorpion", 77 | "black and gold garden spider, Argiope aurantia", 78 | "barn spider, Araneus cavaticus", 79 | "garden spider, Aranea diademata", 80 | "black widow, Latrodectus mactans", 81 | "tarantula", 82 | "wolf spider, hunting spider", 83 | "tick", 84 | "centipede", 85 | "black grouse", 86 | "ptarmigan", 87 | "ruffed grouse, partridge, Bonasa umbellus", 88 | "prairie chicken, prairie grouse, prairie fowl", 89 | "peacock", 90 | "quail", 91 | "partridge", 92 | "African grey, African gray, Psittacus erithacus", 93 | "macaw", 94 | "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita", 95 | "lorikeet", 96 | "coucal", 97 | "bee eater", 98 | "hornbill", 99 | "hummingbird", 100 | "jacamar", 101 | "toucan", 102 | "drake", 103 | "red-breasted merganser, Mergus serrator", 104 | "goose", 105 | "black swan, Cygnus atratus", 106 | "tusker", 107 | "echidna, spiny anteater, anteater", 108 | "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus", 109 | "wallaby, brush kangaroo", 110 | "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus", 111 | "wombat", 112 | "jellyfish", 113 | "sea anemone, anemone", 114 | "brain coral", 115 | "flatworm, platyhelminth", 116 | "nematode, nematode worm, roundworm", 117 | "conch", 118 | "snail", 119 | "slug", 120 | "sea slug, nudibranch", 121 | "chiton, coat-of-mail shell, sea cradle, polyplacophore", 122 | "chambered nautilus, pearly nautilus, nautilus", 123 | "Dungeness crab, Cancer magister", 124 | "rock crab, Cancer irroratus", 125 | "fiddler crab", 126 | "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica", 127 | "American lobster, Northern lobster, Maine lobster, Homarus americanus", 128 | "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish", 129 | "crayfish, crawfish, crawdad, crawdaddy", 130 | "hermit crab", 131 | "isopod", 132 | "white stork, Ciconia ciconia", 133 | "black stork, Ciconia nigra", 134 | 
"spoonbill", 135 | "flamingo", 136 | "little blue heron, Egretta caerulea", 137 | "American egret, great white heron, Egretta albus", 138 | "bittern", 139 | "crane", 140 | "limpkin, Aramus pictus", 141 | "European gallinule, Porphyrio porphyrio", 142 | "American coot, marsh hen, mud hen, water hen, Fulica americana", 143 | "bustard", 144 | "ruddy turnstone, Arenaria interpres", 145 | "red-backed sandpiper, dunlin, Erolia alpina", 146 | "redshank, Tringa totanus", 147 | "dowitcher", 148 | "oystercatcher, oyster catcher", 149 | "pelican", 150 | "king penguin, Aptenodytes patagonica", 151 | "albatross, mollymawk", 152 | "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus", 153 | "killer whale, killer, orca, grampus, sea wolf, Orcinus orca", 154 | "dugong, Dugong dugon", 155 | "sea lion", 156 | "Chihuahua", 157 | "Japanese spaniel", 158 | "Maltese dog, Maltese terrier, Maltese", 159 | "Pekinese, Pekingese, Peke", 160 | "Shih-Tzu", 161 | "Blenheim spaniel", 162 | "papillon", 163 | "toy terrier", 164 | "Rhodesian ridgeback", 165 | "Afghan hound, Afghan", 166 | "basset, basset hound", 167 | "beagle", 168 | "bloodhound, sleuthhound", 169 | "bluetick", 170 | "black-and-tan coonhound", 171 | "Walker hound, Walker foxhound", 172 | "English foxhound", 173 | "redbone", 174 | "borzoi, Russian wolfhound", 175 | "Irish wolfhound", 176 | "Italian greyhound", 177 | "whippet", 178 | "Ibizan hound, Ibizan Podenco", 179 | "Norwegian elkhound, elkhound", 180 | "otterhound, otter hound", 181 | "Saluki, gazelle hound", 182 | "Scottish deerhound, deerhound", 183 | "Weimaraner", 184 | "Staffordshire bullterrier, Staffordshire bull terrier", 185 | "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier", 186 | "Bedlington terrier", 187 | "Border terrier", 188 | "Kerry blue terrier", 189 | "Irish terrier", 190 | "Norfolk terrier", 191 | "Norwich terrier", 192 | "Yorkshire terrier", 193 | "wire-haired fox terrier", 194 | 
"Lakeland terrier", 195 | "Sealyham terrier, Sealyham", 196 | "Airedale, Airedale terrier", 197 | "cairn, cairn terrier", 198 | "Australian terrier", 199 | "Dandie Dinmont, Dandie Dinmont terrier", 200 | "Boston bull, Boston terrier", 201 | "miniature schnauzer", 202 | "giant schnauzer", 203 | "standard schnauzer", 204 | "Scotch terrier, Scottish terrier, Scottie", 205 | "Tibetan terrier, chrysanthemum dog", 206 | "silky terrier, Sydney silky", 207 | "soft-coated wheaten terrier", 208 | "West Highland white terrier", 209 | "Lhasa, Lhasa apso", 210 | "flat-coated retriever", 211 | "curly-coated retriever", 212 | "golden retriever", 213 | "Labrador retriever", 214 | "Chesapeake Bay retriever", 215 | "German short-haired pointer", 216 | "vizsla, Hungarian pointer", 217 | "English setter", 218 | "Irish setter, red setter", 219 | "Gordon setter", 220 | "Brittany spaniel", 221 | "clumber, clumber spaniel", 222 | "English springer, English springer spaniel", 223 | "Welsh springer spaniel", 224 | "cocker spaniel, English cocker spaniel, cocker", 225 | "Sussex spaniel", 226 | "Irish water spaniel", 227 | "kuvasz", 228 | "schipperke", 229 | "groenendael", 230 | "malinois", 231 | "briard", 232 | "kelpie", 233 | "komondor", 234 | "Old English sheepdog, bobtail", 235 | "Shetland sheepdog, Shetland sheep dog, Shetland", 236 | "collie", 237 | "Border collie", 238 | "Bouvier des Flandres, Bouviers des Flandres", 239 | "Rottweiler", 240 | "German shepherd, German shepherd dog, German police dog, alsatian", 241 | "Doberman, Doberman pinscher", 242 | "miniature pinscher", 243 | "Greater Swiss Mountain dog", 244 | "Bernese mountain dog", 245 | "Appenzeller", 246 | "EntleBucher", 247 | "boxer", 248 | "bull mastiff", 249 | "Tibetan mastiff", 250 | "French bulldog", 251 | "Great Dane", 252 | "Saint Bernard, St Bernard", 253 | "Eskimo dog, husky", 254 | "malamute, malemute, Alaskan malamute", 255 | "Siberian husky", 256 | "dalmatian, coach dog, carriage dog", 257 | "affenpinscher, monkey 
pinscher, monkey dog", 258 | "basenji", 259 | "pug, pug-dog", 260 | "Leonberg", 261 | "Newfoundland, Newfoundland dog", 262 | "Great Pyrenees", 263 | "Samoyed, Samoyede", 264 | "Pomeranian", 265 | "chow, chow chow", 266 | "keeshond", 267 | "Brabancon griffon", 268 | "Pembroke, Pembroke Welsh corgi", 269 | "Cardigan, Cardigan Welsh corgi", 270 | "toy poodle", 271 | "miniature poodle", 272 | "standard poodle", 273 | "Mexican hairless", 274 | "timber wolf, grey wolf, gray wolf, Canis lupus", 275 | "white wolf, Arctic wolf, Canis lupus tundrarum", 276 | "red wolf, maned wolf, Canis rufus, Canis niger", 277 | "coyote, prairie wolf, brush wolf, Canis latrans", 278 | "dingo, warrigal, warragal, Canis dingo", 279 | "dhole, Cuon alpinus", 280 | "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus", 281 | "hyena, hyaena", 282 | "red fox, Vulpes vulpes", 283 | "kit fox, Vulpes macrotis", 284 | "Arctic fox, white fox, Alopex lagopus", 285 | "grey fox, gray fox, Urocyon cinereoargenteus", 286 | "tabby, tabby cat", 287 | "tiger cat", 288 | "Persian cat", 289 | "Siamese cat, Siamese", 290 | "Egyptian cat", 291 | "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor", 292 | "lynx, catamount", 293 | "leopard, Panthera pardus", 294 | "snow leopard, ounce, Panthera uncia", 295 | "jaguar, panther, Panthera onca, Felis onca", 296 | "lion, king of beasts, Panthera leo", 297 | "tiger, Panthera tigris", 298 | "cheetah, chetah, Acinonyx jubatus", 299 | "brown bear, bruin, Ursus arctos", 300 | "American black bear, black bear, Ursus americanus, Euarctos americanus", 301 | "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus", 302 | "sloth bear, Melursus ursinus, Ursus ursinus", 303 | "mongoose", 304 | "meerkat, mierkat", 305 | "tiger beetle", 306 | "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle", 307 | "ground beetle, carabid beetle", 308 | "long-horned beetle, longicorn, longicorn beetle", 309 | "leaf beetle, chrysomelid", 310 | "dung 
beetle", 311 | "rhinoceros beetle", 312 | "weevil", 313 | "fly", 314 | "bee", 315 | "ant, emmet, pismire", 316 | "grasshopper, hopper", 317 | "cricket", 318 | "walking stick, walkingstick, stick insect", 319 | "cockroach, roach", 320 | "mantis, mantid", 321 | "cicada, cicala", 322 | "leafhopper", 323 | "lacewing, lacewing fly", 324 | "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", 325 | "damselfly", 326 | "admiral", 327 | "ringlet, ringlet butterfly", 328 | "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus", 329 | "cabbage butterfly", 330 | "sulphur butterfly, sulfur butterfly", 331 | "lycaenid, lycaenid butterfly", 332 | "starfish, sea star", 333 | "sea urchin", 334 | "sea cucumber, holothurian", 335 | "wood rabbit, cottontail, cottontail rabbit", 336 | "hare", 337 | "Angora, Angora rabbit", 338 | "hamster", 339 | "porcupine, hedgehog", 340 | "fox squirrel, eastern fox squirrel, Sciurus niger", 341 | "marmot", 342 | "beaver", 343 | "guinea pig, Cavia cobaya", 344 | "sorrel", 345 | "zebra", 346 | "hog, pig, grunter, squealer, Sus scrofa", 347 | "wild boar, boar, Sus scrofa", 348 | "warthog", 349 | "hippopotamus, hippo, river horse, Hippopotamus amphibius", 350 | "ox", 351 | "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis", 352 | "bison", 353 | "ram, tup", 354 | "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis", 355 | "ibex, Capra ibex", 356 | "hartebeest", 357 | "impala, Aepyceros melampus", 358 | "gazelle", 359 | "Arabian camel, dromedary, Camelus dromedarius", 360 | "llama", 361 | "weasel", 362 | "mink", 363 | "polecat, fitch, foulmart, foumart, Mustela putorius", 364 | "black-footed ferret, ferret, Mustela nigripes", 365 | "otter", 366 | "skunk, polecat, wood pussy", 367 | "badger", 368 | "armadillo", 369 | "three-toed sloth, ai, Bradypus tridactylus", 370 | "orangutan, orang, orangutang, Pongo pygmaeus", 371 | 
"gorilla, Gorilla gorilla", 372 | "chimpanzee, chimp, Pan troglodytes", 373 | "gibbon, Hylobates lar", 374 | "siamang, Hylobates syndactylus, Symphalangus syndactylus", 375 | "guenon, guenon monkey", 376 | "patas, hussar monkey, Erythrocebus patas", 377 | "baboon", 378 | "macaque", 379 | "langur", 380 | "colobus, colobus monkey", 381 | "proboscis monkey, Nasalis larvatus", 382 | "marmoset", 383 | "capuchin, ringtail, Cebus capucinus", 384 | "howler monkey, howler", 385 | "titi, titi monkey", 386 | "spider monkey, Ateles geoffroyi", 387 | "squirrel monkey, Saimiri sciureus", 388 | "Madagascar cat, ring-tailed lemur, Lemur catta", 389 | "indri, indris, Indri indri, Indri brevicaudatus", 390 | "Indian elephant, Elephas maximus", 391 | "African elephant, Loxodonta africana", 392 | "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens", 393 | "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca", 394 | "barracouta, snoek", 395 | "eel", 396 | "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch", 397 | "rock beauty, Holocanthus tricolor", 398 | "anemone fish", 399 | "sturgeon", 400 | "gar, garfish, garpike, billfish, Lepisosteus osseus", 401 | "lionfish", 402 | "puffer, pufferfish, blowfish, globefish", 403 | "abacus", 404 | "abaya", 405 | "academic gown, academic robe, judge's robe", 406 | "accordion, piano accordion, squeeze box", 407 | "acoustic guitar", 408 | "aircraft carrier, carrier, flattop, attack aircraft carrier", 409 | "airliner", 410 | "airship, dirigible", 411 | "altar", 412 | "ambulance", 413 | "amphibian, amphibious vehicle", 414 | "analog clock", 415 | "apiary, bee house", 416 | "apron", 417 | "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", 418 | "assault rifle, assault gun", 419 | "backpack, back pack, knapsack, packsack, rucksack, haversack", 420 | "bakery, bakeshop, bakehouse", 421 | "balance beam, beam", 422 | "balloon", 423 | "ballpoint, 
ballpoint pen, ballpen, Biro", 424 | "Band Aid", 425 | "banjo", 426 | "bannister, banister, balustrade, balusters, handrail", 427 | "barbell", 428 | "barber chair", 429 | "barbershop", 430 | "barn", 431 | "barometer", 432 | "barrel, cask", 433 | "barrow, garden cart, lawn cart, wheelbarrow", 434 | "baseball", 435 | "basketball", 436 | "bassinet", 437 | "bassoon", 438 | "bathing cap, swimming cap", 439 | "bath towel", 440 | "bathtub, bathing tub, bath, tub", 441 | "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon", 442 | "beacon, lighthouse, beacon light, pharos", 443 | "beaker", 444 | "bearskin, busby, shako", 445 | "beer bottle", 446 | "beer glass", 447 | "bell cote, bell cot", 448 | "bib", 449 | "bicycle-built-for-two, tandem bicycle, tandem", 450 | "bikini, two-piece", 451 | "binder, ring-binder", 452 | "binoculars, field glasses, opera glasses", 453 | "birdhouse", 454 | "boathouse", 455 | "bobsled, bobsleigh, bob", 456 | "bolo tie, bolo, bola tie, bola", 457 | "bonnet, poke bonnet", 458 | "bookcase", 459 | "bookshop, bookstore, bookstall", 460 | "bottlecap", 461 | "bow", 462 | "bow tie, bow-tie, bowtie", 463 | "brass, memorial tablet, plaque", 464 | "brassiere, bra, bandeau", 465 | "breakwater, groin, groyne, mole, bulwark, seawall, jetty", 466 | "breastplate, aegis, egis", 467 | "broom", 468 | "bucket, pail", 469 | "buckle", 470 | "bulletproof vest", 471 | "bullet train, bullet", 472 | "butcher shop, meat market", 473 | "cab, hack, taxi, taxicab", 474 | "caldron, cauldron", 475 | "candle, taper, wax light", 476 | "cannon", 477 | "canoe", 478 | "can opener, tin opener", 479 | "cardigan", 480 | "car mirror", 481 | "carousel, carrousel, merry-go-round, roundabout, whirligig", 482 | "carpenter's kit, tool kit", 483 | "carton", 484 | "car wheel", 485 | "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM", 486 | "cassette", 487 | "cassette player", 488 | 
"castle", 489 | "catamaran", 490 | "CD player", 491 | "cello, violoncello", 492 | "cellular telephone, cellular phone, cellphone, cell, mobile phone", 493 | "chain", 494 | "chainlink fence", 495 | "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour", 496 | "chain saw, chainsaw", 497 | "chest", 498 | "chiffonier, commode", 499 | "chime, bell, gong", 500 | "china cabinet, china closet", 501 | "Christmas stocking", 502 | "church, church building", 503 | "cinema, movie theater, movie theatre, movie house, picture palace", 504 | "cleaver, meat cleaver, chopper", 505 | "cliff dwelling", 506 | "cloak", 507 | "clog, geta, patten, sabot", 508 | "cocktail shaker", 509 | "coffee mug", 510 | "coffeepot", 511 | "coil, spiral, volute, whorl, helix", 512 | "combination lock", 513 | "computer keyboard, keypad", 514 | "confectionery, confectionary, candy store", 515 | "container ship, containership, container vessel", 516 | "convertible", 517 | "corkscrew, bottle screw", 518 | "cornet, horn, trumpet, trump", 519 | "cowboy boot", 520 | "cowboy hat, ten-gallon hat", 521 | "cradle", 522 | "crane", 523 | "crash helmet", 524 | "crate", 525 | "crib, cot", 526 | "Crock Pot", 527 | "croquet ball", 528 | "crutch", 529 | "cuirass", 530 | "dam, dike, dyke", 531 | "desk", 532 | "desktop computer", 533 | "dial telephone, dial phone", 534 | "diaper, nappy, napkin", 535 | "digital clock", 536 | "digital watch", 537 | "dining table, board", 538 | "dishrag, dishcloth", 539 | "dishwasher, dish washer, dishwashing machine", 540 | "disk brake, disc brake", 541 | "dock, dockage, docking facility", 542 | "dogsled, dog sled, dog sleigh", 543 | "dome", 544 | "doormat, welcome mat", 545 | "drilling platform, offshore rig", 546 | "drum, membranophone, tympan", 547 | "drumstick", 548 | "dumbbell", 549 | "Dutch oven", 550 | "electric fan, blower", 551 | "electric guitar", 552 | "electric locomotive", 553 | "entertainment center", 554 | "envelope", 555 | "espresso maker", 556 | 
"face powder", 557 | "feather boa, boa", 558 | "file, file cabinet, filing cabinet", 559 | "fireboat", 560 | "fire engine, fire truck", 561 | "fire screen, fireguard", 562 | "flagpole, flagstaff", 563 | "flute, transverse flute", 564 | "folding chair", 565 | "football helmet", 566 | "forklift", 567 | "fountain", 568 | "fountain pen", 569 | "four-poster", 570 | "freight car", 571 | "French horn, horn", 572 | "frying pan, frypan, skillet", 573 | "fur coat", 574 | "garbage truck, dustcart", 575 | "gasmask, respirator, gas helmet", 576 | "gas pump, gasoline pump, petrol pump, island dispenser", 577 | "goblet", 578 | "go-kart", 579 | "golf ball", 580 | "golfcart, golf cart", 581 | "gondola", 582 | "gong, tam-tam", 583 | "gown", 584 | "grand piano, grand", 585 | "greenhouse, nursery, glasshouse", 586 | "grille, radiator grille", 587 | "grocery store, grocery, food market, market", 588 | "guillotine", 589 | "hair slide", 590 | "hair spray", 591 | "half track", 592 | "hammer", 593 | "hamper", 594 | "hand blower, blow dryer, blow drier, hair dryer, hair drier", 595 | "hand-held computer, hand-held microcomputer", 596 | "handkerchief, hankie, hanky, hankey", 597 | "hard disc, hard disk, fixed disk", 598 | "harmonica, mouth organ, harp, mouth harp", 599 | "harp", 600 | "harvester, reaper", 601 | "hatchet", 602 | "holster", 603 | "home theater, home theatre", 604 | "honeycomb", 605 | "hook, claw", 606 | "hoopskirt, crinoline", 607 | "horizontal bar, high bar", 608 | "horse cart, horse-cart", 609 | "hourglass", 610 | "iPod", 611 | "iron, smoothing iron", 612 | "jack-o'-lantern", 613 | "jean, blue jean, denim", 614 | "jeep, landrover", 615 | "jersey, T-shirt, tee shirt", 616 | "jigsaw puzzle", 617 | "jinrikisha, ricksha, rickshaw", 618 | "joystick", 619 | "kimono", 620 | "knee pad", 621 | "knot", 622 | "lab coat, laboratory coat", 623 | "ladle", 624 | "lampshade, lamp shade", 625 | "laptop, laptop computer", 626 | "lawn mower, mower", 627 | "lens cap, lens cover", 628 | "letter 
opener, paper knife, paperknife", 629 | "library", 630 | "lifeboat", 631 | "lighter, light, igniter, ignitor", 632 | "limousine, limo", 633 | "liner, ocean liner", 634 | "lipstick, lip rouge", 635 | "Loafer", 636 | "lotion", 637 | "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system", 638 | "loupe, jeweler's loupe", 639 | "lumbermill, sawmill", 640 | "magnetic compass", 641 | "mailbag, postbag", 642 | "mailbox, letter box", 643 | "maillot", 644 | "maillot, tank suit", 645 | "manhole cover", 646 | "maraca", 647 | "marimba, xylophone", 648 | "mask", 649 | "matchstick", 650 | "maypole", 651 | "maze, labyrinth", 652 | "measuring cup", 653 | "medicine chest, medicine cabinet", 654 | "megalith, megalithic structure", 655 | "microphone, mike", 656 | "microwave, microwave oven", 657 | "military uniform", 658 | "milk can", 659 | "minibus", 660 | "miniskirt, mini", 661 | "minivan", 662 | "missile", 663 | "mitten", 664 | "mixing bowl", 665 | "mobile home, manufactured home", 666 | "Model T", 667 | "modem", 668 | "monastery", 669 | "monitor", 670 | "moped", 671 | "mortar", 672 | "mortarboard", 673 | "mosque", 674 | "mosquito net", 675 | "motor scooter, scooter", 676 | "mountain bike, all-terrain bike, off-roader", 677 | "mountain tent", 678 | "mouse, computer mouse", 679 | "mousetrap", 680 | "moving van", 681 | "muzzle", 682 | "nail", 683 | "neck brace", 684 | "necklace", 685 | "nipple", 686 | "notebook, notebook computer", 687 | "obelisk", 688 | "oboe, hautboy, hautbois", 689 | "ocarina, sweet potato", 690 | "odometer, hodometer, mileometer, milometer", 691 | "oil filter", 692 | "organ, pipe organ", 693 | "oscilloscope, scope, cathode-ray oscilloscope, CRO", 694 | "overskirt", 695 | "oxcart", 696 | "oxygen mask", 697 | "packet", 698 | "paddle, boat paddle", 699 | "paddlewheel, paddle wheel", 700 | "padlock", 701 | "paintbrush", 702 | "pajama, pyjama, pj's, jammies", 703 | "palace", 704 | "panpipe, pandean pipe, syrinx", 705 | "paper towel", 706 | 
"parachute, chute", 707 | "parallel bars, bars", 708 | "park bench", 709 | "parking meter", 710 | "passenger car, coach, carriage", 711 | "patio, terrace", 712 | "pay-phone, pay-station", 713 | "pedestal, plinth, footstall", 714 | "pencil box, pencil case", 715 | "pencil sharpener", 716 | "perfume, essence", 717 | "Petri dish", 718 | "photocopier", 719 | "pick, plectrum, plectron", 720 | "pickelhaube", 721 | "picket fence, paling", 722 | "pickup, pickup truck", 723 | "pier", 724 | "piggy bank, penny bank", 725 | "pill bottle", 726 | "pillow", 727 | "ping-pong ball", 728 | "pinwheel", 729 | "pirate, pirate ship", 730 | "pitcher, ewer", 731 | "plane, carpenter's plane, woodworking plane", 732 | "planetarium", 733 | "plastic bag", 734 | "plate rack", 735 | "plow, plough", 736 | "plunger, plumber's helper", 737 | "Polaroid camera, Polaroid Land camera", 738 | "pole", 739 | "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria", 740 | "poncho", 741 | "pool table, billiard table, snooker table", 742 | "pop bottle, soda bottle", 743 | "pot, flowerpot", 744 | "potter's wheel", 745 | "power drill", 746 | "prayer rug, prayer mat", 747 | "printer", 748 | "prison, prison house", 749 | "projectile, missile", 750 | "projector", 751 | "puck, hockey puck", 752 | "punching bag, punch bag, punching ball, punchball", 753 | "purse", 754 | "quill, quill pen", 755 | "quilt, comforter, comfort, puff", 756 | "racer, race car, racing car", 757 | "racket, racquet", 758 | "radiator", 759 | "radio, wireless", 760 | "radio telescope, radio reflector", 761 | "rain barrel", 762 | "recreational vehicle, RV, R.V.", 763 | "reel", 764 | "reflex camera", 765 | "refrigerator, icebox", 766 | "remote control, remote", 767 | "restaurant, eating house, eating place, eatery", 768 | "revolver, six-gun, six-shooter", 769 | "rifle", 770 | "rocking chair, rocker", 771 | "rotisserie", 772 | "rubber eraser, rubber, pencil eraser", 773 | "rugby ball", 774 | "rule, ruler", 775 | "running shoe", 
776 | "safe", 777 | "safety pin", 778 | "saltshaker, salt shaker", 779 | "sandal", 780 | "sarong", 781 | "sax, saxophone", 782 | "scabbard", 783 | "scale, weighing machine", 784 | "school bus", 785 | "schooner", 786 | "scoreboard", 787 | "screen, CRT screen", 788 | "screw", 789 | "screwdriver", 790 | "seat belt, seatbelt", 791 | "sewing machine", 792 | "shield, buckler", 793 | "shoe shop, shoe-shop, shoe store", 794 | "shoji", 795 | "shopping basket", 796 | "shopping cart", 797 | "shovel", 798 | "shower cap", 799 | "shower curtain", 800 | "ski", 801 | "ski mask", 802 | "sleeping bag", 803 | "slide rule, slipstick", 804 | "sliding door", 805 | "slot, one-armed bandit", 806 | "snorkel", 807 | "snowmobile", 808 | "snowplow, snowplough", 809 | "soap dispenser", 810 | "soccer ball", 811 | "sock", 812 | "solar dish, solar collector, solar furnace", 813 | "sombrero", 814 | "soup bowl", 815 | "space bar", 816 | "space heater", 817 | "space shuttle", 818 | "spatula", 819 | "speedboat", 820 | "spider web, spider's web", 821 | "spindle", 822 | "sports car, sport car", 823 | "spotlight, spot", 824 | "stage", 825 | "steam locomotive", 826 | "steel arch bridge", 827 | "steel drum", 828 | "stethoscope", 829 | "stole", 830 | "stone wall", 831 | "stopwatch, stop watch", 832 | "stove", 833 | "strainer", 834 | "streetcar, tram, tramcar, trolley, trolley car", 835 | "stretcher", 836 | "studio couch, day bed", 837 | "stupa, tope", 838 | "submarine, pigboat, sub, U-boat", 839 | "suit, suit of clothes", 840 | "sundial", 841 | "sunglass", 842 | "sunglasses, dark glasses, shades", 843 | "sunscreen, sunblock, sun blocker", 844 | "suspension bridge", 845 | "swab, swob, mop", 846 | "sweatshirt", 847 | "swimming trunks, bathing trunks", 848 | "swing", 849 | "switch, electric switch, electrical switch", 850 | "syringe", 851 | "table lamp", 852 | "tank, army tank, armored combat vehicle, armoured combat vehicle", 853 | "tape player", 854 | "teapot", 855 | "teddy, teddy bear", 856 | "television, 
television system", 857 | "tennis ball", 858 | "thatch, thatched roof", 859 | "theater curtain, theatre curtain", 860 | "thimble", 861 | "thresher, thrasher, threshing machine", 862 | "throne", 863 | "tile roof", 864 | "toaster", 865 | "tobacco shop, tobacconist shop, tobacconist", 866 | "toilet seat", 867 | "torch", 868 | "totem pole", 869 | "tow truck, tow car, wrecker", 870 | "toyshop", 871 | "tractor", 872 | "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi", 873 | "tray", 874 | "trench coat", 875 | "tricycle, trike, velocipede", 876 | "trimaran", 877 | "tripod", 878 | "triumphal arch", 879 | "trolleybus, trolley coach, trackless trolley", 880 | "trombone", 881 | "tub, vat", 882 | "turnstile", 883 | "typewriter keyboard", 884 | "umbrella", 885 | "unicycle, monocycle", 886 | "upright, upright piano", 887 | "vacuum, vacuum cleaner", 888 | "vase", 889 | "vault", 890 | "velvet", 891 | "vending machine", 892 | "vestment", 893 | "viaduct", 894 | "violin, fiddle", 895 | "volleyball", 896 | "waffle iron", 897 | "wall clock", 898 | "wallet, billfold, notecase, pocketbook", 899 | "wardrobe, closet, press", 900 | "warplane, military plane", 901 | "washbasin, handbasin, washbowl, lavabo, wash-hand basin", 902 | "washer, automatic washer, washing machine", 903 | "water bottle", 904 | "water jug", 905 | "water tower", 906 | "whiskey jug", 907 | "whistle", 908 | "wig", 909 | "window screen", 910 | "window shade", 911 | "Windsor tie", 912 | "wine bottle", 913 | "wing", 914 | "wok", 915 | "wooden spoon", 916 | "wool, woolen, woollen", 917 | "worm fence, snake fence, snake-rail fence, Virginia fence", 918 | "wreck", 919 | "yawl", 920 | "yurt", 921 | "web site, website, internet site, site", 922 | "comic book", 923 | "crossword puzzle, crossword", 924 | "street sign", 925 | "traffic light, traffic signal, stoplight", 926 | "book jacket, dust cover, dust jacket, dust wrapper", 927 | "menu", 928 | "plate", 929 | "guacamole", 930 | "consomme", 931 | "hot 
pot, hotpot", 932 | "trifle", 933 | "ice cream, icecream", 934 | "ice lolly, lolly, lollipop, popsicle", 935 | "French loaf", 936 | "bagel, beigel", 937 | "pretzel", 938 | "cheeseburger", 939 | "hotdog, hot dog, red hot", 940 | "mashed potato", 941 | "head cabbage", 942 | "broccoli", 943 | "cauliflower", 944 | "zucchini, courgette", 945 | "spaghetti squash", 946 | "acorn squash", 947 | "butternut squash", 948 | "cucumber, cuke", 949 | "artichoke, globe artichoke", 950 | "bell pepper", 951 | "cardoon", 952 | "mushroom", 953 | "Granny Smith", 954 | "strawberry", 955 | "orange", 956 | "lemon", 957 | "fig", 958 | "pineapple, ananas", 959 | "banana", 960 | "jackfruit, jak, jack", 961 | "custard apple", 962 | "pomegranate", 963 | "hay", 964 | "carbonara", 965 | "chocolate sauce, chocolate syrup", 966 | "dough", 967 | "meat loaf, meatloaf", 968 | "pizza, pizza pie", 969 | "potpie", 970 | "burrito", 971 | "red wine", 972 | "espresso", 973 | "cup", 974 | "eggnog", 975 | "alp", 976 | "bubble", 977 | "cliff, drop, drop-off", 978 | "coral reef", 979 | "geyser", 980 | "lakeside, lakeshore", 981 | "promontory, headland, head, foreland", 982 | "sandbar, sand bar", 983 | "seashore, coast, seacoast, sea-coast", 984 | "valley, vale", 985 | "volcano", 986 | "ballplayer, baseball player", 987 | "groom, bridegroom", 988 | "scuba diver", 989 | "rapeseed", 990 | "daisy", 991 | "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", 992 | "corn", 993 | "acorn", 994 | "hip, rose hip, rosehip", 995 | "buckeye, horse chestnut, conker", 996 | "coral fungus", 997 | "agaric", 998 | "gyromitra", 999 | "stinkhorn, carrion fungus", 1000 | "earthstar", 1001 | "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa", 1002 | "bolete", 1003 | "ear, spike, capitulum", 1004 | "toilet tissue, toilet paper, bathroom tissue" 1005 | }; -------------------------------------------------------------------------------- /src/pytorch.cpp: 
-------------------------------------------------------------------------------- 1 | 2 | #include "ATen/ATen.h" 3 | #include "ATen/Type.h" 4 | #include 5 | #include 6 | #include "H5Cpp.h" 7 | 8 | #include 9 | 10 | #define TENSOR_DEFAULT_TYPE CPU(kFloat) 11 | 12 | using namespace at; 13 | 14 | 15 | using std::map; 16 | using std::string; 17 | using std::vector; 18 | using std::pair; 19 | using std::shared_ptr; 20 | using std::make_shared; 21 | using std::cout; 22 | using std::endl; 23 | using std::tie; 24 | 25 | using namespace cv; 26 | 27 | 28 | namespace torch 29 | { 30 | 31 | 32 | map load(string hdf5_filename); 33 | void save( string hdf5_filename, map dict_to_write ); 34 | 35 | class Module 36 | { 37 | 38 | public: 39 | 40 | // Sequential module needs the counter 41 | // as names of submodules are not provided 42 | // sometimes. 43 | int submodule_counter; 44 | 45 | Module() : submodule_counter(0) {}; 46 | 47 | ~Module() {}; 48 | 49 | // We will use pointer to other modules a lot 50 | // This is done to automatically handle deallocation of created 51 | // module objects 52 | typedef shared_ptr Ptr; 53 | 54 | virtual Tensor forward(Tensor input) { return input; }; 55 | 56 | string module_name = "Module"; 57 | 58 | // This function gets overwritten 59 | // for the leafnodes like Conv2d, AvgPool2d and so on 60 | virtual string tostring(int indentation_level=0) 61 | { 62 | 63 | std::stringstream s; 64 | 65 | string indentation = string(indentation_level, ' '); 66 | 67 | s << indentation << module_name << " (" << std::endl; 68 | 69 | for(auto name_module_pair: modules) 70 | { 71 | 72 | s << indentation << " (" << name_module_pair.first << ") " 73 | << name_module_pair.second->tostring(indentation_level + 1) << std::endl; 74 | } 75 | 76 | s << indentation << ")" << std::endl; 77 | 78 | return s.str(); 79 | 80 | } 81 | 82 | // vector> because we want to emulate 83 | // the ordered dict this way, meaning that elements 84 | // are stored in the same order they were 
added 85 | 86 | // Like in Pytorch each module stores the modules that it uses 87 | vector> modules; 88 | 89 | // And parameters that are explicitly used by the current module 90 | map parameters; 91 | 92 | // Plus buffers which are meant to store running mean and var for batchnorm layers 93 | map buffers; 94 | 95 | // We store parameter related to gradient computation here and other 96 | // tensors so far 97 | // TODO: some members of grads are not related to gradient computation 98 | // and were put there temporary -- put them in a more relevant container. 99 | map grads; 100 | 101 | // A function to add another modules inside current module 102 | // Acts as Pytorch's Module.add_module() function 103 | void add_module(string module_name, Module::Ptr module) 104 | { 105 | 106 | 107 | modules.push_back(pair(module_name, module)); 108 | } 109 | 110 | 111 | 112 | // Sometimes, when modules are being added, not all of them 113 | // have weights, like RELU. In this case the weights can be 114 | // numerated out of order. For example: 115 | // net = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2)) 116 | // net.state_dict().keys() 117 | // output: ['0.weight', '0.bias', '2.weight', '2.bias'] 118 | 119 | // Equivalent behaviour will be seen with the add() function 120 | // described below: if relu is added, the counter for weights will 121 | // be increased. 122 | 123 | void add(Module::Ptr module) 124 | { 125 | 126 | string module_name = std::to_string(submodule_counter); 127 | 128 | add_module(module_name, module); 129 | 130 | submodule_counter++; 131 | } 132 | 133 | 134 | map state_dict( map & destination, 135 | string prefix="") 136 | { 137 | 138 | // TODO: add another function that will not accept any parameters 139 | // and just return the state_dict() 140 | 141 | for(auto name_parameter_pair: parameters) 142 | { 143 | 144 | // Check if the parameter defined -- for example if we don't use bias 145 | // in the convolution, the bias weight will be undefined. 
146 | // We need this in order to match the state_dict() function of Pytorch 147 | if(name_parameter_pair.second.defined()) 148 | { 149 | 150 | 151 | destination[prefix + name_parameter_pair.first] = name_parameter_pair.second; 152 | } 153 | } 154 | 155 | for(auto name_buffer_pair: buffers) 156 | { 157 | 158 | 159 | destination[prefix + name_buffer_pair.first] = name_buffer_pair.second; 160 | } 161 | 162 | for(auto name_module_pair: modules) 163 | { 164 | 165 | name_module_pair.second->state_dict(destination, prefix + name_module_pair.first + '.'); 166 | } 167 | 168 | return destination; 169 | 170 | } 171 | 172 | 173 | template 174 | void apply(Func closure) 175 | { 176 | 177 | 178 | for(auto name_parameter_pair: parameters) 179 | { 180 | 181 | if(name_parameter_pair.second.defined()) 182 | { 183 | // maybe catch if it is undefined here 184 | parameters[name_parameter_pair.first] = closure(name_parameter_pair.second); 185 | } 186 | } 187 | 188 | for(auto name_buffer_pair: buffers) 189 | { 190 | 191 | buffers[name_buffer_pair.first] = closure(name_buffer_pair.second); 192 | } 193 | 194 | for(auto name_grad_pair: grads) 195 | { 196 | 197 | grads[name_grad_pair.first] = closure(name_grad_pair.second); 198 | } 199 | 200 | for(auto name_module_pair: modules) 201 | { 202 | 203 | name_module_pair.second->apply(closure); 204 | } 205 | } 206 | 207 | void cuda() 208 | { 209 | 210 | // Transfer each tensor to GPU 211 | this->apply([](Tensor & tensor) { 212 | 213 | return tensor.toBackend(Backend::CUDA); 214 | 215 | }); 216 | 217 | } 218 | 219 | void cpu() 220 | { 221 | 222 | // Transfer each tensor to CPU 223 | this->apply([](Tensor & tensor) { 224 | 225 | return tensor.toBackend(Backend::CPU); 226 | 227 | }); 228 | 229 | } 230 | 231 | void save_weights(string hdf5_filename) 232 | { 233 | 234 | map model_state_dict; 235 | 236 | this->state_dict(model_state_dict); 237 | 238 | save(hdf5_filename, model_state_dict); 239 | } 240 | 241 | 242 | void load_weights(string 
hdf5_filename) 243 | { 244 | 245 | 246 | // TODO: 247 | // (1) Add check to make sure that the network is on cpu 248 | // before loading weights 249 | // (2) Add support for not float. So far only works with 250 | // float weights only. 251 | 252 | map model_state_dict; 253 | map checkpoint_dict; 254 | 255 | this->state_dict(model_state_dict); 256 | checkpoint_dict = load(hdf5_filename); 257 | 258 | // Compare model_state_dict -> checkpoint_dict keys consistency 259 | 260 | for(auto name_tensor_pair : model_state_dict) 261 | { 262 | 263 | if(checkpoint_dict.count(name_tensor_pair.first) != 1) 264 | { 265 | 266 | cout << "WARNING: model requires parameter ('" << name_tensor_pair.first << "') " 267 | << "which is not present in the checkpoint file. Using model's default." << endl; 268 | } 269 | } 270 | 271 | // Compare checkpoint_dict -> model_state_dict keys consistency 272 | 273 | for(auto name_tensor_pair : checkpoint_dict) 274 | { 275 | 276 | if(model_state_dict.count(name_tensor_pair.first) != 1) 277 | { 278 | 279 | cout << "WARNING: checkpoint file contains parameter ('" << name_tensor_pair.first << "') " 280 | << "which is not required by the model. The parameter is not used." 
<< endl; 281 | } 282 | } 283 | 284 | for(auto name_tensor_pair : model_state_dict) 285 | { 286 | 287 | if(checkpoint_dict.count(name_tensor_pair.first) == 1) 288 | { 289 | 290 | // Copy in-place 291 | name_tensor_pair.second.copy_(checkpoint_dict[name_tensor_pair.first]); 292 | } 293 | } 294 | 295 | } 296 | 297 | }; 298 | 299 | 300 | class Sequential : public Module 301 | { 302 | public: 303 | 304 | Sequential() 305 | { 306 | 307 | module_name = "Sequential"; 308 | }; 309 | 310 | ~Sequential() { }; 311 | 312 | // Forward for sequential block makes forward pass 313 | // for each submodule and passed it to the next one 314 | Tensor forward(Tensor input) 315 | { 316 | Tensor out = input; 317 | 318 | for(auto name_module_pair: modules) 319 | { 320 | out = name_module_pair.second->forward(out); 321 | } 322 | 323 | return out; 324 | } 325 | 326 | 327 | Module::Ptr get(int i) const { return modules[i].second; } 328 | 329 | }; 330 | 331 | 332 | class ReLU : public Module 333 | { 334 | public: 335 | 336 | ReLU() {}; 337 | ~ReLU() {}; 338 | 339 | Tensor forward(Tensor input) 340 | { 341 | Threshold_updateOutput(input, input, 0, 0, true) ; 342 | return input; 343 | }; 344 | 345 | 346 | string tostring(int indentation_level=0) 347 | { 348 | 349 | string indentation = string(indentation_level, ' '); 350 | 351 | return indentation + std::string("ReLU"); 352 | } 353 | }; 354 | 355 | 356 | class Conv2d : public Module 357 | { 358 | 359 | public: 360 | 361 | int in_channels; 362 | int out_channels; 363 | int kernel_width; 364 | int kernel_height; 365 | int stride_width; 366 | int stride_height; 367 | int dilation_width; 368 | int dilation_height; 369 | int padding_width; 370 | int padding_height; 371 | int groups; 372 | int bias; 373 | bool dilated; 374 | 375 | Conv2d( int in_channels, 376 | int out_channels, 377 | int kernel_width, 378 | int kernel_height, 379 | int stride_width=1, 380 | int stride_height=1, 381 | int padding_width=0, 382 | int padding_height=0, 383 | int 
dilation_width=1, 384 | int dilation_height=1, 385 | int groups=1, 386 | int bias=true) : 387 | 388 | in_channels(in_channels), 389 | out_channels(out_channels), 390 | kernel_width(kernel_width), 391 | kernel_height(kernel_height), 392 | stride_width(stride_width), 393 | stride_height(stride_height), 394 | padding_width(padding_width), 395 | padding_height(padding_height), 396 | dilation_width(dilation_width), 397 | dilation_height(dilation_height), 398 | groups(groups), 399 | bias(bias) 400 | { 401 | 402 | // Register "wight" as a parameter in order to be able to 403 | // restore it from a file later on 404 | parameters["weight"] = TENSOR_DEFAULT_TYPE.zeros({out_channels, 405 | in_channels, 406 | kernel_width, 407 | kernel_height}); 408 | 409 | 410 | // Check if we need bias for our convolution 411 | if(bias) 412 | { 413 | parameters["bias"] = TENSOR_DEFAULT_TYPE.zeros({out_channels}); 414 | } 415 | else 416 | { 417 | 418 | // Doesn't work with TENSOR_DEFAULT_TYPE.tensor();, 419 | // This is why we use Tensor() 420 | parameters["bias"] = Tensor(); 421 | } 422 | 423 | // These variables are not needed for forward inferece, 424 | // but we need them in order to call an underlying C 425 | // function. Later they will be used for backward pass 426 | 427 | grads["finput"] = TENSOR_DEFAULT_TYPE.tensor(); 428 | grads["fgradInput"] = TENSOR_DEFAULT_TYPE.tensor(); 429 | 430 | // These variables depend on # of groups, so far only 431 | // one group is supported. Needs to be changed to tensor_list 432 | // in order to support multiple groups. 
433 | grads["ones"] = TENSOR_DEFAULT_TYPE.tensor(); 434 | grads["columns"] = TENSOR_DEFAULT_TYPE.tensor(); 435 | 436 | // There are separate functions for dilated and non-dilated convolutions 437 | dilated = false; 438 | 439 | if( (dilation_width > 1) || (dilation_height > 1) ) 440 | { 441 | dilated = true; 442 | } 443 | 444 | }; 445 | 446 | ~Conv2d() {}; 447 | 448 | 449 | string tostring(int indentation_level=0) 450 | { 451 | 452 | std::stringstream string_stream; 453 | 454 | string indentation = string(indentation_level, ' '); 455 | 456 | string_stream << indentation << "Conv2d( " 457 | << "in_channels=" << std::to_string(in_channels) << " " 458 | << "out_channels=" << std::to_string(out_channels) << " " 459 | << "kernel_size=(" << std::to_string(kernel_width) << ", " << std::to_string(kernel_height) << ") " 460 | << "stride=(" << std::to_string(stride_width) << ", " << std::to_string(stride_height) << ") " 461 | << "padding=(" << std::to_string(padding_width) << ", " << std::to_string(padding_height) << ") " 462 | << "dilation=(" << std::to_string(dilation_width) << ", " << std::to_string(dilation_height) << ") " 463 | << "groups=" << std::to_string(groups) << " " 464 | << "bias=" << std::to_string(bias) << " )"; 465 | 466 | return string_stream.str(); 467 | 468 | }; 469 | 470 | Tensor forward(Tensor input) 471 | { 472 | 473 | Tensor output = input.type().tensor(); 474 | 475 | if (dilated) 476 | { 477 | 478 | SpatialDilatedConvolution_updateOutput(input, 479 | output, 480 | parameters["weight"], 481 | parameters["bias"], 482 | grads["columns"], 483 | grads["ones"], 484 | kernel_width, 485 | kernel_height, 486 | stride_width, 487 | stride_height, 488 | padding_width, 489 | padding_height, 490 | dilation_width, 491 | dilation_height); 492 | } 493 | else 494 | { 495 | 496 | SpatialConvolutionMM_updateOutput(input, 497 | output, 498 | parameters["weight"], 499 | parameters["bias"], 500 | grads["finput"], 501 | grads["fgradInput"], 502 | kernel_width, 503 | 
kernel_height, 504 | stride_width, 505 | stride_height, 506 | padding_width, 507 | padding_height); 508 | } 509 | 510 | 511 | return output; 512 | }; 513 | }; 514 | 515 | class BatchNorm2d : public Module 516 | { 517 | public: 518 | 519 | int num_features; 520 | bool affine; 521 | bool training; 522 | double momentum; 523 | double eps; 524 | 525 | 526 | 527 | BatchNorm2d( int num_features, 528 | double eps=1e-5, 529 | double momentum=0.1, 530 | bool affine=true, 531 | bool training=false) : 532 | 533 | num_features(num_features), 534 | eps(eps), 535 | momentum(momentum), 536 | affine(affine), 537 | training(training) 538 | 539 | { 540 | 541 | // Initialize weights here 542 | 543 | // Ones initialization is temporarry -- just to avoid 544 | // division by zero during testing 545 | parameters["weight"] = TENSOR_DEFAULT_TYPE.ones(num_features); 546 | parameters["bias"] = TENSOR_DEFAULT_TYPE.zeros(num_features); 547 | 548 | buffers["running_mean"] = TENSOR_DEFAULT_TYPE.zeros(num_features); 549 | buffers["running_var"] = TENSOR_DEFAULT_TYPE.ones(num_features); 550 | 551 | // We don't recompute the mean and var during inference 552 | // So, some variables are initialized for possible future use case. 
553 | grads["save_mean"] = TENSOR_DEFAULT_TYPE.ones(num_features); 554 | grads["save_std"] = TENSOR_DEFAULT_TYPE.ones(num_features); 555 | 556 | }; 557 | 558 | ~BatchNorm2d() {}; 559 | 560 | string tostring(int indentation_level=0) 561 | { 562 | 563 | std::stringstream string_stream; 564 | 565 | string indentation = string(indentation_level, ' '); 566 | 567 | string_stream << indentation 568 | << "BatchNorm2d( " 569 | << "num_features=" << std::to_string(num_features) << " " 570 | << "eps=" << std::to_string(eps) << " " 571 | << "momentum=" << std::to_string(momentum) << " )"; 572 | 573 | return string_stream.str(); 574 | 575 | }; 576 | 577 | 578 | Tensor forward(Tensor input) 579 | { 580 | 581 | Tensor output = input.type().tensor(); 582 | 583 | BatchNormalization_updateOutput(input, 584 | output, 585 | parameters["weight"], 586 | parameters["bias"], 587 | buffers["running_mean"], 588 | buffers["running_var"], 589 | grads["save_mean"], 590 | grads["save_std"], 591 | training, 592 | momentum, 593 | eps); 594 | return output; 595 | }; 596 | 597 | }; 598 | 599 | 600 | // TODO: move this thing out in a separate logical unit: models/resnet 601 | 602 | // Helper functions for a 3 by 3 convolution without bias 603 | // Which is used in every resnet architecture. 
604 | Tensor compute_full_padding_for_dilated_conv(Tensor kernel_size, int dilation=1) 605 | { 606 | 607 | // Convert IntList to Tensor to be able to use element-wise operations 608 | Tensor kernel_size_tensor = kernel_size.toType(CPU(kFloat)); 609 | 610 | // Compute the actual kernel size after dilation 611 | auto actual_kernel_size = (kernel_size_tensor - 1) * (dilation - 1) + kernel_size_tensor; 612 | 613 | // Compute the padding size in order to achieve the 'full padding' mode 614 | auto full_padding = (actual_kernel_size / 2).floor_() 615 | .toType(CPU(kInt)); 616 | 617 | return full_padding; 618 | }; 619 | 620 | Module::Ptr conv3x3(int in_planes, int out_planes, int stride=1, int dilation=1) 621 | { 622 | 623 | // {3, 3} tuple in tensor form. 624 | // We need this because next function accepts Tensor 625 | Tensor kernel_size = CPU(kInt).tensor({2}) 626 | .fill_(3); 627 | 628 | Tensor padding = compute_full_padding_for_dilated_conv(kernel_size, dilation); 629 | 630 | auto padding_accessor = padding.accessor(); 631 | 632 | return std::make_shared(in_planes, 633 | out_planes, 634 | 3, 3, 635 | stride, stride, 636 | padding_accessor[0], padding_accessor[1], 637 | dilation, dilation, 638 | 1, false); 639 | }; 640 | 641 | 642 | Module::Ptr resnet_base_conv7x7() 643 | { 644 | 645 | return make_shared(3, /* in_planes */ 646 | 64, /* out_planes */ 647 | 7, /* kernel_w */ 648 | 7, /* kernel_h */ 649 | 2, /* stride_w */ 650 | 2, /* stride_h */ 651 | 3, /* padding_w */ 652 | 3, /* padding_h */ 653 | 1, /* dilation_w */ 654 | 1, /* dilation_h */ 655 | 1, /* groups */ 656 | false); /* bias */ 657 | } 658 | 659 | Module::Ptr renset_conv1x1(int in_planes, int planes) 660 | { 661 | 662 | return std::make_shared(in_planes, planes, 663 | 1, 1, 664 | 1, 1, 665 | 0, 0, 666 | 1, 1, 667 | 1, false); 668 | } 669 | 670 | 671 | class MaxPool2d : public Module 672 | { 673 | public: 674 | 675 | Tensor indices; 676 | 677 | bool ceil_mode; 678 | int kernel_width; 679 | int kernel_height; 
680 | int stride_width; 681 | int stride_height; 682 | int padding_width; 683 | int padding_height; 684 | 685 | 686 | MaxPool2d(int kernel_width, 687 | int kernel_height, 688 | int stride_width=1, 689 | int stride_height=1, 690 | int padding_width=0, 691 | int padding_height=0, 692 | bool ceil_mode=false) : 693 | 694 | kernel_width(kernel_width), 695 | kernel_height(kernel_height), 696 | stride_width(stride_width), 697 | stride_height(stride_height), 698 | padding_width(padding_width), 699 | padding_height(padding_height), 700 | ceil_mode(ceil_mode) 701 | { 702 | 703 | // TODO: so far this one is hardcoded. 704 | // Change to make it gpu or cpu depending 705 | // on the network placement 706 | grads["indices"] = CPU(kLong).tensor(); 707 | }; 708 | 709 | 710 | ~MaxPool2d() {}; 711 | 712 | Tensor forward(Tensor input) 713 | { 714 | 715 | Tensor output = input.type().tensor(); 716 | 717 | SpatialMaxPooling_updateOutput(input, 718 | output, 719 | grads["indices"], 720 | kernel_width, 721 | kernel_width, 722 | stride_width, 723 | stride_height, 724 | padding_width, 725 | padding_height, 726 | ceil_mode); 727 | 728 | return output; 729 | }; 730 | 731 | string tostring(int indentation_level=0) 732 | { 733 | 734 | std::stringstream string_stream; 735 | 736 | string indentation = string(indentation_level, ' '); 737 | 738 | string_stream << indentation 739 | << "MaxPool2d( " 740 | << "kernel_size=(" << std::to_string(kernel_width) << ", " << std::to_string(kernel_height) << ") " 741 | << "stride=(" << std::to_string(stride_width) << ", " << std::to_string(stride_height) << ") " 742 | << "padding=(" << std::to_string(padding_width) << ", " << std::to_string(padding_height) << ") )"; 743 | 744 | return string_stream.str(); 745 | 746 | }; 747 | }; 748 | 749 | 750 | class AvgPool2d : public Module 751 | { 752 | public: 753 | 754 | bool ceil_mode; 755 | bool count_include_pad; 756 | int kernel_width; 757 | int kernel_height; 758 | int stride_width; 759 | int stride_height; 760 | 
int padding_width; 761 | int padding_height; 762 | 763 | 764 | AvgPool2d(int kernel_width, 765 | int kernel_height, 766 | int stride_width=1, 767 | int stride_height=1, 768 | int padding_width=0, 769 | int padding_height=0, 770 | bool ceil_mode=false, 771 | bool count_include_pad=true) : 772 | 773 | kernel_width(kernel_width), 774 | kernel_height(kernel_height), 775 | stride_width(stride_width), 776 | stride_height(stride_height), 777 | padding_width(padding_width), 778 | padding_height(padding_height), 779 | ceil_mode(ceil_mode), 780 | count_include_pad(count_include_pad) 781 | { }; 782 | 783 | 784 | ~AvgPool2d() {}; 785 | 786 | Tensor forward(Tensor input) 787 | { 788 | 789 | Tensor output = input.type().tensor(); 790 | 791 | SpatialAveragePooling_updateOutput(input, 792 | output, 793 | kernel_width, 794 | kernel_height, 795 | stride_width, 796 | stride_height, 797 | padding_width, 798 | padding_height, 799 | ceil_mode, 800 | count_include_pad); 801 | 802 | return output; 803 | }; 804 | 805 | string tostring(int indentation_level=0) 806 | { 807 | 808 | std::stringstream string_stream; 809 | 810 | string indentation = string(indentation_level, ' '); 811 | 812 | string_stream << indentation 813 | << "AvgPool2d( " 814 | << "kernel_size=(" << std::to_string(kernel_width) << ", " << std::to_string(kernel_height) << ") " 815 | << "stride=(" << std::to_string(stride_width) << ", " << std::to_string(stride_height) << ") " 816 | << "padding=(" << std::to_string(padding_width) << ", " << std::to_string(padding_height) << ") )"; 817 | 818 | return string_stream.str(); 819 | 820 | }; 821 | }; 822 | 823 | 824 | class Linear : public Module 825 | { 826 | 827 | public: 828 | 829 | 830 | int in_features; 831 | int out_features; 832 | bool bias; 833 | 834 | Linear( int in_features, 835 | int out_features, 836 | bool bias=true) : 837 | 838 | in_features(in_features), 839 | out_features(out_features), 840 | bias(bias) 841 | { 842 | 843 | // Initialize weights here 844 | 845 | 
parameters["weight"] = TENSOR_DEFAULT_TYPE.zeros({out_features, in_features}); 846 | 847 | // Check if we need bias for our convolution 848 | if(bias) 849 | { 850 | 851 | parameters["bias"] = TENSOR_DEFAULT_TYPE.ones({out_features}); 852 | } 853 | else 854 | { 855 | 856 | // don't know why this works yet, doesn't work with TENSOR_DEFAULT_TYPE.tensor(); 857 | parameters["bias"] = Tensor(); 858 | } 859 | 860 | }; 861 | 862 | ~Linear() {}; 863 | 864 | string tostring(int indentation_level=0) 865 | { 866 | 867 | std::stringstream string_stream; 868 | 869 | string indentation = string(indentation_level, ' '); 870 | 871 | string_stream << indentation 872 | << "nn.Linear( " 873 | << "in_features=" << std::to_string(in_features) << " " 874 | << "out_features=" << std::to_string(out_features) << " " 875 | << "bias=" << std::to_string(bias) << " )"; 876 | 877 | return string_stream.str(); 878 | 879 | }; 880 | 881 | Tensor forward(Tensor input) 882 | { 883 | 884 | // https://github.com/pytorch/pytorch/blob/49ec984c406e67107aae2891d24c8839b7dc7c33/torch/nn/_functions/linear.py 885 | 886 | Tensor output = input.type().zeros({input.size(0), parameters["weight"].size(0)}); 887 | 888 | output.addmm_(0, 1, input, parameters["weight"].t()); 889 | 890 | if(bias) 891 | { 892 | // TODO: check if in-place resize affects the result 893 | output.add_(parameters["bias"].expand({output.size(0), output.size(1)})); 894 | } 895 | 896 | return output; 897 | }; 898 | }; 899 | 900 | 901 | 902 | class BasicBlock : public Module 903 | { 904 | 905 | public: 906 | 907 | static const int expansion = 1; 908 | 909 | int stride; 910 | Module::Ptr conv1; 911 | Module::Ptr bn1; 912 | Module::Ptr relu; 913 | Module::Ptr conv2; 914 | Module::Ptr bn2; 915 | Module::Ptr downsample; 916 | 917 | // Make a standart value 918 | BasicBlock(int inplanes, int planes, int stride=1, int dilation=1, Module::Ptr downsample=nullptr) 919 | { 920 | 921 | conv1 = conv3x3(inplanes, planes, stride, dilation); 922 | bn1 = 
std::make_shared(planes); 923 | relu = std::make_shared(); 924 | conv2 = conv3x3(planes, planes, 1, dilation); 925 | bn2 = std::make_shared(planes); 926 | 927 | // This doesn't work 928 | // downsample = downsample because 929 | // the argument gets assigned instead of a class member, 930 | // Should probably change the name of the member and argument 931 | // to be different 932 | this->downsample = downsample; 933 | 934 | stride = stride; 935 | 936 | add_module("conv1", conv1); 937 | add_module("bn1", bn1); 938 | add_module("conv2", conv2); 939 | add_module("bn2", bn2); 940 | 941 | if( downsample != nullptr ) 942 | { 943 | 944 | add_module("downsample", downsample); 945 | } 946 | 947 | module_name = "BasicBlock"; 948 | 949 | }; 950 | 951 | ~BasicBlock() {}; 952 | 953 | Tensor forward(Tensor input) 954 | { 955 | 956 | // This is done in case we don't have the 957 | // downsample module 958 | Tensor residual = input; 959 | Tensor out; 960 | 961 | out = conv1->forward(input); 962 | out = bn1->forward(out); 963 | out = relu->forward(out); 964 | out = conv2->forward(out); 965 | out = bn2->forward(out); 966 | 967 | if(downsample != nullptr) 968 | { 969 | 970 | residual = downsample->forward(input); 971 | } 972 | 973 | out += residual; 974 | out = relu->forward(out); 975 | 976 | return out; 977 | } 978 | 979 | }; 980 | 981 | 982 | class Bottleneck : public Module 983 | { 984 | 985 | public: 986 | 987 | static const int expansion = 4; 988 | 989 | // done 990 | int stride; 991 | Module::Ptr conv1; 992 | Module::Ptr bn1; 993 | Module::Ptr relu; 994 | Module::Ptr conv2; 995 | Module::Ptr bn2; 996 | Module::Ptr conv3; 997 | Module::Ptr bn3; 998 | Module::Ptr downsample; 999 | 1000 | 1001 | // Make a standart value 1002 | Bottleneck(int inplanes, int planes, int stride=1, int dilation=1, Module::Ptr downsample=nullptr) 1003 | { 1004 | 1005 | conv1 = renset_conv1x1(inplanes, planes); 1006 | bn1 = std::make_shared(planes); 1007 | relu = std::make_shared(); 1008 | 1009 | conv2 = 
conv3x3(planes, planes, stride, dilation); 1010 | bn2 = std::make_shared(planes); 1011 | 1012 | conv3 = renset_conv1x1(inplanes, planes * Bottleneck::expansion); 1013 | bn3 = std::make_shared(planes * Bottleneck::expansion); 1014 | 1015 | // Avoiding ambiguitiy -- this is why we are using 'this' keyword. 1016 | this->downsample = downsample; 1017 | 1018 | stride = stride; 1019 | 1020 | add_module("conv1", conv1); 1021 | add_module("bn1", bn1); 1022 | add_module("conv2", conv2); 1023 | add_module("bn2", bn2); 1024 | add_module("conv3", conv3); 1025 | add_module("bn3", bn3); 1026 | 1027 | 1028 | if( downsample != nullptr ) 1029 | { 1030 | 1031 | add_module("downsample", downsample); 1032 | } 1033 | 1034 | module_name = "Bottleneck"; 1035 | 1036 | }; 1037 | 1038 | ~Bottleneck() {}; 1039 | 1040 | // done 1041 | Tensor forward(Tensor input) 1042 | { 1043 | 1044 | // This is done in case we don't have the 1045 | // downsample module 1046 | Tensor residual = input; 1047 | Tensor out; 1048 | 1049 | out = conv1->forward(input); 1050 | out = bn1->forward(out); 1051 | out = relu->forward(out); 1052 | 1053 | out = conv2->forward(out); 1054 | out = bn2->forward(out); 1055 | out = relu->forward(out); 1056 | 1057 | out = conv3->forward(out); 1058 | out = bn3->forward(out); 1059 | 1060 | 1061 | if(downsample != nullptr) 1062 | { 1063 | 1064 | residual = downsample->forward(input); 1065 | } 1066 | 1067 | out += residual; 1068 | out = relu->forward(out); 1069 | 1070 | return out; 1071 | } 1072 | }; 1073 | 1074 | 1075 | 1076 | template 1077 | class ResNet : public Module 1078 | { 1079 | 1080 | public: 1081 | 1082 | int output_stride; 1083 | int in_planes; 1084 | 1085 | // Helper variables to help track 1086 | // dilation factor and output stride 1087 | int current_stride; 1088 | int current_dilation; 1089 | 1090 | // Variables realted to the type of architecture. 
1091 | // Image Segmentation models don't have average pool 1092 | // layer and Linear layers are converted to 1x1 convolution 1093 | bool fully_conv; 1094 | bool remove_avg_pool; 1095 | 1096 | Module::Ptr conv1; 1097 | Module::Ptr bn1; 1098 | Module::Ptr relu; 1099 | Module::Ptr maxpool; 1100 | Module::Ptr layer1; 1101 | Module::Ptr layer2; 1102 | Module::Ptr layer3; 1103 | Module::Ptr layer4; 1104 | Module::Ptr avgpool; 1105 | Module::Ptr fc; 1106 | 1107 | // block, layers, num_classes=1000): 1108 | ResNet(IntList layers, 1109 | int num_classes=1000, 1110 | bool fully_conv=false, 1111 | bool remove_avg_pool=false, 1112 | int output_stride=32) : 1113 | 1114 | // First depth input is the same for all resnet models 1115 | in_planes(64), 1116 | output_stride(output_stride), 1117 | fully_conv(fully_conv), 1118 | remove_avg_pool(remove_avg_pool) 1119 | 1120 | { 1121 | 1122 | // Stride is four after first convolution and maxpool layer. 1123 | // We use this class member to track current output stride in make_layer() 1124 | current_stride = 4; 1125 | 1126 | // Dilation hasn't been applied after convolution and maxpool layer. 
1127 | // We use this class member to track dilation factor in make_layer() 1128 | current_dilation = 1; 1129 | 1130 | conv1 = resnet_base_conv7x7(); 1131 | bn1 = std::make_shared(64); 1132 | relu = std::make_shared(); 1133 | // Kernel size: 3, Stride: 2, Padding, 1 -- full padding 1134 | maxpool = std::make_shared(3, 3, 2, 2, 1, 1); 1135 | 1136 | layer1 = make_layer(64, layers[0], 1); 1137 | layer2 = make_layer(128, layers[1], 2); 1138 | layer3 = make_layer(256, layers[2], 2); 1139 | layer4 = make_layer(512, layers[3], 2); 1140 | 1141 | avgpool = std::make_shared(7, 7); 1142 | fc = std::make_shared(512 * BlockType::expansion, num_classes); 1143 | 1144 | if(fully_conv) 1145 | { 1146 | 1147 | // Average pooling with 'full padding' mode 1148 | avgpool = std::make_shared(7, 7, 1149 | 1, 1, 1150 | 3, 3 ); 1151 | 1152 | // 1x1 Convolution -- Convolutionalized Linear Layer 1153 | fc = std::make_shared(512 * BlockType::expansion, 1154 | num_classes, 1155 | 1, 1); 1156 | } 1157 | 1158 | add_module("conv1", conv1); 1159 | add_module("bn1", bn1); 1160 | add_module("relu", relu); 1161 | 1162 | add_module("maxpool", maxpool); 1163 | 1164 | add_module("layer1", layer1); 1165 | add_module("layer2", layer2); 1166 | add_module("layer3", layer3); 1167 | add_module("layer4", layer4); 1168 | 1169 | add_module("avgpool", avgpool); 1170 | 1171 | add_module("fc", fc); 1172 | 1173 | module_name = "ResNet"; 1174 | 1175 | } 1176 | 1177 | Tensor forward(Tensor input) 1178 | { 1179 | 1180 | Tensor output = input.type().tensor(); 1181 | 1182 | output = conv1->forward(input); 1183 | output = bn1->forward(output); 1184 | output = relu->forward(output); 1185 | output = maxpool->forward(output); 1186 | 1187 | output = layer1->forward(output); 1188 | output = layer2->forward(output); 1189 | output = layer3->forward(output); 1190 | output = layer4->forward(output); 1191 | 1192 | if(!remove_avg_pool) 1193 | { 1194 | 1195 | output = avgpool->forward(output); 1196 | } 1197 | 1198 | if(!fully_conv) 
1199 | { 1200 | 1201 | // Flatten the output in order to apply linear layer 1202 | output = output.view({output.size(0), -1}); 1203 | } 1204 | 1205 | output = fc->forward(output); 1206 | 1207 | return output; 1208 | 1209 | } 1210 | 1211 | 1212 | Module::Ptr make_layer(int planes, int blocks, int stride) 1213 | { 1214 | 1215 | auto new_layer = std::make_shared(); 1216 | 1217 | Module::Ptr downsample = nullptr; 1218 | 1219 | // Check if we need to downsample 1220 | if(stride != 1 || in_planes != planes * BlockType::expansion) 1221 | { 1222 | 1223 | // See if we already achieved desired output stride 1224 | if(current_stride == output_stride) 1225 | { 1226 | 1227 | // If so, replace subsampling with dilation to preserve 1228 | // current spatial resolution 1229 | current_dilation = current_dilation * stride; 1230 | stride = 1; 1231 | } 1232 | else 1233 | { 1234 | 1235 | // If not, we perform subsampling 1236 | current_stride = current_stride * stride; 1237 | } 1238 | 1239 | 1240 | downsample = std::make_shared(); 1241 | 1242 | downsample->add( std::make_shared(in_planes, 1243 | planes * BlockType::expansion, 1244 | 1, 1, 1245 | stride, stride, 1246 | 0, 0, 1247 | 1, 1, 1248 | 1, 1249 | false) ); 1250 | 1251 | downsample->add(std::make_shared(planes * BlockType::expansion)); 1252 | 1253 | } 1254 | 1255 | auto first_block = std::make_shared(in_planes, 1256 | planes, 1257 | stride, 1258 | current_dilation, 1259 | downsample); 1260 | new_layer->add(first_block); 1261 | 1262 | in_planes = planes * BlockType::expansion; 1263 | 1264 | for (int i = 0; i < blocks - 1; ++i) 1265 | { 1266 | 1267 | new_layer->add(std::make_shared(in_planes, 1268 | planes, 1269 | 1, 1270 | current_dilation)); 1271 | } 1272 | 1273 | return new_layer; 1274 | 1275 | } 1276 | 1277 | }; 1278 | 1279 | 1280 | Tensor preprocess_batch(Tensor input_batch) 1281 | { 1282 | 1283 | // Subtracts mean and divides by std. 
1284 | // Important: image should be in a 0-1 range and not in 0-255 1285 | 1286 | // TODO: create a pull request to add broadcastable 1287 | // operations 1288 | 1289 | auto mean_value = CPU(kFloat).ones({1, 3, 1, 1}); 1290 | 1291 | mean_value[0][0][0][0] = 0.485f; 1292 | mean_value[0][1][0][0] = 0.456f; 1293 | mean_value[0][2][0][0] = 0.406f; 1294 | 1295 | // Broadcast the value 1296 | auto mean_value_broadcasted = mean_value.expand(input_batch.sizes()); 1297 | 1298 | auto std_value = CPU(kFloat).ones({1, 3, 1, 1}); 1299 | 1300 | std_value[0][0][0][0] = 0.229f; 1301 | std_value[0][1][0][0] = 0.224f; 1302 | std_value[0][2][0][0] = 0.225f; 1303 | 1304 | auto std_value_broadcasted = std_value.expand(input_batch.sizes()); 1305 | 1306 | return (input_batch - mean_value_broadcasted) / std_value_broadcasted; 1307 | 1308 | } 1309 | 1310 | vector get_hdf5_file_keys(string hdf5_filename) 1311 | { 1312 | 1313 | // We open and close hdf5 file here. It might be an overkill 1314 | // as we can open the file once, read keys and read tensors outright, 1315 | // but this way we also add a simple debugging function to be able to 1316 | // easily get keys without dealing with HDF5 API directly. 
1317 | 1318 | // Open the file 1319 | H5::H5File file = H5::H5File(hdf5_filename, H5F_ACC_RDONLY); 1320 | 1321 | vector names; 1322 | 1323 | // Define a closure to populate our names array 1324 | auto closure = [] (hid_t loc_id, const char *name, const H5L_info_t *linfo, void *opdata) 1325 | { 1326 | 1327 | vector * names_array_pointer = reinterpret_cast< vector *>(opdata); 1328 | 1329 | names_array_pointer->push_back(string(name)); 1330 | 1331 | return 0; 1332 | }; 1333 | 1334 | // Run our closure and populate array 1335 | H5Literate(file.getId(), H5_INDEX_NAME, H5_ITER_INC, NULL, closure, &names); 1336 | 1337 | file.close(); 1338 | 1339 | return names; 1340 | 1341 | } 1342 | 1343 | map load(string hdf5_filename) 1344 | { 1345 | 1346 | map tensor_dict; 1347 | 1348 | // use our get_names function 1349 | vector tensor_names = get_hdf5_file_keys(hdf5_filename); 1350 | 1351 | H5::H5File file = H5::H5File(hdf5_filename, H5F_ACC_RDONLY); 1352 | 1353 | // Array to store the shape of the current tensor 1354 | hsize_t * dims_hsize_t; 1355 | 1356 | // We need this because one function can't accept hsize_t 1357 | vector dims_int; 1358 | 1359 | // Float buffer to intermediately store weights 1360 | float * float_buffer; 1361 | 1362 | // 'Rank' of the tensor 1363 | int ndims; 1364 | 1365 | // Number of elements in the current tensor 1366 | hsize_t tensor_flattened_size; 1367 | 1368 | Tensor buffer_tensor; 1369 | 1370 | 1371 | for (auto tensor_name: tensor_names) 1372 | { 1373 | 1374 | dims_int.clear(); 1375 | 1376 | // Open a 'dataset' which stores current tensor 1377 | H5::DataSet current_dataset = file.openDataSet(tensor_name); 1378 | 1379 | // We can infer the sizes of a store tensor from H5::DataSpace 1380 | H5::DataSpace dataspace = current_dataset.getSpace(); 1381 | ndims = dataspace.getSimpleExtentNdims(); 1382 | 1383 | // Get the overall number of elements -- we need this 1384 | // to allocate the temporary buffer 1385 | tensor_flattened_size = 
dataspace.getSimpleExtentNpoints(); 1386 | 1387 | // Get the shame of the tensor 1388 | dims_hsize_t = new hsize_t[ndims]; 1389 | dataspace.getSimpleExtentDims(dims_hsize_t, NULL); 1390 | 1391 | for (int i = 0; i < ndims; ++i) 1392 | { 1393 | 1394 | // Converting hsize_t to int 1395 | dims_int.push_back(long(dims_hsize_t[i])); 1396 | } 1397 | 1398 | // Allocate temporary float buffer 1399 | // TODO: add support for other types like int 1400 | // and make automatic type inference 1401 | float_buffer = new float[tensor_flattened_size]; 1402 | 1403 | current_dataset.read(float_buffer, H5::PredType::NATIVE_FLOAT, 1404 | dataspace, dataspace); 1405 | 1406 | 1407 | buffer_tensor = CPU(kFloat).tensorFromBlob(float_buffer, dims_int); 1408 | 1409 | tensor_dict[tensor_name] = buffer_tensor.type().copy(buffer_tensor); 1410 | 1411 | delete[] float_buffer; 1412 | delete[] dims_hsize_t; 1413 | 1414 | } 1415 | 1416 | file.close(); 1417 | 1418 | return tensor_dict; 1419 | } 1420 | 1421 | void save( string hdf5_filename, map dict_to_write) 1422 | { 1423 | 1424 | H5::H5File file = H5::H5File(hdf5_filename, H5F_ACC_TRUNC); 1425 | 1426 | for(auto name_tensor_pair : dict_to_write) 1427 | { 1428 | 1429 | auto tensor_to_write = name_tensor_pair.second.contiguous(); 1430 | auto tensor_name = name_tensor_pair.first; 1431 | 1432 | auto dims = tensor_to_write.sizes(); 1433 | 1434 | // The dimensionality of the tensor 1435 | auto ndims = tensor_to_write.ndimension(); 1436 | auto tensor_flattened_size = tensor_to_write.numel(); 1437 | auto tensor_to_write_flatten = tensor_to_write.view({-1}); 1438 | auto tensor_to_write_flatten_accessor = tensor_to_write_flatten.accessor(); 1439 | 1440 | float * float_buffer = new float[tensor_flattened_size]; 1441 | 1442 | // Convert an array of ints into an array of hsize_t 1443 | auto dims_hsize_t = new hsize_t[ndims]; 1444 | 1445 | for (int i = 0; i < ndims; ++i) 1446 | { 1447 | dims_hsize_t[i] = dims[i]; 1448 | } 1449 | 1450 | for (int i = 0; i < 
tensor_flattened_size; ++i) 1451 | { 1452 | 1453 | float_buffer[i] = tensor_to_write_flatten_accessor[i]; 1454 | } 1455 | 1456 | H5::DataSpace space(ndims, dims_hsize_t); 1457 | 1458 | H5::DataSet dataset = H5::DataSet(file.createDataSet(tensor_name, 1459 | H5::PredType::NATIVE_FLOAT, 1460 | space)); 1461 | 1462 | 1463 | dataset.write(float_buffer, H5::PredType::NATIVE_FLOAT); 1464 | 1465 | delete[] float_buffer; 1466 | 1467 | } 1468 | 1469 | file.close(); 1470 | 1471 | } 1472 | 1473 | void inspect_checkpoint(string hdf5_filename) 1474 | { 1475 | 1476 | auto dict = load(hdf5_filename); 1477 | 1478 | for (auto name_tensor_pair : dict) 1479 | { 1480 | cout << name_tensor_pair.first << ": " << name_tensor_pair.second.sizes() <forward(input); 1631 | 1632 | auto full_prediction = upsample_bilinear(subsampled_prediction, output_height, output_width); 1633 | 1634 | return full_prediction; 1635 | } 1636 | }; 1637 | 1638 | class Resnet18_8s : public Module 1639 | { 1640 | 1641 | public: 1642 | 1643 | int num_classes; 1644 | Module::Ptr resnet18_8s; 1645 | 1646 | Resnet18_8s(int num_classes=21): 1647 | num_classes(num_classes) 1648 | 1649 | { 1650 | 1651 | resnet18_8s = torch::resnet18(num_classes, 1652 | true, /* fully convolutional model */ 1653 | 8, /* we want subsampled by 8 prediction*/ 1654 | true); /* remove average pooling layer */ 1655 | 1656 | // Adding a module with this name to be able to easily load 1657 | // weights from pytorch models 1658 | add_module("resnet18_8s", resnet18_8s); 1659 | 1660 | } 1661 | 1662 | Tensor forward(Tensor input) 1663 | { 1664 | 1665 | // probably we can add some utility functions to add softmax on top 1666 | // resize the ouput in a proper way 1667 | 1668 | // input is a tensor of shape batch_size x #channels x height x width 1669 | int output_height = input.size(2); 1670 | int output_width = input.size(3); 1671 | 1672 | auto subsampled_prediction = resnet18_8s->forward(input); 1673 | 1674 | auto full_prediction = 
upsample_bilinear(subsampled_prediction, output_height, output_width); 1675 | 1676 | return full_prediction; 1677 | } 1678 | }; 1679 | 1680 | 1681 | class Resnet18_16s : public Module 1682 | { 1683 | 1684 | public: 1685 | 1686 | int num_classes; 1687 | Module::Ptr resnet18_16s; 1688 | 1689 | Resnet18_16s(int num_classes=21): 1690 | num_classes(num_classes) 1691 | 1692 | { 1693 | 1694 | resnet18_16s = torch::resnet18(num_classes, 1695 | true, /* fully convolutional model */ 1696 | 16, /* we want subsampled by 16 prediction*/ 1697 | true); /* remove average pooling layer */ 1698 | 1699 | // Adding a module with this name to be able to easily load 1700 | // weights from pytorch models 1701 | add_module("resnet18_16s", resnet18_16s); 1702 | 1703 | } 1704 | 1705 | Tensor forward(Tensor input) 1706 | { 1707 | 1708 | // probably we can add some utility functions to add softmax on top 1709 | // resize the ouput in a proper way 1710 | 1711 | // input is a tensor of shape batch_size x #channels x height x width 1712 | int output_height = input.size(2); 1713 | int output_width = input.size(3); 1714 | 1715 | auto subsampled_prediction = resnet18_16s->forward(input); 1716 | 1717 | auto full_prediction = upsample_bilinear(subsampled_prediction, output_height, output_width); 1718 | 1719 | return full_prediction; 1720 | } 1721 | }; 1722 | 1723 | 1724 | class Resnet34_8s : public Module 1725 | { 1726 | 1727 | public: 1728 | 1729 | int num_classes; 1730 | Module::Ptr resnet34_8s; 1731 | 1732 | Resnet34_8s(int num_classes=21): 1733 | num_classes(num_classes) 1734 | 1735 | { 1736 | 1737 | resnet34_8s = torch::resnet34(num_classes, 1738 | true, /* fully convolutional model */ 1739 | 8, /* we want subsampled by 8 prediction*/ 1740 | true); /* remove average pooling layer */ 1741 | 1742 | // Adding a module with this name to be able to easily load 1743 | // weights from pytorch models 1744 | add_module("resnet34_8s", resnet34_8s); 1745 | 1746 | } 1747 | 1748 | Tensor forward(Tensor 
input) 1749 | { 1750 | 1751 | // TODO: 1752 | 1753 | // (1) This part with upsampling is the same for all fully conv models 1754 | // Might make sense to write an abstract class to avoid duplication 1755 | // (2) Probably we can add some utility functions to add softmax on top 1756 | // resize the ouput in a proper way 1757 | 1758 | // input is a tensor of shape batch_size x #channels x height x width 1759 | int output_height = input.size(2); 1760 | int output_width = input.size(3); 1761 | 1762 | auto subsampled_prediction = resnet34_8s->forward(input); 1763 | 1764 | auto full_prediction = upsample_bilinear(subsampled_prediction, output_height, output_width); 1765 | 1766 | return full_prediction; 1767 | } 1768 | }; 1769 | 1770 | 1771 | Module::Ptr resnet18(int num_classes=1000, bool fully_conv=false, int output_stride=32, bool remove_avg_pool=false) 1772 | { 1773 | 1774 | return std::shared_ptr>( 1775 | new torch::ResNet({2, 2, 2, 2}, 1776 | num_classes, 1777 | fully_conv, 1778 | remove_avg_pool, 1779 | output_stride )); 1780 | } 1781 | 1782 | 1783 | Module::Ptr resnet34(int num_classes=1000, bool fully_conv=false, int output_stride=32, bool remove_avg_pool=false) 1784 | { 1785 | 1786 | return std::shared_ptr>( 1787 | new torch::ResNet({3, 4, 6, 3}, 1788 | num_classes, 1789 | fully_conv, 1790 | remove_avg_pool, 1791 | output_stride )); 1792 | } 1793 | 1794 | Module::Ptr resnet50(int num_classes=1000, bool fully_conv=false, int output_stride=32, bool remove_avg_pool=false) 1795 | { 1796 | 1797 | return std::shared_ptr>( 1798 | new torch::ResNet({3, 4, 6, 3}, 1799 | num_classes, 1800 | fully_conv, 1801 | remove_avg_pool, 1802 | output_stride )); 1803 | } 1804 | 1805 | Module::Ptr resnet101(int num_classes=1000, bool fully_conv=false, int output_stride=32, bool remove_avg_pool=false) 1806 | { 1807 | 1808 | return std::shared_ptr>( 1809 | new torch::ResNet({3, 4, 23, 3}, 1810 | num_classes, 1811 | fully_conv, 1812 | remove_avg_pool, 1813 | output_stride )); 1814 | } 
1815 | 1816 | Module::Ptr resnet152(int num_classes=1000, bool fully_conv=false, int output_stride=32, bool remove_avg_pool=false) 1817 | { 1818 | 1819 | return std::shared_ptr>( 1820 | new torch::ResNet({3, 8, 36, 3}, 1821 | num_classes, 1822 | fully_conv, 1823 | remove_avg_pool, 1824 | output_stride )); 1825 | } 1826 | 1827 | Module::Ptr resnet9(int num_classes=1000, bool fully_conv=false, int output_stride=32, bool remove_avg_pool=false) 1828 | { 1829 | 1830 | return std::shared_ptr>( 1831 | new torch::ResNet({1, 1, 1, 1}, 1832 | num_classes, 1833 | fully_conv, 1834 | remove_avg_pool, 1835 | output_stride )); 1836 | } 1837 | 1838 | 1839 | // Maybe add new options like add_softmax?, 1840 | Module::Ptr resnet18_imagenet() 1841 | { 1842 | 1843 | return resnet18(1000, false, 32, false); 1844 | } 1845 | 1846 | Module::Ptr resnet34_imagenet() 1847 | { 1848 | 1849 | return resnet34(1000, false, 32, false); 1850 | } 1851 | 1852 | Module::Ptr resnet50_imagenet() 1853 | { 1854 | 1855 | return resnet50(1000, false, 32, false); 1856 | } 1857 | 1858 | Module::Ptr resnet101_imagenet() 1859 | { 1860 | 1861 | return resnet101(1000, false, 32, false); 1862 | } 1863 | 1864 | Module::Ptr resnet152_imagenet() 1865 | { 1866 | 1867 | return resnet152(1000, false, 32, false); 1868 | } 1869 | 1870 | Module::Ptr resnet18_8s_pascal_voc() 1871 | { 1872 | 1873 | return make_shared(21); 1874 | } 1875 | 1876 | Module::Ptr resnet18_16s_pascal_voc() 1877 | { 1878 | 1879 | return make_shared(21); 1880 | } 1881 | 1882 | Module::Ptr resnet34_8s_pascal_voc() 1883 | { 1884 | 1885 | return make_shared(21); 1886 | } 1887 | 1888 | Module::Ptr resnet9_8s_endovis_binary() 1889 | { 1890 | 1891 | return make_shared(2); 1892 | } 1893 | 1894 | } --------------------------------------------------------------------------------