├── .gitignore
├── CMakeLists.txt
├── Config.cmake.in
├── LICENSE
├── README-Chinese.md
├── README.md
├── build
└── cmake_buid_here.txt
├── docs
└── training tricks.md
├── prediction.jpg
├── segmentation.pc.in
├── src
├── SegDataset.cpp
├── SegDataset.h
├── Segmentor.cpp
├── Segmentor.h
├── architectures
│ ├── DeepLab.cpp
│ ├── DeepLab.h
│ ├── DeepLabDecoder.cpp
│ ├── DeepLabDecoder.h
│ ├── FPN.cpp
│ ├── FPN.h
│ ├── FPNDecoder.cpp
│ ├── FPNDecoder.h
│ ├── LinkNet.cpp
│ ├── LinkNet.h
│ ├── LinknetDecoder.cpp
│ ├── LinknetDecoder.h
│ ├── PAN.cpp
│ ├── PAN.h
│ ├── PANDecoder.cpp
│ ├── PANDecoder.h
│ ├── PSPNet.cpp
│ ├── PSPNet.h
│ ├── PSPNetDecoder.cpp
│ ├── PSPNetDecoder.h
│ ├── UNet.cpp
│ ├── UNet.h
│ ├── UNetDecoder.cpp
│ └── UNetDecoder.h
├── backbones
│ ├── ResNet.cpp
│ ├── ResNet.h
│ ├── VGG.cpp
│ └── VGG.h
└── utils
│ ├── Augmentations.cpp
│ ├── Augmentations.h
│ ├── InterFace.h
│ ├── _dirent.h
│ ├── json.hpp
│ ├── loss.h
│ ├── readfile.h
│ ├── util.cpp
│ └── util.h
├── test
├── CMakeLists.txt
├── resnet34.cpp
└── train.cpp
├── trace.py
├── voc_person_seg
├── train
│ ├── 2007_000027.jpg
│ ├── 2007_000027.json
│ ├── 2007_000170.jpg
│ ├── 2007_000170.json
│ ├── 2007_000272.jpg
│ ├── 2007_000272.json
│ ├── 2007_000323.jpg
│ ├── 2007_000346.jpg
│ ├── 2007_000346.json
│ ├── 2007_000423.jpg
│ ├── 2007_000423.json
│ ├── 2007_000664.jpg
│ ├── 2007_000664.json
│ ├── 2007_000733.jpg
│ ├── 2007_000733.json
│ ├── 2007_000762.jpg
│ ├── 2007_000783.jpg
│ ├── 2007_000783.json
│ ├── 2007_000799.jpg
│ ├── 2007_000799.json
│ ├── 2007_000807.jpg
│ ├── 2007_000807.json
│ ├── 2007_000836.jpg
│ ├── 2007_000836.json
│ ├── 2007_000847.jpg
│ ├── 2007_000847.json
│ ├── 2007_000999.jpg
│ ├── 2007_000999.json
│ ├── 2007_001185.jpg
│ ├── 2007_001185.json
│ ├── 2007_001408.jpg
│ ├── 2007_001408.json
│ ├── 2007_001430.jpg
│ ├── 2007_001430.json
│ ├── 2007_001526.jpg
│ ├── 2007_001558.jpg
│ ├── 2007_001583.jpg
│ ├── 2007_001585.jpg
│ ├── 2007_001585.json
│ ├── 2007_001627.jpg
│ ├── 2007_001627.json
│ ├── 2007_001630.jpg
│ ├── 2007_001630.json
│ ├── 2007_001717.jpg
│ ├── 2007_001717.json
│ ├── 2007_002293.jpg
│ ├── 2007_002293.json
│ ├── 2007_002668.jpg
│ ├── 2007_002728.jpg
│ ├── 2007_002728.json
│ ├── 2007_002824.jpg
│ ├── 2007_002824.json
│ ├── 2007_003106.jpg
│ ├── 2007_003106.json
│ ├── 2007_003118.jpg
│ ├── 2007_003188.jpg
│ ├── 2007_003188.json
│ ├── 2007_003189.jpg
│ ├── 2007_003189.json
│ ├── 2007_003191.jpg
│ ├── 2007_003191.json
│ ├── 2007_003205.jpg
│ ├── 2007_003205.json
│ ├── 2007_003329.jpg
│ ├── 2007_003529.jpg
│ ├── 2007_003530.jpg
│ ├── 2007_003530.json
│ ├── 2007_003541.jpg
│ ├── 2007_003541.json
│ ├── 2007_003742.jpg
│ ├── 2007_003742.json
│ ├── 2007_003745.jpg
│ └── 2007_003745.json
└── val
│ ├── 2007_003747.jpg
│ ├── 2007_003747.json
│ ├── 2007_004000.jpg
│ ├── 2007_004000.json
│ ├── 2007_004049.jpg
│ ├── 2007_004133.jpg
│ ├── 2007_004133.json
│ ├── 2007_004197.jpg
│ ├── 2007_004197.json
│ ├── 2007_004291.jpg
│ ├── 2007_004392.jpg
│ ├── 2007_004397.jpg
│ ├── 2007_004476.jpg
│ ├── 2007_004476.json
│ ├── 2007_004712.jpg
│ ├── 2007_004712.json
│ ├── 2007_004831.jpg
│ ├── 2007_004902.jpg
│ ├── 2007_005019.jpg
│ ├── 2007_005086.jpg
│ ├── 2007_005086.json
│ ├── 2007_005124.jpg
│ └── 2007_005124.json
└── weights
└── input weights here.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Prerequisites
2 | *.d
3 |
4 | # Compiled Object files
5 | *.slo
6 | *.lo
7 | *.o
8 | *.obj
9 |
10 | # Precompiled Headers
11 | *.gch
12 | *.pch
13 |
14 | # Compiled Dynamic libraries
15 | *.so
16 | *.dylib
17 | *.dll
18 |
19 | # Fortran module files
20 | *.mod
21 | *.smod
22 |
23 | # Compiled Static libraries
24 | *.lai
25 | *.la
26 | *.a
27 | *.lib
28 |
29 | # Executables
30 | *.exe
31 | *.out
32 | *.app
33 |
34 | # LibtorchSegmentation
35 | build
36 | weights
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # CMakeLists for LibTorchSegmentation: builds the `segmentation` library
2 | # (static by default, shared with -DBUILD_SHARED=TRUE) and installs the
3 | # headers, the pkg-config file and the CMake package-config files.
4 | 
5 | # cmake_minimum_required must be the very first command: it sets policy defaults.
6 | cmake_minimum_required(VERSION 3.10)
7 | 
8 | set(PROJECT_VERSION 1.0.0)
9 | 
10 | project(LibTorchSegmentation VERSION ${PROJECT_VERSION}
11 |         DESCRIPTION "Image Segmentation library based on LibTorch")
12 | 
13 | # GNUInstallDirs needs a language-enabled project(), so include it after project().
14 | include(GNUInstallDirs)
15 | 
16 | set(CMAKE_CXX_STANDARD 14)
17 | set(CMAKE_CXX_STANDARD_REQUIRED ON)
18 | 
19 | # First of all set your libtorch path (read from the Torch_DIR environment
20 | # variable at configure time).
21 | set(Torch_DIR "$ENV{Torch_DIR}/share/cmake/Torch")
22 | 
23 | # REQUIRED already aborts configuration with a fatal error when not found,
24 | # so no manual Torch_FOUND check is needed.
25 | find_package(Torch REQUIRED)
26 | message(STATUS "Torch library found!")
27 | 
28 | # At this point, OpenCV should be already installed.
29 | find_package(OpenCV REQUIRED)
30 | message(STATUS "OpenCV VERSION ${OpenCV_VERSION}")
31 | 
32 | # NOTE(review): file(GLOB) does not notice newly added sources until the
33 | # next re-configure; list sources explicitly if that becomes a problem.
34 | file(GLOB ALL_SOURCES
35 |   "src/*.cpp"
36 |   "src/architectures/*.cpp"
37 |   "src/backbones/*.cpp"
38 |   "src/utils/*.cpp"
39 | )
40 | 
41 | if(BUILD_SHARED)
42 |   add_library(segmentation SHARED ${ALL_SOURCES})
43 |   message(STATUS "Target shared library")
44 | else()
45 |   add_library(segmentation STATIC ${ALL_SOURCES})
46 |   message(STATUS "Target static library")
47 | endif()
48 | 
49 | # Target-scoped include dirs instead of directory-wide include_directories().
50 | target_include_directories(segmentation PUBLIC ${OpenCV_INCLUDE_DIRS})
51 | 
52 | configure_file(segmentation.pc.in segmentation.pc @ONLY)
53 | 
54 | # Install the public headers under <prefix>/include/segmentation.
55 | install(DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/src/"
56 |   DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/segmentation
57 |   FILES_MATCHING
58 |   PATTERN "*.h*"
59 | )
60 | 
61 | # ARCHIVE covers the default STATIC build; LIBRARY covers -DBUILD_SHARED=TRUE.
62 | install(TARGETS segmentation
63 |   LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
64 |   ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
65 | )
66 | 
67 | install(FILES ${CMAKE_BINARY_DIR}/segmentation.pc
68 |   DESTINATION ${CMAKE_INSTALL_DATAROOTDIR}/pkgconfig)
69 | 
70 | include(CMakePackageConfigHelpers)
71 | 
72 | configure_package_config_file(
73 |   "Config.cmake.in"
74 |   "SegmentationConfig.cmake"
75 |   INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/Segmentation
76 |   PATH_VARS
77 |     CMAKE_INSTALL_LIBDIR
78 | )
79 | 
80 | write_basic_package_version_file(
81 |   ${CMAKE_CURRENT_BINARY_DIR}/SegmentationConfigVersion.cmake
82 |   VERSION ${PROJECT_VERSION}
83 |   COMPATIBILITY SameMajorVersion
84 | )
85 | 
86 | ### Install Config and ConfigVersion files
87 | install(
88 |   FILES "${CMAKE_CURRENT_BINARY_DIR}/SegmentationConfig.cmake"
89 |         "${CMAKE_CURRENT_BINARY_DIR}/SegmentationConfigVersion.cmake"
90 |   DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/Segmentation"
91 | )
92 | 
93 | set_target_properties(segmentation PROPERTIES VERSION ${PROJECT_VERSION})
94 | 
95 | # PUBLIC keeps the legacy (keyword-less) transitive-link behaviour so
96 | # consumers of the static library still get Torch/OpenCV/CUDA linked in.
97 | target_link_libraries(segmentation PUBLIC
98 |   ${OpenCV_LIBS}
99 |   ${TORCH_LIBRARIES}
100 |   ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_cudart_static_LIBRARY}
101 | )
--------------------------------------------------------------------------------
/Config.cmake.in:
--------------------------------------------------------------------------------
1 | @PACKAGE_INIT@
2 | # Variables exported to consumers of find_package(Segmentation):
3 | # headers are installed under <prefix>/include/segmentation (see CMakeLists.txt).
4 | set(SEGMENTATION_INCLUDE_DIRS "@CMAKE_INSTALL_PREFIX@/include/segmentation" )
5 | set(Segmentation_FOUND TRUE)
6 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 AllentDan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README-Chinese.md:
--------------------------------------------------------------------------------
1 | [English](https://github.com/AllentDan/SegmentationCpp) | 中文
2 |
3 |
4 |
5 | 
6 | **基于[LibTorch](https://pytorch.org/)的C++开源图像分割神经网络库.**
7 |
8 |
9 |
10 | **⭐如果有用请给我一个star⭐**
11 |
12 | 这个库具有以下优点:
13 |
14 | - 高级的API (只需一行代码就可创建网络)
15 | - 7 种模型架构可用于单类或者多类的分割任务 (包括Unet)
16 | - 15 种编码器网络
17 | - 所有的编码器都有预训练权重,可以更快更好地收敛
18 | - 相比于python下的GPU前向推理速度具有30%或以上的提速, cpu下保持速度一致. (Unet测试于RTX 2070S).
19 |
20 | ### [📚 Libtorch教程 📚](https://github.com/AllentDan/LibtorchTutorials)
21 |
22 | 如果你想对该开源项目有更多更详细的了解,请前往本人另一个开源项目:[Libtorch教程](https://github.com/AllentDan/LibtorchTutorials) .
23 |
24 | ### 📋 目录
25 | 1. [快速开始](#start)
26 | 2. [例子](#examples)
27 | 3. [训练自己的数据](#trainingOwn)
28 | 4. [模型](#models)
29 | 1. [架构](#architectures)
30 | 2. [编码器](#encoders)
31 | 5. [安装](#installation)
32 | 6. [ToDo](#todo)
33 | 7. [感谢](#thanks)
34 | 8. [引用](#citing)
35 | 9. [证书](#license)
36 | 10. [相关项目](#related_repos)
37 |
38 | ### ⏳ 快速开始
39 |
40 | #### 1. 用 Libtorch Segment 创建你的第一个分割网络
41 |
42 | [这](https://github.com/AllentDan/LibtorchSegmentation/releases/download/weights/segmentor.pt)是一个resnet34的torchscript模型,可以作为骨干网络权重。分割模型是 LibTorch 的 torch::nn::Module的派生类, 可以很容易生成:
43 |
44 | ```cpp
45 | #include "Segmentor.h"
46 | auto model = UNet(1, /*num of classes*/
47 | "resnet34", /*encoder name, could be resnet50 or others*/
48 | "path to resnet34.pt"/*weight path pretrained on ImageNet, it is produced by torchscript*/
49 | );
50 | ```
51 | - 见 [表](#architectures) 查看所有支持的模型架构
52 | - 见 [表](#encoders) 查看所有的编码器网络和相应的预训练权重
53 |
54 | #### 2. 生成自己的预训练权重
55 |
56 | 所有编码器均具有预训练的权重。加载预训练权重,以相同的方式训练数据,可能会获得更好的结果(更高的指标得分和更快的收敛速度)。还可以在冻结主干的同时仅训练解码器和分割头。
57 |
58 | ```python
59 | import torch
60 | from torchvision import models
61 |
62 | # resnet50 for example
63 | model = models.resnet50(pretrained=True)
64 | model.eval()
65 | var=torch.ones((1,3,224,224))
66 | traced_script_module = torch.jit.trace(model, var)
67 | traced_script_module.save("resnet50.pt")
68 | ```
69 |
70 | 恭喜你! 大功告成! 现在,您可以使用自己喜欢的主干和分割框架来训练模型了。
71 |
72 | ### 💡 例子
73 | - 使用来自PASCAL VOC数据集的图像进行人体分割数据训练模型. "voc_person_seg" 目录包含32个json标签及其相应的jpeg图像用于训练,还有8个json标签以及相应的图像用于验证。
74 | ```cpp
75 | Segmentor segmentor;
76 | segmentor.Initialize(0/*gpu id, -1 for cpu*/,
77 | 512/*resize width*/,
78 | 512/*resize height*/,
79 | {"background","person"}/*class name dict, background included*/,
80 | "resnet34"/*backbone name*/,
81 | "your path to resnet34.pt");
82 | segmentor.Train(0.0003/*initial learning rate*/,
83 | 300/*training epochs*/,
84 | 4/*batch size*/,
85 | "your path to voc_person_seg",
86 | ".jpg"/*image type*/,
87 | "your path to save segmentor.pt");
88 | ```
89 |
90 | - 预测测试。项目中提供了以ResNet34为骨干网络的FPN网络,训练了一些周期得到segmentor.pt文件[在这](https://github.com/AllentDan/LibtorchSegmentation/releases/download/weights/segmentor.pt)。 您可以直接测试分割结果:
91 | ```cpp
92 | cv::Mat image = cv::imread("your path to voc_person_seg\\val\\2007_004000.jpg");
93 | Segmentor segmentor;
94 | segmentor.Initialize(0,512,512,{"background","person"},
95 | "resnet34","your path to resnet34.pt");
96 | segmentor.LoadWeight("segmentor.pt"/*the saved .pt path*/);
97 | segmentor.Predict(image,"person"/*class name for showing*/);
98 | ```
99 | 预测结果显示如下:
100 |
101 | 
102 |
103 | ### 🧑🚀 训练自己的数据
104 | - 创建自己的数据集. 使用"pip install"安装[labelme](https://github.com/wkentaro/labelme)并标注你的图像. 将输出的json文件和图像分成以下文件夹:
105 | ```
106 | Dataset
107 | ├── train
108 | │ ├── xxx.json
109 | │ ├── xxx.jpg
110 | │ └......
111 | ├── val
112 | │ ├── xxxx.json
113 | │ ├── xxxx.jpg
114 | │ └......
115 | ```
116 | - 训练或测试。就像“ voc_person_seg”的示例一样,用自己的数据集路径替换“ voc_person_seg”。
117 | - 记得使用[训练技巧](https://github.com/AllentDan/LibtorchSegmentation/blob/main/docs/training%20tricks.md)以提高模型的训练效果。
118 |
119 |
120 | ### 📦 Models
121 |
122 | #### Architectures
123 | - [x] Unet [[paper](https://arxiv.org/abs/1505.04597)]
124 | - [x] FPN [[paper](http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf)]
125 | - [x] PAN [[paper](https://arxiv.org/abs/1805.10180)]
126 | - [x] PSPNet [[paper](https://arxiv.org/abs/1612.01105)]
127 | - [x] LinkNet [[paper](https://arxiv.org/abs/1707.03718)]
128 | - [x] DeepLabV3 [[paper](https://arxiv.org/abs/1706.05587)]
129 | - [x] DeepLabV3+ [[paper](https://arxiv.org/abs/1802.02611)]
130 |
131 | #### Encoders
132 | - [x] ResNet
133 | - [x] ResNext
134 | - [x] VGG
135 |
136 | 以下是该项目中受支持的编码器的列表。除resnest外,所有编码器权重都可以通过torchvision生成。选择适当的编码器,然后单击以展开表格,然后选择特定的编码器及其预训练的权重。
137 |
138 |
139 | ResNet
140 |
141 |
142 | |Encoder |Weights |Params, M |
143 | |--------------------------------|:------------------------------:|:------------------------------:|
144 | |resnet18 |imagenet |11M |
145 | |resnet34 |imagenet |21M |
146 | |resnet50 |imagenet |23M |
147 | |resnet101 |imagenet |42M |
148 | |resnet152 |imagenet |58M |
149 |
150 |
151 |
152 |
153 |
154 | ResNeXt
155 |
156 |
157 | |Encoder |Weights |Params, M |
158 | |--------------------------------|:------------------------------:|:------------------------------:|
159 | |resnext50_32x4d |imagenet |22M |
160 | |resnext101_32x8d |imagenet |86M |
161 |
162 |
163 |
164 |
165 |
166 | ResNeSt
167 |
168 |
169 | |Encoder |Weights |Params, M |
170 | |--------------------------------|:------------------------------:|:------------------------------:|
171 | |timm-resnest14d |imagenet |8M |
172 | |timm-resnest26d |imagenet |15M |
173 | |timm-resnest50d |imagenet |25M |
174 | |timm-resnest101e |imagenet |46M |
175 | |timm-resnest200e |imagenet |68M |
176 | |timm-resnest269e |imagenet |108M |
177 | |timm-resnest50d_4s2x40d |imagenet |28M |
178 | |timm-resnest50d_1s4x24d |imagenet |23M |
179 |
180 |
181 |
182 |
183 |
184 | SE-Net
185 |
186 |
187 | |Encoder |Weights |Params, M |
188 | |--------------------------------|:------------------------------:|:------------------------------:|
189 | |senet154 |imagenet |113M |
190 | |se_resnet50 |imagenet |26M |
191 | |se_resnet101 |imagenet |47M |
192 | |se_resnet152 |imagenet |64M |
193 | |se_resnext50_32x4d |imagenet |25M |
194 | |se_resnext101_32x4d |imagenet |46M |
195 |
196 |
197 |
198 |
199 |
200 | VGG
201 |
202 |
203 | |Encoder |Weights |Params, M |
204 | |--------------------------------|:------------------------------:|:------------------------------:|
205 | |vgg11 |imagenet |9M |
206 | |vgg11_bn |imagenet |9M |
207 | |vgg13 |imagenet |9M |
208 | |vgg13_bn |imagenet |9M |
209 | |vgg16 |imagenet |14M |
210 | |vgg16_bn |imagenet |14M |
211 | |vgg19 |imagenet |20M |
212 | |vgg19_bn |imagenet |20M |
213 |
214 |
215 |
216 |
217 | ### 🛠 安装
218 | **依赖库:**
219 |
220 | - [Opencv 3+](https://opencv.org/releases/)
221 | - [Libtorch 1.7+](https://pytorch.org/)
222 |
223 | **Windows:**
224 |
225 | 配置libtorch 开发环境. [Visual studio](https://allentdan.github.io/2020/12/16/pytorch%E9%83%A8%E7%BD%B2torchscript%E7%AF%87) 和 [Qt Creator](https://allentdan.github.io/2021/01/21/QT%20Creator%20+%20Opencv4.x%20+%20Libtorch1.7%E9%85%8D%E7%BD%AE/#more)已经通过libtorch1.7x release的验证.
226 |
227 | **Linux && MacOS:**
228 |
229 | 安装libtorch和opencv。
230 | 对于libtorch, 按照官方[教程](https://pytorch.org/tutorials/advanced/cpp_export.html)安装。
231 | 对于opencv, 按照官方安装[步骤](https://github.com/opencv/opencv)。
232 |
233 | 如果你都配置好了他们,恭喜!!! 下载一个resnet34的预训练权重,[点击下载](https://github.com/AllentDan/LibtorchSegmentation/releases/download/weights/resnet34.pt)和一个示例.pt文件,[点击下载](https://github.com/AllentDan/LibtorchSegmentation/releases/download/weights/segmentor.pt),放入weights文件夹。
234 |
235 | 更改src/main.cpp中的图片路径预训练权重和加载的segmentor权重路径。随后,build路径在终端输入:
236 | ```bash
237 | export Torch_DIR='/path/to/libtorch'
238 | cd build
239 | cmake ..
240 | make
241 | ./LibtorchSegmentation
242 | ```
243 |
244 | ### ⏳ ToDo
245 | - [ ] 更多的骨干网络和分割框架
246 | - [ ] UNet++ [[paper](https://arxiv.org/pdf/1807.10165.pdf)]
247 | - [ ] ResNest
248 | - [ ] Se-Net
249 | - [ ] ...
250 | - [x] 数据增强
251 | - [x] 随机水平翻转
252 | - [x] 随机垂直翻转
253 | - [x] 随机缩放和旋转
254 | - [ ] ...
255 | - [x] 训练技巧
256 | - [x] 联合损失:dice和交叉熵
257 | - [x] 冻结骨干网络
258 | - [x] 学习率衰减策略
259 | - [ ] ...
260 |
261 |
262 | ### 🤝 感谢
263 | 以下是目前给予帮助的项目.
264 | - [official pytorch](https://github.com/pytorch/pytorch)
265 | - [qubvel SMP](https://github.com/qubvel/segmentation_models.pytorch)
266 | - [wkentaro labelme](https://github.com/wkentaro/labelme)
267 | - [nlohmann json](https://github.com/nlohmann/json)
268 |
269 | ### 📝 引用
270 | ```
271 | @misc{Chunyu:2021,
272 | Author = {Chunyu Dong},
273 | Title = {Libtorch Segment},
274 | Year = {2021},
275 | Publisher = {GitHub},
276 | Journal = {GitHub repository},
277 | Howpublished = {\url{https://github.com/AllentDan/SegmentationCpp}}
278 | }
279 | ```
280 |
281 | ### 🛡️ 证书
282 | 该项目以 [MIT License](https://github.com/qubvel/segmentation_models.pytorch/blob/master/LICENSE)开源,
283 |
284 | ## 相关项目
285 | 基于libtorch,我释放了如下开源项目:
286 | - [LibtorchTutorials](https://github.com/AllentDan/LibtorchTutorials)
287 | - [LibtorchSegmentation](https://github.com/AllentDan/LibtorchSegmentation)
288 | - [LibtorchDetection](https://github.com/AllentDan/LibtorchDetection)
289 |
290 | 别忘了点赞哟
291 | 
292 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | English | [中文](https://github.com/AllentDan/SegmentationCpp/blob/main/README-Chinese.md)
2 |
3 |
4 |
5 | 
6 | **C++ library with Neural Networks for Image
7 | Segmentation based on [LibTorch](https://pytorch.org/).**
8 |
9 |
10 |
11 | **⭐Please give a star if this project helps you.⭐**
12 |
13 | The main features of this library are:
14 |
15 | - High level API (just a line to create a neural network)
16 | - 7 models architectures for binary and multi class segmentation (including legendary Unet)
17 | - 15 available encoders
18 | - All encoders have pre-trained weights for faster and better convergence
19 | - 35% or more inference speed boost compared with pytorch cuda, same speed for cpu. (Unet tested in rtx 2070s).
20 |
21 | ### [📚 Libtorch Tutorials 📚](https://github.com/AllentDan/LibtorchTutorials)
22 |
23 | Visit [Libtorch Tutorials Project](https://github.com/AllentDan/LibtorchTutorials) if you want to know more about Libtorch Segment library.
24 |
25 | ### 📋 Table of content
26 | 1. [Quick start](#start)
27 | 2. [Examples](#examples)
28 | 3. [Train your own data](#trainingOwn)
29 | 4. [Models](#models)
30 | 1. [Architectures](#architectures)
31 | 2. [Encoders](#encoders)
32 | 5. [Installation](#installation)
33 | 6. [Thanks](#thanks)
34 | 7. [To do list](#todo)
35 | 8. [Citing](#citing)
36 | 9. [License](#license)
37 | 10. [Related repository](#related_repos)
38 |
39 | ### ⏳ Quick start
40 |
41 | #### 1. Create your first Segmentation model with Libtorch Segment
42 |
43 | A resnet34 torchscript file is provided [here](https://github.com/AllentDan/LibtorchSegmentation/releases/download/weights/resnet34.pt). Segmentation model is just a LibTorch torch::nn::Module, which can be created as easily as:
44 |
45 | ```cpp
46 | #include "Segmentor.h"
47 | auto model = UNet(1, /*num of classes*/
48 | "resnet34", /*encoder name, could be resnet50 or others*/
49 | "path to resnet34.pt"/*weight path pretrained on ImageNet, it is produced by torchscript*/
50 | );
51 | ```
52 | - see [table](#architectures) with available model architectures
53 | - see [table](#encoders) with available encoders and their corresponding weights
54 |
55 | #### 2. Generate your own pretrained weights
56 |
57 | All encoders have pretrained weights. Preparing your data the same way as during weights pre-training may give you better results (higher metric score and faster convergence). And you can also train only the decoder and segmentation head while freezing the backbone.
58 |
59 | ```python
60 | import torch
61 | from torchvision import models
62 |
63 | # resnet34 for example
64 | model = models.resnet34(pretrained=True)
65 | model.eval()
66 | var=torch.ones((1,3,224,224))
67 | traced_script_module = torch.jit.trace(model, var)
68 | traced_script_module.save("resnet34.pt")
69 | ```
70 |
71 | Congratulations! You are done! Now you can train your model with your favorite backbone and segmentation framework.
72 |
73 | ### 💡 Examples
74 | - Training model for person segmentation using images from PASCAL VOC Dataset. "voc_person_seg" dir contains 32 json labels and their corresponding jpeg images for training and 8 json labels with corresponding images for validation.
75 | ```cpp
76 | Segmentor segmentor;
77 | segmentor.Initialize(0/*gpu id, -1 for cpu*/,
78 | 512/*resize width*/,
79 | 512/*resize height*/,
80 | {"background","person"}/*class name dict, background included*/,
81 | "resnet34"/*backbone name*/,
82 | "your path to resnet34.pt");
83 | segmentor.Train(0.0003/*initial learning rate*/,
84 | 300/*training epochs*/,
85 | 4/*batch size*/,
86 | "your path to voc_person_seg",
87 | ".jpg"/*image type*/,
88 | "your path to save segmentor.pt");
89 | ```
90 |
91 | - Predicting test. A segmentor.pt file is provided in the project [here](https://github.com/AllentDan/LibtorchSegmentation/releases/download/weights/segmentor.pt). It is trained through a FPN with ResNet34 backbone for a few epochs. You can directly test the segmentation result through:
92 | ```cpp
93 | cv::Mat image = cv::imread("your path to voc_person_seg\\val\\2007_004000.jpg");
94 | Segmentor segmentor;
95 | segmentor.Initialize(0,512,512,{"background","person"},
96 | "resnet34","your path to resnet34.pt");
97 | segmentor.LoadWeight("segmentor.pt"/*the saved .pt path*/);
98 | segmentor.Predict(image,"person"/*class name for showing*/);
99 | ```
100 | the predicted result shows as follow:
101 |
102 | 
103 |
104 | ### 🧑🚀 Train your own data
105 | - Create your own dataset. Using [labelme](https://github.com/wkentaro/labelme) through "pip install" and label your images. Split the output json files and images into folders just like below:
106 | ```
107 | Dataset
108 | ├── train
109 | │ ├── xxx.json
110 | │ ├── xxx.jpg
111 | │ └......
112 | ├── val
113 | │ ├── xxxx.json
114 | │ ├── xxxx.jpg
115 | │ └......
116 | ```
117 | - Training or testing. Just like the example of "voc_person_seg", replace "voc_person_seg" with your own dataset path.
118 | - Refer to [training tricks](https://github.com/AllentDan/LibtorchSegmentation/blob/main/docs/training%20tricks.md) to improve your final training performance.
119 |
120 |
121 | ### 📦 Models
122 |
123 | #### Architectures
124 | - [x] Unet [[paper](https://arxiv.org/abs/1505.04597)]
125 | - [x] FPN [[paper](http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf)]
126 | - [x] PAN [[paper](https://arxiv.org/abs/1805.10180)]
127 | - [x] PSPNet [[paper](https://arxiv.org/abs/1612.01105)]
128 | - [x] LinkNet [[paper](https://arxiv.org/abs/1707.03718)]
129 | - [x] DeepLabV3 [[paper](https://arxiv.org/abs/1706.05587)]
130 | - [x] DeepLabV3+ [[paper](https://arxiv.org/abs/1802.02611)]
131 |
132 |
133 | #### Encoders
134 | - [x] ResNet
135 | - [x] ResNext
136 | - [x] VGG
137 |
138 | The following is a list of supported encoders in the Libtorch Segment. All the encoders weights can be generated through torchvision except resnest. Select the appropriate family of encoders and click to expand the table and select a specific encoder and its pre-trained weights.
139 |
140 |
141 | ResNet
142 |
143 |
144 | | Encoder | Weights | Params, M |
145 | | --------- | :------: | :-------: |
146 | | resnet18 | imagenet | 11M |
147 | | resnet34 | imagenet | 21M |
148 | | resnet50 | imagenet | 23M |
149 | | resnet101 | imagenet | 42M |
150 | | resnet152 | imagenet | 58M |
151 |
152 |
153 |
154 |
155 |
156 | ResNeXt
157 |
158 |
159 | | Encoder | Weights | Params, M |
160 | | ---------------- | :------: | :-------: |
161 | | resnext50_32x4d | imagenet | 22M |
162 | | resnext101_32x8d | imagenet | 86M |
163 |
164 |
165 |
166 |
167 |
168 | ResNeSt
169 |
170 |
171 | | Encoder | Weights | Params, M |
172 | | ----------------------- | :------: | :-------: |
173 | | timm-resnest14d | imagenet | 8M |
174 | | timm-resnest26d | imagenet | 15M |
175 | | timm-resnest50d | imagenet | 25M |
176 | | timm-resnest101e | imagenet | 46M |
177 | | timm-resnest200e | imagenet | 68M |
178 | | timm-resnest269e | imagenet | 108M |
179 | | timm-resnest50d_4s2x40d | imagenet | 28M |
180 | | timm-resnest50d_1s4x24d | imagenet | 23M |
181 |
182 |
183 |
184 |
185 |
186 | SE-Net
187 |
188 |
189 | | Encoder | Weights | Params, M |
190 | | ------------------- | :------: | :-------: |
191 | | senet154 | imagenet | 113M |
192 | | se_resnet50 | imagenet | 26M |
193 | | se_resnet101 | imagenet | 47M |
194 | | se_resnet152 | imagenet | 64M |
195 | | se_resnext50_32x4d | imagenet | 25M |
196 | | se_resnext101_32x4d | imagenet | 46M |
197 |
198 |
199 |
200 |
201 |
202 | VGG
203 |
204 |
205 | | Encoder | Weights | Params, M |
206 | | -------- | :------: | :-------: |
207 | | vgg11 | imagenet | 9M |
208 | | vgg11_bn | imagenet | 9M |
209 | | vgg13 | imagenet | 9M |
210 | | vgg13_bn | imagenet | 9M |
211 | | vgg16 | imagenet | 14M |
212 | | vgg16_bn | imagenet | 14M |
213 | | vgg19 | imagenet | 20M |
214 | | vgg19_bn | imagenet | 20M |
215 |
216 |
217 |
218 |
219 | ### 🛠 Installation
220 | **Dependency:**
221 |
222 | - [Opencv 3+](https://opencv.org/releases/)
223 | - [Libtorch 1.7+](https://pytorch.org/)
224 |
225 | **Windows:**
226 |
227 | Configure the environment for libtorch development. [Visual studio](https://allentdan.github.io/2020/03/05/windows-libtorch-configuration/) and [Qt Creator](https://allentdan.github.io/2020/03/05/QT-Creator-Opencv-Libtorch-CUDA-English/) are verified for libtorch1.7x release.
228 |
229 | **Linux && MacOS:**
230 |
231 | Install libtorch and opencv.
232 |
233 | For libtorch, follow the official pytorch c++ tutorials [here](https://pytorch.org/tutorials/advanced/cpp_export.html).
234 |
235 | For opencv, follow the official opencv install steps [here](https://github.com/opencv/opencv).
236 |
237 | If you have already configured them both, congratulations!!! Download the pretrained weight [here](https://github.com/AllentDan/LibtorchSegmentation/releases/download/weights/resnet34.pt) and a demo .pt file [here](https://github.com/AllentDan/LibtorchSegmentation/releases/download/weights/segmentor.pt) into weights.
238 |
239 | Building shared or static library -DBUILD_SHARED=:
240 |
241 | ```bash
242 | export Torch_DIR='/path/to/libtorch'
243 | cd build
244 | cmake -DBUILD_SHARED=TRUE ..
245 | make
246 | sudo make install
247 | ```
248 |
249 | Building tests:
250 | ```bash
251 | cd test
252 | mkdir build && cd build
253 | cmake ..
254 | make
255 | ./resnet34 ../../voc_person_seg/val/2007_003747.jpg ../../weights/resnet34.pt ../../weights/segmentor.pt
256 | ```
257 |
258 | ### ⏳ ToDo
259 | - [ ] More segmentation architectures and backbones
260 | - [ ] UNet++ [[paper](https://arxiv.org/pdf/1807.10165.pdf)]
261 | - [ ] ResNest
262 | - [ ] Se-Net
263 | - [ ] ...
264 | - [x] Data augmentations
265 | - [x] Random horizontal flip
266 | - [x] Random vertical flip
267 | - [x] Random scale rotation
268 | - [ ] ...
269 | - [x] Training tricks
270 | - [x] Combined dice and cross entropy loss
271 | - [x] Freeze backbone
272 | - [x] Multi step learning rate schedule
273 | - [ ] ...
274 |
275 |
276 | ### 🤝 Thanks
277 | By now, these projects have helped a lot.
278 | - [official pytorch](https://github.com/pytorch/pytorch)
279 | - [qubvel SMP](https://github.com/qubvel/segmentation_models.pytorch)
280 | - [wkentaro labelme](https://github.com/wkentaro/labelme)
281 | - [nlohmann json](https://github.com/nlohmann/json)
282 |
283 | ### 📝 Citing
284 | ```
285 | @misc{Chunyu:2021,
286 | Author = {Chunyu Dong},
287 | Title = {Libtorch Segment},
288 | Year = {2021},
289 | Publisher = {GitHub},
290 | Journal = {GitHub repository},
291 | Howpublished = {\url{https://github.com/AllentDan/SegmentationCpp}}
292 | }
293 | ```
294 |
295 | ### 🛡️ License
296 | Project is distributed under [MIT License](https://github.com/qubvel/segmentation_models.pytorch/blob/master/LICENSE).
297 |
298 | ## Related repository
299 | Based on libtorch, I released following repositories:
300 | - [LibtorchTutorials](https://github.com/AllentDan/LibtorchTutorials)
301 | - [LibtorchSegmentation](https://github.com/AllentDan/LibtorchSegmentation)
302 | - [LibtorchDetection](https://github.com/AllentDan/LibtorchDetection)
303 |
304 | Last but not least, **don't forget** your star...
305 |
306 | Feel free to commit issues or pull requests, contributors wanted.
307 |
308 | 
309 |
--------------------------------------------------------------------------------
/build/cmake_buid_here.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AllentDan/LibtorchSegmentation/3fc7e71378e71fa119d87782bd0c8a67e541f27b/build/cmake_buid_here.txt
--------------------------------------------------------------------------------
/docs/training tricks.md:
--------------------------------------------------------------------------------
1 | ## Training tricks
2 | We provide a struct **trianTricks** for users to get better training performance. Users can apply the settings by:
3 | ```
4 | trianTricks tricks;
5 | segmentor.SetTrainTricks(tricks);
6 | ```
7 |
8 | ### Data augmentations
9 | Through OpenCV, we can also make data augmentations during training procedure. This repository mainly provide the following data augmentations:
10 | - Random horizontal flip
11 | - Random vertical flip
12 | - Random scale rotation
13 | For horizontal and vertical flip, the probabilities to implement them are controlled by:
14 | ```
15 | horizontal_flip_prob (float): probability to do horizontal flip augmentation, default 0;
16 | vertical_flip_prob (float): probability to do vertical flip augmentation, default 0;
17 | ```
18 | For random scale and rotate augmentation, it is set as:
19 | ```
20 | scale_rotate_prob (float): probability to do rotate and scale augmentation, default 0;
21 | ```
22 | Default 0 means this augmentation will be applied with 0 probability during the training procedure, not applied in other words.
23 |
24 | Besides, we also provide scale and rotate limitation parameters, the interpolation method and the padding mode for this augmentation.
25 | ```
26 | scale_limit (float): random enlarge or shrink the image by scale limit. For instance, if scale_limit equal to 0.1, \
27 | it will random resize the image size to [size*(0.9), size*(1.1)];
28 | rotate_limit (float): random rotate the image with the angle of [-rotate_limit, rotate_limit], in degree measure. \
29 | 45 degrees by default.
30 | interpolation (int): It use opencv interpolation setting, cv::INTER_LINEAR by default.
31 | border_mode (int): It use opencv border type setting, cv::BORDER_CONSTANT by default.
32 | ```
33 |
34 | ### Loss
35 | For better training, we can also use different segmentation loss functions. This repository mainly provide the following loss functions:
36 | - Cross entropy loss funcion
37 | - Dice loss function
38 | We can control the final loss by a hyper-parameter:
39 | ```
40 | dice_ce_ratio (float): the weight of dice loss in the combined loss, default 0.5;
41 | ```
42 | The final loss used during training will be:
43 | ```
44 | loss = DiceLoss * dice_ce_ratio + CELoss * (1-dice_ce_ratio)
45 | ```
46 |
47 | ### Freeze backbone
48 | Just like pytorch, we can also freeze the backbone of the segmentation network. This repository provides a parameter to control this:
49 | ```
50 | freeze_epochs (unsigned int): freeze the backbone during the first freeze_epochs, default 0;
51 | ```
52 | By default, the training will not use it.
53 |
54 | ### Decay the learning rate
55 | Well, libtorch does not provide a learning rate scheduler function for users. But we can also design it by ourselves. In this repository, we provide a multi-step decay scheduler.
56 | ```
57 | decay_epochs (std::vector): every decay_epoch, learning rate will decay by 90 percent, default {0};
58 | ```
59 | By default, this scheduler will not be used during training.
60 |
61 | ## Some advice or wishes
62 | All the training tricks used in pytorch or python can be implemented in libtorch or cpp. But apparently, this repository cannot fulfill all of them. If you need a specific trick for the training, just implement it.
63 |
64 | **Any pull request to share your tricks or networks will be welcomed.**
--------------------------------------------------------------------------------
/prediction.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AllentDan/LibtorchSegmentation/3fc7e71378e71fa119d87782bd0c8a67e541f27b/prediction.jpg
--------------------------------------------------------------------------------
/segmentation.pc.in:
--------------------------------------------------------------------------------
1 | prefix=@CMAKE_INSTALL_PREFIX@
2 | exec_prefix=@CMAKE_INSTALL_PREFIX@
3 | libdir=${exec_prefix}/@CMAKE_INSTALL_LIBDIR@
4 | includedir=${prefix}/@CMAKE_INSTALL_INCLUDEDIR@/segmentation
5 |
6 | Name: @PROJECT_NAME@
7 | Description: @PROJECT_DESCRIPTION@
8 | Version: @PROJECT_VERSION@
9 |
10 | Requires:
11 | Libs: -L${libdir} -lsegmentation
12 | Cflags: -I${includedir}
13 |
--------------------------------------------------------------------------------
/src/SegDataset.cpp:
--------------------------------------------------------------------------------
1 | #include "SegDataset.h"
2 | #include"utils/Augmentations.h"
3 |
4 | std::vector get_color_list(){
5 | std::vector color_list = {
6 | cv::Scalar(0, 0, 0),
7 | cv::Scalar(128, 0, 0),
8 | cv::Scalar(0, 128, 0),
9 | cv::Scalar(128, 128, 0),
10 | cv::Scalar(0, 0, 128),
11 | cv::Scalar(128, 0, 128),
12 | cv::Scalar(0, 128, 128),
13 | cv::Scalar(128, 128, 128),
14 | cv::Scalar(64, 0, 0),
15 | cv::Scalar(192, 0, 0),
16 | cv::Scalar(64, 128, 0),
17 | cv::Scalar(192, 128, 0),
18 | cv::Scalar(64, 0, 128),
19 | cv::Scalar(192, 0, 128),
20 | cv::Scalar(64, 128, 128),
21 | cv::Scalar(192, 128, 128),
22 | cv::Scalar(0, 64, 0),
23 | cv::Scalar(128, 64, 0),
24 | cv::Scalar(0, 192, 0),
25 | cv::Scalar(128, 192, 0),
26 | cv::Scalar(0, 64, 128),
27 | };
28 | return color_list;
29 | }
30 |
31 |
32 | void show_mask(std::string json_path, std::string image_type) {
33 | using namespace std;
34 | using json = nlohmann::json;
35 | std::string image_path = replace_all_distinct(json_path, ".json", image_type);
36 | cv::Mat image = cv::imread(image_path);
37 | cv::Mat mask;
38 | mask.create(image.size(), CV_8UC3);
39 |
40 | std::ifstream jfile(json_path);
41 | json j;
42 | jfile >> j;
43 | size_t num_blobs = j["shapes"].size();
44 |
45 | for (int i = 0; i < num_blobs; i++)
46 | {
47 | std::string name = j["shapes"][i]["label"];
48 | size_t points_len = j["shapes"][i]["points"].size();
49 | cout << name << endl;
50 | std::vector contour = {};
51 | for (int t = 0; t < points_len; t++)
52 | {
53 | int x = round(double(j["shapes"][i]["points"][t][0]));
54 | int y = round(double(j["shapes"][i]["points"][t][1]));
55 | cout << x << "\t" << y << endl;
56 | contour.push_back(cv::Point(x, y));
57 | }
58 | const cv::Point* ppt[1] = { contour.data() };
59 | int npt[] = { int(contour.size()) };
60 | cv::fillPoly(mask, ppt, npt, 1, cv::Scalar(255, 255, 255));
61 | }
62 | cv::imshow("mask", mask);
63 | cv::imshow("image", image);
64 | cv::waitKey(0);
65 | cv::destroyAllWindows();
66 | }
67 |
68 | void SegDataset::draw_mask(std::string json_path, cv::Mat &mask){
69 | std::ifstream jfile(json_path);
70 | nlohmann::json j;
71 | jfile >> j;
72 | size_t num_blobs = j["shapes"].size();
73 |
74 |
75 | for (int i = 0; i < num_blobs; i++)
76 | {
77 | std::string name = j["shapes"][i]["label"];
78 | size_t points_len = j["shapes"][i]["points"].size();
79 | // std::cout << name << std::endl;
80 | std::vector contour = {};
81 | for (int t = 0; t < points_len; t++)
82 | {
83 | int x = round(double(j["shapes"][i]["points"][t][0]));
84 | int y = round(double(j["shapes"][i]["points"][t][1]));
85 | // std::cout << x << "\t" << y << std::endl;
86 | contour.push_back(cv::Point(x, y));
87 | }
88 | const cv::Point* ppt[1] = { contour.data() };
89 | int npt[] = { int(contour.size()) };
90 | cv::fillPoly(mask, ppt, npt, 1, name2color[name]);
91 | }
92 | }
93 |
94 | SegDataset::SegDataset(int resize_width, int resize_height, std::vector list_images,
95 | std::vector list_labels, std::vector name_list,
96 | trainTricks tricks, bool isTrain)
97 | {
98 | this->tricks = tricks;
99 | this->name_list = name_list;
100 | this->resize_width = resize_width;
101 | this->resize_height = resize_height;
102 | this->list_images = list_images;
103 | this->list_labels = list_labels;
104 | this->isTrain = isTrain;
105 | for(int i=0; i(name_list[i], i));
107 | }
108 | std::vector color_list = get_color_list();
109 | if(name_list.size()>color_list.size()){
110 | std::cout<< "Num of classes exceeds defined color list, please add color to color list in SegDataset.cpp";
111 | }
112 | for(int i = 0; i(name_list[i],color_list[i]));
114 | }
115 | }
116 |
117 | torch::data::Example<> SegDataset::get(size_t index) {
118 | std::string image_path = list_images.at(index);
119 | std::string label_path = list_labels.at(index);
120 | cv::Mat image = cv::imread(image_path);
121 | cv::Mat mask = cv::Mat::zeros(image.rows, image.cols, CV_8UC3);
122 | draw_mask(label_path,mask);
123 |
124 | //Data augmentation like flip or rotate could be implemented here...
125 | auto m_data = Data(image, mask);
126 | if (isTrain) {
127 | m_data = Augmentations::Resize(m_data, resize_width, resize_height, 1);
128 | m_data = Augmentations::HorizontalFlip(m_data, tricks.horizontal_flip_prob);
129 | m_data = Augmentations::VerticalFlip(m_data, tricks.vertical_flip_prob);
130 | m_data = Augmentations::RandomScaleRotate(m_data, tricks.scale_rotate_prob, \
131 | tricks.rotate_limit, tricks.scale_limit, \
132 | tricks.interpolation, tricks.border_mode);
133 | }
134 | else {
135 | m_data = Augmentations::Resize(m_data, resize_width, resize_height, 1);
136 | }
137 | torch::Tensor img_tensor = torch::from_blob(m_data.image.data, { m_data.image.rows, m_data.image.cols, 3 }, torch::kByte).permute({ 2, 0, 1 }); // Channels x Height x Width
138 | torch::Tensor colorful_label_tensor = torch::from_blob(m_data.mask.data, { m_data.mask.rows, m_data.mask.cols, 3 }, torch::kByte);
139 | torch::Tensor label_tensor = torch::zeros({ m_data.image.rows, m_data.image.cols});
140 |
141 | //encode "colorful" tensor to class_index meaning tensor, [w,h,3]->[w,h], pixel value is the index of a class
142 | for(int i = 0; i
6 |
7 | //freeze_epochs (unsigned int): freeze the backbone during the first freeze_epochs, default 0;
8 | //decay_epochs (std::vector): every decay_epoch, learning rate will decay by 90 percent, default {0};
9 | //dice_ce_ratio (float): the weight of dice loss in combind loss, default 0.5;
10 | //horizontal_flip_prob (float): probability to do horizontal flip augmentation, default 0;
11 | //vertical_flip_prob (float): probability to do vertical flip augmentation, default 0;
12 | //scale_rotate_prob (float): probability to do rotate and scale augmentation, default 0;
13 | struct trainTricks {
14 | unsigned int freeze_epochs = 0;
15 | std::vector decay_epochs = { 0 };
16 | float dice_ce_ratio = 0.5;
17 |
18 | float horizontal_flip_prob = 0;
19 | float vertical_flip_prob = 0;
20 | float scale_rotate_prob = 0;
21 |
22 | float scale_limit = 0.1;
23 | float rotate_limit = 45;
24 | int interpolation = cv::INTER_LINEAR;
25 | int border_mode = cv::BORDER_CONSTANT;
26 | };
27 |
28 | void show_mask(std::string json_path, std::string image_type = ".jpg");
29 |
30 | class SegDataset :public torch::data::Dataset
31 | {
32 | public:
33 | SegDataset(int resize_width, int resize_height, std::vector list_images,
34 | std::vector list_labels, std::vector name_list,
35 | trainTricks tricks, bool isTrain = false);
36 | // Override get() function to return tensor at location index
37 | torch::data::Example<> get(size_t index) override;
38 | // Return the length of data
39 | torch::optional size() const override {
40 | return list_labels.size();
41 | };
42 | private:
43 | void draw_mask(std::string json_path, cv::Mat &mask);
44 | int resize_width = 512; int resize_height = 512; bool isTrain = false;
45 | std::vector name_list = {};
46 | std::map name2index = {};
47 | std::map name2color = {};
48 | std::vector list_images;
49 | std::vector list_labels;
50 | trainTricks tricks;
51 | };
52 |
53 | #endif // SEGDATASET_H
54 |
--------------------------------------------------------------------------------
/src/Segmentor.cpp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AllentDan/LibtorchSegmentation/3fc7e71378e71fa119d87782bd0c8a67e541f27b/src/Segmentor.cpp
--------------------------------------------------------------------------------
/src/Segmentor.h:
--------------------------------------------------------------------------------
1 | /*
2 | This libtorch implementation is writen by AllentDan.
3 | Copyright(c) AllentDan 2021,
4 | All rights reserved.
5 | */
6 | #ifndef SEGMENTOR_H
7 | #define SEGMENTOR_H
8 | #include"architectures/FPN.h"
9 | #include"architectures/PAN.h"
10 | #include"architectures/UNet.h"
11 | #include"architectures/LinkNet.h"
12 | #include"architectures/PSPNet.h"
13 | #include"architectures/DeepLab.h"
14 | #include"utils/loss.h"
15 | #include"SegDataset.h"
16 | #include
17 | #if _WIN32
18 | #include
19 | #else
20 | #include
21 | #endif
22 |
23 | template
24 | class Segmentor
25 | {
26 | public:
27 | Segmentor();
28 | ~Segmentor() {};
29 | void Initialize(int gpu_id, int width, int height, std::vector&& name_list,
30 | std::string encoder_name, std::string pretrained_path);
31 | void SetTrainTricks(trainTricks &tricks);
32 | void Train(float learning_rate, int epochs, int batch_size,
33 | std::string train_val_path, std::string image_type, std::string save_path);
34 | void LoadWeight(std::string weight_path);
35 | void Predict(cv::Mat& image, const std::string& which_class);
36 | private:
37 | int width = 512; int height = 512; std::vector name_list;
38 | torch::Device device = torch::Device(torch::kCPU);
39 | trainTricks tricks;
40 | // FPN fpn{nullptr};
41 | // UNet unet{nullptr};
42 | Model model{ nullptr };
43 | };
44 |
45 | template
46 | Segmentor::Segmentor()
47 | {
48 | };
49 |
50 | template
51 | void Segmentor::Initialize(int gpu_id, int _width, int _height, std::vector&& _name_list,
52 | std::string encoder_name, std::string pretrained_path) {
53 | width = _width;
54 | height = _height;
55 | name_list = _name_list;
56 | //std::cout << pretrained_path << std::endl;
57 | //struct stat s {};
58 | //lstat(pretrained_path.c_str(), &s);
59 | #ifdef _WIN32
60 | if ((_access(pretrained_path.data(), 0)) == -1)
61 | {
62 | std::cout<< "Pretrained path is invalid";
63 | }
64 | #else
65 | if (access(pretrained_path.data(), F_OK) != 0)
66 | {
67 | std::cout<< "Pretrained path is invalid";
68 | }
69 | #endif
70 | if (name_list.size() < 2) std::cout<< "Class num is less than 1";
71 | int gpu_num = torch::getNumGPUs();
72 | if (gpu_id >= gpu_num) std::cout<< "GPU id exceeds max number of gpus";
73 | if (gpu_id >= 0) device = torch::Device(torch::kCUDA, gpu_id);
74 |
75 | model = Model(name_list.size(), encoder_name, pretrained_path);
76 | // fpn = FPN(name_list.size(),encoder_name,pretrained_path);
77 | model->to(device);
78 | }
79 |
80 | template
81 | void Segmentor::SetTrainTricks(trainTricks &tricks) {
82 | this->tricks = tricks;
83 | return;
84 | }
85 |
86 | template
87 | void Segmentor::Train(float learning_rate, int epochs, int batch_size,
88 | std::string train_val_path, std::string image_type, std::string save_path) {
89 |
90 | std::string train_dir = train_val_path.append({ file_sepator() }).append("train");
91 | std::string val_dir = replace_all_distinct(train_dir,"train","val");
92 |
93 | std::vector list_images_train = {};
94 | std::vector list_labels_train = {};
95 | std::vector list_images_val = {};
96 | std::vector list_labels_val = {};
97 |
98 | load_seg_data_from_folder(train_dir, image_type, list_images_train, list_labels_train);
99 | load_seg_data_from_folder(val_dir, image_type, list_images_val, list_labels_val);
100 |
101 | auto custom_dataset_train = SegDataset(width, height, list_images_train, list_labels_train, \
102 | name_list, tricks, true).map(torch::data::transforms::Stack<>());
103 | auto custom_dataset_val = SegDataset(width, height, list_images_val, list_labels_val, \
104 | name_list, tricks, false).map(torch::data::transforms::Stack<>());
105 | auto options = torch::data::DataLoaderOptions();
106 | options.drop_last(true);
107 | options.batch_size(batch_size);
108 | auto data_loader_train = torch::data::make_data_loader(std::move(custom_dataset_train), options);
109 | auto data_loader_val = torch::data::make_data_loader(std::move(custom_dataset_val), options);
110 |
111 | float best_loss = 1e10;
112 | for (int epoch = 0; epoch < epochs; epoch++) {
113 | float loss_sum = 0;
114 | int batch_count = 0;
115 | float loss_train = 0;
116 | float dice_coef_sum = 0;
117 |
118 | for (auto decay_epoch : tricks.decay_epochs) {
119 | if(decay_epoch-1 == epoch)
120 | learning_rate /= 10;
121 | }
122 | torch::optim::Adam optimizer(model->parameters(), learning_rate);
123 | if (epoch < tricks.freeze_epochs) {
124 | for (auto mm : model->named_parameters())
125 | {
126 | if (strstr(mm.key().data(), "encoder"))
127 | {
128 | mm.value().set_requires_grad(false);
129 | }
130 | else
131 | {
132 | mm.value().set_requires_grad(true);
133 | }
134 | }
135 | }
136 | else {
137 | for (auto mm : model->named_parameters())
138 | {
139 | mm.value().set_requires_grad(true);
140 | }
141 | }
142 | model->train();
143 | for (auto& batch : *data_loader_train) {
144 | auto data = batch.data;
145 | auto target = batch.target;
146 | data = data.to(torch::kF32).to(device).div(255.0);
147 | target = target.to(torch::kLong).to(device).squeeze(1);//.clamp_max(1);//if you choose clamp, all classes will be set to only one
148 |
149 | optimizer.zero_grad();
150 | // Execute the model
151 | torch::Tensor prediction = model->forward(data);
152 | // Compute loss value
153 | torch::Tensor ce_loss = CELoss(prediction, target);
154 | torch::Tensor dice_loss = DiceLoss(torch::softmax(prediction, 1), target.unsqueeze(1), name_list.size());
155 | auto loss = dice_loss * tricks.dice_ce_ratio + ce_loss * (1 - tricks.dice_ce_ratio);
156 | // Compute gradients
157 | loss.backward();
158 | // Update the parameters
159 | optimizer.step();
160 | loss_sum += loss.item().toFloat();
161 | dice_coef_sum += (1- dice_loss).item().toFloat();
162 | batch_count++;
163 | loss_train = loss_sum / batch_count / batch_size;
164 | auto dice_coef = dice_coef_sum / batch_count;
165 |
166 | std::cout << "Epoch: " << epoch << "," << " Training Loss: " << loss_train << \
167 | "," << " Dice coefficient: " << dice_coef << "\r";
168 | }
169 | std::cout << std::endl;
170 | // validation part
171 | model->eval();
172 | loss_sum = 0; batch_count = 0; dice_coef_sum = 0;
173 | float loss_val = 0;
174 | for (auto& batch : *data_loader_val) {
175 | auto data = batch.data;
176 | auto target = batch.target;
177 | data = data.to(torch::kF32).to(device).div(255.0);
178 | target = target.to(torch::kLong).to(device).squeeze(1);//.clamp_max(1);
179 |
180 | // Execute the model
181 | torch::Tensor prediction = model->forward(data);
182 |
183 | // Compute loss value
184 | torch::Tensor ce_loss = CELoss(prediction, target);
185 | torch::Tensor dice_loss = DiceLoss(torch::softmax(prediction, 1), target.unsqueeze(1), name_list.size());
186 | auto loss = dice_loss * tricks.dice_ce_ratio + ce_loss * (1 - tricks.dice_ce_ratio);
187 | loss_sum += loss.template item();
188 | dice_coef_sum += (1 - dice_loss).item().toFloat();
189 | batch_count++;
190 | loss_val = loss_sum / batch_count / batch_size;
191 | auto dice_coef = dice_coef_sum / batch_count;
192 |
193 | std::cout << "Epoch: " << epoch << "," << " Validation Loss: " << loss_val << \
194 | "," << " Dice coefficient: " << dice_coef << "\r";
195 | }
196 | std::cout << std::endl;
197 | if (loss_val < best_loss) {
198 | torch::save(model, save_path);
199 | best_loss = loss_val;
200 | }
201 | }
202 | return;
203 | }
204 |
205 | template
206 | void Segmentor::LoadWeight(std::string weight_path) {
207 | torch::load(model, weight_path);
208 | model->to(device);
209 | model->eval();
210 | return;
211 | }
212 |
213 | template
214 | void Segmentor::Predict(cv::Mat& image, const std::string& which_class) {
215 | cv::Mat srcImg = image.clone();
216 | int which_class_index = -1;
217 | for (int i = 0; i < name_list.size(); i++) {
218 | if (name_list[i] == which_class) {
219 | which_class_index = i;
220 | break;
221 | }
222 | }
223 | if (which_class_index == -1) std::cout<< which_class + "not in the name list";
224 | int image_width = image.cols;
225 | int image_height = image.rows;
226 | cv::resize(image, image, cv::Size(width, height));
227 | torch::Tensor tensor_image = torch::from_blob(image.data, { 1, height, width,3 }, torch::kByte);
228 | tensor_image = tensor_image.to(device);
229 | tensor_image = tensor_image.permute({ 0,3,1,2 });
230 | tensor_image = tensor_image.to(torch::kFloat);
231 | tensor_image = tensor_image.div(255.0);
232 |
233 | try
234 | {
235 | at::Tensor output = model->forward({ tensor_image });
236 |
237 | }
238 | catch (const std::exception& e)
239 | {
240 | std::cout << e.what();
241 | }
242 | at::Tensor output = model->forward({ tensor_image });
243 | output = torch::softmax(output, 1).mul(255.0).toType(torch::kByte);
244 |
245 | image = cv::Mat::ones(cv::Size(width, height), CV_8UC1);
246 |
247 | at::Tensor re = output[0][which_class_index].to(at::kCPU).detach();
248 | memcpy(image.data, re.data_ptr(), width * height * sizeof(unsigned char));
249 | cv::resize(image, image, cv::Size(image_width, image_height));
250 |
251 | // draw the prediction
252 | cv::imwrite("prediction.jpg", image);
253 | cv::imshow("prediction", image);
254 | cv::imshow("srcImage", srcImg);
255 | cv::waitKey(0);
256 | cv::destroyAllWindows();
257 | return;
258 | }
259 |
260 | #endif // SEGMENTOR_H
261 |
--------------------------------------------------------------------------------
/src/architectures/DeepLab.cpp:
--------------------------------------------------------------------------------
1 | #include "DeepLab.h"
2 |
3 | DeepLabV3Impl::DeepLabV3Impl(int _num_classes, std::string encoder_name, std::string pretrained_path, int encoder_depth,
4 | int decoder_channels, int in_channels, double upsampling) {
5 | num_classes = _num_classes;
6 | auto encoder_param = encoder_params();
7 | std::vector encoder_channels = encoder_param[encoder_name]["out_channels"];
8 | if (!encoder_param.contains(encoder_name))
9 | std::cout<< "encoder name must in {resnet18, resnet34, resnet50, resnet101, resnet150, \
10 | resnext50_32x4d, resnext101_32x8d, vgg11, vgg11_bn, vgg13, vgg13_bn, \
11 | vgg16, vgg16_bn, vgg19, vgg19_bn,}";
12 | if (encoder_param[encoder_name]["class_type"] == "resnet")
13 | encoder = new ResNetImpl(encoder_param[encoder_name]["layers"], 1000, encoder_name);
14 | else if (encoder_param[encoder_name]["class_type"] == "vgg")
15 | encoder = new VGGImpl(encoder_param[encoder_name]["cfg"], 1000, encoder_param[encoder_name]["batch_norm"]);
16 | else std::cout<< "unknown error in backbone initialization";
17 |
18 | encoder->load_pretrained(pretrained_path);
19 | encoder->make_dilated({ 5,4 }, {4,2});
20 |
21 | decoder = DeepLabV3Decoder(encoder_channels[encoder_channels.size()-1], decoder_channels);
22 | segmentation_head = SegmentationHead(decoder_channels, num_classes, 1, upsampling);
23 |
24 | register_module("encoder", std::shared_ptr(encoder));
25 | register_module("decoder", decoder);
26 | register_module("segmentation_head", segmentation_head);
27 | }
28 |
29 | torch::Tensor DeepLabV3Impl::forward(torch::Tensor x) {
30 | std::vector features = encoder->features(x);
31 | x = decoder->forward(features);
32 | x = segmentation_head->forward(x);
33 | return x;
34 | }
35 |
36 | DeepLabV3PlusImpl::DeepLabV3PlusImpl(int _num_classes, std::string encoder_name, std::string pretrained_path, int encoder_depth,
37 | int encoder_output_stride, int decoder_channels, int in_channels, double upsampling) {
38 | num_classes = _num_classes;
39 | auto encoder_param = encoder_params();
40 | std::vector encoder_channels = encoder_param[encoder_name]["out_channels"];
41 | if (!encoder_param.contains(encoder_name))
42 | std::cout<< "encoder name must in {resnet18, resnet34, resnet50, resnet101, resnet150, \
43 | resnext50_32x4d, resnext101_32x8d, vgg11, vgg11_bn, vgg13, vgg13_bn, \
44 | vgg16, vgg16_bn, vgg19, vgg19_bn,}";
45 | if (encoder_param[encoder_name]["class_type"] == "resnet")
46 | encoder = new ResNetImpl(encoder_param[encoder_name]["layers"], 1000, encoder_name);
47 | else if (encoder_param[encoder_name]["class_type"] == "vgg")
48 | encoder = new VGGImpl(encoder_param[encoder_name]["cfg"], 1000, encoder_param[encoder_name]["batch_norm"]);
49 | else std::cout<< "unknown error in backbone initialization";
50 |
51 | encoder->load_pretrained(pretrained_path);
52 | if (encoder_output_stride == 8) {
53 | encoder->make_dilated({ 5,4 }, { 4,2 });
54 | }
55 | else if (encoder_output_stride == 16) {
56 | encoder->make_dilated({ 5 }, { 2 });
57 | }
58 | else {
59 | std::cout<< "Encoder output stride should be 8 or 16";
60 | }
61 |
62 | decoder = DeepLabV3PlusDecoder(encoder_channels, decoder_channels, decoder_atrous_rates, encoder_output_stride);
63 | segmentation_head = SegmentationHead(decoder_channels, num_classes, 1, upsampling);
64 |
65 | register_module("encoder", std::shared_ptr(encoder));
66 | register_module("decoder", decoder);
67 | register_module("segmentation_head", segmentation_head);
68 | }
69 |
70 | torch::Tensor DeepLabV3PlusImpl::forward(torch::Tensor x) {
71 | std::vector features = encoder->features(x);
72 | x = decoder->forward(features);
73 | x = segmentation_head->forward(x);
74 | return x;
75 | }
76 |
--------------------------------------------------------------------------------
/src/architectures/DeepLab.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include"../backbones/ResNet.h"
3 | #include"../backbones/VGG.h"
4 | #include"DeepLabDecoder.h"
5 |
6 | class DeepLabV3Impl : public torch::nn::Module
7 | {
8 | public:
9 | DeepLabV3Impl() {}
10 | ~DeepLabV3Impl() {
11 | //delete encoder;
12 | }
13 | DeepLabV3Impl(int num_classes, std::string encoder_name = "resnet18", std::string pretrained_path = "", int encoder_depth = 5,
14 | int decoder_channels = 256, int in_channels = 3, double upsampling = 8);
15 | torch::Tensor forward(torch::Tensor x);
16 | private:
17 | Backbone *encoder;
18 | DeepLabV3Decoder decoder{ nullptr };
19 | SegmentationHead segmentation_head{ nullptr };
20 | int num_classes = 1;
21 | }; TORCH_MODULE(DeepLabV3);
22 |
23 | class DeepLabV3PlusImpl : public torch::nn::Module
24 | {
25 | public:
26 | DeepLabV3PlusImpl() {};
27 | ~DeepLabV3PlusImpl() {
28 | //delete encoder;
29 | }
30 | DeepLabV3PlusImpl(int num_classes, std::string encoder_name = "resnet18", std::string pretrained_path = "", int encoder_depth = 5,
31 | int encoder_output_stride = 16, int decoder_channels = 256, int in_channels = 3, double upsampling = 4);
32 | torch::Tensor forward(torch::Tensor x);
33 | private:
34 | Backbone* encoder;
35 | DeepLabV3PlusDecoder decoder{ nullptr };
36 | SegmentationHead segmentation_head{ nullptr };
37 | int num_classes = 1;
38 | std::vector decoder_atrous_rates = { 12, 24, 36 };
39 | }; TORCH_MODULE(DeepLabV3Plus);
--------------------------------------------------------------------------------
/src/architectures/DeepLabDecoder.cpp:
--------------------------------------------------------------------------------
1 | #include "DeepLabDecoder.h"
2 |
3 |
4 | torch::nn::Sequential ASPPConv(int in_channels, int out_channels, int dilation) {
5 | return torch::nn::Sequential(
6 | torch::nn::Conv2d(conv_options(in_channels, out_channels, 3, 1, dilation, 1, false, dilation)),
7 | torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(out_channels)),
8 | torch::nn::ReLU()
9 | );
10 | }
11 |
12 | torch::nn::Sequential SeparableConv2d(int in_channels, int out_channels, int kernel_size, int stride,
13 | int padding, int dilation, bool bias) {
14 | torch::nn::Conv2d dephtwise_conv = torch::nn::Conv2d(conv_options(in_channels, in_channels, kernel_size,
15 | stride, padding, 1, false, dilation));
16 | torch::nn::Conv2d pointwise_conv = torch::nn::Conv2d(conv_options(in_channels, out_channels, 1, 1, 0, 1, bias));
17 | return torch::nn::Sequential(dephtwise_conv, pointwise_conv);
18 | };
19 |
20 | torch::nn::Sequential ASPPSeparableConv(int in_channels, int out_channels, int dilation) {
21 | torch::nn::Sequential seq = SeparableConv2d(in_channels, out_channels, 3, 1, dilation, dilation, false);
22 | seq->push_back(torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(out_channels)));
23 | seq->push_back(torch::nn::ReLU());
24 | return seq;
25 | }
26 |
27 | ASPPPoolingImpl::ASPPPoolingImpl(int in_channels, int out_channels) {
28 | seq = torch::nn::Sequential(torch::nn::AdaptiveAvgPool2d(torch::nn::AdaptiveAvgPool2dOptions(1)),
29 | torch::nn::Conv2d(conv_options(in_channels, out_channels, 1, 1, 0, 1, false)),
30 | torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(out_channels)),
31 | torch::nn::ReLU());
32 | register_module("seq", seq);
33 | }
34 |
35 | torch::Tensor ASPPPoolingImpl::forward(torch::Tensor x) {
36 | auto residual(x.clone());
37 | x = seq->forward(x);
38 | x = at::upsample_bilinear2d(x, residual[0][0].sizes(), false);
39 | return x;
40 | }
41 |
42 | ASPPImpl::ASPPImpl(int in_channels, int out_channels, std::vector atrous_rates, bool separable) {
43 | modules->push_back(torch::nn::Sequential(torch::nn::Conv2d(conv_options(in_channels, out_channels, 1, 1, 0, 1, false)),
44 | torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(out_channels)),
45 | torch::nn::ReLU()));
46 | if (atrous_rates.size() != 3) std::cout<< "size of atrous_rates must be 3";
47 | if (separable) {
48 | modules->push_back(ASPPSeparableConv(in_channels, out_channels, atrous_rates[0]));
49 | modules->push_back(ASPPSeparableConv(in_channels, out_channels, atrous_rates[1]));
50 | modules->push_back(ASPPSeparableConv(in_channels, out_channels, atrous_rates[2]));
51 | }
52 | else {
53 | modules->push_back(ASPPConv(in_channels, out_channels, atrous_rates[0]));
54 | modules->push_back(ASPPConv(in_channels, out_channels, atrous_rates[1]));
55 | modules->push_back(ASPPConv(in_channels, out_channels, atrous_rates[2]));
56 | }
57 | aspppooling = ASPPPooling(in_channels, out_channels);
58 |
59 | project = torch::nn::Sequential(
60 | torch::nn::Conv2d(conv_options(5 * out_channels, out_channels, 1, 1, 0, 1, false)),
61 | torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(out_channels)),
62 | torch::nn::ReLU(),
63 | torch::nn::Dropout(torch::nn::DropoutOptions(0.5)));
64 |
65 | register_module("modules", modules);
66 | register_module("aspppooling", aspppooling);
67 | register_module("project", project);
68 | }
69 |
70 | torch::Tensor ASPPImpl::forward(torch::Tensor x) {
71 | std::vector res;
72 | for (int i = 0; i < 4; i++) {
73 | res.push_back(modules[i]->as()->forward(x));
74 | }
75 | res.push_back(aspppooling->forward(x));
76 | x = torch::cat(res, 1);
77 | x = project->forward(x);
78 | return x;
79 | }
80 |
81 | DeepLabV3DecoderImpl::DeepLabV3DecoderImpl(int in_channels, int out_channels, std::vector atrous_rates) {
82 | seq->push_back(ASPP(in_channels, out_channels, atrous_rates));
83 | seq->push_back(torch::nn::Conv2d(conv_options(out_channels, out_channels, 3, 1, 1, 1, false)));
84 | seq->push_back(torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(out_channels)));
85 | seq->push_back(torch::nn::ReLU());
86 |
87 | register_module("seq", seq);
88 | }
89 |
90 | torch::Tensor DeepLabV3DecoderImpl::forward(std::vector< torch::Tensor> x_list) {
91 | auto x = seq->forward(x_list[x_list.size() - 1]);
92 | return x;
93 | }
94 |
95 | DeepLabV3PlusDecoderImpl::DeepLabV3PlusDecoderImpl(std::vector encoder_channels, int out_channels,
96 | std::vector atrous_rates, int output_stride) {
97 | if (output_stride != 8 && output_stride != 16) std::cout<< "Output stride should be 8 or 16";
98 | aspp = ASPP(encoder_channels[encoder_channels.size() - 1], out_channels, atrous_rates, true);
99 | aspp_seq = SeparableConv2d(out_channels, out_channels, 3, 1, 1, 1, false);
100 | aspp_seq->push_back(torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(out_channels)));
101 | aspp_seq->push_back(torch::nn::ReLU());
102 | double scale_factor = double(output_stride / 4);
103 | up = torch::nn::Upsample(torch::nn::UpsampleOptions().align_corners(true).scale_factor(std::vector({ scale_factor,scale_factor })).mode(torch::kBilinear));
104 | int highres_in_channels = encoder_channels[encoder_channels.size() -4];
105 | int highres_out_channels = 48; // proposed by authors of paper
106 |
107 | block1 = torch::nn::Sequential(
108 | torch::nn::Conv2d(conv_options(highres_in_channels, highres_out_channels, 1, 1, 0, 1, false)),
109 | torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(highres_out_channels)),
110 | torch::nn::ReLU());
111 | block2 = SeparableConv2d(highres_out_channels + out_channels, out_channels, 3, 1, 1, 1, false);
112 | block2->push_back(torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(out_channels)));
113 | block2->push_back(torch::nn::ReLU());
114 |
115 | register_module("aspp", aspp);
116 | register_module("aspp_seq", aspp_seq);
117 | register_module("block1", block1);
118 | register_module("block2", block2);
119 | }
120 |
121 | torch::Tensor DeepLabV3PlusDecoderImpl::forward(std::vector x_list) {
122 | auto aspp_features = aspp->forward(x_list[x_list.size() - 1]);
123 | aspp_features = aspp_seq->forward(aspp_features);
124 | aspp_features = up->forward(aspp_features);
125 |
126 | auto high_res_features = block1->forward(x_list[x_list.size() - 4]);
127 | auto concat_features = torch::cat({ aspp_features, high_res_features }, 1);
128 | auto fused_features = block2->forward(concat_features);
129 | return fused_features;
130 | }
--------------------------------------------------------------------------------
/src/architectures/DeepLabDecoder.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | /*
3 | BSD 3 - Clause License
4 |
5 | Copyright(c) Soumith Chintala 2016,
6 | All rights reserved.
7 |
8 | Redistribution and use in source and binary forms, with or without
9 | modification, are permitted provided that the following conditions are met :
10 |
11 | *Redistributions of source code must retain the above copyright notice, this
12 | list of conditions and the following disclaimer.
13 |
14 | * Redistributions in binary form must reproduce the above copyright notice,
15 | this list of conditions and the following disclaimer in the documentation
16 | and/or other materials provided with the distribution.
17 |
18 | * Neither the name of the copyright holder nor the names of its
19 | contributors may be used to endorse or promote products derived from
20 | this software without specific prior written permission.
21 |
22 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
23 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
24 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
25 | DISCLAIMED.IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
26 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
27 | DAMAGES(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
28 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
29 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
30 | OR TORT(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 |
33 | This libtorch implementation is written by AllentDan.
34 | Copyright(c) AllentDan 2021,
35 | All rights reserved.
36 | */
37 |
38 | #include"../utils/util.h"
39 |
40 | //struct StackSequentailImpl : torch::nn::SequentialImpl {
41 | // using SequentialImpl::SequentialImpl;
42 | //
43 | // torch::Tensor forward(torch::Tensor x) {
44 | // return SequentialImpl::forward(x);
45 | // }
46 | //}; TORCH_MODULE(StackSequentail);
47 |
48 | torch::nn::Sequential ASPPConv(int in_channels, int out_channels, int dilation);
49 |
50 | torch::nn::Sequential SeparableConv2d(int in_channels, int out_channels, int kernel_size, int stride = 1,
51 | int padding = 0, int dilation = 1, bool bias = true);
52 |
53 | torch::nn::Sequential ASPPSeparableConv(int in_channels, int out_channels, int dilation);
54 |
55 | class ASPPPoolingImpl : public torch::nn::Module {
56 | public:
57 | torch::nn::Sequential seq{nullptr};
58 | ASPPPoolingImpl(int in_channels, int out_channels);
59 | torch::Tensor forward(torch::Tensor x);
60 |
61 | }; TORCH_MODULE(ASPPPooling);
62 |
63 | class ASPPImpl : public torch::nn::Module {
64 | public:
65 | ASPPImpl(int in_channels, int out_channels, std::vector atrous_rates, bool separable = false);
66 | torch::Tensor forward(torch::Tensor x);
67 | private:
68 | torch::nn::ModuleList modules{};
69 | ASPPPooling aspppooling{ nullptr };
70 | torch::nn::Sequential project{ nullptr };
71 | }; TORCH_MODULE(ASPP);
72 |
73 | class DeepLabV3DecoderImpl : public torch::nn::Module
74 | {
75 | public:
76 | DeepLabV3DecoderImpl(int in_channels, int out_channels = 256, std::vector atrous_rates = { 12, 24, 36 });
77 | torch::Tensor forward(std::vector< torch::Tensor> x);
78 | int out_channels = 0;
79 | private:
80 | torch::nn::Sequential seq{};
81 | }; TORCH_MODULE(DeepLabV3Decoder);
82 |
83 | class DeepLabV3PlusDecoderImpl :public torch::nn::Module {
84 | public:
85 | DeepLabV3PlusDecoderImpl(std::vector encoder_channels, int out_channels,
86 | std::vector atrous_rates, int output_stride = 16);
87 | torch::Tensor forward(std::vector< torch::Tensor> x);
88 | private:
89 | ASPP aspp{ nullptr };
90 | torch::nn::Sequential aspp_seq{ nullptr };
91 | torch::nn::Upsample up{ nullptr };
92 | torch::nn::Sequential block1{ nullptr };
93 | torch::nn::Sequential block2{ nullptr };
94 | }; TORCH_MODULE(DeepLabV3PlusDecoder);
95 |
--------------------------------------------------------------------------------
/src/architectures/FPN.cpp:
--------------------------------------------------------------------------------
1 | #include "FPN.h"
2 |
3 | FPNImpl::FPNImpl(int _num_classes, std::string encoder_name, std::string pretrained_path, int encoder_depth,
4 | int decoder_pyramid_channel, int decoder_segmentation_channels, std::string decoder_merge_policy,
5 | float decoder_dropout, double upsampling){
6 | num_classes = _num_classes;
7 | auto encoder_param = encoder_params();
8 | std::vector encoder_channels = encoder_param[encoder_name]["out_channels"];
9 | if(!encoder_param.contains(encoder_name))
10 | std::cout<< "encoder name must in {resnet18, resnet34, resnet50, resnet101, resnet150, \
11 | resnext50_32x4d, resnext101_32x8d, vgg11, vgg11_bn, vgg13, vgg13_bn, \
12 | vgg16, vgg16_bn, vgg19, vgg19_bn,}";
13 | if (encoder_param[encoder_name]["class_type"] == "resnet")
14 | encoder = new ResNetImpl(encoder_param[encoder_name]["layers"], 1000, encoder_name);
15 | else if (encoder_param[encoder_name]["class_type"] == "vgg")
16 | encoder = new VGGImpl(encoder_param[encoder_name]["cfg"], 1000, encoder_param[encoder_name]["batch_norm"]);
17 | else std::cout<< "unknown error in backbone initialization";
18 |
19 | encoder->load_pretrained(pretrained_path);
20 | decoder = FPNDecoder(encoder_channels,encoder_depth, decoder_pyramid_channel,
21 | decoder_segmentation_channels,decoder_dropout, decoder_merge_policy);
22 | segmentation_head = SegmentationHead(decoder_segmentation_channels,num_classes,1,upsampling);
23 |
24 | register_module("encoder",std::shared_ptr(encoder));
25 | register_module("decoder",decoder);
26 | register_module("segmentation_head",segmentation_head);
27 | }
28 |
29 | torch::Tensor FPNImpl::forward(torch::Tensor x){
30 | std::vector features = encoder->features(x);
31 | x = decoder->forward(features);
32 | x = segmentation_head->forward(x);
33 | return x;
34 | }
35 |
--------------------------------------------------------------------------------
/src/architectures/FPN.h:
--------------------------------------------------------------------------------
1 | /*
2 | This libtorch implementation is written by AllentDan.
3 | Copyright(c) AllentDan 2021,
4 | All rights reserved.
5 | */
6 |
7 | #ifndef FPN_H
8 | #define FPN_H
9 | #include"../backbones/ResNet.h"
10 | #include"../backbones/VGG.h"
11 | #include"FPNDecoder.h"
12 |
13 | class FPNImpl : public torch::nn::Module
14 | {
15 | public:
16 | FPNImpl() {}
17 | ~FPNImpl() {
18 | //delete encoder;
19 | }
20 | FPNImpl(int num_classes, std::string encoder_name = "resnet18", std::string pretrained_path = "", int encoder_depth = 5,
21 | int decoder_pyramid_channel=256, int decoder_segmentation_channels = 128, std::string decoder_merge_policy = "add",
22 | float decoder_dropout = 0.2, double upsampling = 4);
23 | torch::Tensor forward(torch::Tensor x);
24 | private:
25 | Backbone *encoder;
26 | FPNDecoder decoder{nullptr};
27 | SegmentationHead segmentation_head{nullptr};
28 | int num_classes = 1;
29 | };TORCH_MODULE(FPN);
30 |
31 | #endif // FPN_H
32 |
--------------------------------------------------------------------------------
/src/architectures/FPNDecoder.cpp:
--------------------------------------------------------------------------------
1 | #include "FPNDecoder.h"
2 |
3 | Conv3x3GNReLUImpl::Conv3x3GNReLUImpl(int _in_channels, int _out_channels, bool _upsample){
4 | upsample = _upsample;
5 | block = torch::nn::Sequential(torch::nn::Conv2d(conv_options(_in_channels, _out_channels, 3, 1, 1, 1, false)),
6 | torch::nn::GroupNorm(torch::nn::GroupNormOptions(32, _out_channels)),
7 | torch::nn::ReLU(torch::nn::ReLUOptions(true)));
8 | register_module("block",block);
9 | }
10 |
11 | torch::Tensor Conv3x3GNReLUImpl::forward(torch::Tensor x){
12 | x = block->forward(x);
13 | if (upsample){
14 | x = torch::nn::Upsample(upsample_options(std::vector{2,2}))->forward(x);
15 | }
16 | return x;
17 | }
18 |
19 | FPNBlockImpl::FPNBlockImpl(int pyramid_channels, int skip_channels)
20 | {
21 | skip_conv = torch::nn::Conv2d(conv_options(skip_channels, pyramid_channels,1));
22 | upsample = torch::nn::Upsample(torch::nn::UpsampleOptions().scale_factor(std::vector({2,2})).mode(torch::kNearest));
23 |
24 | register_module("skip_conv",skip_conv);
25 | }
26 |
27 | torch::Tensor FPNBlockImpl::forward(torch::Tensor x, torch::Tensor skip){
28 | x = upsample->forward(x);
29 | skip = skip_conv(skip);
30 | x = x + skip;
31 | return x;
32 | }
33 |
34 | SegmentationBlockImpl::SegmentationBlockImpl(int in_channels, int out_channels, int n_upsamples)
35 | {
36 | block = torch::nn::Sequential();
37 | block->push_back(Conv3x3GNReLU(in_channels, out_channels, bool(n_upsamples)));
38 | if(n_upsamples>1){
39 | for (int i=1; ipush_back(Conv3x3GNReLU(out_channels, out_channels, true));
41 | }
42 | }
43 | register_module("block",block);
44 | }
45 |
46 | torch::Tensor SegmentationBlockImpl::forward(torch::Tensor x){
47 | x = block->forward(x);
48 | return x;
49 | }
50 |
51 | //vector求和
52 | template
53 | T sumTensor(std::vector x_list){
54 | if(x_list.empty()) std::cout<< "sumTensor only accept non-empty list";
55 | T re = x_list[0];
56 | for(int i = 1; i x){
70 | if(_policy=="add") return sumTensor(x);
71 | else if (_policy == "cat") return torch::cat(x, 1);
72 | else std::cout<< "`merge_policy` must be one of: ['add', 'cat'], got "+_policy;
73 | }
74 |
75 | FPNDecoderImpl::FPNDecoderImpl(std::vector encoder_channels, int encoder_depth, int pyramid_channels, int segmentation_channels,
76 | float dropout_, std::string merge_policy)
77 | {
78 | out_channels = merge_policy == "add"? segmentation_channels :segmentation_channels * 4;
79 | if(encoder_depth<3) std::cout<< "Encoder depth for FPN decoder cannot be less than 3";
80 | std::reverse(std::begin(encoder_channels),std::end(encoder_channels));
81 | encoder_channels = std::vector (encoder_channels.begin(),encoder_channels.begin()+encoder_depth+1);
82 | p5 = torch::nn::Conv2d(conv_options(encoder_channels[0], pyramid_channels, 1));
83 | p4 = FPNBlock(pyramid_channels, encoder_channels[1]);
84 | p3 = FPNBlock(pyramid_channels, encoder_channels[2]);
85 | p2 = FPNBlock(pyramid_channels, encoder_channels[3]);
86 | for(int i = 3; i>=0; i--){
87 | seg_blocks->push_back(SegmentationBlock(pyramid_channels, segmentation_channels, i));
88 | }
89 | merge = MergeBlock(merge_policy);
90 | dropout = torch::nn::Dropout2d(torch::nn::Dropout2dOptions().p(dropout_).inplace(true));
91 |
92 | register_module("p5",p5);
93 | register_module("p4",p4);
94 | register_module("p3",p3);
95 | register_module("p2",p2);
96 | register_module("seg_blocks",seg_blocks);
97 | register_module("merge",merge);
98 | }
99 |
100 | torch::Tensor FPNDecoderImpl::forward(std::vector features){
101 | int features_len = features.size();
102 | auto _p5 = p5->forward(features[features_len-1]);
103 | auto _p4 = p4->forward(_p5, features[features_len - 2]);
104 | auto _p3 = p3->forward(_p4, features[features_len - 3]);
105 | auto _p2 = p2->forward(_p3, features[features_len - 4]);
106 | _p5 = seg_blocks[0]->as()->forward(_p5);
107 | _p4 = seg_blocks[1]->as()->forward(_p4);
108 | _p3 = seg_blocks[2]->as()->forward(_p3);
109 | _p2 = seg_blocks[3]->as()->forward(_p2);
110 |
111 | auto x = merge->forward(std::vector{_p5,_p4,_p3,_p2});
112 | x = dropout->forward(x);
113 | return x;
114 | }
115 |
--------------------------------------------------------------------------------
/src/architectures/FPNDecoder.h:
--------------------------------------------------------------------------------
1 | #ifndef FPNDECODER_H
2 | #define FPNDECODER_H
3 | #include"../utils/util.h"
4 |
5 | class Conv3x3GNReLUImpl : public torch::nn::Module
6 | {
7 | public:
8 | Conv3x3GNReLUImpl(int in_channels, int out_channels, bool upsample=false);
9 | torch::Tensor forward(torch::Tensor x);
10 | private:
11 | bool upsample;
12 | torch::nn::Sequential block{nullptr};
13 | };
14 | TORCH_MODULE(Conv3x3GNReLU);
15 |
16 | class FPNBlockImpl : public torch::nn::Module
17 | {
18 | public:
19 | FPNBlockImpl(int pyramid_channels, int skip_channels);
20 | torch::Tensor forward(torch::Tensor x, torch::Tensor skip);
21 | private:
22 | torch::nn::Conv2d skip_conv{nullptr};
23 | torch::nn::Upsample upsample{nullptr};
24 | };
25 | TORCH_MODULE(FPNBlock);
26 |
27 | class SegmentationBlockImpl: public torch::nn::Module
28 | {
29 | public:
30 | SegmentationBlockImpl(int in_channels, int out_channels, int n_upsamples=0);
31 | torch::Tensor forward(torch::Tensor x);
32 | private:
33 | torch::nn::Sequential block{nullptr};
34 | };TORCH_MODULE(SegmentationBlock);
35 |
36 | class MergeBlockImpl: public torch::nn::Module{
37 | public:
38 | MergeBlockImpl(std::string policy);
39 | torch::Tensor forward(std::vector x);
40 | private:
41 | std::string _policy;
42 | std::string policies[2] = {"add","cat"};
43 | };TORCH_MODULE(MergeBlock);
44 |
45 | class FPNDecoderImpl: public torch::nn::Module
46 | {
47 | public:
48 | FPNDecoderImpl(std::vector encoder_channels = {3, 64, 64, 128, 256, 512}, int encoder_depth=5, int pyramid_channels=256,
49 | int segmentation_channels=128,float dropout=0.2, std::string merge_policy="add");
50 | torch::Tensor forward(std::vector features);
51 | private:
52 | int out_channels;
53 | torch::nn::Conv2d p5{nullptr};
54 | FPNBlock p4{nullptr};
55 | FPNBlock p3{nullptr};
56 | FPNBlock p2{nullptr};
57 | torch::nn::ModuleList seg_blocks{};
58 | MergeBlock merge{nullptr};
59 | torch::nn::Dropout2d dropout{nullptr};
60 |
61 | };TORCH_MODULE(FPNDecoder);
62 |
63 | #endif // FPNDECODER_H
64 |
--------------------------------------------------------------------------------
/src/architectures/LinkNet.cpp:
--------------------------------------------------------------------------------
1 | #include "LinkNet.h"
2 | LinkNetImpl::LinkNetImpl(int _num_classes, std::string encoder_name, std::string pretrained_path, int encoder_depth,
3 | int decoder_use_batchnorm) {
4 | num_classes = _num_classes;
5 | auto encoder_param = encoder_params();
6 | std::vector encoder_channels = encoder_param[encoder_name]["out_channels"];
7 | if (!encoder_param.contains(encoder_name))
8 | std::cout<< "encoder name must in {resnet18, resnet34, resnet50, resnet101, resnet150, \
9 | resnext50_32x4d, resnext101_32x8d, vgg11, vgg11_bn, vgg13, vgg13_bn, \
10 | vgg16, vgg16_bn, vgg19, vgg19_bn,}";
11 | if (encoder_param[encoder_name]["class_type"] == "resnet")
12 | encoder = new ResNetImpl(encoder_param[encoder_name]["layers"], 1000, encoder_name);
13 | else if (encoder_param[encoder_name]["class_type"] == "vgg")
14 | encoder = new VGGImpl(encoder_param[encoder_name]["cfg"], 1000, encoder_param[encoder_name]["batch_norm"]);
15 | else std::cout<< "unknown error in backbone initialization";
16 |
17 | encoder->load_pretrained(pretrained_path);
18 | decoder = LinknetDecoder(encoder_channels, 32, encoder_depth, decoder_use_batchnorm);
19 | segmentation_head = SegmentationHead(32, num_classes, 1);
20 |
21 | register_module("encoder", std::shared_ptr(encoder));
22 | register_module("decoder", decoder);
23 | register_module("segmentation_head", segmentation_head);
24 | }
25 |
26 | torch::Tensor LinkNetImpl::forward(torch::Tensor x) {
27 | std::vector features = encoder->features(x);
28 | x = decoder->forward(features);
29 | x = segmentation_head->forward(x);
30 | return x;
31 | }
32 |
--------------------------------------------------------------------------------
/src/architectures/LinkNet.h:
--------------------------------------------------------------------------------
1 | /*
2 | This libtorch implementation is written by AllentDan.
3 | Copyright(c) AllentDan 2021,
4 | All rights reserved.
5 | */
6 | #pragma once
7 | #include"../backbones/ResNet.h"
8 | #include"../backbones/VGG.h"
9 | #include"LinknetDecoder.h"
10 |
11 | class LinkNetImpl : public torch::nn::Module
12 | {
13 | public:
14 | LinkNetImpl() {}
15 | ~LinkNetImpl() {
16 | //delete encoder;
17 | }
18 | LinkNetImpl(int num_classes, std::string encoder_name = "resnet18", std::string pretrained_path = "", int encoder_depth = 5,
19 | int decoder_use_batchnorm = true);
20 | torch::Tensor forward(torch::Tensor x);
21 | private:
22 | Backbone* encoder;
23 | LinknetDecoder decoder{ nullptr };
24 | SegmentationHead segmentation_head{ nullptr };
25 | int num_classes = 1;
26 | }; TORCH_MODULE(LinkNet);
27 |
28 |
29 |
--------------------------------------------------------------------------------
/src/architectures/LinknetDecoder.cpp:
--------------------------------------------------------------------------------
1 | #include "LinknetDecoder.h"
2 |
3 | torch::nn::Sequential TransposeX2(int in_channels, int out_channels, bool use_batchnorm) {
4 | torch::nn::Sequential seq = torch::nn::Sequential();
5 | seq->push_back(torch::nn::ConvTranspose2d(torch::nn::ConvTranspose2dOptions(in_channels, out_channels,4).stride(2).padding(1)));
6 | if (use_batchnorm)
7 | seq->push_back(torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(out_channels)));
8 | seq->push_back(torch::nn::ReLU(torch::nn::ReLUOptions(true)));
9 | return seq;
10 | }
11 |
12 | torch::nn::Sequential Conv2dReLU(int in_channels, int out_channels, int kernel_size, int padding,
13 | int stride, bool use_batchnorm) {
14 | torch::nn::Sequential seq = torch::nn::Sequential();
15 | seq->push_back(torch::nn::Conv2d(conv_options(in_channels, out_channels, kernel_size,
16 | stride, padding, 1, !use_batchnorm, 1)));
17 | if (use_batchnorm)
18 | seq->push_back(torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(out_channels)));
19 |
20 | seq->push_back(torch::nn::ReLU(torch::nn::ReLUOptions(true)));
21 | return seq;
22 | }
23 |
24 | DecoderBlockLinkImpl::DecoderBlockLinkImpl(int in_channels, int out_channels, bool use_batchnorm) {
25 | conv2drelu1 = Conv2dReLU(in_channels, in_channels / 4, 1, 0, 1, use_batchnorm);
26 | transpose = TransposeX2(in_channels / 4, in_channels / 4, use_batchnorm);
27 | conv2drelu2 = Conv2dReLU(in_channels / 4, out_channels, 1, 0, 1, use_batchnorm);
28 |
29 | register_module("conv2drelu1", conv2drelu1);
30 | register_module("transpose", transpose);
31 | register_module("conv2drelu2", conv2drelu2);
32 | }
33 |
34 | torch::Tensor DecoderBlockLinkImpl::forward(torch::Tensor x, torch::Tensor skip) {
35 | x = conv2drelu1->forward(x);
36 | x = transpose->forward(x);
37 | x = conv2drelu2->forward(x);
38 | if (skip.sizes()==x.sizes())
39 | x = x + skip;
40 | return x;
41 | }
42 |
43 | LinknetDecoderImpl::LinknetDecoderImpl(std::vector encoder_channels, int prefinal_channels,
44 | int n_blocks, bool use_batchnorm) {
45 | encoder_channels = std::vector(encoder_channels.begin()+1, encoder_channels.end());
46 | std::reverse(std::begin(encoder_channels), std::end(encoder_channels));
47 | std::vector channels = encoder_channels;
48 | channels.push_back(prefinal_channels);
49 | for (int i = 0; i < n_blocks; i++) {
50 | blocks->push_back(DecoderBlockLink(channels[i], channels[i + 1], use_batchnorm));
51 | }
52 |
53 | register_module("blocks", blocks);
54 | }
55 |
56 | torch::Tensor LinknetDecoderImpl::forward(std::vector< torch::Tensor> features) {
57 | features = std::vector(features.begin() + 1, features.end());
58 | std::reverse(std::begin(features), std::end(features));
59 |
60 | auto x = features[0];
61 | auto skips = std::vector(features.begin() + 1, features.end());
62 | for (int i = 0; i < blocks->size(); i++) {
63 | auto skip = i < skips.size() ? skips[i] : torch::zeros({ 1 });
64 | x = blocks[i]->as()->forward(x, skip);
65 | }
66 | return x;
67 | }
--------------------------------------------------------------------------------
/src/architectures/LinknetDecoder.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include"../utils/util.h"
3 |
4 | torch::nn::Sequential TransposeX2(int in_channels, int out_channels, bool use_batchnorm = true);
5 |
6 | torch::nn::Sequential Conv2dReLU(int in_channels, int out_channels, int kernel_size, int padding = 0,
7 | int stride = 1, bool use_batchnorm = true);
8 |
9 | class DecoderBlockLinkImpl : public torch::nn::Module {
10 | public:
11 | DecoderBlockLinkImpl(int in_channels, int out_channels, bool use_batchnorm = true);
12 | torch::Tensor forward(torch::Tensor x, torch::Tensor skip);
13 | private:
14 | torch::nn::Sequential conv2drelu1 { nullptr };
15 | torch::nn::Sequential conv2drelu2{ nullptr };
16 | torch::nn::Sequential transpose{ nullptr };
17 | }; TORCH_MODULE(DecoderBlockLink);
18 |
19 | class LinknetDecoderImpl : public torch::nn::Module
20 | {
21 | public:
22 | LinknetDecoderImpl(std::vector encoder_channels, int prefinal_channels = 32,
23 | int n_blocks = 5, bool use_batchnorm = true);
24 | torch::Tensor forward(std::vector< torch::Tensor> x_list);
25 | private:
26 | torch::nn::ModuleList blocks;
27 | }; TORCH_MODULE(LinknetDecoder);
28 |
29 |
--------------------------------------------------------------------------------
/src/architectures/PAN.cpp:
--------------------------------------------------------------------------------
1 | #include "PAN.h"
2 |
3 | PANImpl::PANImpl(int _num_classes, std::string encoder_name, std::string pretrained_path, int decoder_channels,
4 | double upsampling) {
5 | num_classes = _num_classes;
6 | auto encoder_param = encoder_params();
7 | std::vector encoder_channels = encoder_param[encoder_name]["out_channels"];
8 | if (!encoder_param.contains(encoder_name))
9 | std::cout<< "encoder name must in {resnet18, resnet34, resnet50, resnet101, resnet150, \
10 | resnext50_32x4d, resnext101_32x8d, vgg11, vgg11_bn, vgg13, vgg13_bn, \
11 | vgg16, vgg16_bn, vgg19, vgg19_bn,}";
12 | if (encoder_param[encoder_name]["class_type"] == "resnet")
13 | encoder = new ResNetImpl(encoder_param[encoder_name]["layers"], 1000, encoder_name);
14 | else if (encoder_param[encoder_name]["class_type"] == "vgg")
15 | encoder = new VGGImpl(encoder_param[encoder_name]["cfg"], 1000, encoder_param[encoder_name]["batch_norm"]);
16 | else std::cout<< "unknown error in backbone initialization";
17 |
18 | encoder->load_pretrained(pretrained_path);
19 | encoder->make_dilated({ 5 }, { 2 });
20 |
21 | decoder = PANDecoder(encoder_channels, decoder_channels);
22 | segmentation_head = SegmentationHead(decoder_channels, num_classes, 3, upsampling);
23 |
24 | register_module("encoder", std::shared_ptr(encoder));
25 | register_module("decoder", decoder);
26 | register_module("segmentation_head", segmentation_head);
27 | }
28 |
29 | torch::Tensor PANImpl::forward(torch::Tensor x) {
30 | std::vector features = encoder->features(x);
31 | x = decoder->forward(features);
32 | x = segmentation_head->forward(x);
33 | return x;
34 | }
35 |
--------------------------------------------------------------------------------
/src/architectures/PAN.h:
--------------------------------------------------------------------------------
1 | /*
2 | This libtorch implementation is written by AllentDan.
3 | Copyright(c) AllentDan 2021,
4 | All rights reserved.
5 | */
6 | #pragma once
7 | #include"../backbones/ResNet.h"
8 | #include"../backbones/VGG.h"
9 | #include "PANDecoder.h"
10 |
11 | class PANImpl : public torch::nn::Module
12 | {
13 | public:
14 | PANImpl() {};
15 | ~PANImpl() {
16 | //delete encoder;
17 | }
18 | PANImpl(int num_classes, std::string encoder_name = "resnet18", std::string pretrained_path = "", int decoder_channels = 32,
19 | double upsampling = 4);
20 | torch::Tensor forward(torch::Tensor x);
21 | private:
22 | Backbone* encoder;
23 | //ResNet encoder{ nullptr };
24 | PANDecoder decoder{ nullptr };
25 | SegmentationHead segmentation_head{ nullptr };
26 | int num_classes = 1;
27 | std::vector BasicChannels = { 3, 64, 64, 128, 256, 512 };
28 | std::vector BottleChannels = { 3, 64, 256, 512, 1024, 2048 };
29 | std::map> name2layers = getParams();
30 | }; TORCH_MODULE(PAN);
31 |
32 |
33 |
--------------------------------------------------------------------------------
/src/architectures/PANDecoder.cpp:
--------------------------------------------------------------------------------
1 | #include "PANDecoder.h"
2 |
3 | ConvBnReluImpl::ConvBnReluImpl(int in_channels, int out_channels, int kernel_size, int stride, int padding,
4 | int dilation, int groups, bool bias, bool _add_relu, bool _interpolate) {
5 | add_relu = _add_relu;
6 | interpolate = _interpolate;
7 | conv = torch::nn::Conv2d(conv_options(in_channels, out_channels, kernel_size, stride, padding, groups, bias, dilation));
8 | bn = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(out_channels));
9 | activation = torch::nn::ReLU(torch::nn::ReLUOptions(true));
10 | up = torch::nn::Upsample(upsample_options(std::vector{2, 2}, true).mode(torch::kBilinear));
11 |
12 | register_module("conv", conv);
13 | register_module("bn", bn);
14 | }
15 |
16 | torch::Tensor ConvBnReluImpl::forward(torch::Tensor x) {
17 | x = conv->forward(x);
18 | x = bn->forward(x);
19 | if (add_relu)
20 | x = activation->forward(x);
21 | if (interpolate)
22 | x = up->forward(x);
23 | return x;
24 | }
25 |
26 | FPABlockImpl::FPABlockImpl(int in_channels, int out_channels, std::string _upscale_mode) {
27 | align_corners = _upscale_mode == "bilinear";
28 | branch1 = torch::nn::Sequential(torch::nn::AdaptiveAvgPool2d(torch::nn::AdaptiveAvgPool2dOptions(1)),
29 | ConvBnRelu(in_channels, out_channels, 1, 1, 0));
30 | mid = torch::nn::Sequential(ConvBnRelu(in_channels, out_channels, 1, 1, 0));
31 | down1 = torch::nn::Sequential(torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2)),
32 | ConvBnRelu(in_channels, 1, 7, 1, 3));
33 | down2 = torch::nn::Sequential(torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2)),
34 | ConvBnRelu(1, 1, 5, 1, 2));
35 | down3 = torch::nn::Sequential(torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2)),
36 | ConvBnRelu(1, 1, 3, 1, 1),
37 | ConvBnRelu(1, 1, 3, 1, 1));
38 | conv2 = ConvBnRelu(1, 1, 5, 1, 2);
39 | conv1 = ConvBnRelu(1, 1, 7, 1, 3);
40 |
41 | register_module("branch1", branch1);
42 | register_module("mid", mid);
43 | register_module("down1", down1);
44 | register_module("down2", down2);
45 | register_module("down3", down3);
46 | register_module("conv2", conv2);
47 | register_module("conv1", conv1);
48 | }
49 |
50 | torch::Tensor FPABlockImpl::forward(torch::Tensor x) {
51 | auto h = x.sizes()[2];
52 | auto w = x.sizes()[3];
53 | auto b1 = branch1->forward(x);
54 | b1 = at::upsample_bilinear2d(b1, { h,w }, align_corners);
55 |
56 | auto mid_tensor = mid->forward(x);
57 | auto x1 = down1->forward(x);
58 | auto x2 = down2->forward(x1);
59 | auto x3 = down3->forward(x2);
60 | x3 = at::upsample_bilinear2d(x3, {h/4,w/4}, align_corners);
61 |
62 | x2 = conv2->forward(x2);
63 | x = x2 + x3;
64 | x = at::upsample_bilinear2d(x, { h / 2,w / 2 }, align_corners);
65 |
66 | x1 = conv1->forward(x1);
67 | x = x + x1;
68 | x = at::upsample_bilinear2d(x, { h ,w }, align_corners);
69 |
70 | x = torch::mul(x, mid_tensor);
71 | x = x + b1;
72 | return x;
73 | }
74 |
75 | GAUBlockImpl::GAUBlockImpl(int in_channels, int out_channels, std::string upscale_mode) {
76 | align_corners = upscale_mode == "bilinear";
77 | conv1 = torch::nn::Sequential(
78 | torch::nn::AdaptiveAvgPool2d(torch::nn::AdaptiveAvgPool2dOptions(1)),
79 | ConvBnRelu(out_channels, out_channels, 1, 1, 0, 1, 1, true, false),
80 | torch::nn::Sigmoid()
81 | );
82 | conv2 = ConvBnRelu(in_channels, out_channels, 3, 1, 1);
83 |
84 | register_module("conv1", conv1);
85 | register_module("conv2", conv2);
86 | }
87 |
88 | torch::Tensor GAUBlockImpl::forward(torch::Tensor x, torch::Tensor y) {
89 | auto h = x.sizes()[2];
90 | auto w = x.sizes()[3];
91 | auto y_up = at::upsample_bilinear2d(y, { h ,w }, align_corners);
92 | x = conv2->forward(x);
93 | y = conv1->forward(y);
94 | auto z = torch::mul(x, y);
95 | return y_up + z;
96 | }
97 |
98 | PANDecoderImpl::PANDecoderImpl(std::vector encoder_channels, int decoder_channels, std::string upscale_mode) {
99 | fpa = FPABlock(encoder_channels[encoder_channels.size() - 1], decoder_channels);
100 | gau3 = GAUBlock(encoder_channels[encoder_channels.size() - 2], decoder_channels, upscale_mode);
101 | gau2 = GAUBlock(encoder_channels[encoder_channels.size() - 3], decoder_channels, upscale_mode);
102 | gau1 = GAUBlock(encoder_channels[encoder_channels.size() - 4], decoder_channels, upscale_mode);
103 |
104 | register_module("fpa", fpa);
105 | register_module("gau3", gau3);
106 | register_module("gau2", gau2);
107 | register_module("gau1", gau1);
108 | }
109 |
110 | torch::Tensor PANDecoderImpl::forward(std::vector features) {
111 | auto bottleneck = features[features.size() - 1];
112 | auto x5 = fpa->forward(bottleneck); // 1 / 32
113 | auto x4 = gau3->forward(features[features.size() -2], x5); //1 / 16
114 | auto x3 = gau2->forward(features[features.size() -3], x4); // 1 / 8
115 | auto x2 = gau1->forward(features[features.size() -4], x3); // 1 / 4
116 |
117 | return x2;
118 | }
--------------------------------------------------------------------------------
/src/architectures/PANDecoder.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include"../utils/util.h"
3 | //Pyramid Attention Network Decoder
4 |
5 | class ConvBnReluImpl : public torch::nn::Module {
6 | public:
7 | ConvBnReluImpl(int in_channels, int out_channels, int kernel_size, int stride = 1, int padding = 0,
8 | int dilation = 1, int groups = 1, bool bias = true, bool add_relu = true, bool interpolate = false);
9 | torch::Tensor forward(torch::Tensor x);
10 | private:
11 | bool add_relu;
12 | bool interpolate;
13 | torch::nn::Conv2d conv{ nullptr };
14 | torch::nn::BatchNorm2d bn{ nullptr };
15 | torch::nn::ReLU activation{ nullptr };
16 | torch::nn::Upsample up{ nullptr };
17 | }; TORCH_MODULE(ConvBnRelu);
18 |
19 | class FPABlockImpl : public torch::nn::Module {
20 | public:
21 | FPABlockImpl(int in_channels, int out_channels, std::string upscale_mode = "bilinear");
22 | torch::Tensor forward(torch::Tensor x);
23 | private:
24 | bool align_corners;
25 | torch::nn::Sequential branch1{ nullptr };
26 | torch::nn::Sequential mid{ nullptr };
27 | torch::nn::Sequential down1{ nullptr };
28 | torch::nn::Sequential down2{ nullptr };
29 | torch::nn::Sequential down3{ nullptr };
30 | ConvBnRelu conv1{ nullptr };
31 | ConvBnRelu conv2{ nullptr };
32 | }; TORCH_MODULE(FPABlock);
33 |
34 | class GAUBlockImpl :public torch::nn::Module {
35 | public:
36 | GAUBlockImpl(int in_channels, int out_channels, std::string upscale_mode = "bilinear");
37 | torch::Tensor forward(torch::Tensor x, torch::Tensor y);
38 | private:
39 | bool align_corners;
40 | torch::nn::Sequential conv1{ nullptr };
41 | ConvBnRelu conv2{ nullptr };
42 | }; TORCH_MODULE(GAUBlock);
43 |
44 | class PANDecoderImpl:public torch::nn::Module
45 | {
46 | public:
47 | PANDecoderImpl(std::vector encoder_channels, int decoder_channels, std::string upscale_mode = "bilinear");
48 | torch::Tensor forward(std::vector x);
49 | private:
50 | FPABlock fpa{ nullptr };
51 | GAUBlock gau3{ nullptr };
52 | GAUBlock gau2{ nullptr };
53 | GAUBlock gau1{ nullptr };
54 |
55 | }; TORCH_MODULE(PANDecoder);
56 |
57 |
--------------------------------------------------------------------------------
/src/architectures/PSPNet.cpp:
--------------------------------------------------------------------------------
1 | #include "PSPNet.h"
2 |
3 | PSPNetImpl::PSPNetImpl(int _num_classes, std::string encoder_name, std::string pretrained_path, int _encoder_depth,
4 | int psp_out_channels, bool psp_use_batchnorm, float psp_dropout, double upsampling) {
5 | num_classes = _num_classes;
6 | encoder_depth = _encoder_depth;
7 |
8 | auto encoder_param = encoder_params();
9 | std::vector encoder_channels = encoder_param[encoder_name]["out_channels"];
10 | if (!encoder_param.contains(encoder_name))
11 | std::cout<< "encoder name must in {resnet18, resnet34, resnet50, resnet101, resnet150, \
12 | resnext50_32x4d, resnext101_32x8d, vgg11, vgg11_bn, vgg13, vgg13_bn, \
13 | vgg16, vgg16_bn, vgg19, vgg19_bn,}";
14 | if (encoder_param[encoder_name]["class_type"] == "resnet")
15 | encoder = new ResNetImpl(encoder_param[encoder_name]["layers"], 1000, encoder_name);
16 | else if (encoder_param[encoder_name]["class_type"] == "vgg")
17 | encoder = new VGGImpl(encoder_param[encoder_name]["cfg"], 1000, encoder_param[encoder_name]["batch_norm"]);
18 | else std::cout<< "unknown error in backbone initialization";
19 |
20 | encoder->load_pretrained(pretrained_path);
21 | decoder = PSPDecoder(encoder_channels, psp_out_channels, psp_dropout, psp_use_batchnorm);
22 | segmentation_head = SegmentationHead(psp_out_channels, num_classes, 3, upsampling);
23 |
24 | register_module("encoder", std::shared_ptr(encoder));
25 | register_module("decoder", decoder);
26 | register_module("segmentation_head", segmentation_head);
27 | }
28 |
29 | torch::Tensor PSPNetImpl::forward(torch::Tensor x) {
30 | std::vector features = encoder->features(x, encoder_depth);
31 | x = decoder->forward(features);
32 | x = segmentation_head->forward(x);
33 | return x;
34 | }
35 |
--------------------------------------------------------------------------------
/src/architectures/PSPNet.h:
--------------------------------------------------------------------------------
1 | /*
2 | This libtorch implementation is writen by AllentDan.
3 | Copyright(c) AllentDan 2021,
4 | All rights reserved.
5 | */
6 | #pragma once
7 | #include "../backbones/ResNet.h"
8 | #include"../backbones//VGG.h"
9 | #include "PSPNetDecoder.h"
10 |
11 | class PSPNetImpl : public torch::nn::Module
12 | {
13 | public:
14 | PSPNetImpl() {}
15 | ~PSPNetImpl() {
16 | //delete encoder;
17 | }
18 | PSPNetImpl(int num_classes, std::string encoder_name = "resnet18", std::string pretrained_path = "", int encoder_depth = 3,
19 | int psp_out_channels = 512, bool psp_use_batchnorm = true, float psp_dropout = 0.2, double upsampling = 8);
20 | torch::Tensor forward(torch::Tensor x);
21 | private:
22 | Backbone* encoder;
23 | PSPDecoder decoder{ nullptr };
24 | SegmentationHead segmentation_head{ nullptr };
25 | int num_classes = 1; int encoder_depth = 3;
26 | std::vector BasicChannels = { 3, 64, 64, 128, 256, 512 };
27 | std::vector BottleChannels = { 3, 64, 256, 512, 1024, 2048 };
28 | std::map> name2layers = getParams();
29 | }; TORCH_MODULE(PSPNet);
--------------------------------------------------------------------------------
/src/architectures/PSPNetDecoder.cpp:
--------------------------------------------------------------------------------
1 | #include "PSPNetDecoder.h"
2 |
3 | PSPBlockImpl::PSPBlockImpl(int in_channels, int out_channels, int pool_size, bool use_bathcnorm) {
4 | if (pool_size == 1)
5 | use_bathcnorm = false;
6 | pool = torch::nn::Sequential(torch::nn::AdaptiveAvgPool2d(torch::nn::AdaptiveAvgPool2dOptions(pool_size)));
7 | conv = Conv2dReLU(in_channels, out_channels, 1, 0, 1, use_bathcnorm);
8 |
9 | register_module("pool", pool);
10 | register_module("conv", conv);
11 | }
12 |
13 | torch::Tensor PSPBlockImpl::forward(torch::Tensor x) {
14 | auto h = x.sizes()[2];
15 | auto w = x.sizes()[3];
16 | x = pool->forward(x);
17 | x = conv->forward(x);
18 | x = at::upsample_bilinear2d(x, { h ,w }, true);
19 | return x;
20 | }
21 |
22 |
23 | PSPModuleImpl::PSPModuleImpl(int in_channels, std::vector _sizes, bool use_bathcnorm) {
24 | for (auto size : _sizes) {
25 | blocks->push_back(PSPBlock(in_channels, in_channels / _sizes.size(), size, use_bathcnorm));
26 | }
27 | register_module("blocks", blocks);
28 | }
29 |
30 | torch::Tensor PSPModuleImpl::forward(torch::Tensor x) {
31 | std::vector xs;
32 | for (int i = 0; i < blocks->size(); i++) {
33 | xs.push_back(blocks[i]->as()->forward(x));
34 | }
35 | xs.push_back(x);
36 | x = torch::cat(xs, 1);
37 | return x;
38 | }
39 |
40 | PSPDecoderImpl::PSPDecoderImpl(std::vector encoder_channels, int out_channels, double _dropout,
41 | bool use_batchnorm, int _encoder_depth) {
42 | encoder_depth = _encoder_depth;
43 | std::vector size = { 1, 2, 3, 6 };
44 | psp = PSPModule(encoder_channels[encoder_depth], size, use_batchnorm);
45 | conv = Conv2dReLU(encoder_channels[encoder_depth] * 2, out_channels, 1, 0, 1, use_batchnorm);
46 | dropout = torch::nn::Dropout2d(torch::nn::Dropout2dOptions(_dropout));
47 |
48 | register_module("psp", psp);
49 | register_module("conv", conv);
50 | }
51 |
52 | torch::Tensor PSPDecoderImpl::forward(std::vector features) {
53 | auto x = features[features.size()-1];
54 | x = psp->forward(x);
55 | x = conv->forward(x);
56 | x = dropout->forward(x);
57 | return x;
58 | }
--------------------------------------------------------------------------------
/src/architectures/PSPNetDecoder.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "LinknetDecoder.h"
3 |
// One pyramid-pooling branch: adaptive-average-pool to pool_size x pool_size,
// 1x1 conv, then upsample back to the input resolution in forward().
class PSPBlockImpl : public torch::nn::Module {
public:
    // NOTE(review): `use_bathcnorm` is a typo for `use_batchnorm`, kept for
    // source compatibility with existing callers.
    PSPBlockImpl(int in_channels, int out_channels, int pool_size, bool use_bathcnorm = true);
    torch::Tensor forward(torch::Tensor x);
private:
    torch::nn::Sequential pool{ nullptr };
    torch::nn::Sequential conv{ nullptr };
}; TORCH_MODULE(PSPBlock);
12 |
13 |
14 | class PSPModuleImpl : public torch::nn::Module {
15 | public:
16 | PSPModuleImpl(int in_channels, std::vector sizes, bool use_bathcnorm = true);
17 | torch::Tensor forward(torch::Tensor x);
18 | private:
19 | torch::nn::ModuleList blocks;
20 | }; TORCH_MODULE(PSPModule);
21 |
22 | class PSPDecoderImpl : public torch::nn::Module {
23 | public:
24 | PSPDecoderImpl(std::vector encoder_channels, int out_channels = 512, double dropout = 0.2, bool use_batchnorm = true, int encoder_depth = 3);
25 | torch::Tensor forward(std::vector x);
26 | private:
27 | int encoder_depth = 3;
28 | PSPModule psp{ nullptr };
29 | torch::nn::Sequential conv{ nullptr };
30 | torch::nn::Dropout2d dropout{ nullptr };
31 | }; TORCH_MODULE(PSPDecoder);
32 |
33 |
--------------------------------------------------------------------------------
/src/architectures/UNet.cpp:
--------------------------------------------------------------------------------
1 | #include "UNet.h"
2 |
3 | UNetImpl::UNetImpl(int _num_classes, std::string encoder_name, std::string pretrained_path, int encoder_depth,
4 | std::vector decoder_channels, bool use_attention){
5 | num_classes = _num_classes;
6 | auto encoder_param = encoder_params();
7 | std::vector encoder_channels = encoder_param[encoder_name]["out_channels"];
8 | if (!encoder_param.contains(encoder_name))
9 | std::cout<< "encoder name must in {resnet18, resnet34, resnet50, resnet101, resnet150, \
10 | resnext50_32x4d, resnext101_32x8d, vgg11, vgg11_bn, vgg13, vgg13_bn, \
11 | vgg16, vgg16_bn, vgg19, vgg19_bn,}";
12 | if (encoder_param[encoder_name]["class_type"] == "resnet")
13 | encoder = new ResNetImpl(encoder_param[encoder_name]["layers"], 1000, encoder_name);
14 | else if (encoder_param[encoder_name]["class_type"] == "vgg")
15 | encoder = new VGGImpl(encoder_param[encoder_name]["cfg"], 1000, encoder_param[encoder_name]["batch_norm"]);
16 | else std::cout<< "unknown error in backbone initialization";
17 |
18 | encoder->load_pretrained(pretrained_path);
19 | decoder = UNetDecoder(encoder_channels,decoder_channels, encoder_depth, use_attention, false);
20 | segmentation_head = SegmentationHead(decoder_channels[decoder_channels.size()-1], num_classes, 1, 1);
21 |
22 | register_module("encoder",std::shared_ptr(encoder));
23 | register_module("decoder",decoder);
24 | register_module("segmentation_head",segmentation_head);
25 | }
26 |
27 | torch::Tensor UNetImpl::forward(torch::Tensor x){
28 | std::vector features = encoder->features(x);
29 | x = decoder->forward(features);
30 | x = segmentation_head->forward(x);
31 | return x;
32 | }
33 |
--------------------------------------------------------------------------------
/src/architectures/UNet.h:
--------------------------------------------------------------------------------
1 | /*
2 | This libtorch implementation is writen by AllentDan.
3 | Copyright(c) AllentDan 2021,
4 | All rights reserved.
5 | */
6 | #ifndef UNET_H
7 | #define UNET_H
8 | #include"../backbones/ResNet.h"
9 | #include"../backbones/VGG.h"
10 | #include"UNetDecoder.h"
11 |
12 | class UNetImpl : public torch::nn::Module
13 | {
14 | public:
15 | UNetImpl() {}
16 | ~UNetImpl() {
17 | //delete encoder;
18 | }
19 | UNetImpl(int num_classes, std::string encoder_name = "resnet18", std::string pretrained_path = "", int encoder_depth = 5,
20 | std::vector decoder_channels={256, 128, 64, 32, 16}, bool use_attention = false);
21 | torch::Tensor forward(torch::Tensor x);
22 | private:
23 | Backbone *encoder;
24 | UNetDecoder decoder{nullptr};
25 | SegmentationHead segmentation_head{nullptr};
26 | int num_classes = 1;
27 | std::vector BasicChannels = {3, 64, 64, 128, 256, 512};
28 | std::vector BottleChannels = {3, 64, 256, 512, 1024, 2048};
29 | std::map> name2layers = getParams();
30 | };TORCH_MODULE(UNet);
31 |
32 | #endif // UNET_H
33 |
--------------------------------------------------------------------------------
/src/architectures/UNetDecoder.cpp:
--------------------------------------------------------------------------------
1 | #include "UNetDecoder.h"
2 |
3 | SCSEModuleImpl::SCSEModuleImpl(int in_channels, int reduction, bool _use_attention){
4 | use_attention = _use_attention;
5 | cSE = torch::nn::Sequential(
6 | torch::nn::AdaptiveAvgPool2d(torch::nn::AdaptiveAvgPool2dOptions(1)),
7 | torch::nn::Conv2d(conv_options(in_channels, in_channels / reduction, 1)),
8 | torch::nn::ReLU(torch::nn::ReLUOptions(true)),
9 | torch::nn::Conv2d(conv_options(in_channels / reduction, in_channels, 1)),
10 | torch::nn::Sigmoid());
11 | sSE = torch::nn::Sequential(torch::nn::Conv2d(conv_options(in_channels, 1, 1)), torch::nn::Sigmoid());
12 | register_module("cSE",cSE);
13 | register_module("sSE",sSE);
14 | }
15 |
16 | torch::Tensor SCSEModuleImpl::forward(torch::Tensor x){
17 | if(!use_attention) return x;
18 | return x * cSE->forward(x) + x * sSE->forward(x);
19 | }
20 |
// Conv (stride 1) followed by batch-norm. NOTE(review): despite the class
// name, no ReLU is created here or applied in forward() — confirm against the
// reference implementation before relying on an activation.
Conv2dReLUImpl::Conv2dReLUImpl(int in_channels, int out_channels, int kernel_size, int padding){
    conv2d = torch::nn::Conv2d(conv_options(in_channels,out_channels,kernel_size,1,padding));
    bn = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(out_channels));
    register_module("conv2d", conv2d);
    register_module("bn", bn);
}
27 |
// Applies conv then batch-norm. NOTE(review): no ReLU despite the name —
// see the constructor note; changing this would alter trained behavior.
torch::Tensor Conv2dReLUImpl::forward(torch::Tensor x){
    x = conv2d->forward(x);
    x = bn->forward(x);
    return x;
}
33 |
34 | DecoderBlockImpl::DecoderBlockImpl(int in_channels, int skip_channels, int out_channels, bool skip, bool attention){
35 | conv1 = Conv2dReLU(in_channels + skip_channels, out_channels, 3, 1);
36 | conv2 = Conv2dReLU(out_channels, out_channels, 3, 1);
37 | register_module("conv1", conv1);
38 | register_module("conv2", conv2);
39 | upsample = torch::nn::Upsample(torch::nn::UpsampleOptions().scale_factor(std::vector({2,2})).mode(torch::kNearest));
40 |
41 | attention1 = SCSEModule(in_channels + skip_channels, 16, attention);
42 | attention2 = SCSEModule(out_channels, 16, attention);
43 | register_module("attention1", attention1);
44 | register_module("attention2", attention2);
45 | is_skip = skip;
46 | }
47 |
48 | torch::Tensor DecoderBlockImpl::forward(torch::Tensor x, torch::Tensor skip){
49 | x = upsample->forward(x);
50 | if (is_skip){
51 | x = torch::cat({x, skip}, 1);
52 | x = attention1->forward(x);
53 | }
54 | x = conv1->forward(x);
55 | x = conv2->forward(x);
56 | x = attention2->forward(x);
57 | return x;
58 | }
59 |
60 | torch::nn::Sequential CenterBlock(int in_channels, int out_channels){
61 | return torch::nn::Sequential(Conv2dReLU(in_channels, out_channels, 3, 1),
62 | Conv2dReLU(out_channels, out_channels, 3, 1));
63 | }
64 |
65 | UNetDecoderImpl::UNetDecoderImpl(std::vector encoder_channels, std::vector decoder_channels, int n_blocks,
66 | bool use_attention, bool use_center)
67 | {
68 | if (n_blocks != decoder_channels.size()) std::cout<< "Model depth not equal to your provided `decoder_channels`";
69 | std::reverse(std::begin(encoder_channels),std::end(encoder_channels));
70 |
71 | // computing blocks input and output channels
72 | int head_channels = encoder_channels[0];
73 | std::vector out_channels = decoder_channels;
74 | decoder_channels.pop_back();
75 | decoder_channels.insert(decoder_channels.begin(),head_channels);
76 | std::vector in_channels = decoder_channels;
77 | encoder_channels.erase(encoder_channels.begin());
78 | std::vector skip_channels = encoder_channels;
79 | skip_channels[skip_channels.size()-1] = 0;
80 |
81 | if(use_center) center = CenterBlock(head_channels, head_channels);
82 | else center = torch::nn::Sequential(torch::nn::Identity());
83 | //the last DecoderBlock of blocks need no skip tensor
84 | for (int i = 0; i< in_channels.size()-1; i++) {
85 | blocks->push_back(DecoderBlock(in_channels[i], skip_channels[i], out_channels[i], true, use_attention));
86 | }
87 | blocks->push_back(DecoderBlock(in_channels[in_channels.size()-1], skip_channels[in_channels.size()-1],
88 | out_channels[in_channels.size()-1], false, use_attention));
89 |
90 | register_module("center", center);
91 | register_module("blocks", blocks);
92 | }
93 |
94 | torch::Tensor UNetDecoderImpl::forward(std::vector features){
95 | std::reverse(std::begin(features),std::end(features));
96 | torch::Tensor head = features[0];
97 | features.erase(features.begin());
98 | auto x = center->forward(head);
99 | for (int i = 0; isize(); i++) {
100 | x = blocks[i]->as()->forward(x, features[i]);
101 | }
102 | return x;
103 | }
104 |
--------------------------------------------------------------------------------
/src/architectures/UNetDecoder.h:
--------------------------------------------------------------------------------
1 | #ifndef UNETDECODER_H
2 | #define UNETDECODER_H
3 | #include"../utils/util.h"
4 |
5 | //attention and basic
// Concurrent spatial & channel squeeze-excitation (scSE) attention gate.
// When use_attention is false, forward() is an identity pass-through.
class SCSEModuleImpl: public torch::nn::Module{
public:
    SCSEModuleImpl(int in_channels, int reduction=16, bool use_attention = false);
    torch::Tensor forward(torch::Tensor x);
private:
    bool use_attention = false;
    torch::nn::Sequential cSE{nullptr};  // channel gate
    torch::nn::Sequential sSE{nullptr};  // spatial gate
};TORCH_MODULE(SCSEModule);
15 |
// Conv2d followed by BatchNorm2d. NOTE(review): no ReLU is applied in the
// implementation despite the name — confirm before assuming an activation.
class Conv2dReLUImpl: public torch::nn::Module{
public:
    Conv2dReLUImpl(int in_channels, int out_channels, int kernel_size = 3, int padding = 1);
    torch::Tensor forward(torch::Tensor x);
private:
    torch::nn::Conv2d conv2d{nullptr};
    torch::nn::BatchNorm2d bn{nullptr};
};TORCH_MODULE(Conv2dReLU);
24 |
25 | //decoderblock and center block
// UNet decoder stage: 2x upsample, optional skip concatenation guarded by
// scSE attention, then two conv blocks.
class DecoderBlockImpl: public torch::nn::Module{
public:
    DecoderBlockImpl(int in_channels, int skip_channels, int out_channels, bool skip = true, bool attention = false);
    // `skip` is ignored when the block was constructed with skip = false.
    torch::Tensor forward(torch::Tensor x, torch::Tensor skip);
private:
    Conv2dReLU conv1{nullptr};
    Conv2dReLU conv2{nullptr};
    SCSEModule attention1{nullptr};
    SCSEModule attention2{nullptr};
    torch::nn::Upsample upsample{nullptr};
    bool is_skip = true;
};TORCH_MODULE(DecoderBlock);
38 |
39 | torch::nn::Sequential CenterBlock(int in_channels, int out_channels);
40 |
41 | class UNetDecoderImpl:public torch::nn::Module
42 | {
43 | public:
44 | UNetDecoderImpl(std::vector encoder_channels, std::vector decoder_channels, int n_blocks = 5,
45 | bool use_attention = false, bool use_center=false);
46 | torch::Tensor forward(std::vector features);
47 | private:
48 | torch::nn::Sequential center{nullptr};
49 | torch::nn::ModuleList blocks = torch::nn::ModuleList();
50 | };TORCH_MODULE(UNetDecoder);
51 |
52 | #endif // UNETDECODER_H
53 |
--------------------------------------------------------------------------------
/src/backbones/ResNet.cpp:
--------------------------------------------------------------------------------
1 | #include "ResNet.h"
2 |
// Unified residual block: acts as a BasicBlock (two 3x3 convs) when
// _is_basic is true, or as a Bottleneck (1x1 -> 3x3 -> 1x1, 4x expansion)
// otherwise. `downsample_` projects the identity branch when shapes change.
BlockImpl::BlockImpl(int64_t inplanes, int64_t planes, int64_t stride_,
    torch::nn::Sequential downsample_, int groups, int base_width, bool _is_basic)
{
    downsample = downsample_;
    stride = stride_;
    // ResNeXt-style width scaling; equals `planes` for plain ResNets.
    int width = int(planes * (base_width / 64.)) * groups;

    // Basic-block convs are built first; for bottlenecks conv1/conv2 are
    // rebuilt below and the basic versions are discarded before registration.
    conv1 = torch::nn::Conv2d(conv_options(inplanes, width, 3, stride_, 1, groups, false));
    bn1 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(width));
    conv2 = torch::nn::Conv2d(conv_options(width, width, 3, 1, 1, groups, false));
    bn2 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(width));
    is_basic = _is_basic;
    if (!is_basic) {
        // Bottleneck: 1x1 reduce, 3x3 (carries the stride), 1x1 expand by 4.
        conv1 = torch::nn::Conv2d(conv_options(inplanes, width, 1, 1, 0, 1, false));
        conv2 = torch::nn::Conv2d(conv_options(width, width, 3, stride_, 1, groups, false));
        conv3 = torch::nn::Conv2d(conv_options(width, planes * 4, 1, 1, 0, 1, false));
        bn3 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(planes * 4));
    }

    register_module("conv1", conv1);
    register_module("bn1", bn1);
    register_module("conv2", conv2);
    register_module("bn2", bn2);
    if (!is_basic) {
        register_module("conv3", conv3);
        register_module("bn3", bn3);
    }

    // Only register the projection when one was actually supplied.
    if (!downsample->is_empty()) {
        register_module("downsample", downsample);
    }
}
35 |
36 | torch::Tensor BlockImpl::forward(torch::Tensor x) {
37 | torch::Tensor residual = x.clone();
38 |
39 | x = conv1->forward(x);
40 | x = bn1->forward(x);
41 | x = torch::relu(x);
42 |
43 | x = conv2->forward(x);
44 | x = bn2->forward(x);
45 |
46 | if (!is_basic) {
47 | x = torch::relu(x);
48 | x = conv3->forward(x);
49 | x = bn3->forward(x);
50 | }
51 |
52 | if (!downsample->is_empty()) {
53 | residual = downsample->forward(residual);
54 | }
55 |
56 | x += residual;
57 | x = torch::relu(x);
58 |
59 | return x;
60 | }
61 |
62 | ResNetImpl::ResNetImpl(std::vector layers, int num_classes, std::string _model_type, int _groups, int _width_per_group)
63 | {
64 | model_type = _model_type;
65 | if (model_type != "resnet18" && model_type != "resnet34")
66 | {
67 | expansion = 4;
68 | is_basic = false;
69 | }
70 | if (model_type == "resnext50_32x4d") {
71 | groups = 32; base_width = 4;
72 | }
73 | if (model_type == "resnext101_32x8d") {
74 | groups = 32; base_width = 8;
75 | }
76 | conv1 = torch::nn::Conv2d(conv_options(3, 64, 7, 2, 3, 1, false));
77 | bn1 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(64));
78 | layer1 = torch::nn::Sequential(_make_layer(64, layers[0]));
79 | layer2 = torch::nn::Sequential(_make_layer(128, layers[1], 2));
80 | layer3 = torch::nn::Sequential(_make_layer(256, layers[2], 2));
81 | layer4 = torch::nn::Sequential(_make_layer(512, layers[3], 2));
82 |
83 | fc = torch::nn::Linear(512 * expansion, num_classes);
84 | register_module("conv1", conv1);
85 | register_module("bn1", bn1);
86 | register_module("layer1", layer1);
87 | register_module("layer2", layer2);
88 | register_module("layer3", layer3);
89 | register_module("layer4", layer4);
90 | register_module("fc", fc);
91 | }
92 |
93 |
94 | torch::Tensor ResNetImpl::forward(torch::Tensor x) {
95 | x = conv1->forward(x);
96 | x = bn1->forward(x);
97 | x = torch::relu(x);
98 | x = torch::max_pool2d(x, 3, 2, 1);
99 |
100 | x = layer1->forward(x);
101 | x = layer2->forward(x);
102 | x = layer3->forward(x);
103 | x = layer4->forward(x);
104 |
105 | x = torch::avg_pool2d(x, 7, 1);
106 | x = x.view({ x.sizes()[0], -1 });
107 | x = fc->forward(x);
108 |
109 | return torch::log_softmax(x, 1);
110 | }
111 |
112 | std::vector ResNetImpl::get_stages() {
113 | std::vector ans;
114 | ans.push_back(this->layer1);
115 | ans.push_back(this->layer2);
116 | ans.push_back(this->layer3);
117 | ans.push_back(this->layer4);
118 | return ans;
119 | }
120 |
121 | std::vector ResNetImpl::features(torch::Tensor x, int encoder_depth){
122 | std::vector features;
123 | features.push_back(x);
124 | x = conv1->forward(x);
125 | x = bn1->forward(x);
126 | x = torch::relu(x);
127 | features.push_back(x);
128 | x = torch::max_pool2d(x, 3, 2, 1);
129 |
130 | std::vector stages = get_stages();
131 | for (int i = 0; i < encoder_depth - 1; i++) {
132 | x = stages[i]->as()->forward(x);
133 | features.push_back(x);
134 | }
135 | //x = layer1->forward(x);
136 | //features.push_back(x);
137 | //x = layer2->forward(x);
138 | //features.push_back(x);
139 | //x = layer3->forward(x);
140 | //features.push_back(x);
141 | //x = layer4->forward(x);
142 | //features.push_back(x);
143 |
144 | return features;
145 | }
146 |
147 | torch::Tensor ResNetImpl::features_at(torch::Tensor x, int stage_num) {
148 | assert(stage_num > 0 && "the stage number must in range(1,5)");
149 | x = conv1->forward(x);
150 | x = bn1->forward(x);
151 | x = torch::relu(x);
152 | if (stage_num == 1) return x;
153 | x = torch::max_pool2d(x, 3, 2, 1);
154 |
155 | x = layer1->forward(x);
156 | if (stage_num == 2) return x;
157 | x = layer2->forward(x);
158 | if (stage_num == 3) return x;
159 | x = layer3->forward(x);
160 | if (stage_num == 4) return x;
161 | x = layer4->forward(x);
162 | if (stage_num == 5) return x;
163 | return x;
164 | }
165 |
166 | void ResNetImpl::load_pretrained(std::string pretrained_path) {
167 | std::map> name2layers = getParams();
168 | ResNet net_pretrained = ResNet(name2layers[model_type], 1000, model_type, groups, base_width);
169 | torch::load(net_pretrained, pretrained_path);
170 |
171 | torch::OrderedDict pretrained_dict = net_pretrained->named_parameters();
172 | torch::OrderedDict model_dict = this->named_parameters();
173 |
174 | for (auto n = pretrained_dict.begin(); n != pretrained_dict.end(); n++)
175 | {
176 | if (strstr((*n).key().data(), "fc.")) {
177 | continue;
178 | }
179 | model_dict[(*n).key()] = (*n).value();
180 | }
181 |
182 | torch::autograd::GradMode::set_enabled(false); // make parameters copying possible
183 | auto new_params = model_dict; // implement this
184 | auto params = this->named_parameters(true /*recurse*/);
185 | auto buffers = this->named_buffers(true /*recurse*/);
186 | for (auto& val : new_params) {
187 | auto name = val.key();
188 | auto* t = params.find(name);
189 | if (t != nullptr) {
190 | t->copy_(val.value());
191 | }
192 | else {
193 | t = buffers.find(name);
194 | if (t != nullptr) {
195 | t->copy_(val.value());
196 | }
197 | }
198 | }
199 | torch::autograd::GradMode::set_enabled(true);
200 | return;
201 | }
202 |
203 | torch::nn::Sequential ResNetImpl::_make_layer(int64_t planes, int64_t blocks, int64_t stride) {
204 |
205 | torch::nn::Sequential downsample;
206 | if (stride != 1 || inplanes != planes * expansion) {
207 | downsample = torch::nn::Sequential(
208 | torch::nn::Conv2d(conv_options(inplanes, planes * expansion, 1, stride, 0, 1, false)),
209 | torch::nn::BatchNorm2d(planes * expansion)
210 | );
211 | }
212 | torch::nn::Sequential layers;
213 | layers->push_back(Block(inplanes, planes, stride, downsample, groups, base_width, is_basic));
214 | inplanes = planes * expansion;
215 | for (int64_t i = 1; i < blocks; i++) {
216 | layers->push_back(Block(inplanes, planes, 1, torch::nn::Sequential(), groups, base_width,is_basic));
217 | }
218 |
219 | return layers;
220 | }
221 |
222 | void ResNetImpl::make_dilated(std::vector stage_list, std::vector dilation_list) {
223 | if (stage_list.size() != dilation_list.size()) {
224 | std::cout << "make sure stage list len equal to dilation list len";
225 | return;
226 | }
227 | std::map stage_dict = {};
228 | stage_dict.insert(std::pair(5, this->layer4));
229 | stage_dict.insert(std::pair(4, this->layer3));
230 | stage_dict.insert(std::pair(3, this->layer2));
231 | stage_dict.insert(std::pair(2, this->layer1));
232 | for (int i = 0; i < stage_list.size(); i++) {
233 | int dilation_rate = dilation_list[i];
234 | for (auto m : stage_dict[stage_list[i]]->modules()) {
235 | if (m->name() == "torch::nn::Conv2dImpl") {
236 | m->as()->options.stride(1);
237 | m->as()->options.dilation(dilation_rate);
238 | int kernel_size = m->as()->options.kernel_size()->at(0);
239 | m->as()->options.padding((kernel_size / 2) * dilation_rate);
240 | }
241 | }
242 | }
243 | return;
244 | }
245 |
246 | ResNet resnet18(int64_t num_classes) {
247 | std::vector layers = { 2, 2, 2, 2 };
248 | ResNet model(layers, num_classes, "resnet18");
249 | return model;
250 | }
251 |
252 | ResNet resnet34(int64_t num_classes) {
253 | std::vector layers = { 3, 4, 6, 3 };
254 | ResNet model(layers, num_classes, "resnet34");
255 | return model;
256 | }
257 |
258 | ResNet resnet50(int64_t num_classes) {
259 | std::vector layers = { 3, 4, 6, 3 };
260 | ResNet model(layers, num_classes, "resnet50");
261 | return model;
262 | }
263 |
264 | ResNet resnet101(int64_t num_classes) {
265 | std::vector layers = { 3, 4, 23, 3 };
266 | ResNet model(layers, num_classes, "resnet101");
267 | return model;
268 | }
269 |
270 | ResNet pretrained_resnet(int64_t num_classes, std::string model_name, std::string weight_path){
271 | std::map> name2layers = getParams();
272 | int groups = 1;
273 | int width_per_group = 64;
274 | if (model_name == "resnext50_32x4d") {
275 | groups = 32; width_per_group = 4;
276 | }
277 | if (model_name == "resnext101_32x8d") {
278 | groups = 32; width_per_group = 8;
279 | }
280 | ResNet net_pretrained = ResNet(name2layers[model_name],1000,model_name,groups,width_per_group);
281 | torch::load(net_pretrained, weight_path);
282 | if(num_classes == 1000) return net_pretrained;
283 | ResNet module = ResNet(name2layers[model_name],num_classes,model_name);
284 |
285 | torch::OrderedDict pretrained_dict = net_pretrained->named_parameters();
286 | torch::OrderedDict model_dict = module->named_parameters();
287 |
288 | for (auto n = pretrained_dict.begin(); n != pretrained_dict.end(); n++)
289 | {
290 | if (strstr((*n).key().data(), "fc.")) {
291 | continue;
292 | }
293 | model_dict[(*n).key()] = (*n).value();
294 | }
295 |
296 | torch::autograd::GradMode::set_enabled(false); // make parameters copying possible
297 | auto new_params = model_dict; // implement this
298 | auto params = module->named_parameters(true /*recurse*/);
299 | auto buffers = module->named_buffers(true /*recurse*/);
300 | for (auto& val : new_params) {
301 | auto name = val.key();
302 | auto* t = params.find(name);
303 | if (t != nullptr) {
304 | t->copy_(val.value());
305 | }
306 | else {
307 | t = buffers.find(name);
308 | if (t != nullptr) {
309 | t->copy_(val.value());
310 | }
311 | }
312 | }
313 | torch::autograd::GradMode::set_enabled(true);
314 | return module;
315 | }
316 |
--------------------------------------------------------------------------------
/src/backbones/ResNet.h:
--------------------------------------------------------------------------------
1 | /*
2 | This libtorch implementation is writen by AllentDan.
3 | Copyright(c) AllentDan 2021,
4 | All rights reserved.
5 | */
6 | #ifndef RESNET_H
7 | #define RESNET_H
8 | #include"../utils/util.h"
9 | #include"../utils/InterFace.h"
10 |
11 | /*
12 | In this implementation, bottleneck and basicblock are merged.
13 | */
// Residual block shared by the BasicBlock (is_basic = true) and Bottleneck
// (is_basic = false, 4x channel expansion) variants.
class BlockImpl : public torch::nn::Module {
public:
    BlockImpl(int64_t inplanes, int64_t planes, int64_t stride_ = 1,
        torch::nn::Sequential downsample_ = nullptr, int groups = 1, int base_width = 64, bool is_basic = true);
    torch::Tensor forward(torch::Tensor x);
    // Projection for the identity branch; empty when shapes already match.
    torch::nn::Sequential downsample{ nullptr };
private:
    bool is_basic = true;
    int64_t stride = 1;
    torch::nn::Conv2d conv1{ nullptr };
    torch::nn::BatchNorm2d bn1{ nullptr };
    torch::nn::Conv2d conv2{ nullptr };
    torch::nn::BatchNorm2d bn2{ nullptr };
    // conv3/bn3 are only created and registered for bottleneck blocks.
    torch::nn::Conv2d conv3{ nullptr };
    torch::nn::BatchNorm2d bn3{ nullptr };
};
TORCH_MODULE(Block);
31 |
32 |
33 | class ResNetImpl : public Backbone{
34 | public:
35 | ResNetImpl(std::vector layers, int num_classes = 1000, std::string model_type = "resnet18",
36 | int groups = 1, int width_per_group = 64);
37 | torch::Tensor forward(torch::Tensor x);
38 | torch::nn::Sequential _make_layer(int64_t planes, int64_t blocks, int64_t stride = 1);
39 | std::vector get_stages();
40 |
41 | std::vector features(torch::Tensor x, int encoder_depth = 5) override;
42 | torch::Tensor features_at(torch::Tensor x, int stage_num) override;
43 | void make_dilated(std::vector stage_list, std::vector dilation_list) override;
44 | void load_pretrained(std::string pretrained_path) override;
45 | private:
46 | std::string model_type = "resnet18";
47 | int expansion = 1; bool is_basic = true;
48 | int64_t inplanes = 64; int groups = 1; int base_width = 64;
49 | torch::nn::Conv2d conv1{ nullptr };
50 | torch::nn::BatchNorm2d bn1{ nullptr };
51 | torch::nn::Sequential layer1{ nullptr };
52 | torch::nn::Sequential layer2{ nullptr };
53 | torch::nn::Sequential layer3{ nullptr };
54 | torch::nn::Sequential layer4{ nullptr };
55 | torch::nn::Linear fc{nullptr};
56 | };
57 | TORCH_MODULE(ResNet);
58 |
// Maps each supported ResNet/ResNeXt variant to its per-stage block counts.
inline std::map<std::string, std::vector<int>> getParams(){
    std::map<std::string, std::vector<int>> name2layers = {};
    name2layers.insert(std::pair<std::string, std::vector<int>>("resnet18", {2, 2, 2, 2}));
    name2layers.insert(std::pair<std::string, std::vector<int>>("resnet34", {3, 4, 6, 3}));
    name2layers.insert(std::pair<std::string, std::vector<int>>("resnet50", {3, 4, 6, 3}));
    name2layers.insert(std::pair<std::string, std::vector<int>>("resnet101", {3, 4, 23, 3}));
    name2layers.insert(std::pair<std::string, std::vector<int>>("resnet152", { 3, 8, 36, 3 }));
    name2layers.insert(std::pair<std::string, std::vector<int>>("resnext50_32x4d", { 3, 4, 6, 3 }));
    name2layers.insert(std::pair<std::string, std::vector<int>>("resnext101_32x8d", { 3, 4, 23, 3 }));

    return name2layers;
}
71 |
72 | ResNet resnet18(int64_t num_classes);
73 | ResNet resnet34(int64_t num_classes);
74 | ResNet resnet50(int64_t num_classes);
75 | ResNet resnet101(int64_t num_classes);
76 |
77 | ResNet pretrained_resnet(int64_t num_classes, std::string model_name, std::string weight_path);
78 | #endif // RESNET_H
79 |
--------------------------------------------------------------------------------
/src/backbones/VGG.cpp:
--------------------------------------------------------------------------------
1 | #include "VGG.h"
2 |
3 | torch::nn::Sequential make_features(std::vector &cfg, bool batch_norm) {
4 | torch::nn::Sequential features;
5 | int in_channels = 3;
6 | for (auto v : cfg) {
7 | if (v == -1) {
8 | features->push_back(torch::nn::MaxPool2d(maxpool_options(2, 2)));
9 | }
10 | else {
11 | auto conv2d = torch::nn::Conv2d(conv_options(in_channels, v, 3, 1, 1));
12 | features->push_back(conv2d);
13 | if (batch_norm) {
14 | features->push_back(torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(v)));
15 | }
16 | features->push_back(torch::nn::ReLU(torch::nn::ReLUOptions(true)));
17 | in_channels = v;
18 | }
19 | }
20 | return features;
21 | }
22 |
23 | VGGImpl::VGGImpl(std::vector _cfg, int num_classes, bool batch_norm_) {
24 | cfg = _cfg;
25 | batch_norm = batch_norm_;
26 | features_ = make_features(cfg, batch_norm);
27 | avgpool = torch::nn::AdaptiveAvgPool2d(torch::nn::AdaptiveAvgPool2dOptions(7));
28 | classifier->push_back(torch::nn::Linear(torch::nn::LinearOptions(512 * 7 * 7, 4096)));
29 | classifier->push_back(torch::nn::ReLU(torch::nn::ReLUOptions(true)));
30 | classifier->push_back(torch::nn::Dropout());
31 | classifier->push_back(torch::nn::Linear(torch::nn::LinearOptions(4096, 4096)));
32 | classifier->push_back(torch::nn::ReLU(torch::nn::ReLUOptions(true)));
33 | classifier->push_back(torch::nn::Dropout());
34 | classifier->push_back(torch::nn::Linear(torch::nn::LinearOptions(4096, num_classes)));
35 |
36 | features_ = register_module("features", features_);
37 | classifier = register_module("classifier", classifier);
38 | }
39 |
40 | torch::Tensor VGGImpl::forward(torch::Tensor x) {
41 | x = features_->forward(x);
42 | x = avgpool(x);
43 | x = torch::flatten(x, 1);
44 | x = classifier->forward(x);
45 | return torch::log_softmax(x, 1);
46 | }
47 |
48 |
49 | std::vector VGGImpl::features(torch::Tensor x, int encoder_depth) {
50 | std::vector ans;
51 |
52 | int j = 0;// layer index of features_
53 | for (int i = 0; i < cfg.size(); i++) {
54 | if (cfg[i] == -1) {
55 | ans.push_back(x);
56 | if (ans.size() == encoder_depth )
57 | {
58 | break;
59 | }
60 | x = this->features_[j++]->as()->forward(x);
61 | }
62 | else {
63 | x = this->features_[j++]->as()->forward(x);
64 | if (batch_norm) {
65 | x = this->features_[j++]->as()->forward(x);
66 | }
67 | x = this->features_[j++]->as()->forward(x);
68 | }
69 | }
70 | if (ans.size() == encoder_depth && encoder_depth==5)
71 | {
72 | x = this->features_[j++]->as()->forward(x);
73 | ans.push_back(x);
74 | }
75 | return ans;
76 | }
77 |
78 | torch::Tensor VGGImpl::features_at(torch::Tensor x, int stage_num) {
79 | assert(stage_num > 0 && stage_num <=5 && "the stage number must in range[1,5]");
80 | int j = 0;
81 | int stage_count = 0;
82 | for (int i = 0; i < cfg.size(); i++) {
83 | if (cfg[i] == -1) {
84 | x = this->features_[j++]->as()->forward(x);
85 | stage_count++;
86 | if (stage_count == stage_num)
87 | return x;
88 | }
89 | else {
90 | x = this->features_[j++]->as()->forward(x);
91 | if (batch_norm) {
92 | x = this->features_[j++]->as()->forward(x);
93 | }
94 | x = this->features_[j++]->as()->forward(x);
95 | }
96 | }
97 | return x;
98 | }
99 |
// Copy weights from a serialized pretrained VGG (1000-class head) located at
// pretrained_path into this backbone, skipping the "classifier." parameters
// whose shapes depend on num_classes.
void VGGImpl::load_pretrained(std::string pretrained_path) {
	// Load the donor network with the reference 1000-class head so its
	// parameter names/shapes match the serialized file.
	VGG net_pretrained = VGG(cfg, 1000, batch_norm);
	torch::load(net_pretrained, pretrained_path);

	torch::OrderedDict pretrained_dict = net_pretrained->named_parameters();
	torch::OrderedDict model_dict = this->named_parameters();

	// Take every pretrained parameter except the classifier head.
	for (auto n = pretrained_dict.begin(); n != pretrained_dict.end(); n++)
	{
		if (strstr((*n).key().data(), "classifier.")) {
			continue;
		}
		model_dict[(*n).key()] = (*n).value();
	}

	// In-place copy_ into this module's parameters/buffers requires autograd
	// to be disabled while copying.
	torch::autograd::GradMode::set_enabled(false); // make parameters copying possible
	auto new_params = model_dict; // implement this
	auto params = this->named_parameters(true /*recurse*/);
	auto buffers = this->named_buffers(true /*recurse*/);
	for (auto& val : new_params) {
		auto name = val.key();
		auto* t = params.find(name);
		if (t != nullptr) {
			t->copy_(val.value());
		}
		else {
			// Not a parameter: try the buffers (non-trainable state).
			t = buffers.find(name);
			if (t != nullptr) {
				t->copy_(val.value());
			}
		}
	}
	torch::autograd::GradMode::set_enabled(true);
	return;
}
135 |
136 | void VGGImpl::make_dilated(std::vector stage_list, std::vector dilation_list) {
137 | std::cout<< "'VGG' models do not support dilated mode due to Max Pooling operations for downsampling!";
138 | return;
139 | }
140 |
--------------------------------------------------------------------------------
/src/backbones/VGG.h:
--------------------------------------------------------------------------------
/*
This libtorch implementation is written by AllentDan.
Copyright(c) AllentDan 2021,
All rights reserved.
*/
#pragma once
#include"../utils/util.h"
#include"../utils/InterFace.h"
// According to the make_features function in torchvision: builds and returns
// a torch::nn::Sequential holding the VGG convolutional trunk described by cfg.
// NOTE(review): container element types appear stripped in this dump
// (e.g. `std::vector &cfg`) — presumably std::vector<int>; confirm upstream.
torch::nn::Sequential make_features(std::vector &cfg, bool batch_norm);

// VGG backbone, usable both as a classifier (forward) and as a segmentation
// encoder via the Backbone interface (features / features_at).
class VGGImpl : public Backbone
{
private:
	torch::nn::Sequential features_{ nullptr };       // conv/pool trunk from make_features
	torch::nn::AdaptiveAvgPool2d avgpool{ nullptr };  // pooling before the classifier head
	torch::nn::Sequential classifier;                 // fully-connected head
	std::vector cfg = {};                             // layer spec; -1 marks a pooling boundary
	bool batch_norm;                                  // whether convs are followed by batch norm

public:
	VGGImpl(std::vector cfg, int num_classes = 1000, bool batch_norm = false);
	torch::Tensor forward(torch::Tensor x);

	// Backbone interface (see InterFace.h).
	std::vector features(torch::Tensor x, int encoder_depth = 5) override;
	torch::Tensor features_at(torch::Tensor x, int stage_num) override;
	void make_dilated(std::vector stage_list, std::vector dilation_list) override;
	void load_pretrained(std::string pretrained_path) override;
};
TORCH_MODULE(VGG);
--------------------------------------------------------------------------------
/src/utils/Augmentations.cpp:
--------------------------------------------------------------------------------
#include "Augmentations.h"

#include <cstdlib>   // rand, RAND_MAX
#include <utility>   // std::swap
2 |
// Uniform random value in [_min, _max] (inclusive bounds), for any
// arithmetic type T. Bounds may be supplied in either order.
// Uses rand(), so the sequence depends on any srand() seeding by the caller.
// NOTE(review): the template parameter list was missing in this dump
// (extraction artifact); restored here as `template <class T>`.
template <class T>
T RandomNum(T _min, T _max)
{
	if (_min > _max)
		std::swap(_min, _max);
	return rand() / (double)RAND_MAX * (_max - _min) + _min;
}
15 |
16 |
17 | cv::Mat centerCrop(cv::Mat srcImage, int width, int height) {
18 | int srcHeight = srcImage.rows;
19 | int srcWidth = srcImage.cols;
20 | int maxHeight = srcHeight > height ? srcHeight : height;
21 | int maxWidth = srcWidth > width ? srcWidth : width;
22 | cv::Mat maxImage = cv::Mat::zeros(cv::Size(maxWidth, maxHeight), srcImage.type());
23 | int h_max_start = int((maxHeight - srcHeight) / 2);
24 | int w_max_start = int((maxWidth - srcWidth) / 2);
25 | srcImage.clone().copyTo(maxImage(cv::Rect(w_max_start, h_max_start, srcWidth, srcHeight)));
26 |
27 | int h_start = int((maxHeight - height) / 2);
28 | int w_start = int((maxWidth - width) / 2);
29 | cv::Mat dstImage = maxImage(cv::Rect(w_start, h_start, width, height)).clone();
30 | return dstImage;
31 | }
32 |
// Rotate `src` by `angle` degrees around its center after zooming by
// (1 + scale), returning an image of the same size as the input.
// `interpolation` and `boder_mode` are forwarded to cv::warpAffine.
// Note: `src` is taken by value; cv::resize below rewrites only the local
// Mat header, not the caller's data.
// NOTE(review): the static_cast calls lost their template arguments in this
// dump (presumably static_cast<float>) — confirm upstream.
cv::Mat RotateImage(cv::Mat src, float angle, float scale, int interpolation, int boder_mode)
{
	cv::Mat dst;

	//make output size same with input after scaling
	cv::Size dst_sz(src.cols, src.rows);
	scale = 1 + scale;
	cv::resize(src, src, cv::Size(int(src.cols*scale), int(src.rows*scale)));
	src = centerCrop(src, dst_sz.width, dst_sz.height);

	//center for rotating
	cv::Point2f center(static_cast(src.cols / 2.), static_cast(src.rows / 2.));

	//rotate matrix (rotation only; scaling was already applied above)
	cv::Mat rot_mat = cv::getRotationMatrix2D(center, angle, 1.0);

	cv::warpAffine(src, dst, rot_mat, dst_sz, interpolation, boder_mode);
	return dst;
}
52 |
53 |
54 | Data Augmentations::Resize(Data mData, int width, int height, float probability) {
55 | float rand_number = RandomNum(0, 1);
56 | if (rand_number <= probability) {
57 |
58 | float h_scale = height * 1.0 / mData.image.rows;
59 | float w_scale = width * 1.0 / mData.image.cols;
60 |
61 | cv::resize(mData.image, mData.image, cv::Size(width, height));
62 | cv::resize(mData.mask, mData.mask, cv::Size(width, height));
63 | }
64 | return mData;
65 | }
66 |
67 | Data Augmentations::HorizontalFlip(Data mData, float probability) {
68 | float rand_number = RandomNum(0, 1);
69 | if (rand_number <= probability) {
70 |
71 | cv::flip(mData.image, mData.image, 1);
72 | cv::flip(mData.mask, mData.mask, 1);
73 |
74 | }
75 | return mData;
76 | }
77 |
78 | Data Augmentations::VerticalFlip(Data mData, float probability) {
79 | float rand_number = RandomNum(0, 1);
80 | if (rand_number <= probability) {
81 |
82 | cv::flip(mData.image, mData.image, 0);
83 | cv::flip(mData.mask, mData.mask, 0);
84 |
85 | }
86 | return mData;
87 | }
88 |
89 | Data Augmentations::RandomScaleRotate(Data mData, float probability, float rotate_limit, float scale_limit, int interpolation, int boder_mode) {
90 | float rand_number = RandomNum(0, 1);
91 | if (rand_number <= probability) {
92 | float angle = RandomNum(-rotate_limit, rotate_limit);
93 | float scale = RandomNum(-scale_limit, scale_limit);
94 | mData.image = RotateImage(mData.image, angle, scale, interpolation, boder_mode);
95 | mData.mask = RotateImage(mData.mask, angle, scale, interpolation, boder_mode);
96 | return mData;
97 | }
98 | return mData;
99 | }
--------------------------------------------------------------------------------
/src/utils/Augmentations.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 |
// Contains mask and source image: one training sample for segmentation.
struct Data {
	Data(cv::Mat img, cv::Mat _mask) :image(img), mask(_mask) {};
	cv::Mat image;  // source image
	cv::Mat mask;   // per-pixel label mask — presumably same size as image; verify at call sites
};
10 |
// Stateless collection of data-augmentation operations. Each function is
// applied with the given probability and returns the (possibly modified)
// sample; image and mask always receive the same transform.
class Augmentations
{
public:
	// Resize image and mask to width x height.
	static Data Resize(Data mData, int width, int height, float probability);
	// Mirror left-right.
	static Data HorizontalFlip(Data mData, float probability);
	// Mirror top-bottom.
	static Data VerticalFlip(Data mData, float probability);
	// Random rotation within ±rotate_limit and scaling within ±scale_limit.
	static Data RandomScaleRotate(Data mData, float probability, float rotate_limit, \
		float scale_limit, int interpolation, int boder_mode);
};
20 |
21 |
22 |
--------------------------------------------------------------------------------
/src/utils/InterFace.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #include
4 |
// Interface every encoder backbone (ResNet, VGG, ...) must implement so the
// segmentation architectures can consume them interchangeably.
class Backbone : public torch::nn::Module
{
public:
	// Per-stage feature maps, up to encoder_depth stages.
	virtual std::vector features(torch::Tensor x, int encoder_depth = 5) = 0;
	// Feature map of a single stage; stage_num in [1,5].
	virtual torch::Tensor features_at(torch::Tensor x, int stage_num) = 0;
	// Load weights from a serialized pretrained model file.
	virtual void load_pretrained(std::string pretrained_path)=0;
	// Replace downsampling with dilation in the listed stages, if supported.
	virtual void make_dilated(std::vector stage_list, std::vector dilation_list)=0;
	virtual ~Backbone() {}
};
14 |
--------------------------------------------------------------------------------
/src/utils/_dirent.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Dirent interface for Microsoft Visual Studio
3 | *
4 | * Copyright (C) 1998-2019 Toni Ronkko
5 | * This file is part of dirent. Dirent may be freely distributed
6 | * under the MIT license. For all details and documentation, see
7 | * https://github.com/tronkko/dirent
8 | */
9 | #ifndef DIRENT_H
10 | #define DIRENT_H
11 |
12 | /* Hide warnings about unreferenced local functions */
13 | #if defined(__clang__)
14 | # pragma clang diagnostic ignored "-Wunused-function"
15 | #elif defined(_MSC_VER)
16 | # pragma warning(disable:4505)
17 | #elif defined(__GNUC__)
18 | # pragma GCC diagnostic ignored "-Wunused-function"
19 | #endif
20 |
21 | /*
22 | * Include windows.h without Windows Sockets 1.1 to prevent conflicts with
23 | * Windows Sockets 2.0.
24 | */
25 | #ifndef WIN32_LEAN_AND_MEAN
26 | # define WIN32_LEAN_AND_MEAN
27 | #endif
28 | #include
29 |
30 | #include
31 | #include
32 | #include
33 | #include
34 | #include
35 | #include
36 | #include
37 | #include
38 | #include
39 | #include
40 |
41 | /* Indicates that d_type field is available in dirent structure */
42 | #define _DIRENT_HAVE_D_TYPE
43 |
44 | /* Indicates that d_namlen field is available in dirent structure */
45 | #define _DIRENT_HAVE_D_NAMLEN
46 |
47 | /* Entries missing from MSVC 6.0 */
48 | #if !defined(FILE_ATTRIBUTE_DEVICE)
49 | # define FILE_ATTRIBUTE_DEVICE 0x40
50 | #endif
51 |
52 | /* File type and permission flags for stat(), general mask */
53 | #if !defined(S_IFMT)
54 | # define S_IFMT _S_IFMT
55 | #endif
56 |
57 | /* Directory bit */
58 | #if !defined(S_IFDIR)
59 | # define S_IFDIR _S_IFDIR
60 | #endif
61 |
62 | /* Character device bit */
63 | #if !defined(S_IFCHR)
64 | # define S_IFCHR _S_IFCHR
65 | #endif
66 |
67 | /* Pipe bit */
68 | #if !defined(S_IFFIFO)
69 | # define S_IFFIFO _S_IFFIFO
70 | #endif
71 |
72 | /* Regular file bit */
73 | #if !defined(S_IFREG)
74 | # define S_IFREG _S_IFREG
75 | #endif
76 |
77 | /* Read permission */
78 | #if !defined(S_IREAD)
79 | # define S_IREAD _S_IREAD
80 | #endif
81 |
82 | /* Write permission */
83 | #if !defined(S_IWRITE)
84 | # define S_IWRITE _S_IWRITE
85 | #endif
86 |
87 | /* Execute permission */
88 | #if !defined(S_IEXEC)
89 | # define S_IEXEC _S_IEXEC
90 | #endif
91 |
92 | /* Pipe */
93 | #if !defined(S_IFIFO)
94 | # define S_IFIFO _S_IFIFO
95 | #endif
96 |
97 | /* Block device */
98 | #if !defined(S_IFBLK)
99 | # define S_IFBLK 0
100 | #endif
101 |
102 | /* Link */
103 | #if !defined(S_IFLNK)
104 | # define S_IFLNK 0
105 | #endif
106 |
107 | /* Socket */
108 | #if !defined(S_IFSOCK)
109 | # define S_IFSOCK 0
110 | #endif
111 |
112 | /* Read user permission */
113 | #if !defined(S_IRUSR)
114 | # define S_IRUSR S_IREAD
115 | #endif
116 |
117 | /* Write user permission */
118 | #if !defined(S_IWUSR)
119 | # define S_IWUSR S_IWRITE
120 | #endif
121 |
122 | /* Execute user permission */
123 | #if !defined(S_IXUSR)
124 | # define S_IXUSR 0
125 | #endif
126 |
127 | /* Read group permission */
128 | #if !defined(S_IRGRP)
129 | # define S_IRGRP 0
130 | #endif
131 |
132 | /* Write group permission */
133 | #if !defined(S_IWGRP)
134 | # define S_IWGRP 0
135 | #endif
136 |
137 | /* Execute group permission */
138 | #if !defined(S_IXGRP)
139 | # define S_IXGRP 0
140 | #endif
141 |
142 | /* Read others permission */
143 | #if !defined(S_IROTH)
144 | # define S_IROTH 0
145 | #endif
146 |
147 | /* Write others permission */
148 | #if !defined(S_IWOTH)
149 | # define S_IWOTH 0
150 | #endif
151 |
152 | /* Execute others permission */
153 | #if !defined(S_IXOTH)
154 | # define S_IXOTH 0
155 | #endif
156 |
157 | /* Maximum length of file name */
158 | #if !defined(PATH_MAX)
159 | # define PATH_MAX MAX_PATH
160 | #endif
161 | #if !defined(FILENAME_MAX)
162 | # define FILENAME_MAX MAX_PATH
163 | #endif
164 | #if !defined(NAME_MAX)
165 | # define NAME_MAX FILENAME_MAX
166 | #endif
167 |
168 | /* File type flags for d_type */
169 | #define DT_UNKNOWN 0
170 | #define DT_REG S_IFREG
171 | #define DT_DIR S_IFDIR
172 | #define DT_FIFO S_IFIFO
173 | #define DT_SOCK S_IFSOCK
174 | #define DT_CHR S_IFCHR
175 | #define DT_BLK S_IFBLK
176 | #define DT_LNK S_IFLNK
177 |
178 | /* Macros for converting between st_mode and d_type */
179 | #define IFTODT(mode) ((mode) & S_IFMT)
180 | #define DTTOIF(type) (type)
181 |
182 | /*
183 | * File type macros. Note that block devices, sockets and links cannot be
184 | * distinguished on Windows and the macros S_ISBLK, S_ISSOCK and S_ISLNK are
185 | * only defined for compatibility. These macros should always return false
186 | * on Windows.
187 | */
188 | #if !defined(S_ISFIFO)
189 | # define S_ISFIFO(mode) (((mode) & S_IFMT) == S_IFIFO)
190 | #endif
191 | #if !defined(S_ISDIR)
192 | # define S_ISDIR(mode) (((mode) & S_IFMT) == S_IFDIR)
193 | #endif
194 | #if !defined(S_ISREG)
195 | # define S_ISREG(mode) (((mode) & S_IFMT) == S_IFREG)
196 | #endif
197 | #if !defined(S_ISLNK)
198 | # define S_ISLNK(mode) (((mode) & S_IFMT) == S_IFLNK)
199 | #endif
200 | #if !defined(S_ISSOCK)
201 | # define S_ISSOCK(mode) (((mode) & S_IFMT) == S_IFSOCK)
202 | #endif
203 | #if !defined(S_ISCHR)
204 | # define S_ISCHR(mode) (((mode) & S_IFMT) == S_IFCHR)
205 | #endif
206 | #if !defined(S_ISBLK)
207 | # define S_ISBLK(mode) (((mode) & S_IFMT) == S_IFBLK)
208 | #endif
209 |
210 | /* Return the exact length of the file name without zero terminator */
211 | #define _D_EXACT_NAMLEN(p) ((p)->d_namlen)
212 |
213 | /* Return the maximum size of a file name */
214 | #define _D_ALLOC_NAMLEN(p) ((PATH_MAX)+1)
215 |
216 |
217 | #ifdef __cplusplus
218 | extern "C" {
219 | #endif
220 |
221 |
/* Wide-character version of the directory entry (see struct dirent below) */
struct _wdirent {
	/* Always zero (inode numbers are not provided by the Win32 find API) */
	long d_ino;

	/* File position within stream */
	long d_off;

	/* Structure size */
	unsigned short d_reclen;

	/* Length of name without \0 */
	size_t d_namlen;

	/* File type: DT_REG, DT_DIR, DT_CHR or DT_UNKNOWN */
	int d_type;

	/* File name, zero-terminated, truncated to PATH_MAX characters */
	wchar_t d_name[PATH_MAX + 1];
};
typedef struct _wdirent _wdirent;
243 |
/* Wide-character directory stream state */
struct _WDIR {
	/* Current directory entry, handed out by _wreaddir() */
	struct _wdirent ent;

	/* Private file data from the Win32 find API */
	WIN32_FIND_DATAW data;

	/* True if `data` holds an entry not yet handed out (see dirent_next) */
	int cached;

	/* Win32 search handle */
	HANDLE handle;

	/* Absolute directory name with the "\*" search pattern appended */
	wchar_t *patt;
};
typedef struct _WDIR _WDIR;
261 |
/* Multi-byte character version of the directory entry */
struct dirent {
	/* Always zero (inode numbers are not provided by the Win32 find API) */
	long d_ino;

	/* File position within stream */
	long d_off;

	/* Structure size */
	unsigned short d_reclen;

	/* Length of name without \0 */
	size_t d_namlen;

	/* File type: DT_REG, DT_DIR, DT_CHR or DT_UNKNOWN */
	int d_type;

	/* File name, zero-terminated, truncated to PATH_MAX bytes */
	char d_name[PATH_MAX + 1];
};
typedef struct dirent dirent;
283 |
/* Multi-byte directory stream: caches the converted entry and wraps the
 * wide-character stream that performs the actual Win32 traversal. */
struct DIR {
	struct dirent ent;    /* entry returned by readdir() */
	struct _WDIR *wdirp;  /* underlying wide-character stream */
};
typedef struct DIR DIR;
289 |
290 |
291 | /* Dirent functions */
292 | static DIR *opendir(const char *dirname);
293 | static _WDIR *_wopendir(const wchar_t *dirname);
294 |
295 | static struct dirent *readdir(DIR *dirp);
296 | static struct _wdirent *_wreaddir(_WDIR *dirp);
297 |
298 | static int readdir_r(
299 | DIR *dirp, struct dirent *entry, struct dirent **result);
300 | static int _wreaddir_r(
301 | _WDIR *dirp, struct _wdirent *entry, struct _wdirent **result);
302 |
303 | static int closedir(DIR *dirp);
304 | static int _wclosedir(_WDIR *dirp);
305 |
306 | static void rewinddir(DIR* dirp);
307 | static void _wrewinddir(_WDIR* dirp);
308 |
309 | static int scandir(const char *dirname, struct dirent ***namelist,
310 | int(*filter)(const struct dirent*),
311 | int(*compare)(const struct dirent**, const struct dirent**));
312 |
313 | static int alphasort(const struct dirent **a, const struct dirent **b);
314 |
315 | static int versionsort(const struct dirent **a, const struct dirent **b);
316 |
317 | static int strverscmp(const char *a, const char *b);
318 |
319 | /* For compatibility with Symbian */
320 | #define wdirent _wdirent
321 | #define WDIR _WDIR
322 | #define wopendir _wopendir
323 | #define wreaddir _wreaddir
324 | #define wclosedir _wclosedir
325 | #define wrewinddir _wrewinddir
326 |
327 | /* Compatibility with older Microsoft compilers and non-Microsoft compilers */
328 | #if !defined(_MSC_VER) || _MSC_VER < 1400
329 | # define wcstombs_s dirent_wcstombs_s
330 | # define mbstowcs_s dirent_mbstowcs_s
331 | #endif
332 |
333 | /* Optimize dirent_set_errno() away on modern Microsoft compilers */
334 | #if defined(_MSC_VER) && _MSC_VER >= 1400
335 | # define dirent_set_errno _set_errno
336 | #endif
337 |
338 |
339 | /* Internal utility functions */
340 | static WIN32_FIND_DATAW *dirent_first(_WDIR *dirp);
341 | static WIN32_FIND_DATAW *dirent_next(_WDIR *dirp);
342 |
343 | #if !defined(_MSC_VER) || _MSC_VER < 1400
344 | static int dirent_mbstowcs_s(
345 | size_t *pReturnValue, wchar_t *wcstr, size_t sizeInWords,
346 | const char *mbstr, size_t count);
347 | #endif
348 |
349 | #if !defined(_MSC_VER) || _MSC_VER < 1400
350 | static int dirent_wcstombs_s(
351 | size_t *pReturnValue, char *mbstr, size_t sizeInBytes,
352 | const wchar_t *wcstr, size_t count);
353 | #endif
354 |
355 | #if !defined(_MSC_VER) || _MSC_VER < 1400
356 | static void dirent_set_errno(int error);
357 | #endif
358 |
359 |
/*
 * Open directory stream DIRNAME for read and return a pointer to the
 * internal working area that is used to retrieve individual directory
 * entries. Returns NULL (with errno set via dirent_set_errno) on failure.
 */
static _WDIR *_wopendir(const wchar_t *dirname)
{
	wchar_t *p;

	/* Must have directory name */
	if (dirname == NULL || dirname[0] == '\0') {
		dirent_set_errno(ENOENT);
		return NULL;
	}

	/* Allocate new _WDIR structure */
	_WDIR *dirp = (_WDIR*)malloc(sizeof(struct _WDIR));
	if (!dirp)
		return NULL;

	/* Reset _WDIR structure so _wclosedir() is safe on any exit path */
	dirp->handle = INVALID_HANDLE_VALUE;
	dirp->patt = NULL;
	dirp->cached = 0;

	/*
	 * Compute the length of full path plus zero terminator
	 *
	 * Note that on WinRT there's no way to convert relative paths
	 * into absolute paths, so just assume it is an absolute path.
	 */
#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP)
	/* Desktop */
	DWORD n = GetFullPathNameW(dirname, 0, NULL, NULL);
#else
	/* WinRT */
	size_t n = wcslen(dirname);
#endif

	/* Allocate room for absolute directory name and search pattern.
	 * NOTE(review): due to precedence this adds 16 *bytes* (8 wchar_t on
	 * Windows), not 16 wchar_t, beyond the name — still enough for the
	 * "\*" + terminator appended below, but confirm it matches upstream. */
	dirp->patt = (wchar_t*)malloc(sizeof(wchar_t) * n + 16);
	if (dirp->patt == NULL)
		goto exit_closedir;

	/*
	 * Convert relative directory name to an absolute one. This
	 * allows rewinddir() to function correctly even when current
	 * working directory is changed between opendir() and rewinddir().
	 *
	 * Note that on WinRT there's no way to convert relative paths
	 * into absolute paths, so just assume it is an absolute path.
	 */
#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP)
	/* Desktop. n is unsigned (DWORD), so `n <= 0` only catches n == 0. */
	n = GetFullPathNameW(dirname, n, dirp->patt, NULL);
	if (n <= 0)
		goto exit_closedir;
#else
	/* WinRT */
	wcsncpy_s(dirp->patt, n + 1, dirname, n);
#endif

	/* Append search pattern \* to the directory name */
	p = dirp->patt + n;
	switch (p[-1]) {
	case '\\':
	case '/':
	case ':':
		/* Directory ends in path separator, e.g. c:\temp\ */
		/*NOP*/;
		break;

	default:
		/* Directory name doesn't end in path separator */
		*p++ = '\\';
	}
	*p++ = '*';
	*p = '\0';

	/* Open directory stream and retrieve the first entry */
	if (!dirent_first(dirp))
		goto exit_closedir;

	/* Success */
	return dirp;

	/* Failure: _wclosedir() releases patt and dirp */
exit_closedir:
	_wclosedir(dirp);
	return NULL;
}
451 |
452 | /*
453 | * Read next directory entry.
454 | *
455 | * Returns pointer to static directory entry which may be overwritten by
456 | * subsequent calls to _wreaddir().
457 | */
458 | static struct _wdirent *_wreaddir(_WDIR *dirp)
459 | {
460 | /*
461 | * Read directory entry to buffer. We can safely ignore the return
462 | * value as entry will be set to NULL in case of error.
463 | */
464 | struct _wdirent *entry;
465 | (void)_wreaddir_r(dirp, &dirp->ent, &entry);
466 |
467 | /* Return pointer to statically allocated directory entry */
468 | return entry;
469 | }
470 |
/*
 * Read next directory entry into the caller-supplied buffer.
 *
 * Returns zero on success. If end of directory stream is reached, then sets
 * result to NULL and returns zero.
 */
static int _wreaddir_r(
	_WDIR *dirp, struct _wdirent *entry, struct _wdirent **result)
{
	/* Read next directory entry */
	WIN32_FIND_DATAW *datap = dirent_next(dirp);
	if (!datap) {
		/* Return NULL to indicate end of directory */
		*result = NULL;
		return /*OK*/0;
	}

	/*
	 * Copy file name as wide-character string. If the file name is too
	 * long to fit in to the destination buffer, then truncate file name
	 * to PATH_MAX characters and zero-terminate the buffer.
	 */
	size_t n = 0;
	while (n < PATH_MAX && datap->cFileName[n] != 0) {
		entry->d_name[n] = datap->cFileName[n];
		n++;
	}
	entry->d_name[n] = 0;

	/* Length of file name excluding zero terminator */
	entry->d_namlen = n;

	/* Map Win32 attributes to d_type (device > directory > regular) */
	DWORD attr = datap->dwFileAttributes;
	if ((attr & FILE_ATTRIBUTE_DEVICE) != 0)
		entry->d_type = DT_CHR;
	else if ((attr & FILE_ATTRIBUTE_DIRECTORY) != 0)
		entry->d_type = DT_DIR;
	else
		entry->d_type = DT_REG;

	/* Reset dummy fields (not provided by the Win32 find API) */
	entry->d_ino = 0;
	entry->d_off = 0;
	entry->d_reclen = sizeof(struct _wdirent);

	/* Set result address */
	*result = entry;
	return /*OK*/0;
}
521 |
522 | /*
523 | * Close directory stream opened by opendir() function. This invalidates the
524 | * DIR structure as well as any directory entry read previously by
525 | * _wreaddir().
526 | */
527 | static int _wclosedir(_WDIR *dirp)
528 | {
529 | if (!dirp) {
530 | dirent_set_errno(EBADF);
531 | return /*failure*/-1;
532 | }
533 |
534 | /* Release search handle */
535 | if (dirp->handle != INVALID_HANDLE_VALUE)
536 | FindClose(dirp->handle);
537 |
538 | /* Release search pattern */
539 | free(dirp->patt);
540 |
541 | /* Release directory structure */
542 | free(dirp);
543 | return /*success*/0;
544 | }
545 |
546 | /*
547 | * Rewind directory stream such that _wreaddir() returns the very first
548 | * file name again.
549 | */
550 | static void _wrewinddir(_WDIR* dirp)
551 | {
552 | if (!dirp)
553 | return;
554 |
555 | /* Release existing search handle */
556 | if (dirp->handle != INVALID_HANDLE_VALUE)
557 | FindClose(dirp->handle);
558 |
559 | /* Open new search handle */
560 | dirent_first(dirp);
561 | }
562 |
/* Open the Win32 search and cache the first directory entry.
 * Returns a pointer to the cached entry, or NULL with errno set. */
static WIN32_FIND_DATAW *dirent_first(_WDIR *dirp)
{
	if (!dirp)
		return NULL;

	/* Open directory and retrieve the first entry */
	dirp->handle = FindFirstFileExW(
		dirp->patt, FindExInfoStandard, &dirp->data,
		FindExSearchNameMatch, NULL, 0);
	if (dirp->handle == INVALID_HANDLE_VALUE)
		goto error;

	/* A directory entry is now waiting in memory */
	dirp->cached = 1;
	return &dirp->data;

error:
	/* Failed to open directory: no directory entry in memory */
	dirp->cached = 0;

	/* Map the Win32 error to a POSIX errno value */
	DWORD errorcode = GetLastError();
	switch (errorcode) {
	case ERROR_ACCESS_DENIED:
		/* No read access to directory */
		dirent_set_errno(EACCES);
		break;

	case ERROR_DIRECTORY:
		/* Directory name is invalid */
		dirent_set_errno(ENOTDIR);
		break;

	case ERROR_PATH_NOT_FOUND:
	default:
		/* Cannot find the file */
		dirent_set_errno(ENOENT);
	}
	return NULL;
}
604 |
605 | /* Get next directory entry */
606 | static WIN32_FIND_DATAW *dirent_next(_WDIR *dirp)
607 | {
608 | /* Is the next directory entry already in cache? */
609 | if (dirp->cached) {
610 | /* Yes, a valid directory entry found in memory */
611 | dirp->cached = 0;
612 | return &dirp->data;
613 | }
614 |
615 | /* No directory entry in cache */
616 | if (dirp->handle == INVALID_HANDLE_VALUE)
617 | return NULL;
618 |
619 | /* Read the next directory entry from stream */
620 | if (FindNextFileW(dirp->handle, &dirp->data) == FALSE)
621 | goto exit_close;
622 |
623 | /* Success */
624 | return &dirp->data;
625 |
626 | /* Failure */
627 | exit_close:
628 | FindClose(dirp->handle);
629 | dirp->handle = INVALID_HANDLE_VALUE;
630 | return NULL;
631 | }
632 |
/* Open directory stream using plain old C-string.
 * Converts the name to wide characters and delegates to _wopendir(). */
static DIR *opendir(const char *dirname)
{
	/* Must have directory name */
	if (dirname == NULL || dirname[0] == '\0') {
		dirent_set_errno(ENOENT);
		return NULL;
	}

	/* Allocate memory for DIR structure */
	struct DIR *dirp = (DIR*)malloc(sizeof(struct DIR));
	if (!dirp)
		return NULL;

	/* Convert directory name to wide-character string */
	wchar_t wname[PATH_MAX + 1];
	size_t n;
	int error = mbstowcs_s(&n, wname, PATH_MAX + 1, dirname, PATH_MAX + 1);
	if (error)
		goto exit_failure;

	/* Open directory stream using wide-character name */
	dirp->wdirp = _wopendir(wname);
	if (!dirp->wdirp)
		goto exit_failure;

	/* Success */
	return dirp;

	/* Failure: release the wrapper; _wopendir cleaned up after itself */
exit_failure:
	free(dirp);
	return NULL;
}
667 |
668 | /* Read next directory entry */
669 | static struct dirent *readdir(DIR *dirp)
670 | {
671 | /*
672 | * Read directory entry to buffer. We can safely ignore the return
673 | * value as entry will be set to NULL in case of error.
674 | */
675 | struct dirent *entry;
676 | (void)readdir_r(dirp, &dirp->ent, &entry);
677 |
678 | /* Return pointer to statically allocated directory entry */
679 | return entry;
680 | }
681 |
/*
 * Read next directory entry into caller-allocated buffer.
 *
 * Returns zero on success. If the end of directory stream is reached, then
 * sets result to NULL and returns zero.
 */
static int readdir_r(
	DIR *dirp, struct dirent *entry, struct dirent **result)
{
	/* Read next directory entry */
	WIN32_FIND_DATAW *datap = dirent_next(dirp->wdirp);
	if (!datap) {
		/* No more directory entries */
		*result = NULL;
		return /*OK*/0;
	}

	/* Attempt to convert file name to multi-byte string */
	size_t n;
	int error = wcstombs_s(
		&n, entry->d_name, PATH_MAX + 1,
		datap->cFileName, PATH_MAX + 1);

	/*
	 * If the file name cannot be represented by a multi-byte string, then
	 * attempt to use old 8+3 file name. This allows the program to
	 * access files although file names may seem unfamiliar to the user.
	 *
	 * Beware that the code below cannot come up with a short file name
	 * unless the file system provides one. At least VirtualBox shared
	 * folders fail to do this.
	 */
	if (error && datap->cAlternateFileName[0] != '\0') {
		error = wcstombs_s(
			&n, entry->d_name, PATH_MAX + 1,
			datap->cAlternateFileName, PATH_MAX + 1);
	}

	if (!error) {
		/* Length of file name excluding zero terminator
		 * (wcstombs_s reports the converted size including it) */
		entry->d_namlen = n - 1;

		/* Map Win32 attributes to d_type (device > directory > regular) */
		DWORD attr = datap->dwFileAttributes;
		if ((attr & FILE_ATTRIBUTE_DEVICE) != 0)
			entry->d_type = DT_CHR;
		else if ((attr & FILE_ATTRIBUTE_DIRECTORY) != 0)
			entry->d_type = DT_DIR;
		else
			entry->d_type = DT_REG;

		/* Reset dummy fields (not provided by the Win32 find API) */
		entry->d_ino = 0;
		entry->d_off = 0;
		entry->d_reclen = sizeof(struct dirent);
	}
	else {
		/*
		 * Cannot convert file name to multi-byte string so construct
		 * an erroneous directory entry and return that. Note that
		 * we cannot return NULL as that would stop the processing
		 * of directory entries completely.
		 */
		entry->d_name[0] = '?';
		entry->d_name[1] = '\0';
		entry->d_namlen = 1;
		entry->d_type = DT_UNKNOWN;
		entry->d_ino = 0;
		entry->d_off = -1;
		entry->d_reclen = 0;
	}

	/* Return pointer to directory entry */
	*result = entry;
	return /*OK*/0;
}
758 |
759 | /* Close directory stream */
760 | static int closedir(DIR *dirp)
761 | {
762 | int ok;
763 |
764 | if (!dirp)
765 | goto exit_failure;
766 |
767 | /* Close wide-character directory stream */
768 | ok = _wclosedir(dirp->wdirp);
769 | dirp->wdirp = NULL;
770 |
771 | /* Release multi-byte character version */
772 | free(dirp);
773 | return ok;
774 |
775 | exit_failure:
776 | /* Invalid directory stream */
777 | dirent_set_errno(EBADF);
778 | return /*failure*/-1;
779 | }
780 |
781 | /* Rewind directory stream to beginning */
782 | static void rewinddir(DIR* dirp)
783 | {
784 | if (!dirp)
785 | return;
786 |
787 | /* Rewind wide-character string directory stream */
788 | _wrewinddir(dirp->wdirp);
789 | }
790 |
/* Scan directory `dirname`: read all entries, keep those accepted by
 * `filter` (or all if filter is NULL), sort them with `compare`, and hand
 * the resulting pointer table to the caller via *namelist.
 * Returns the number of entries, or -1 on error. On success the caller owns
 * *namelist and every entry in it (free() each entry, then the table).
 * NOTE(review): if namelist is NULL on success, the table is leaked —
 * confirm whether callers ever pass NULL. */
static int scandir(
	const char *dirname, struct dirent ***namelist,
	int(*filter)(const struct dirent*),
	int(*compare)(const struct dirent**, const struct dirent**))
{
	int result;

	/* Open directory stream */
	DIR *dir = opendir(dirname);
	if (!dir) {
		/* Cannot open directory */
		return /*Error*/ -1;
	}

	/* Read directory entries to memory */
	struct dirent *tmp = NULL;
	struct dirent **files = NULL;
	size_t size = 0;
	size_t allocated = 0;
	while (1) {
		/* Allocate room for a temporary directory entry; it is reused
		 * across filtered-out entries and re-allocated after each
		 * entry that is stored in the table. */
		if (!tmp) {
			tmp = (struct dirent*) malloc(sizeof(struct dirent));
			if (!tmp)
				goto exit_failure;
		}

		/* Read directory entry to temporary area */
		struct dirent *entry;
		if (readdir_r(dir, tmp, &entry) != /*OK*/0)
			goto exit_failure;

		/* Stop if we already read the last directory entry */
		if (entry == NULL)
			goto exit_success;

		/* Determine whether to include the entry in results */
		if (filter && !filter(tmp))
			continue;

		/* Enlarge pointer table to make room for another pointer */
		if (size >= allocated) {
			/* Grow geometrically: 16, 48, 112, ... entries */
			size_t num_entries = size * 2 + 16;

			/* Allocate new pointer table or enlarge existing */
			void *p = realloc(files, sizeof(void*) * num_entries);
			if (!p)
				goto exit_failure;

			/* Got the memory */
			files = (dirent**)p;
			allocated = num_entries;
		}

		/* Store the temporary entry to ptr table; ownership moves to
		 * the table and a fresh tmp is allocated next iteration */
		files[size++] = tmp;
		tmp = NULL;
	}

exit_failure:
	/* Release allocated file entries */
	for (size_t i = 0; i < size; i++) {
		free(files[i]);
	}

	/* Release the pointer table */
	free(files);
	files = NULL;

	/* Exit with error code */
	result = /*error*/ -1;
	goto exit_status;

exit_success:
	/* Sort directory entries.
	 * NOTE(review): for an empty directory files==NULL and size==0;
	 * qsort(NULL, 0, ...) is technically undefined behavior — confirm. */
	qsort(files, size, sizeof(void*),
		(int(*) (const void*, const void*)) compare);

	/* Pass pointer table to caller */
	if (namelist)
		*namelist = files;

	/* Return the number of directory entries read */
	result = (int)size;

exit_status:
	/* Release temporary directory entry, if we had one */
	free(tmp);

	/* Close directory stream */
	closedir(dir);
	return result;
}
886 |
887 | /* Alphabetical sorting */
888 | static int alphasort(const struct dirent **a, const struct dirent **b)
889 | {
890 | return strcoll((*a)->d_name, (*b)->d_name);
891 | }
892 |
893 | /* Sort versions */
894 | static int versionsort(const struct dirent **a, const struct dirent **b)
895 | {
896 | return strverscmp((*a)->d_name, (*b)->d_name);
897 | }
898 |
/* Compare strings so that embedded numbers order numerically, mimicking
 * GNU strverscmp: "file9" < "file10", while digit runs with leading
 * zeros compare like fractions, e.g. "002" < "01".
 *
 * Fix: isdigit() has undefined behavior for negative char values
 * (possible with 8-bit characters on platforms where char is signed);
 * every argument is now cast to unsigned char first.
 */
static int strverscmp(const char *a, const char *b)
{
    size_t i = 0;
    size_t j;

    /* Find first difference */
    while (a[i] == b[i]) {
        if (a[i] == '\0') {
            /* No difference */
            return 0;
        }
        ++i;
    }

    /* Count backwards and find the leftmost digit of the run that
     * contains the first differing character */
    j = i;
    while (j > 0 && isdigit((unsigned char) a[j - 1])) {
        --j;
    }

    /* Determine mode of comparison */
    if (a[j] == '0' || b[j] == '0') {
        /* Leading zeros present: skip the common zeros */
        while (a[j] == '0' && a[j] == b[j]) {
            j++;
        }

        /* String with more digits is smaller, e.g 002 < 01 */
        if (isdigit((unsigned char) a[j])) {
            if (!isdigit((unsigned char) b[j])) {
                return -1;
            }
        }
        else if (isdigit((unsigned char) b[j])) {
            return 1;
        }
    }
    else if (isdigit((unsigned char) a[j]) && isdigit((unsigned char) b[j])) {
        /* Numeric comparison */
        size_t k1 = j;
        size_t k2 = j;

        /* Compute number of digits in each string */
        while (isdigit((unsigned char) a[k1])) {
            k1++;
        }
        while (isdigit((unsigned char) b[k2])) {
            k2++;
        }

        /* Number with more digits is bigger, e.g 999 < 1000 */
        if (k1 < k2)
            return -1;
        else if (k1 > k2)
            return 1;
    }

    /* Alphabetical comparison of the first differing byte */
    return (int) ((unsigned char) a[i]) - ((unsigned char) b[i]);
}
960 |
961 | /* Convert multi-byte string to wide character string */
962 | #if !defined(_MSC_VER) || _MSC_VER < 1400
/* Fallback mbstowcs_s for older Visual Studio / non-Microsoft compilers.
 * Converts mbstr into wcstr (capacity sizeInWords), failing with 1 when
 * the converted length reaches count. On success the output is
 * zero-terminated and *pReturnValue receives the length including the
 * terminator. Returns 0 on success. */
static int dirent_mbstowcs_s(
    size_t *pReturnValue, wchar_t *wcstr,
    size_t sizeInWords, const char *mbstr, size_t count)
{
    /* Convert with the plain C library routine */
    size_t converted = mbstowcs(wcstr, mbstr, sizeInWords);

    /* Converted string does not fit within count characters */
    if (wcstr != NULL && converted >= count)
        return 1;

    /* Terminate, clamping to the last slot if the buffer was filled */
    if (wcstr != NULL && sizeInWords != 0) {
        if (converted >= sizeInWords)
            converted = sizeInWords - 1;
        wcstr[converted] = 0;
    }

    /* Report length of the result including the zero terminator */
    if (pReturnValue != NULL)
        *pReturnValue = converted + 1;

    /* Success */
    return 0;
}
987 | #endif
988 |
989 | /* Convert wide-character string to multi-byte string */
990 | #if !defined(_MSC_VER) || _MSC_VER < 1400
/* Fallback wcstombs_s for older Visual Studio / non-Microsoft compilers.
 * Converts wcstr into mbstr (capacity sizeInBytes), failing with 1 when
 * the converted length reaches count. On success the output is
 * zero-terminated and *pReturnValue receives the length including the
 * terminator. Returns 0 on success. */
static int dirent_wcstombs_s(
    size_t *pReturnValue, char *mbstr,
    size_t sizeInBytes, const wchar_t *wcstr, size_t count)
{
    /* Convert with the plain C library routine */
    size_t converted = wcstombs(mbstr, wcstr, sizeInBytes);

    /* Converted string does not fit within count bytes */
    if (mbstr != NULL && converted >= count)
        return 1;

    /* Terminate, clamping to the last slot if the buffer was filled */
    if (mbstr != NULL && sizeInBytes != 0) {
        if (converted >= sizeInBytes)
            converted = sizeInBytes - 1;
        mbstr[converted] = '\0';
    }

    /* Report length of the result including the zero terminator */
    if (pReturnValue != NULL)
        *pReturnValue = converted + 1;

    /* Success */
    return 0;
}
1016 | #endif
1017 |
1018 | /* Set errno variable */
1019 | #if !defined(_MSC_VER) || _MSC_VER < 1400
static void dirent_set_errno(int error)
{
    /* Non-Microsoft compiler or older Microsoft compiler: the global
     * errno lvalue is assignable directly; newer MSVC builds take the
     * other branch of the surrounding #if instead. */
    errno = error;
}
1025 | #endif
1026 |
1027 | #ifdef __cplusplus
1028 | }
1029 | #endif
1030 | #endif /*DIRENT_H*/
--------------------------------------------------------------------------------
/src/utils/loss.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 |
4 | //prediction [NCHW], a tensor after softmax activation at C dim
5 | //target [N1HW], a tensor refer to label
6 | //num_class: int, equal to C, refer to class numbers, including background
7 | torch::Tensor DiceLoss(torch::Tensor prediction, torch::Tensor target, int num_class) {
8 | auto target_onehot = torch::zeros_like(prediction); // N x C x H x W
9 | target_onehot.scatter_(1, target, 1);
10 |
11 | auto prediction_roi = prediction.slice(1, 1, num_class, 1);
12 | auto target_roi = target_onehot.slice(1, 1, num_class, 1);
13 | auto intersection = (prediction_roi*target_roi).sum();
14 | auto union_ = prediction_roi.sum() + target_roi.sum() - intersection;
15 | auto dice = (intersection + 0.0001) / (union_ + 0.0001);
16 | //cout << "prediction_roi: " << prediction_roi.sizes() << "\t" << "target roi: " << target_roi.sizes() << endl;
17 | //cout << "intersection: " << intersection << "\t" << "union: " << union_ << endl;
18 | //target_onehot.scatter()
19 | return 1 - dice;
20 | }
21 |
22 | //prediction [NCHW], target [NHW]
23 | torch::Tensor CELoss(torch::Tensor prediction, torch::Tensor target) {
24 | return torch::nll_loss2d(torch::log_softmax(prediction, /*dim=*/1), target);
25 | }
--------------------------------------------------------------------------------
/src/utils/readfile.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by root on 2021/4/9.
3 | //
4 | #include
5 | #include
6 | #ifdef _WIN32
7 | #include "_dirent.h"
8 | #else
9 | #include
10 | #endif
11 | #include
12 |
13 | #ifndef READ_FILE_READFILE_H
14 | #define READ_FILE_READFILE_H
15 |
16 | //#define throw_if(expression) if(expression)throw "error"
// Returns true when dir_name names an existing directory that can be opened.
// Fixes: (1) diagnostic message had a typo ("nullprt"); (2) a nullptr
// argument previously fell through into opendir(nullptr), which is
// undefined behavior — it now returns false after the diagnostic.
inline bool is_folder(const char* dir_name){
    if (nullptr == dir_name) {
        std::cout << "dir_name is nullptr";
        return false;
    }
    auto dir = opendir(dir_name);
    if (dir) {
        closedir(dir);
        return true;
    }
    return false;
}
#ifdef _WIN32
// Path separator for Windows-style paths.
// NOTE(review): name is a typo for "separator", kept for compatibility.
inline char file_sepator(){
    return '\\';
}
#else
// Path separator for POSIX-style paths.
inline char file_sepator(){
    return '/';
}
#endif
// Directory test for std::string paths; forwards to the C-string overload.
// An empty name only triggers a diagnostic — the overload still decides.
inline bool is_folder(const std::string &dir_name){
    if (dir_name.empty())
        std::cout << "dir_name is empty";
    return is_folder(dir_name.c_str());
}
44 | using file_filter_type=std::function;
45 | /*
46 | * 列出指定目录的所有文件(不包含目录)执行,对每个文件执行filter过滤器,
47 | * filter返回true时将文件名全路径加入std::vector
48 | * sub为true时为目录递归
49 | * 返回每个文件的全路径名
50 | */
51 | static std::vector for_each_file(const std::string&dir_name,file_filter_type filter,bool sub=false){
52 | std::vector v;
53 | auto dir =opendir(dir_name.data());
54 | struct dirent *ent;
55 | if(dir){
56 | while ((ent = readdir (dir)) != nullptr) {
57 | auto p = std::string(dir_name).append({ file_sepator() }).append(ent->d_name);
58 | if(sub){
59 | if ( 0== strcmp (ent->d_name, "..") || 0 == strcmp (ent->d_name, ".")){
60 | continue;
61 | }else if(is_folder(p)){
62 | auto r= for_each_file(p,filter,sub);
63 | v.insert(v.end(),r.begin(),r.end());
64 | continue;
65 | }
66 | }
67 | if (sub||!is_folder(p))//如果是文件,则调用过滤器filter
68 | if(filter(dir_name.data(),ent->d_name))
69 | v.emplace_back(p);
70 | }
71 | closedir(dir);
72 | }
73 | return v;
74 | }
// Return a lower-cased copy of src (byte-wise, via ::tolower).
inline std::string tolower1(const std::string &src){
    std::string result(src.size(), '\0');
    std::transform(src.begin(), src.end(), result.begin(), ::tolower);
    return result;
}
// Returns true when src ends with suffix.
// Fix: when suffix is longer than src, src.size()-suffix.size() underflowed
// (unsigned arithmetic) and substr threw std::out_of_range; such inputs now
// simply return false.
inline bool end_with(const std::string &src, const std::string &suffix){
    if (suffix.size() > src.size())
        return false;
    return src.compare(src.size() - suffix.size(), suffix.size(), suffix) == 0;
}
85 |
86 | #endif //READ_FILE_READFILE_H
87 |
--------------------------------------------------------------------------------
/src/utils/util.cpp:
--------------------------------------------------------------------------------
1 | #include "util.h"
2 |
3 | SegmentationHeadImpl::SegmentationHeadImpl(int in_channels, int out_channels, int kernel_size, double _upsampling){
4 | conv2d = torch::nn::Conv2d(conv_options(in_channels, out_channels, kernel_size, 1, kernel_size / 2));
5 | upsampling = torch::nn::Upsample(upsample_options(std::vector{_upsampling,_upsampling}));
6 | register_module("conv2d",conv2d);
7 | }
8 | torch::Tensor SegmentationHeadImpl::forward(torch::Tensor x){
9 | x = conv2d->forward(x);
10 | x = upsampling->forward(x);
11 | return x;
12 | }
13 |
// Replace every occurrence of old_value in str with new_value, scanning
// left to right and resuming the search just past each inserted new_value
// (so replacements are never re-scanned). Returns the modified copy.
std::string replace_all_distinct(std::string str, const std::string old_value, const std::string new_value)
{
    std::string::size_type cursor = 0;
    while ((cursor = str.find(old_value, cursor)) != std::string::npos) {
        str.replace(cursor, old_value.length(), new_value);
        cursor += new_value.length();
    }
    return str;
}
26 |
27 | //遍历该目录下的.xml文件,并且找到对应的
28 | void load_seg_data_from_folder(std::string folder, std::string image_type,
29 | std::vector &list_images, std::vector &list_labels)
30 | {
31 | for_each_file(folder,
32 | [&](const char*path,const char* name){
33 | auto full_path=std::string(path).append({file_sepator()}).append(name);
34 | std::string lower_name=tolower1(name);
35 |
36 | if(end_with(lower_name,".json")){
37 | list_labels.push_back(full_path);
38 | std::string image_path = replace_all_distinct(full_path, ".json", image_type);
39 | list_images.push_back(image_path);
40 | }
41 | return false;
42 | }
43 | ,true
44 | );
45 | }
46 |
47 |
48 | nlohmann::json encoder_params() {
49 | nlohmann::json params = {
50 | {"resnet18", {
51 | {"class_type", "resnet"},
52 | {"out_channels", {3, 64, 64, 128, 256, 512}},
53 | {"layers" , {2, 2, 2, 2}},
54 | },
55 | },
56 | {"resnet34", {
57 | {"class_type", "resnet"},
58 | {"out_channels", {3, 64, 64, 128, 256, 512}},
59 | {"layers" , {3, 4, 6, 3}},
60 | },
61 | },
62 | {"resnet50", {
63 | {"class_type", "resnet"},
64 | {"out_channels", {3, 64, 256, 512, 1024, 2048}},
65 | {"layers" , {3, 4, 6, 3}},
66 | },
67 | },
68 | {"resnet101", {
69 | {"class_type", "resnet"},
70 | {"out_channels", {3, 64, 256, 512, 1024, 2048}},
71 | {"layers" , {3, 4, 23, 3}},
72 | },
73 | },
74 | {"resnet101", {
75 | {"class_type", "resnet"},
76 | {"out_channels", {3, 64, 256, 512, 1024, 2048}},
77 | {"layers" , {3, 8, 36, 3}},
78 | },
79 | },
80 | {"resnext50_32x4d", {
81 | {"class_type", "resnet"},
82 | {"out_channels", {3, 64, 256, 512, 1024, 2048}},
83 | {"layers" , {3, 4, 6, 3}},
84 | },
85 | },
86 | {"resnext101_32x8d", {
87 | {"class_type", "resnet"},
88 | {"out_channels", {3, 64, 256, 512, 1024, 2048}},
89 | {"layers" , {3, 4, 23, 3}},
90 | },
91 | },
92 | {"vgg11", {
93 | {"class_type", "vgg"},
94 | {"out_channels", {64, 128, 256, 512, 512, 512}},
95 | {"cfg",{64, -1, 128, -1, 256, 256, -1, 512, 512, -1, 512, 512, -1}},
96 | {"batch_norm" , false},
97 | },
98 | },
99 | {"vgg11_bn", {
100 | {"class_type", "vgg"},
101 | {"out_channels", {64, 128, 256, 512, 512, 512}},
102 | {"cfg",{64, -1, 128, -1, 256, 256, -1, 512, 512, -1, 512, 512, -1}},
103 | {"batch_norm" , true},
104 | },
105 | },
106 | {"vgg13", {
107 | {"class_type", "vgg"},
108 | {"out_channels", {64, 128, 256, 512, 512, 512}},
109 | {"cfg",{64, 64, -1, 128, 128, -1, 256, 256, -1, 512, 512, -1, 512, 512, -1}},
110 | {"batch_norm" , false},
111 | },
112 | },
113 | {"vgg13_bn", {
114 | {"class_type", "vgg"},
115 | {"out_channels", {64, 128, 256, 512, 512, 512}},
116 | {"cfg",{64, 64, -1, 128, 128, -1, 256, 256, -1, 512, 512, -1, 512, 512, -1}},
117 | {"batch_norm" , true},
118 | },
119 | },
120 | {"vgg16", {
121 | {"class_type", "vgg"},
122 | {"out_channels", {64, 128, 256, 512, 512, 512}},
123 | {"cfg",{64, 64, -1, 128, 128, -1, 256, 256, 256, -1, 512, 512, 512, -1, 512, 512, 512, -1}},
124 | {"batch_norm" , false},
125 | },
126 | },
127 | {"vgg16_bn", {
128 | {"class_type", "vgg"},
129 | {"out_channels", {64, 128, 256, 512, 512, 512}},
130 | {"cfg",{64, 64, -1, 128, 128, -1, 256, 256, 256, -1, 512, 512, 512, -1, 512, 512, 512, -1}},
131 | {"batch_norm" , true},
132 | },
133 | },
134 | {"vgg19", {
135 | {"class_type", "vgg"},
136 | {"out_channels", {64, 128, 256, 512, 512, 512}},
137 | {"cfg",{64, 64, -1, 128, 128, -1, 256, 256, 256, 256, -1, 512, 512, 512, 512, -1, 512, 512, 512, 512, -1}},
138 | {"batch_norm" , false},
139 | },
140 | },
141 | {"vgg19_bn", {
142 | {"class_type", "vgg"},
143 | {"out_channels", {64, 128, 256, 512, 512, 512}},
144 | {"cfg",{64, 64, -1, 128, 128, -1, 256, 256, 256, 256, -1, 512, 512, 512, 512, -1, 512, 512, 512, 512, -1}},
145 | {"batch_norm" , true},
146 | },
147 | },
148 | };
149 | return params;
150 | }
151 |
--------------------------------------------------------------------------------
/src/utils/util.h:
--------------------------------------------------------------------------------
1 | #ifndef UTIL_H
2 | #define UTIL_H
3 | #undef slots
4 | #include
5 | #include