├── .gitignore
├── .gitmodules
├── CMakeLists.txt
├── README.md
├── convert_weights.ipynb
├── examples
│   ├── CMakeLists.txt
│   ├── opencv_realtime_webcam_human_segmentation.cpp
│   ├── opencv_realtime_webcam_imagenet_classification.cpp
│   ├── pytorch_results_deviation.cpp
│   ├── read_allocated_gpu_memory.cpp
│   ├── resnet_18_16s_benchmark.cpp
│   ├── resnet_18_8s_benchmark.cpp
│   ├── resnet_9_8s_benchmark.cpp
│   └── segmentation_demo_preview.gif
└── src
    ├── imagenet_classes.cpp
    └── pytorch.cpp
/.gitignore:
--------------------------------------------------------------------------------
1 | build/
2 | ATen/build/
3 | .ipynb_checkpoints/
4 | CMake-hdf5-1.8.20/
5 | *.h5
6 | *.ipynb
7 | *.swp
8 |
9 | !convert_weights.ipynb
10 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "ATen"]
2 | path = ATen
3 | url = git@github.com:warmspringwinds/ATen.git
4 |
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.0)
2 | project(boo)
3 |
4 | set(ATen_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/ATen)
5 | set(ATen_BINARY_DIR ${CMAKE_CURRENT_SOURCE_DIR}/ATen/build)
6 |
7 | # C++11
8 | if(CMAKE_VERSION VERSION_LESS "3.1")
9 | set(CMAKE_CXX_FLAGS "--std=c++11 ${CMAKE_CXX_FLAGS}")
10 | else()
11 | set(CMAKE_CXX_STANDARD 11)
12 | endif()
13 |
14 | find_package(CUDA 5.5)
15 | find_package(OpenCV REQUIRED)
16 |
17 | #find_package(HDF5 COMPONENTS C HL NO_MODULE REQUIRED static)
18 |
19 | include_directories(
20 | # dense
21 | ${ATen_SOURCE_DIR}/lib/TH
22 | ${ATen_SOURCE_DIR}/lib/THC
23 | ${ATen_BINARY_DIR}/lib/TH
24 | ${ATen_BINARY_DIR}/lib/THC
25 | # sparse
26 | ${ATen_SOURCE_DIR}/lib/THS
27 | ${ATen_SOURCE_DIR}/lib/THCS
28 | ${ATen_BINARY_DIR}/lib/THS
29 | ${ATen_BINARY_DIR}/lib/THCS
30 |
31 | ${ATen_SOURCE_DIR}/lib
32 | ${ATen_BINARY_DIR}/lib)
33 |
34 | include_directories(
35 | ${ATen_SOURCE_DIR}/lib/THNN
36 | ${ATen_SOURCE_DIR}/lib/THCUNN)
37 |
38 | include_directories(
39 | ${ATen_SOURCE_DIR}/src
40 | ${ATen_BINARY_DIR}/src/ATen)
41 |
42 | include_directories(${CUDA_INCLUDE_DIRS})
43 |
44 | include_directories(
45 | ${CMAKE_CURRENT_SOURCE_DIR}/CMake-hdf5-1.8.20/hdf5-1.8.20/src
46 | ${CMAKE_CURRENT_SOURCE_DIR}/CMake-hdf5-1.8.20/hdf5-1.8.20/c++/src
47 | ${CMAKE_CURRENT_SOURCE_DIR}/CMake-hdf5-1.8.20/build
48 | )
49 |
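   | # Static HDF5 libraries produced by the ./build-unix.sh step described in the README (HDF5 section)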
50 | set(_hdf5_libs
51 | ${CMAKE_CURRENT_SOURCE_DIR}/CMake-hdf5-1.8.20/build/bin/libhdf5_cpp.a
52 | ${CMAKE_CURRENT_SOURCE_DIR}/CMake-hdf5-1.8.20/build/bin/libhdf5.a
53 | ${CMAKE_CURRENT_SOURCE_DIR}/CMake-hdf5-1.8.20/build/bin/libz.a
54 | ${CMAKE_CURRENT_SOURCE_DIR}/CMake-hdf5-1.8.20/build/bin/libszip.a
55 | -ldl
56 | )
57 |
58 | # TODO: structure project in a better way
59 | # Temporary solution -- change to a normal cpp project structure later
60 | include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src)
61 |
62 |
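   | # Each example below is a standalone executable that links against the shared ATen library built in ATen/build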
63 | add_executable(read_allocated_gpu_memory examples/read_allocated_gpu_memory.cpp)
64 | target_link_libraries(read_allocated_gpu_memory ${ATen_BINARY_DIR}/src/ATen/libATen.so.1 ${CUDA_LIBRARIES})
65 |
66 | add_executable(opencv_realtime_webcam_human_segmentation examples/opencv_realtime_webcam_human_segmentation.cpp)
67 | target_link_libraries(opencv_realtime_webcam_human_segmentation ${ATen_BINARY_DIR}/src/ATen/libATen.so.1 ${CUDA_LIBRARIES} ${_hdf5_libs} ${OpenCV_LIBS})
68 |
69 | add_executable(pytorch_results_deviation examples/pytorch_results_deviation.cpp)
70 | target_link_libraries(pytorch_results_deviation ${ATen_BINARY_DIR}/src/ATen/libATen.so.1 ${CUDA_LIBRARIES} ${_hdf5_libs})
71 |
72 | add_executable(resnet_18_8s_benchmark examples/resnet_18_8s_benchmark.cpp)
73 | target_link_libraries(resnet_18_8s_benchmark ${ATen_BINARY_DIR}/src/ATen/libATen.so.1 ${CUDA_LIBRARIES} ${_hdf5_libs})
74 |
75 | add_executable(resnet_18_16s_benchmark examples/resnet_18_16s_benchmark.cpp)
76 | target_link_libraries(resnet_18_16s_benchmark ${ATen_BINARY_DIR}/src/ATen/libATen.so.1 ${CUDA_LIBRARIES} ${_hdf5_libs})
77 |
78 | add_executable(resnet_9_8s_benchmark examples/resnet_9_8s_benchmark.cpp)
79 | target_link_libraries(resnet_9_8s_benchmark ${ATen_BINARY_DIR}/src/ATen/libATen.so.1 ${CUDA_LIBRARIES} ${_hdf5_libs})
80 |
81 | add_executable(opencv_realtime_webcam_imagenet_classification examples/opencv_realtime_webcam_imagenet_classification.cpp)
82 | target_link_libraries(opencv_realtime_webcam_imagenet_classification ${ATen_BINARY_DIR}/src/ATen/libATen.so.1 ${CUDA_LIBRARIES} ${_hdf5_libs} ${OpenCV_LIBS})
83 |
84 |
85 | #add_subdirectory(examples)
86 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Pytorch-C++
2 |
3 | ```Pytorch-C++``` is a simple C++ 11 library which provides a [Pytorch](http://pytorch.org/)-like
4 | interface for building neural networks and running inference (so far only the forward pass is supported). The library
5 | respects the semantics of the ```torch.nn``` module of PyTorch. Models from [pytorch/vision](https://github.com/pytorch/vision)
6 | are supported and can be [easily converted](convert_weights.ipynb). We also support all the models from [our image segmentation repository](https://github.com/warmspringwinds/pytorch-segmentation-detection) (scroll down for a gif with example output of one of our segmentation models).
7 |
8 | The library heavily relies on the amazing [ATen](https://github.com/zdevito/ATen) library and was inspired by
9 | [cunnproduction](https://github.com/szagoruyko/cunnproduction).
10 |
11 | The structure of the project and its CMake configuration will be changed in the future, as they are not optimal now.
12 |
13 | ## Table of contents
14 |
15 | - [Use-cases](#use-cases)
16 | - [Examples](#some-examples)
17 | - [Implemented layers](#implemented-layers)
18 | - [Implemented models](#implemented-models)
19 | - [Demos](#demos)
20 | - [Installation](#installation)
21 | - [About](#about)
22 | - [Contributors](#contributors)
23 |
24 |
25 | ## Use-cases
26 |
27 | The library can be used in cases where you want to integrate your trained ```Pytorch```
28 | networks into an existing C++ stack and you don't want to convert your weights to other libraries
29 | like ```Caffe/Caffe2/Tensorflow```. The library respects the semantics of ```Pytorch``` and uses
30 | the same underlying C library to perform all the operations.
31 |
32 | You can achieve more low-level control over your memory. For example,
33 | you can use memory that was already allocated on the GPU. This way you can accept GPU memory from another
34 | application and avoid an expensive transfer to the CPU. See [this example](examples/read_allocated_gpu_memory.cpp).
35 |
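   | A minimal sketch of the idea, assuming the ```tensorFromBlob``` API of the bundled ```ATen``` version (the buffer and its shape are illustrative):
   |
   | ```c++
   | float * gpu_buffer; // memory another component allocated with cudaMalloc
   |
   | // Wrap the existing GPU buffer into a Tensor without copying;
   | // the Tensor acts as a view over gpu_buffer.
   | Tensor input = CUDA(kFloat).tensorFromBlob(gpu_buffer, {1, 3, 224, 224});
   |
   | auto result = net->forward(input);
   | ```
   |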
36 | Conversion from other image types like OpenCV's ```Mat``` to ```Tensor``` can be easily performed, and all the post-processing
37 | can be done using numpy-like optimized operations, thanks to the [ATen](https://github.com/zdevito/ATen) library.
38 | See examples [here](examples/opencv_realtime_webcam_human_segmentation.cpp).
39 |
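   | A hedged sketch of such a conversion (again assuming the ```tensorFromBlob``` API; preprocessing details depend on the model):
   |
   | ```c++
   | cv::Mat frame; // a BGR frame, e.g. grabbed from a webcam
   |
   | // Wrap the Mat's pixel data into a CPU byte Tensor without copying
   | Tensor bytes = CPU(kByte).tensorFromBlob(frame.data, {frame.rows, frame.cols, 3});
   |
   | // Convert to float on the GPU; further reshaping and normalization
   | // can then be done with the numpy-like Tensor operations.
   | Tensor input = bytes.toType(CUDA(kFloat));
   | ```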
40 |
41 | ## Some examples
42 |
43 | ### Inference
44 |
45 | ```c++
46 | auto net = torch::resnet50_imagenet();
47 |
48 | net->load_weights("../resnet50_imagenet.h5");
49 |
50 | // Transfer network to GPU
51 | net->cuda();
52 |
53 | // Generate a dummy tensor on GPU of type float
54 | Tensor dummy_input = CUDA(kFloat).ones({1, 3, 224, 224});
55 |
56 | // Perform inference
57 | auto result = net->forward(dummy_input);
58 |
59 | map<string, Tensor> dict;
60 |
61 | // Get the result of the inference back to CPU
62 | dict["main"] = result.toBackend(Backend::CPU);
63 |
64 | // Save the result of the inference in the HDF5 file
65 | torch::save("resnet50_output.h5", dict);
66 | ```
67 |
68 | ### Display network's architecture
69 |
70 | ```c++
71 |
72 | auto net = torch::resnet50_imagenet();
73 |
74 | net->load_weights("../resnet50_imagenet.h5");
75 |
76 | cout << net->tostring() << endl;
77 |
78 | ```
79 |
80 | Output:
81 |
82 | ```
83 | ResNet (
84 | (conv1) Conv2d( in_channels=3 out_channels=64 kernel_size=(7, 7) stride=(2, 2) padding=(3, 3) dilation=(1, 1) groups=1 bias=0 )
85 | (bn1) BatchNorm2d( num_features=64 eps=0.000010 momentum=0.100000 )
86 | (relu) ReLU
87 | (maxpool) MaxPool2d( kernel_size=(3, 3) stride=(2, 2) padding=(1, 1) )
88 | (layer1) Sequential (
89 | (0) Bottleneck (
90 | (conv1) Conv2d( in_channels=64 out_channels=64 kernel_size=(1, 1) stride=(1, 1) padding=(0, 0) dilation=(1, 1) groups=1 bias=0 )
91 | (bn1) BatchNorm2d( num_features=64 eps=0.000010 momentum=0.100000 )
92 | (conv2) Conv2d( in_channels=64 out_channels=64 kernel_size=(3, 3) stride=(1, 1) padding=(1, 1) dilation=(1, 1) groups=1 bias=0 )
93 | (bn2) BatchNorm2d( num_features=64 eps=0.000010 momentum=0.100000 )
94 | (conv3) Conv2d( in_channels=64 out_channels=256 kernel_size=(1, 1) stride=(1, 1) padding=(0, 0) dilation=(1, 1) groups=1 bias=0 )
95 | (bn3) BatchNorm2d( num_features=256 eps=0.000010 momentum=0.100000 )
96 | (downsample) Sequential (
97 | (0) Conv2d( in_channels=64 out_channels=256 kernel_size=(1, 1) stride=(1, 1) padding=(0, 0) dilation=(1, 1) groups=1 bias=0 )
98 | (1) BatchNorm2d( num_features=256 eps=0.000010 momentum=0.100000 )
99 | )
100 |
101 | )
102 |
103 | (1) Bottleneck (
104 | (conv1) Conv2d( in_channels=256 out_channels=64 kernel_size=(1, 1) stride=(1, 1) padding=(0, 0) dilation=(1, 1) groups=1 bias=0 )
105 | (bn1) BatchNorm2d( num_features=64 eps=0.000010 momentum=0.100000 )
106 | (conv2) Conv2d( in_channels=64 out_channels=64 kernel_size=(3, 3) stride=(1, 1) padding=(1, 1) dilation=(1, 1) groups=1 bias=0 )
107 | (bn2) BatchNorm2d( num_features=64 eps=0.000010 momentum=0.100000 )
108 | (conv3) Conv2d( in_channels=256 out_channels=256 kernel_size=(1, 1) stride=(1, 1) padding=(0, 0) dilation=(1, 1) groups=1 bias=0 )
109 | (bn3) BatchNorm2d( num_features=256 eps=0.000010 momentum=0.100000 )
110 | )
111 |
112 | (2) Bottleneck (
113 | (conv1) Conv2d( in_channels=256 out_channels=64 kernel_size=(1, 1) stride=(1, 1) padding=(0, 0) dilation=(1, 1) groups=1 bias=0 )
114 | (bn1) BatchNorm2d( num_features=64 eps=0.000010 momentum=0.100000 )
115 | (conv2) Conv2d( in_channels=64 out_channels=64 kernel_size=(3, 3) stride=(1, 1) padding=(1, 1) dilation=(1, 1) groups=1 bias=0 )
116 | (bn2) BatchNorm2d( num_features=64 eps=0.000010 momentum=0.100000 )
117 | (conv3) Conv2d( in_channels=256 out_channels=256 kernel_size=(1, 1) stride=(1, 1) padding=(0, 0) dilation=(1, 1) groups=1 bias=0 )
118 | (bn3) BatchNorm2d( num_features=256 eps=0.000010 momentum=0.100000 )
119 | )
120 |
121 | )
122 |
123 | /* .... */
124 |
125 | (avgpool) AvgPool2d( kernel_size=(7, 7) stride=(1, 1) padding=(0, 0) )
126 | (fc) nn.Linear( in_features=2048 out_features=1000 bias=1 )
127 | )
128 | ```
129 |
130 | ### Inspect a Tensor
131 |
132 |
133 | ```c++
134 | auto net = torch::resnet50_imagenet();
135 |
136 | net->load_weights("../resnet50_imagenet.h5");
137 | net->cuda();
138 |
139 | Tensor dummy_input = CUDA(kFloat).ones({1, 3, 224, 224});
140 |
141 | auto result = net->forward(dummy_input);
142 |
143 | cout << result << endl;
144 | ```
145 |
146 |
147 | ```
148 | Columns 1 to 10    -0.3081 0.0798 -1.1900 -1.4837 -0.5136 0.3683 -2.1639 -0.8705 -1.8812 -0.1608
149 |
150 | Columns 11 to 20    0.2168 -0.9283 -1.2954 -1.0791 -1.4445 -0.8946 -0.0959 -1.3099 -1.2062 -1.2327
151 |
152 | Columns 21 to 30   -1.0658 0.9427 0.5739 -0.2746 -1.0189 -0.3583 -0.1826 0.2785 0.2209 -0.3340
153 |
154 | Columns 31 to 40   -1.9800 -0.5552 -1.0804 -0.8056 -0.0005 -1.8402 -0.7979 -1.4823 1.3657 -0.8970
155 |
156 | /* .... */
157 |
158 | Columns 961 to 970   -0.0557 -0.7405 -0.5501 -1.7207 -0.7043 -1.0925 1.5812 -0.1215 0.8915 0.9794
159 |
160 | Columns 971 to 980   -1.1422 -0.1235 -0.5999 -2.1338 -0.0775 -0.8374 -0.2350 -0.0104 -0.0416 -1.0296
161 |
162 | Columns 981 to 990   -0.2914 -0.2242 -0.8063 -0.7818 -0.2714 0.0002 -1.2355 0.1238 0.0183 -0.6904
163 |
164 | Columns 991 to 1000   0.5216 -1.8008 -1.7826 -1.2970 -1.6565 -1.3306 -0.6564 -1.6531 0.1178 0.2436
165 | [ CUDAFloatTensor{1,1000} ]
166 | ```
167 |
168 | ### Create a network
169 |
170 |
171 | ```c++
172 | auto new_net = std::make_shared<torch::Sequential>();
173 | new_net->add(std::make_shared<torch::Conv2d>(3, 10, 3, 3)); // in_channels, out_channels, kernel height and width
174 | new_net->add(std::make_shared<torch::BatchNorm2d>(10));
175 | new_net->add(std::make_shared<torch::ReLU>());
176 | new_net->add(std::make_shared<torch::Linear>(10, 3));
177 | ```
178 | ## Implemented layers
179 |
180 | So far, the following layers are available. They respect the semantics of Pytorch's layers,
181 | which are documented [here](http://pytorch.org/docs/0.1.12/nn.html#convolution-layers).
182 |
183 |
184 | - [x] nn.Sequential
185 | - [x] nn.Conv2d
186 | - [x] nn.MaxPool2d
187 | - [x] nn.AvgPool2d
188 | - [x] nn.ReLU
189 | - [x] nn.Linear
190 | - [x] nn.SoftMax
191 | - [x] nn.BatchNorm2d
192 | - [ ] nn.Dropout2d
193 | - [ ] nn.DataParallel
194 | - [ ] nn.AdaptiveMaxPool2d
195 | - [ ] nn.Sigmoid
196 | and others.
197 |
198 | ## Implemented models
199 |
200 | Some converted models are provided for ease of access. Other models can be [easily converted](convert_weights.ipynb).
201 |
202 | ### Imagenet models
203 |
204 | All models were converted from [pytorch/vision](https://github.com/pytorch/vision) and checked for
205 | correctness.
206 |
207 | - [x] Resnet-18
208 | - [x] Resnet-34
209 | - [x] [Resnet-50](https://www.dropbox.com/s/bukezzx17dr8qdd/resnet50_imagenet.h5?dl=0)
210 | - [x] Resnet-101
212 | - [x] Resnet-152
213 | - [ ] All VGG models
214 | - [ ] All Densenet models
215 | - [ ] All Inception models
216 | - [ ] All squeezenet models
217 | - [ ] Alexnet
218 |
219 | ### Segmentation PASCAL VOC
220 |
221 | All models were converted from [this repository](https://github.com/warmspringwinds/dense-ai) and checked for
222 | correctness.
223 |
224 | - [x] Resnet-18-8S
225 | - [x] [Resnet-34-8S](https://www.dropbox.com/s/104my8hr5zm6l7d/resnet34_fcn_pascal.h5?dl=0)
226 | - [ ] Resnet-50-8S
227 | - [ ] Resnet-101-8S
228 | - [ ] Resnet-152-8S
229 | - [x] FCN-32s
230 | - [ ] FCN-16s
231 | - [ ] FCN-8s
232 |
233 | ## Demos
234 |
235 | We created a couple of [demos](examples) where we grab frames using ```OpenCV``` and classify
236 | or segment them.
237 |
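   | The gist of the demo loop, as a simplified sketch (the real examples add network inference and pre/post-processing):
   |
   | ```c++
   | #include <opencv2/opencv.hpp>
   |
   | cv::VideoCapture cap(0); // open the default webcam
   | cv::Mat frame;
   |
   | while (cap.read(frame))
   | {
   |     // convert the frame to a Tensor, run net->forward()
   |     // and overlay or display the result
   |     cv::imshow("result", frame);
   |     cv::waitKey(1);
   | }
   | ```
   |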
238 | Here you can see an example of real-time segmentation:
239 |
240 |
241 | 
242 |
243 | ## Installation
244 |
245 | ### ATen
246 |
247 | [ATen](https://github.com/zdevito/ATen) is a C++ 11 library that wraps a powerful C Tensor library with
248 | implementations of numpy-like operations (CPU/CUDA/SPARSE/CUDA-SPARSE backends).
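   |
   | The flavour of the API, as an illustrative sketch based on the upstream ```ATen``` examples:
   |
   | ```c++
   | #include "ATen/ATen.h"
   | using namespace at;
   |
   | Tensor d = CPU(kFloat).ones({3, 4});  // a 3x4 CPU tensor of ones
   | Tensor r = CPU(kFloat).zeros({3, 4}); // a 3x4 CPU tensor of zeros
   | r = r.add(d);                         // numpy-like elementwise op
   | ```
   |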
249 | Follow these steps to install it:
250 |
251 | 0. Make sure you have the [dependencies](https://github.com/zdevito/ATen#installation) of ```ATen``` installed.
252 | 1. ```git clone --recursive https://github.com/warmspringwinds/pytorch-cpp```
253 | 2. ```cd pytorch-cpp/ATen;mkdir build;cd build;cmake-gui .. ``` and specify ```CUDA_TOOLKIT_ROOT_DIR```.
254 | 3. ```make``` or better ```make -j7``` (replace ```7``` with the number of cores that you have).
255 | 4. ```cd ../../``` -- returns you back to the root directory (necessary for the next step).
256 |
257 | ### HDF5
258 |
259 | We use ```HDF5``` to be able to [easily convert](convert_weights.ipynb) weights between ```Pytorch``` and ```Pytorch-C++```.
260 |
261 | 0. ```wget https://support.hdfgroup.org/ftp/HDF5/current18/src/CMake-hdf5-1.8.20.tar.gz; tar xvzf CMake-hdf5-1.8.20.tar.gz```
262 | 1. ```cd CMake-hdf5-1.8.20; ./build-unix.sh```
263 | 2. ```cd ../``` -- return back.
264 |
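   | To sanity-check a converted weights file, you can list its datasets with the ```HDF5``` C++ API (a hedged sketch; ```resnet18.h5``` is the file produced by [convert_weights.ipynb](convert_weights.ipynb)):
   |
   | ```c++
   | #include "H5Cpp.h"
   | #include <iostream>
   |
   | H5::H5File file("resnet18.h5", H5F_ACC_RDONLY);
   |
   | // Print the name of every dataset -- one per parameter tensor
   | for (hsize_t i = 0; i < file.getNumObjs(); ++i)
   |     std::cout << file.getObjnameByIdx(i) << std::endl;
   | ```
   |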
265 | Additional information: ```HDF5``` gets updated from time to time, so there is a good chance that the link above is outdated.
266 | If that's the case, grab the latest version from [the official website](https://support.hdfgroup.org/ftp/HDF5/current18/src/).
267 |
268 | Also, after you do this, don't forget to update the ```CMakeLists.txt``` file with the new ```HDF5``` folder name.
269 |
270 | ### Opencv
271 |
272 | We need ```OpenCV``` for a couple of examples which grab frames from a web camera.
273 | It is not a core dependency, and the corresponding examples can be removed if necessary.
274 | This was tested on ```Ubuntu-16``` and might need some changes on a different system.
275 |
276 | 0. ```sudo apt-get install libopencv-dev python-opencv```
277 |
278 |
279 | ### Pytorch-C++
280 |
281 | ```Pytorch-C++``` is a library on top of ```ATen``` that provides a [Pytorch](http://pytorch.org/)-like
282 | interface for building neural networks and running inference (so far only the forward pass is supported),
283 | inspired by the [cunnproduction](https://github.com/szagoruyko/cunnproduction) library. To install it, follow
284 | these steps:
285 |
286 | 0. ```mkdir build; cd build; cmake-gui ..``` and specify ```CUDA_TOOLKIT_ROOT_DIR```.
287 | 1. ```make```
288 | 2. ```cd ../``` -- return back
289 |
290 | ### Problems with the build
291 |
292 | It was noticed that if you have anaconda installed and your ```PATH``` variable is modified to include
293 | its folder, it can lead to a failed build (caused by the fact that anaconda uses a different version of ```gcc```).
294 | To solve this problem, remove the path to anaconda from ```PATH``` for the duration of the build.
295 |
296 | If you face any problems or some steps are not clear, please open an issue. Note: every time you open ```cmake-gui```,
297 | press ```configure``` first, then specify your ```CUDA``` path, and then press ```generate```; after that you can build.
298 |
299 |
300 | ## About
301 |
302 | If you used the code for your research, please cite the paper:
303 |
304 | @article{pakhomov2017deep,
305 | title={Deep Residual Learning for Instrument Segmentation in Robotic Surgery},
306 | author={Pakhomov, Daniil and Premachandran, Vittal and Allan, Max and Azizian, Mahdi and Navab, Nassir},
307 | journal={arXiv preprint arXiv:1703.08580},
308 | year={2017}
309 | }
310 |
311 | During implementation, some preliminary experiments and notes were reported:
312 | - [Converting Image Classification network into FCN](http://warmspringwinds.github.io/tensorflow/tf-slim/2016/10/30/image-classification-and-segmentation-using-tensorflow-and-tf-slim/)
313 | - [Performing upsampling using transposed convolution](http://warmspringwinds.github.io/tensorflow/tf-slim/2016/11/22/upsampling-and-image-segmentation-with-tensorflow-and-tf-slim/)
314 | - [Conditional Random Fields for Refining of Segmentation and Coarseness of FCN-32s model segmentations](http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/18/image-segmentation-with-tensorflow-using-cnns-and-conditional-random-fields/)
315 | - [TF-records usage](http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/)
316 |
317 | ## Contributors
318 |
319 | - Daniil Pakhomov
320 |
--------------------------------------------------------------------------------
/convert_weights.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## Converting weights between Pytorch and Pytorch-C++"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 2,
13 | "metadata": {
14 | "collapsed": true
15 | },
16 | "outputs": [],
17 | "source": [
18 | "import h5py\n",
19 | "import torch\n",
20 | "import torchvision.models as models\n",
21 | "\n",
22 | "test = models.resnet18(pretrained=True)\n",
23 | "\n",
24 | "state_dict = test.state_dict()\n",
25 | "\n",
26 | "h5f = h5py.File('resnet18.h5', 'w')\n",
27 | "\n",
28 | "for key in state_dict:\n",
29 | " \n",
30 | " h5f.create_dataset(key, data=state_dict[key].numpy())\n",
31 | "\n",
32 | "h5f.close()"
33 | ]
34 | },
35 | {
36 | "cell_type": "markdown",
37 | "metadata": {},
38 | "source": [
39 | "## Comparing the outputs of two classifiction networks"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 1,
45 | "metadata": {
46 | "collapsed": true
47 | },
48 | "outputs": [],
49 | "source": [
50 | "import torch\n",
51 | "import torchvision.models as models\n",
52 | "\n",
53 | "net = models.resnet50(pretrained=True)\n",
54 | "net = net.eval()\n",
55 | "\n",
56 | "ones_input = torch.autograd.Variable( torch.ones(1, 3, 224, 224) )\n",
57 | "\n",
58 | "pytorch_inference_result = net(ones_input).data.numpy()"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": 2,
64 | "metadata": {
65 | "collapsed": true
66 | },
67 | "outputs": [],
68 | "source": [
69 | "import h5py\n",
70 | "import torch\n",
71 | "import torchvision.models as models\n",
72 | "\n",
73 | "\n",
74 | "h5f = h5py.File('build/resnet50_output.h5', 'r')\n",
75 | "pytorch_cpp_inference_result = h5f['main'][:]\n",
76 | "h5f.close()"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": 3,
82 | "metadata": {},
83 | "outputs": [
84 | {
85 | "data": {
86 | "text/plain": [
87 | "True"
88 | ]
89 | },
90 | "execution_count": 3,
91 | "metadata": {},
92 | "output_type": "execute_result"
93 | }
94 | ],
95 | "source": [
96 | "import numpy as np\n",
97 | "\n",
98 | "# Equal up to 1e-4 by absolute value\n",
99 | "np.allclose(pytorch_cpp_inference_result, pytorch_inference_result, atol=1e-4, rtol=0)"
100 | ]
101 | },
102 | {
103 | "cell_type": "markdown",
104 | "metadata": {},
105 | "source": [
106 | "## Comparing the outputs of two segmentation networks"
107 | ]
108 | }
109 | ],
110 | "metadata": {
111 | "kernelspec": {
112 | "display_name": "Python 2",
113 | "language": "python",
114 | "name": "python2"
115 | },
116 | "language_info": {
117 | "codemirror_mode": {
118 | "name": "ipython",
119 | "version": 2
120 | },
121 | "file_extension": ".py",
122 | "mimetype": "text/x-python",
123 | "name": "python",
124 | "nbconvert_exporter": "python",
125 | "pygments_lexer": "ipython2",
126 | "version": "2.7.13"
127 | }
128 | },
129 | "nbformat": 4,
130 | "nbformat_minor": 2
131 | }
132 |
--------------------------------------------------------------------------------
/examples/CMakeLists.txt:
--------------------------------------------------------------------------------
1 |
2 | #ADD_EXECUTABLE(read_allocated_gpu_memory read_allocated_gpu_memory.cpp)
3 | #TARGET_LINK_LIBRARIES(read_allocated_gpu_memory ${ATen_BINARY_DIR}/src/ATen/libATen.so.1 ${CUDA_LIBRARIES})
4 |
5 | #ADD_EXECUTABLE(sequential_test sequential_test.cpp)
6 | #TARGET_LINK_LIBRARIES(sequential_test ${ATen_BINARY_DIR}/src/ATen/libATen.so.1 ${CUDA_LIBRARIES})
7 |
8 |
--------------------------------------------------------------------------------
/examples/opencv_realtime_webcam_human_segmentation.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | This example shows real-time segmentation of the human class from PASCAL VOC.
3 | The network outputs the probability of each pixel belonging to the human class.
4 | These probabilities are later used as a transparency mask for the input image.
5 | The final fused image is displayed in the application window.
6 | */
7 |
8 | #include "ATen/ATen.h"
9 | #include "ATen/Type.h"
10 | #include