├── CMakeLists.txt ├── README.md ├── android_gui.jpg ├── demo └── windows │ ├── 25.jpg │ ├── doc.dll │ ├── models │ └── readme.txt │ ├── ncnn_demo.exe │ └── opencv_world450.dll ├── ncnn └── ncnn-master.zip ├── replace ├── binaryop.cpp ├── expanddims.cpp ├── expanddims.h ├── gather.cpp ├── gather.h ├── gemm.cpp ├── squeeze.cpp └── squeeze.h ├── result1.jpg ├── result2.jpg ├── src └── demo.cpp └── windows_gui.jpg /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.5) 2 | project(ncnn_doctr) 3 | set(CMAKE_BUILD_TYPE Release) 4 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -pie -fPIE -fPIC -Wall -O3") 5 | 6 | find_package(OpenCV REQUIRED) 7 | if (OpenCV_FOUND) 8 | message(STATUS "OpenCV_LIBS: ${OpenCV_LIBS}") 9 | message(STATUS "OpenCV_INCLUDE_DIRS: ${OpenCV_INCLUDE_DIRS}") 10 | else () 11 | message(FATAL_ERROR "opencv Not Found!") 12 | endif (OpenCV_FOUND) 13 | 14 | find_package(OpenMP REQUIRED) 15 | if (OPENMP_FOUND) 16 | message("OPENMP FOUND") 17 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") 18 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") 19 | set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") 20 | else () 21 | message(FATAL_ERROR "OpenMP Not Found!") 22 | endif () 23 | 24 | 25 | include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ncnn/include/ncnn) 26 | link_directories(${CMAKE_CURRENT_SOURCE_DIR}/ncnn/lib) 27 | 28 | set(EXECUTABLE_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/) 29 | add_executable(ncnn_doctr src/demo.cpp) 30 | target_link_libraries(ncnn_doctr ncnn ${OpenCV_LIBS}) 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DocTr-ncnn 2 | ncnn demo of **DocTr: Document Image Transformer for Geometric Unwarping and Illumination Correction** 3 | 4 | ## model support: 5 | **1.Document 
Segmentation** 6 | **2.Geometric Unwarping** 7 | **3.Illumination Correction model** 8 | All models are available in [Baidu Pan](https://pan.baidu.com/s/1lny5IuL9TMUlfAUCg_6iuw) (69c4) 9 | ### PS: 10 | **1.The newest ncnn is needed (which supports 4D Mat); you can copy the cpp files from the replace dir into your ncnn, or use the ncnn in this repo.** 11 | **2.This model uses a transformer, which is very unfriendly to deploy, so it may take 20~30 seconds to get the final result.** 12 | 13 | ### TODO: 14 | 1.~~Illumination Correction model~~ 15 | 2.~~windows gui demo~~ 16 | 3.support ncnn-vulkan 17 | ## Result 18 | ![](android_gui.jpg) 19 | ![](windows_gui.jpg) 20 | ![](result1.jpg) 21 | ![](result2.jpg) 22 | 23 | ## Reference 24 | 1.https://github.com/fh2019ustc/DocTr 25 | 2.https://github.com/Tencent/ncnn 26 | -------------------------------------------------------------------------------- /android_gui.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiGeChuanShu/DocTr-ncnn/b8584412038ff058e8a48acf3d39516b3dc3acdc/android_gui.jpg -------------------------------------------------------------------------------- /demo/windows/25.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiGeChuanShu/DocTr-ncnn/b8584412038ff058e8a48acf3d39516b3dc3acdc/demo/windows/25.jpg -------------------------------------------------------------------------------- /demo/windows/doc.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiGeChuanShu/DocTr-ncnn/b8584412038ff058e8a48acf3d39516b3dc3acdc/demo/windows/doc.dll -------------------------------------------------------------------------------- /demo/windows/models/readme.txt: -------------------------------------------------------------------------------- 1 | Put the ncnn models in models directory 2 | 
-------------------------------------------------------------------------------- /demo/windows/ncnn_demo.exe: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiGeChuanShu/DocTr-ncnn/b8584412038ff058e8a48acf3d39516b3dc3acdc/demo/windows/ncnn_demo.exe -------------------------------------------------------------------------------- /demo/windows/opencv_world450.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiGeChuanShu/DocTr-ncnn/b8584412038ff058e8a48acf3d39516b3dc3acdc/demo/windows/opencv_world450.dll -------------------------------------------------------------------------------- /ncnn/ncnn-master.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiGeChuanShu/DocTr-ncnn/b8584412038ff058e8a48acf3d39516b3dc3acdc/ncnn/ncnn-master.zip -------------------------------------------------------------------------------- /replace/binaryop.cpp: -------------------------------------------------------------------------------- 1 | // Tencent is pleased to support the open source community by making ncnn available. 2 | // 3 | // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved. 4 | // 5 | // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except 6 | // in compliance with the License. You may obtain a copy of the License at 7 | // 8 | // https://opensource.org/licenses/BSD-3-Clause 9 | // 10 | // Unless required by applicable law or agreed to in writing, software distributed 11 | // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 12 | // CONDITIONS OF ANY KIND, either express or implied. See the License for the 13 | // specific language governing permissions and limitations under the License. 
14 | 15 | #include "binaryop.h" 16 | 17 | #include <math.h> 18 | 19 | namespace ncnn { 20 | 21 | BinaryOp::BinaryOp() 22 | { 23 | one_blob_only = false; 24 | support_inplace = false; 25 | } 26 | 27 | int BinaryOp::load_param(const ParamDict& pd) 28 | { 29 | op_type = pd.get(0, 0); 30 | with_scalar = pd.get(1, 0); 31 | b = pd.get(2, 0.f); 32 | 33 | if (with_scalar != 0) 34 | { 35 | one_blob_only = true; 36 | support_inplace = true; 37 | } 38 | 39 | return 0; 40 | } 41 | 42 | // broadcasting rule 43 | // https://github.com/Tencent/ncnn/wiki/binaryop-broadcasting 44 | 45 | template<typename Op> 46 | static int binary_op(const Mat& a, const Mat& b, Mat& c, const Option& opt) 47 | { 48 | Op op; 49 | 50 | int w = a.w; 51 | int h = a.h; 52 | int d = a.d; 53 | int channels = a.c; 54 | int size = w * h * d; 55 | size_t elemsize = a.elemsize; 56 | 57 | int w1 = b.w; 58 | int h1 = b.h; 59 | int d1 = b.d; 60 | int channels1 = b.c; 61 | int size1 = w1 * h1 * d1; 62 | 63 | if (a.dims == 4) 64 | { 65 | if (b.dims == 4) 66 | { 67 | // type 29 68 | c.create(w, h, d, channels, elemsize, opt.blob_allocator); 69 | if (c.empty()) 70 | return -100; 71 | 72 | #pragma omp parallel for num_threads(opt.num_threads) 73 | for (int q = 0; q < channels; q++) 74 | { 75 | const float* ptr = a.channel(q); 76 | const float* ptr1 = b.channel(q); 77 | float* outptr = c.channel(q); 78 | 79 | for (int i = 0; i < size; i++) 80 | { 81 | outptr[i] = op(ptr[i], ptr1[i]); 82 | } 83 | } 84 | 85 | return 0; 86 | } 87 | 88 | c.create(w, h, d, channels, elemsize, opt.blob_allocator); 89 | if (c.empty()) 90 | return -100; 91 | 92 | if (b.dims == 3) 93 | { 94 | // type 28 95 | #pragma omp parallel for num_threads(opt.num_threads) 96 | for (int q = 0; q < channels; q++) 97 | { 98 | const float* ptr = a.channel(q); 99 | const float* ptr1 = b.channel(q); 100 | float* outptr = c.channel(q); 101 | 102 | for (int z = 0; z < d; z++) 103 | { 104 | for (int y = 0; y < h; y++) 105 | { 106 | const float b0 = ptr1[y]; 107 | for (int x = 0; x < w; 
x++) 108 | { 109 | outptr[x] = op(ptr[x], b0); 110 | } 111 | 112 | ptr += w; 113 | outptr += w; 114 | } 115 | 116 | ptr1 += h; 117 | } 118 | } 119 | 120 | return 0; 121 | } 122 | 123 | if (b.dims == 2) 124 | { 125 | // type 27 126 | #pragma omp parallel for num_threads(opt.num_threads) 127 | for (int q = 0; q < channels; q++) 128 | { 129 | const float* ptr = a.channel(q); 130 | const float* ptr1 = b.row(q); 131 | float* outptr = c.channel(q); 132 | 133 | for (int z = 0; z < d; z++) 134 | { 135 | const float b0 = ptr1[z]; 136 | for (int y = 0; y < h; y++) 137 | { 138 | for (int x = 0; x < w; x++) 139 | { 140 | outptr[x] = op(ptr[x], b0); 141 | } 142 | 143 | ptr += w; 144 | outptr += w; 145 | } 146 | } 147 | } 148 | 149 | return 0; 150 | } 151 | 152 | if (b.dims == 1) 153 | { 154 | if (b.w == 1) 155 | { 156 | // type 25 157 | const float b0 = b[0]; 158 | #pragma omp parallel for num_threads(opt.num_threads) 159 | for (int q = 0; q < channels; q++) 160 | { 161 | const float* ptr = a.channel(q); 162 | float* outptr = c.channel(q); 163 | 164 | for (int i = 0; i < size; i++) 165 | { 166 | outptr[i] = op(ptr[i], b0); 167 | } 168 | } 169 | 170 | return 0; 171 | } 172 | 173 | // type 26 174 | #pragma omp parallel for num_threads(opt.num_threads) 175 | for (int q = 0; q < channels; q++) 176 | { 177 | const float* ptr = a.channel(q); 178 | const float b0 = b[q]; 179 | float* outptr = c.channel(q); 180 | 181 | for (int i = 0; i < size; i++) 182 | { 183 | outptr[i] = op(ptr[i], b0); 184 | } 185 | } 186 | 187 | return 0; 188 | } 189 | } 190 | else if (a.dims == 3) 191 | { 192 | if (b.dims == 4) 193 | { 194 | // type 23 195 | c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); 196 | if (c.empty()) 197 | return -100; 198 | 199 | #pragma omp parallel for num_threads(opt.num_threads) 200 | for (int q = 0; q < channels1; q++) 201 | { 202 | const float* ptr = a.channel(q); 203 | const float* ptr1 = b.channel(q); 204 | float* outptr = c.channel(q); 205 | 206 | for (int z = 
0; z < d1; z++) 207 | { 208 | for (int y = 0; y < h1; y++) 209 | { 210 | const float a0 = ptr[y]; 211 | for (int x = 0; x < w1; x++) 212 | { 213 | outptr[x] = op(a0, ptr1[x]); 214 | } 215 | 216 | ptr1 += w1; 217 | outptr += w1; 218 | } 219 | 220 | ptr += h1; 221 | } 222 | } 223 | 224 | return 0; 225 | } 226 | 227 | if (b.dims == 3) 228 | { 229 | if (w1 == 1 && h1 == 1 && channels1 == channels) 230 | { 231 | // special type 1 232 | c.create(w, h, channels, elemsize, opt.blob_allocator); 233 | if (c.empty()) 234 | return -100; 235 | 236 | #pragma omp parallel for num_threads(opt.num_threads) 237 | for (int q = 0; q < channels; q++) 238 | { 239 | const float* ptr = a.channel(q); 240 | const float* b0 = b.channel(q); 241 | float* outptr = c.channel(q); 242 | for (int i = 0; i < size; i++) 243 | { 244 | outptr[i] = op(ptr[i], b0[0]); 245 | } 246 | } 247 | 248 | return 0; 249 | } 250 | 251 | if (w1 == w && h1 == h && channels1 == 1) 252 | { 253 | // special type 2 254 | c.create(w, h, channels, elemsize, opt.blob_allocator); 255 | if (c.empty()) 256 | return -100; 257 | 258 | #pragma omp parallel for num_threads(opt.num_threads) 259 | for (int q = 0; q < channels; q++) 260 | { 261 | const float* ptr = a.channel(q); 262 | const float* ptr1 = b; 263 | float* outptr = c.channel(q); 264 | for (int i = 0; i < size; i++) 265 | { 266 | outptr[i] = op(ptr[i], ptr1[i]); 267 | } 268 | } 269 | 270 | return 0; 271 | } 272 | 273 | if (w == 1 && h == 1 && channels1 == channels) 274 | { 275 | // special type 3 276 | c.create(w1, h1, channels1, elemsize, opt.blob_allocator); 277 | if (c.empty()) 278 | return -100; 279 | 280 | #pragma omp parallel for num_threads(opt.num_threads) 281 | for (int q = 0; q < channels1; q++) 282 | { 283 | const float* a0 = a.channel(q); 284 | const float* ptr1 = b.channel(q); 285 | float* outptr = c.channel(q); 286 | for (int i = 0; i < size1; i++) 287 | { 288 | outptr[i] = op(a0[0], ptr1[i]); 289 | } 290 | } 291 | 292 | return 0; 293 | } 294 | 295 | if (w1 
== w && h1 == h && channels == 1) 296 | { 297 | // special type 4 298 | c.create(w1, h1, channels1, elemsize, opt.blob_allocator); 299 | if (c.empty()) 300 | return -100; 301 | 302 | #pragma omp parallel for num_threads(opt.num_threads) 303 | for (int q = 0; q < channels1; q++) 304 | { 305 | const float* ptr = a; 306 | const float* ptr1 = b.channel(q); 307 | float* outptr = c.channel(q); 308 | for (int i = 0; i < size1; i++) 309 | { 310 | outptr[i] = op(ptr[i], ptr1[i]); 311 | } 312 | } 313 | 314 | return 0; 315 | } 316 | 317 | if (w != 1 && w1 == 1 && h1 == h && channels1 == channels) 318 | { 319 | // special type 5 320 | c.create(w, h, channels, elemsize, opt.blob_allocator); 321 | if (c.empty()) 322 | return -100; 323 | 324 | #pragma omp parallel for num_threads(opt.num_threads) 325 | for (int q = 0; q < channels1; q++) 326 | { 327 | const float* ptr = a.channel(q); 328 | const float* ptr1 = b.channel(q); 329 | float* outptr = c.channel(q); 330 | 331 | for (int y = 0; y < h; y++) 332 | { 333 | const float b0 = ptr1[y]; 334 | for (int x = 0; x < w; x++) 335 | { 336 | outptr[x] = op(ptr[x], b0); 337 | } 338 | 339 | ptr += w; 340 | outptr += w; 341 | } 342 | } 343 | 344 | return 0; 345 | } 346 | 347 | if (w1 == w && h != 1 && h1 == 1 && channels1 == channels) 348 | { 349 | // special type 6 350 | c.create(w, h, channels, elemsize, opt.blob_allocator); 351 | if (c.empty()) 352 | return -100; 353 | 354 | #pragma omp parallel for num_threads(opt.num_threads) 355 | for (int q = 0; q < channels1; q++) 356 | { 357 | const float* ptr = a.channel(q); 358 | const float* ptr1 = b.channel(q); 359 | float* outptr = c.channel(q); 360 | 361 | for (int y = 0; y < h; y++) 362 | { 363 | for (int x = 0; x < w; x++) 364 | { 365 | outptr[x] = op(ptr[x], ptr1[x]); 366 | } 367 | 368 | ptr += w; 369 | outptr += w; 370 | } 371 | } 372 | 373 | return 0; 374 | } 375 | 376 | if (w1 != 1 && w == 1 && h1 == h && channels1 == channels) 377 | { 378 | // special type 7 379 | c.create(w1, h1, 
channels1, elemsize, opt.blob_allocator); 380 | if (c.empty()) 381 | return -100; 382 | 383 | #pragma omp parallel for num_threads(opt.num_threads) 384 | for (int q = 0; q < channels1; q++) 385 | { 386 | const float* ptr = a.channel(q); 387 | const float* ptr1 = b.channel(q); 388 | float* outptr = c.channel(q); 389 | 390 | for (int y = 0; y < h1; y++) 391 | { 392 | const float a0 = ptr[y]; 393 | for (int x = 0; x < w1; x++) 394 | { 395 | outptr[x] = op(a0, ptr1[x]); 396 | } 397 | 398 | ptr1 += w1; 399 | outptr += w1; 400 | } 401 | } 402 | 403 | return 0; 404 | } 405 | 406 | if (w1 == w && h1 != 1 && h == 1 && channels1 == channels) 407 | { 408 | // special type 8 409 | c.create(w1, h1, channels1, elemsize, opt.blob_allocator); 410 | if (c.empty()) 411 | return -100; 412 | 413 | #pragma omp parallel for num_threads(opt.num_threads) 414 | for (int q = 0; q < channels1; q++) 415 | { 416 | const float* ptr = a.channel(q); 417 | const float* ptr1 = b.channel(q); 418 | float* outptr = c.channel(q); 419 | 420 | for (int y = 0; y < h1; y++) 421 | { 422 | for (int x = 0; x < w1; x++) 423 | { 424 | outptr[x] = op(ptr[x], ptr1[x]); 425 | } 426 | 427 | ptr1 += w1; 428 | outptr += w1; 429 | } 430 | } 431 | 432 | return 0; 433 | } 434 | 435 | // type 19 436 | c.create(w, h, channels, elemsize, opt.blob_allocator); 437 | if (c.empty()) 438 | return -100; 439 | 440 | #pragma omp parallel for num_threads(opt.num_threads) 441 | for (int q = 0; q < channels; q++) 442 | { 443 | const float* ptr = a.channel(q); 444 | const float* ptr1 = b.channel(q); 445 | float* outptr = c.channel(q); 446 | 447 | for (int i = 0; i < size; i++) 448 | { 449 | outptr[i] = op(ptr[i], ptr1[i]); 450 | } 451 | } 452 | 453 | return 0; 454 | } 455 | 456 | c.create(w, h, channels, elemsize, opt.blob_allocator); 457 | if (c.empty()) 458 | return -100; 459 | 460 | if (b.dims == 2) 461 | { 462 | // type 18 463 | #pragma omp parallel for num_threads(opt.num_threads) 464 | for (int q = 0; q < channels; q++) 465 | { 
466 | const float* ptr = a.channel(q); 467 | const float* ptr1 = b.row(q); 468 | float* outptr = c.channel(q); 469 | 470 | for (int y = 0; y < h; y++) 471 | { 472 | const float b0 = ptr1[y]; 473 | for (int x = 0; x < w; x++) 474 | { 475 | outptr[x] = op(ptr[x], b0); 476 | } 477 | 478 | ptr += w; 479 | outptr += w; 480 | } 481 | } 482 | 483 | return 0; 484 | } 485 | 486 | if (b.dims == 1) 487 | { 488 | if (b.w == 1) 489 | { 490 | // type 16 491 | const float b0 = b[0]; 492 | #pragma omp parallel for num_threads(opt.num_threads) 493 | for (int q = 0; q < channels; q++) 494 | { 495 | const float* ptr = a.channel(q); 496 | float* outptr = c.channel(q); 497 | 498 | for (int i = 0; i < size; i++) 499 | { 500 | outptr[i] = op(ptr[i], b0); 501 | } 502 | } 503 | 504 | return 0; 505 | } 506 | 507 | // type 17 508 | #pragma omp parallel for num_threads(opt.num_threads) 509 | for (int q = 0; q < channels; q++) 510 | { 511 | const float* ptr = a.channel(q); 512 | const float b0 = b[q]; 513 | float* outptr = c.channel(q); 514 | 515 | for (int i = 0; i < size; i++) 516 | { 517 | outptr[i] = op(ptr[i], b0); 518 | } 519 | } 520 | 521 | return 0; 522 | } 523 | } 524 | else if (a.dims == 2) 525 | { 526 | if (b.dims == 4) 527 | { 528 | // type 22 529 | c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); 530 | if (c.empty()) 531 | return -100; 532 | 533 | #pragma omp parallel for num_threads(opt.num_threads) 534 | for (int q = 0; q < channels1; q++) 535 | { 536 | const float* ptr = a.row(q); 537 | const float* ptr1 = b.channel(q); 538 | float* outptr = c.channel(q); 539 | 540 | for (int z = 0; z < d1; z++) 541 | { 542 | const float a0 = ptr[z]; 543 | for (int y = 0; y < h1; y++) 544 | { 545 | for (int x = 0; x < w1; x++) 546 | { 547 | outptr[x] = op(a0, ptr1[x]); 548 | } 549 | 550 | ptr1 += w1; 551 | outptr += w1; 552 | } 553 | } 554 | } 555 | 556 | return 0; 557 | } 558 | 559 | if (b.dims == 3) 560 | { 561 | // type 14 562 | c.create(w1, h1, channels1, elemsize, 
opt.blob_allocator); 563 | if (c.empty()) 564 | return -100; 565 | 566 | #pragma omp parallel for num_threads(opt.num_threads) 567 | for (int q = 0; q < channels1; q++) 568 | { 569 | const float* ptr = a.row(q); 570 | const float* ptr1 = b.channel(q); 571 | float* outptr = c.channel(q); 572 | 573 | for (int y = 0; y < h1; y++) 574 | { 575 | const float a0 = ptr[y]; 576 | for (int x = 0; x < w1; x++) 577 | { 578 | outptr[x] = op(a0, ptr1[x]); 579 | } 580 | 581 | ptr1 += w1; 582 | outptr += w1; 583 | } 584 | } 585 | 586 | return 0; 587 | } 588 | 589 | c.create(w, h, elemsize, opt.blob_allocator); 590 | if (c.empty()) 591 | return -100; 592 | 593 | if (b.dims == 2) 594 | { 595 | // type 13 596 | for (int i = 0; i < size; i++) 597 | { 598 | c[i] = op(a[i], b[i]); 599 | } 600 | 601 | return 0; 602 | } 603 | 604 | if (b.dims == 1) 605 | { 606 | c.create(w, h, elemsize, opt.blob_allocator); 607 | if (c.empty()) 608 | return -100; 609 | 610 | if (b.w == 1) 611 | { 612 | // type 11 613 | const float b0 = b[0]; 614 | for (int i = 0; i < size; i++) 615 | { 616 | c[i] = op(a[i], b0); 617 | } 618 | 619 | return 0; 620 | } 621 | //====new 622 | if (b.h == 1 && b.w == a.w && b.c == a.c) 623 | { 624 | const float* ptr = a; 625 | float* outptr = c; 626 | 627 | for (int y = 0; y < h; y++) 628 | { 629 | //const float b0 = b[y]; 630 | for (int x = 0; x < w; x++) 631 | { 632 | outptr[x] = op(ptr[x], b[x]); 633 | } 634 | 635 | ptr += w; 636 | outptr += w; 637 | } 638 | 639 | return 0; 640 | } 641 | // type 12 642 | const float* ptr = a; 643 | float* outptr = c; 644 | 645 | for (int y = 0; y < h; y++) 646 | { 647 | const float b0 = b[y]; 648 | for (int x = 0; x < w; x++) 649 | { 650 | outptr[x] = op(ptr[x], b0); 651 | } 652 | 653 | ptr += w; 654 | outptr += w; 655 | } 656 | 657 | return 0; 658 | } 659 | } 660 | else if (a.dims == 1) 661 | { 662 | if (a.w == 1) 663 | { 664 | if (b.dims == 4) 665 | { 666 | // type 20 667 | c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); 668 
| if (c.empty()) 669 | return -100; 670 | 671 | const float a0 = a[0]; 672 | #pragma omp parallel for num_threads(opt.num_threads) 673 | for (int q = 0; q < channels1; q++) 674 | { 675 | const float* ptr1 = b.channel(q); 676 | float* outptr = c.channel(q); 677 | 678 | for (int i = 0; i < size1; i++) 679 | { 680 | outptr[i] = op(a0, ptr1[i]); 681 | } 682 | } 683 | 684 | return 0; 685 | } 686 | 687 | if (b.dims == 3) 688 | { 689 | // type 4 690 | c.create(w1, h1, channels1, elemsize, opt.blob_allocator); 691 | if (c.empty()) 692 | return -100; 693 | 694 | const float a0 = a[0]; 695 | #pragma omp parallel for num_threads(opt.num_threads) 696 | for (int q = 0; q < channels1; q++) 697 | { 698 | const float* ptr1 = b.channel(q); 699 | float* outptr = c.channel(q); 700 | 701 | for (int i = 0; i < size1; i++) 702 | { 703 | outptr[i] = op(a0, ptr1[i]); 704 | } 705 | } 706 | 707 | return 0; 708 | } 709 | 710 | if (b.dims == 2) 711 | { 712 | // type 3 713 | c.create(w1, h1, elemsize, opt.blob_allocator); 714 | if (c.empty()) 715 | return -100; 716 | 717 | const float a0 = a[0]; 718 | for (int i = 0; i < size1; i++) 719 | { 720 | c[i] = op(a0, b[i]); 721 | } 722 | 723 | return 0; 724 | } 725 | 726 | if (b.dims == 1) 727 | { 728 | // type 2 729 | c.create(w1, elemsize, opt.blob_allocator); 730 | if (c.empty()) 731 | return -100; 732 | 733 | const float a0 = a[0]; 734 | for (int i = 0; i < w1; i++) 735 | { 736 | c[i] = op(a0, b[i]); 737 | } 738 | 739 | return 0; 740 | } 741 | } 742 | 743 | if (b.dims == 4) 744 | { 745 | // type 21 746 | c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); 747 | if (c.empty()) 748 | return -100; 749 | 750 | #pragma omp parallel for num_threads(opt.num_threads) 751 | for (int q = 0; q < channels1; q++) 752 | { 753 | const float a0 = a[q]; 754 | const float* ptr1 = b.channel(q); 755 | float* outptr = c.channel(q); 756 | 757 | for (int i = 0; i < size1; i++) 758 | { 759 | outptr[i] = op(a0, ptr1[i]); 760 | } 761 | } 762 | 763 | return 0; 
764 | } 765 | 766 | if (b.dims == 3) 767 | { 768 | // type 9 769 | c.create(w1, h1, channels1, elemsize, opt.blob_allocator); 770 | if (c.empty()) 771 | return -100; 772 | 773 | #pragma omp parallel for num_threads(opt.num_threads) 774 | for (int q = 0; q < channels1; q++) 775 | { 776 | const float a0 = a[q]; 777 | const float* ptr1 = b.channel(q); 778 | float* outptr = c.channel(q); 779 | 780 | for (int i = 0; i < size1; i++) 781 | { 782 | outptr[i] = op(a0, ptr1[i]); 783 | } 784 | } 785 | 786 | return 0; 787 | } 788 | 789 | if (b.dims == 2) 790 | { 791 | // type 8 792 | c.create(w1, h1, elemsize, opt.blob_allocator); 793 | if (c.empty()) 794 | return -100; 795 | 796 | const float* ptr1 = b; 797 | float* outptr = c; 798 | 799 | for (int y = 0; y < h1; y++) 800 | { 801 | const float a0 = a[y]; 802 | for (int x = 0; x < w1; x++) 803 | { 804 | outptr[x] = op(a0, ptr1[x]); 805 | } 806 | 807 | ptr1 += w1; 808 | outptr += w1; 809 | } 810 | 811 | return 0; 812 | } 813 | 814 | if (b.dims == 1) 815 | { 816 | c.create(w, elemsize, opt.blob_allocator); 817 | if (c.empty()) 818 | return -100; 819 | 820 | if (b.w == 1) 821 | { 822 | // type 6 823 | const float b0 = b[0]; 824 | for (int i = 0; i < w; i++) 825 | { 826 | c[i] = op(a[i], b0); 827 | } 828 | 829 | return 0; 830 | } 831 | 832 | // type 7 833 | for (int i = 0; i < w; i++) 834 | { 835 | c[i] = op(a[i], b[i]); 836 | } 837 | } 838 | } 839 | 840 | return 0; 841 | } 842 | 843 | template<typename Op> 844 | static int binary_op_scalar_inplace(Mat& a, float b, const Option& opt) 845 | { 846 | Op op; 847 | 848 | int w = a.w; 849 | int h = a.h; 850 | int d = a.d; 851 | int channels = a.c; 852 | int size = w * h * d; 853 | 854 | #pragma omp parallel for num_threads(opt.num_threads) 855 | for (int q = 0; q < channels; q++) 856 | { 857 | float* ptr = a.channel(q); 858 | 859 | for (int i = 0; i < size; i++) 860 | { 861 | ptr[i] = op(ptr[i], b); 862 | } 863 | } 864 | 865 | return 0; 866 | } 867 | 868 | struct binary_op_add 869 | { 870 | float operator()(const float& x, const float& y) const 871 | { 872 | return x + y; 873 | } 874 | }; 875 | 876 | struct binary_op_sub 877 | { 878 | float operator()(const float& x, const float& y) const 879 | { 880 | return x - y; 881 | } 882 | }; 883 | 884 | struct binary_op_mul 885 | { 886 | float operator()(const float& x, const float& y) const 887 | { 888 | return x * y; 889 | } 890 | }; 891 | 892 | struct binary_op_div 893 | { 894 | float operator()(const float& x, const float& y) const 895 | { 896 | return x / y; 897 | } 898 | }; 899 | 900 | struct binary_op_max 901 | { 902 | float operator()(const float& x, const float& y) const 903 | { 904 | return std::max(x, y); 905 | } 906 | }; 907 | 908 | struct binary_op_min 909 | { 910 | float operator()(const float& x, const float& y) const 911 | { 912 | return std::min(x, y); 913 | } 914 | }; 915 | 916 | struct binary_op_pow 917 | { 918 | float operator()(const float& x, const float& y) const 919 | { 920 | return (float)pow(x, y); 921 | } 922 | }; 923 | 924 | struct binary_op_rsub 925 | { 926 | float operator()(const float& x, const float& y) const 927 | { 928 | return y - x; 929 | } 930 | }; 931 | 932 | struct binary_op_rdiv 933 | { 934 | float operator()(const float& x, const float& y) const 935 | { 936 | return y / x; 937 | } 938 | }; 939 | 940 | int BinaryOp::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const 941 | { 942 | const Mat& bottom_blob = bottom_blobs[0]; 943 | const Mat& bottom_blob1 = bottom_blobs[1]; 944 | 945 | Mat& top_blob = top_blobs[0]; 946 | 947 | if (op_type == Operation_ADD) 948 | return binary_op<binary_op_add>(bottom_blob, bottom_blob1, top_blob, opt); 949 | 950 | if (op_type == Operation_SUB) 951 | return binary_op<binary_op_sub>(bottom_blob, bottom_blob1, top_blob, opt); 952 | 953 | if (op_type == Operation_MUL) 954 | return binary_op<binary_op_mul>(bottom_blob, bottom_blob1, top_blob, opt); 955 | 956 | if (op_type == Operation_DIV) 957 | return binary_op<binary_op_div>(bottom_blob, bottom_blob1, top_blob, opt); 958 
| 959 | if (op_type == Operation_MAX) 960 | return binary_op<binary_op_max>(bottom_blob, bottom_blob1, top_blob, opt); 961 | 962 | if (op_type == Operation_MIN) 963 | return binary_op<binary_op_min>(bottom_blob, bottom_blob1, top_blob, opt); 964 | 965 | if (op_type == Operation_POW) 966 | return binary_op<binary_op_pow>(bottom_blob, bottom_blob1, top_blob, opt); 967 | 968 | if (op_type == Operation_RSUB) 969 | return binary_op<binary_op_sub>(bottom_blob1, bottom_blob, top_blob, opt); 970 | 971 | if (op_type == Operation_RDIV) 972 | return binary_op<binary_op_div>(bottom_blob1, bottom_blob, top_blob, opt); 973 | 974 | return 0; 975 | } 976 | 977 | int BinaryOp::forward_inplace(Mat& bottom_top_blob, const Option& opt) const 978 | { 979 | if (op_type == Operation_ADD) 980 | return binary_op_scalar_inplace<binary_op_add>(bottom_top_blob, b, opt); 981 | 982 | if (op_type == Operation_SUB) 983 | return binary_op_scalar_inplace<binary_op_sub>(bottom_top_blob, b, opt); 984 | 985 | if (op_type == Operation_MUL) 986 | return binary_op_scalar_inplace<binary_op_mul>(bottom_top_blob, b, opt); 987 | 988 | if (op_type == Operation_DIV) 989 | return binary_op_scalar_inplace<binary_op_div>(bottom_top_blob, b, opt); 990 | 991 | if (op_type == Operation_MAX) 992 | return binary_op_scalar_inplace<binary_op_max>(bottom_top_blob, b, opt); 993 | 994 | if (op_type == Operation_MIN) 995 | return binary_op_scalar_inplace<binary_op_min>(bottom_top_blob, b, opt); 996 | 997 | if (op_type == Operation_POW) 998 | return binary_op_scalar_inplace<binary_op_pow>(bottom_top_blob, b, opt); 999 | 1000 | if (op_type == Operation_RSUB) 1001 | return binary_op_scalar_inplace<binary_op_rsub>(bottom_top_blob, b, opt); 1002 | 1003 | if (op_type == Operation_RDIV) 1004 | return binary_op_scalar_inplace<binary_op_rdiv>(bottom_top_blob, b, opt); 1005 | 1006 | return 0; 1007 | } 1008 | 1009 | } // namespace ncnn 1010 | -------------------------------------------------------------------------------- /replace/expanddims.cpp: -------------------------------------------------------------------------------- 1 | // Tencent is pleased to support the open source community by making ncnn available. 
2 | // 3 | // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved. 4 | // 5 | // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except 6 | // in compliance with the License. You may obtain a copy of the License at 7 | // 8 | // https://opensource.org/licenses/BSD-3-Clause 9 | // 10 | // Unless required by applicable law or agreed to in writing, software distributed 11 | // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 12 | // CONDITIONS OF ANY KIND, either express or implied. See the License for the 13 | // specific language governing permissions and limitations under the License. 14 | 15 | #include "expanddims.h" 16 | 17 | namespace ncnn { 18 | 19 | ExpandDims::ExpandDims() 20 | { 21 | one_blob_only = true; 22 | support_inplace = false; 23 | } 24 | 25 | int ExpandDims::load_param(const ParamDict& pd) 26 | { 27 | expand_w = pd.get(0, 0); 28 | expand_h = pd.get(1, 0); 29 | expand_c = pd.get(2, 0); 30 | expand_d = pd.get(11, 0); 31 | 32 | axes = pd.get(3, Mat()); 33 | 34 | return 0; 35 | } 36 | 37 | int ExpandDims::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const 38 | { 39 | int w = bottom_blob.w; 40 | int h = bottom_blob.h; 41 | int c = bottom_blob.c; 42 | int dims = bottom_blob.dims; 43 | 44 | bool _expand_w = false; 45 | bool _expand_h = false; 46 | bool _expand_c = false; 47 | bool _expand_d = false; 48 | 49 | if (axes.empty()) 50 | { 51 | _expand_w = expand_w; 52 | _expand_h = expand_h; 53 | _expand_c = expand_c; 54 | _expand_d = expand_d; 55 | } 56 | else 57 | { 58 | const int* axes_ptr = axes; 59 | for (int i = 0; i < axes.w; i++) 60 | { 61 | int axis = axes_ptr[i]; 62 | if (axis < 0) 63 | axis = dims + 1 + axis; 64 | 65 | if (dims == 1 && axis == 0) 66 | { 67 | _expand_h = true; 68 | } 69 | if (dims == 1 && axis == 1) 70 | { 71 | _expand_w = true; 72 | } 73 | if (dims == 2 && axis == 0) 74 | { 75 | _expand_c = true; 76 | } 77 | if (dims == 2 && axis 
== 1) 78 | { 79 | _expand_h = true; 80 | } 81 | if (dims == 2 && axis == 2) 82 | { 83 | _expand_w = true; 84 | } 85 | if (dims == 3 && axis == 0) 86 | { 87 | _expand_c = true; 88 | } 89 | if (dims == 3 && axis == 1) 90 | { 91 | _expand_d = true; 92 | } 93 | if (dims == 3 && axis == 2) 94 | { 95 | _expand_h = true; 96 | } 97 | if (dims == 3 && axis == 3) 98 | { 99 | _expand_w = true; 100 | } 101 | } 102 | } 103 | 104 | top_blob = bottom_blob; 105 | 106 | if (dims == 1) 107 | { 108 | if (_expand_w && _expand_h) 109 | { 110 | top_blob = bottom_blob.reshape(1, w, 1, opt.blob_allocator); 111 | } 112 | else if (_expand_w) 113 | { 114 | top_blob = bottom_blob.reshape(1, w, opt.blob_allocator); 115 | } 116 | else if (_expand_h) 117 | { 118 | top_blob = bottom_blob.reshape(w, 1, opt.blob_allocator); 119 | } 120 | } 121 | 122 | if (dims == 2) 123 | { 124 | if (_expand_w) 125 | { 126 | top_blob = bottom_blob.reshape(1, w, h, opt.blob_allocator); 127 | } 128 | else if (_expand_h) 129 | { 130 | top_blob = bottom_blob.reshape(w, 1, h, opt.blob_allocator); 131 | } 132 | else if (_expand_c) 133 | { 134 | top_blob = bottom_blob.reshape(w, h, 1, opt.blob_allocator); 135 | } 136 | } 137 | if (dims == 3) 138 | { 139 | if (_expand_w) 140 | { 141 | top_blob = bottom_blob.reshape(1, w, h, c, opt.blob_allocator); 142 | } 143 | else if (_expand_h) 144 | { 145 | top_blob = bottom_blob.reshape(w, 1, h, c, opt.blob_allocator); 146 | } 147 | else if (_expand_c) 148 | { 149 | top_blob = bottom_blob.reshape(w, h, c, 1, opt.blob_allocator); 150 | } 151 | else if (_expand_d) 152 | { 153 | top_blob = bottom_blob.reshape(w, h, 1, c, opt.blob_allocator); 154 | } 155 | } 156 | if (top_blob.empty()) 157 | return -100; 158 | 159 | return 0; 160 | } 161 | 162 | } // namespace ncnn 163 | -------------------------------------------------------------------------------- /replace/expanddims.h: -------------------------------------------------------------------------------- 1 | // Tencent is pleased to support the 
open source community by making ncnn available. 2 | // 3 | // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved. 4 | // 5 | // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except 6 | // in compliance with the License. You may obtain a copy of the License at 7 | // 8 | // https://opensource.org/licenses/BSD-3-Clause 9 | // 10 | // Unless required by applicable law or agreed to in writing, software distributed 11 | // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 12 | // CONDITIONS OF ANY KIND, either express or implied. See the License for the 13 | // specific language governing permissions and limitations under the License. 14 | 15 | #ifndef LAYER_EXPANDDIMS_H 16 | #define LAYER_EXPANDDIMS_H 17 | 18 | #include "layer.h" 19 | 20 | namespace ncnn { 21 | 22 | class ExpandDims : public Layer 23 | { 24 | public: 25 | ExpandDims(); 26 | 27 | virtual int load_param(const ParamDict& pd); 28 | 29 | virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; 30 | 31 | public: 32 | int expand_w; 33 | int expand_h; 34 | int expand_c; 35 | int expand_d; 36 | Mat axes; 37 | }; 38 | 39 | } // namespace ncnn 40 | 41 | #endif // LAYER_EXPANDDIMS_H 42 | -------------------------------------------------------------------------------- /replace/gather.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) OpenMMLab. All rights reserved. 2 | #include "gather.h" 3 | 4 | 5 | namespace ncnn { 6 | 7 | Gather::Gather() 8 | { 9 | one_blob_only = false; 10 | support_inplace = false; 11 | } 12 | 13 | int Gather::load_param(const ParamDict &pd) { 14 | axis = pd.get(0, 0); 15 | indice = pd.get(1, Mat()); 16 | 17 | return 0; 18 | } 19 | 20 | // Gather only support 1-dim of indices, because the data and indices all has 21 | // implicit batch in ncnn, this will lead to wrong shape to match onnx result. 
22 | // When indices dim equals to 1, after eliminating implicit batch, the indices 23 | // dim still be 1. So there is only 1 implicit batch in data, this will make 24 | // the shape match onnx result. 25 | int Gather::forward(const std::vector &bottom_blobs, std::vector &top_blobs, 26 | const Option &opt) const { 27 | const Mat &bottom_blob = bottom_blobs[0]; 28 | const Mat &indices = bottom_blobs[1]; 29 | 30 | int dims = bottom_blob.dims; 31 | int indices_dims = indices.dims; 32 | size_t elemsize = bottom_blob.elemsize; 33 | int positive_axis = axis < 0 ? dims + axis : axis; 34 | Mat &top_blob = top_blobs[0]; 35 | 36 | if(indices.dims != 1) 37 | return -100; 38 | //const float *indices_ptr = indices; 39 | const int* indices_ptr = indice; 40 | 41 | if (dims == 1 && indices_dims == 1) // positive_axis == 0 42 | { 43 | int w = indices.w; 44 | top_blob.create(w, elemsize, opt.blob_allocator); 45 | if (top_blob.empty()) { 46 | return -100; 47 | } 48 | const float *ptr = bottom_blob; 49 | float *outptr = top_blob; 50 | for (int i = 0; i < w; i++) { 51 | float indice = indices_ptr[i]; 52 | outptr[i] = ptr[(int)(indice + 0.5)]; 53 | } 54 | 55 | return 0; 56 | } 57 | 58 | if (dims == 2 && positive_axis == 0 && indices_dims == 1) { 59 | int w = bottom_blob.w; 60 | int h = bottom_blob.h; 61 | top_blob.create(w, indices.w, elemsize, opt.blob_allocator); 62 | // w -> w 63 | // h -> indices.w 64 | // h * w -> indices.w * w 65 | if (top_blob.empty()) { 66 | return -100; 67 | } 68 | const float *ptr = bottom_blob; 69 | float *outptr = top_blob; 70 | for (int i = 0; i < indices.w; i++) { 71 | const int selected = (int)(indices_ptr[i] + 0.5); 72 | memcpy(top_blob.row(i), bottom_blob.row(selected), w * elemsize); 73 | } 74 | 75 | return 0; 76 | } 77 | 78 | if (dims == 2 && positive_axis == 1 && indices_dims == 1) { 79 | int w = bottom_blob.w; 80 | int h = bottom_blob.h; 81 | top_blob.create(indices.w, h, elemsize, opt.blob_allocator); 82 | // w -> h 83 | // h -> indices.w 84 | // 
h * w -> indices.w * h 85 | if (top_blob.empty()) { 86 | return -100; 87 | } 88 | const float *ptr = bottom_blob; 89 | float *outptr = top_blob; 90 | for (int j = 0; j < h; j++) { 91 | for (int i = 0; i < indices.w; i++) { 92 | int selected = (int)(indices_ptr[i] + 0.5); 93 | outptr[j * indices.w + i] = ptr[j * w + selected]; 94 | } 95 | } 96 | return 0; 97 | } 98 | 99 | if (dims == 3 && positive_axis == 0 && indices_dims == 1) { 100 | int w = bottom_blob.w; 101 | int h = bottom_blob.h; 102 | int channels = bottom_blob.c; 103 | top_blob.create(w, h, indices.w, elemsize, opt.blob_allocator); 104 | 105 | if (top_blob.empty()) { 106 | return -100; 107 | } 108 | for (int i = 0; i < indices.w; i++) { 109 | int selected = (int)(indices_ptr[i] + 0.5); 110 | const unsigned char *ptr = bottom_blob.channel(selected); 111 | unsigned char *outptr = top_blob.channel(i); 112 | 113 | memcpy(outptr, ptr, w * h * elemsize); 114 | } 115 | return 0; 116 | } 117 | 118 | if (dims == 3 && positive_axis == 1 && indices_dims == 1) { 119 | int w = bottom_blob.w; 120 | int h = bottom_blob.h; 121 | int channels = bottom_blob.c; 122 | top_blob.create(w, indices.w, channels, elemsize, opt.blob_allocator); 123 | #pragma omp parallel for num_threads(opt.num_threads) 124 | // use parallel programming 125 | for (int i = 0; i < channels; i++) { 126 | float *outptr = top_blob.channel(i); 127 | const float *ptr = bottom_blob.channel(i); 128 | for (int j = 0; j < indices.w; j++) { 129 | int selected = (int)(indices_ptr[j] + 0.5); 130 | for (int k = 0; k < w; k++) { 131 | outptr[j * w + k] = ptr[selected * w + k]; 132 | } 133 | } 134 | } 135 | 136 | return 0; 137 | } 138 | 139 | if (dims == 3 && positive_axis == 2 && indices_dims == 1) { 140 | int w = bottom_blob.w; 141 | int h = bottom_blob.h; 142 | int channels = bottom_blob.c; 143 | top_blob.create(indices.w, h, channels, elemsize, opt.blob_allocator); 144 | #pragma omp parallel for num_threads(opt.num_threads) 145 | // use parallel programming 146 
| for (int i = 0; i < channels; i++) { 147 | float *outptr = top_blob.channel(i); 148 | const float *ptr = bottom_blob.channel(i); 149 | for (int j = 0; j < h; j++) { 150 | for (int k = 0; k < indices.w; k++) { 151 | int selected = (int)(indices_ptr[k] + 0.5); 152 | outptr[j * indices.w + k] = ptr[j * w + selected]; 153 | } 154 | } 155 | } 156 | return 0; 157 | } 158 | 159 | if (dims == 4 && positive_axis == 0 && indices_dims == 1) 160 | { 161 | int w = bottom_blob.w; 162 | int h = bottom_blob.h; 163 | int d = bottom_blob.d; 164 | int channels = bottom_blob.c; 165 | top_blob.create(w, h, d, indice.w, elemsize, opt.blob_allocator); 166 | 167 | if (top_blob.empty()) 168 | { 169 | return -100; 170 | } 171 | for (int i = 0; i < indice.w; i++) 172 | { 173 | int selected = (int)(indices_ptr[i]); 174 | const unsigned char* ptr = bottom_blob.channel(selected); 175 | unsigned char* outptr = top_blob.channel(i); 176 | 177 | memcpy(outptr, ptr, w * h * d * elemsize); 178 | } 179 | return 0; 180 | } 181 | 182 | if (dims == 4 && positive_axis == 1 && indices_dims == 1) 183 | { 184 | int w = bottom_blob.w; 185 | int h = bottom_blob.h; 186 | int channels = bottom_blob.c; 187 | top_blob.create(w, h, indice.w, channels, elemsize, opt.blob_allocator); 188 | #pragma omp parallel for num_threads(opt.num_threads) 189 | // use parallel programming 190 | for (int i = 0; i < channels; i++) 191 | { 192 | float* outptr = top_blob.channel(i); 193 | const float* ptr = bottom_blob.channel(i); 194 | for (int j = 0; j < indice.w; j++) 195 | { 196 | int selected = (int)(indices_ptr[j] + 0.5); 197 | for (int k = 0; k < w * h; k++) 198 | { 199 | outptr[j * w * h + k] = ptr[selected * w * h + k]; 200 | } 201 | } 202 | } 203 | 204 | return 0; 205 | } 206 | 207 | if (dims == 4 && positive_axis == 2 && indices_dims == 1) 208 | { 209 | int w = bottom_blob.w; 210 | int h = bottom_blob.h; 211 | int d = bottom_blob.d; 212 | int channels = bottom_blob.c; 213 | top_blob.create(w, indice.w, d, channels, 
elemsize, opt.blob_allocator); 214 | #pragma omp parallel for num_threads(opt.num_threads) 215 | // use parallel programming 216 | for (int i = 0; i < channels; i++) 217 | { 218 | float* outptr = top_blob.channel(i); 219 | const float* ptr = bottom_blob.channel(i); 220 | for (int j = 0; j < d; j++) 221 | { 222 | for (int k = 0; k < indice.w; k++) 223 | { 224 | int selected = (int)(indices_ptr[k] + 0.5); 225 | for(int l = 0; l < w; l++) 226 | outptr[j * indice.w * w + l] = ptr[j * selected* w + l]; 227 | } 228 | } 229 | } 230 | return 0; 231 | } 232 | if (dims == 4 && positive_axis == 3 && indices_dims == 1) 233 | { 234 | int w = bottom_blob.w; 235 | int h = bottom_blob.h; 236 | int d = bottom_blob.d; 237 | int channels = bottom_blob.c; 238 | top_blob.create(indice.w, h, d, channels, elemsize, opt.blob_allocator); 239 | #pragma omp parallel for num_threads(opt.num_threads) 240 | // use parallel programming 241 | for (int i = 0; i < channels; i++) 242 | { 243 | float* outptr = top_blob.channel(i); 244 | const float* ptr = bottom_blob.channel(i); 245 | for (int l = 0; l < d; l++) 246 | { 247 | for (int j = 0; j < h; j++) 248 | { 249 | for (int k = 0; k < indice.w; k++) 250 | { 251 | int selected = (int)(indices_ptr[k] + 0.5); 252 | outptr[l * h * indice.w + j * indice.w + k] = ptr[l * h * w + j * w + selected]; 253 | } 254 | } 255 | } 256 | 257 | } 258 | return 0; 259 | } 260 | return 0; 261 | } 262 | 263 | } // namespace mmdeploy 264 | -------------------------------------------------------------------------------- /replace/gather.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) OpenMMLab. All rights reserved. 
2 | #ifndef LAYER_GATHER_H 3 | #define LAYER_GATHER_H 4 | 5 | #include "layer.h" 6 | 7 | namespace ncnn { 8 | 9 | class Gather : public ncnn::Layer { 10 | public: 11 | Gather(); 12 | 13 | virtual int load_param(const ncnn::ParamDict& pd); 14 | 15 | virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, 16 | const ncnn::Option& opt) const; 17 | 18 | public: 19 | int axis; 20 | Mat indice; 21 | }; 22 | 23 | } // namespace ncnn 24 | 25 | #endif // LAYER_GATHER_H 26 | -------------------------------------------------------------------------------- /replace/gemm.cpp: -------------------------------------------------------------------------------- 1 | // Tencent is pleased to support the open source community by making ncnn available. 2 | // 3 | // Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved. 4 | // 5 | // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except 6 | // in compliance with the License. You may obtain a copy of the License at 7 | // 8 | // https://opensource.org/licenses/BSD-3-Clause 9 | // 10 | // Unless required by applicable law or agreed to in writing, software distributed 11 | // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 12 | // CONDITIONS OF ANY KIND, either express or implied. See the License for the 13 | // specific language governing permissions and limitations under the License. 
14 | 15 | #include "gemm.h" 16 | 17 | namespace ncnn { 18 | 19 | Gemm::Gemm() 20 | { 21 | one_blob_only = false; 22 | support_inplace = false; 23 | } 24 | 25 | int Gemm::load_param(const ParamDict& pd) 26 | { 27 | alpha = pd.get(0, 1.f); 28 | beta = pd.get(1, 1.f); 29 | transA = pd.get(2, 0); 30 | transB = pd.get(3, 0); 31 | 32 | return 0; 33 | } 34 | 35 | int Gemm::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const 36 | { 37 | const Mat& A0 = bottom_blobs[0]; 38 | const Mat& B0 = bottom_blobs[1]; 39 | 40 | size_t elemsize = A0.elemsize; 41 | 42 | Mat A; 43 | if (transA == 0) 44 | { 45 | A = A0; 46 | } 47 | else 48 | { 49 | // transpose A to row-major 50 | A.create(A0.h, A0.w, A0.c, elemsize, opt.workspace_allocator); 51 | for (int c = 0; c < A0.c; c++) 52 | { 53 | for (int i = 0; i < A.h; i++) 54 | { 55 | float* ptr = A.channel(c).row(i); 56 | for (int j = 0; j < A.w; j++) 57 | { 58 | ptr[j] = A0.channel(c).row(j)[i]; 59 | } 60 | } 61 | } 62 | } 63 | 64 | Mat B; 65 | if (transB == 0) 66 | { 67 | // transpose B to col-major 68 | B.create(B0.h, B0.w, B0.c, elemsize, opt.workspace_allocator); 69 | for (int c = 0; c < B0.c; c++) 70 | { 71 | for (int i = 0; i < B.h; i++) 72 | { 73 | float* ptr = B.channel(c).row(i); 74 | for (int j = 0; j < B.w; j++) 75 | { 76 | ptr[j] = B0.channel(c).row(j)[i]; 77 | } 78 | } 79 | } 80 | } 81 | else 82 | { 83 | B = B0; 84 | } 85 | 86 | int M = A.h; 87 | int K = A.w; // assert A.w == B.w 88 | int N = B.h; 89 | 90 | bool has_C = bottom_blobs.size() == 3; 91 | 92 | const float* ptrC = 0; 93 | int broadcast_type_C = 0; 94 | if (has_C) 95 | { 96 | const Mat& C = bottom_blobs[2]; 97 | 98 | ptrC = C; 99 | 100 | if (C.dims == 1 && C.w == 1) 101 | { 102 | // scalar 103 | broadcast_type_C = 0; 104 | } 105 | if (C.dims == 1 && C.w == M) 106 | { 107 | // M 108 | // auto broadcast from h to w is the ncnn-style convention 109 | broadcast_type_C = 1; 110 | } 111 | if (C.dims == 2 && C.w == 1 && C.h == M) 112 | { 
113 | // Mx1 114 | broadcast_type_C = 2; 115 | } 116 | if (C.dims == 2 && C.w == N && C.h == M) 117 | { 118 | // MxN 119 | broadcast_type_C = 3; 120 | } 121 | if (C.dims == 2 && C.w == N && C.h == 1) 122 | { 123 | // 1xN 124 | broadcast_type_C = 4; 125 | } 126 | } 127 | 128 | //====new 129 | Mat& top_blob = top_blobs[0]; 130 | if (A0.c == B0.c && A0.c > 1) 131 | { 132 | top_blob.create(N, M, A0.c, elemsize, opt.blob_allocator); 133 | if (top_blob.empty()) 134 | return -100; 135 | for (int c = 0; c < top_blob.c; c++) 136 | { 137 | float* outptr = top_blob.channel(c); 138 | for (int i = 0; i < M; i++) 139 | { 140 | const float* ptrA = A.channel(c).row(i); 141 | 142 | for (int j = 0; j < N; j++) 143 | { 144 | const float* ptrB = B.channel(c).row(j); 145 | 146 | float sum = 0.f; 147 | if (has_C) 148 | { 149 | if (broadcast_type_C == 0) 150 | { 151 | sum = ptrC[0]; 152 | } 153 | if (broadcast_type_C == 1) 154 | { 155 | sum = ptrC[i]; 156 | } 157 | if (broadcast_type_C == 2) 158 | { 159 | sum = ptrC[i]; 160 | } 161 | if (broadcast_type_C == 3) 162 | { 163 | sum = ptrC[i * N + j]; 164 | } 165 | if (broadcast_type_C == 4) 166 | { 167 | sum = ptrC[j]; 168 | } 169 | 170 | sum *= beta; 171 | } 172 | 173 | for (int k = 0; k < K; k++) 174 | { 175 | sum += ptrA[k] * ptrB[k]; 176 | } 177 | 178 | *outptr++ = sum * alpha; 179 | } 180 | } 181 | } 182 | 183 | 184 | return 0; 185 | } 186 | top_blob.create(N, M, elemsize, opt.blob_allocator); 187 | if (top_blob.empty()) 188 | return -100; 189 | 190 | float* outptr = top_blob; 191 | for (int i = 0; i < M; i++) 192 | { 193 | const float* ptrA = A.row(i); 194 | 195 | for (int j = 0; j < N; j++) 196 | { 197 | const float* ptrB = B.row(j); 198 | 199 | float sum = 0.f; 200 | if (has_C) 201 | { 202 | if (broadcast_type_C == 0) 203 | { 204 | sum = ptrC[0]; 205 | } 206 | if (broadcast_type_C == 1) 207 | { 208 | sum = ptrC[i]; 209 | } 210 | if (broadcast_type_C == 2) 211 | { 212 | sum = ptrC[i]; 213 | } 214 | if (broadcast_type_C == 3) 215 | { 
216 | sum = ptrC[i * N + j]; 217 | } 218 | if (broadcast_type_C == 4) 219 | { 220 | sum = ptrC[j]; 221 | } 222 | 223 | sum *= beta; 224 | } 225 | 226 | for (int k = 0; k < K; k++) 227 | { 228 | sum += ptrA[k] * ptrB[k]; 229 | } 230 | 231 | *outptr++ = sum * alpha; 232 | } 233 | } 234 | 235 | return 0; 236 | } 237 | 238 | } // namespace ncnn 239 | -------------------------------------------------------------------------------- /replace/squeeze.cpp: -------------------------------------------------------------------------------- 1 | // Tencent is pleased to support the open source community by making ncnn available. 2 | // 3 | // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved. 4 | // 5 | // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except 6 | // in compliance with the License. You may obtain a copy of the License at 7 | // 8 | // https://opensource.org/licenses/BSD-3-Clause 9 | // 10 | // Unless required by applicable law or agreed to in writing, software distributed 11 | // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 12 | // CONDITIONS OF ANY KIND, either express or implied. See the License for the 13 | // specific language governing permissions and limitations under the License. 
14 | 15 | #include "squeeze.h" 16 | 17 | namespace ncnn { 18 | 19 | Squeeze::Squeeze() 20 | { 21 | one_blob_only = true; 22 | support_inplace = false; 23 | } 24 | 25 | int Squeeze::load_param(const ParamDict& pd) 26 | { 27 | squeeze_w = pd.get(0, 0); 28 | squeeze_h = pd.get(1, 0); 29 | squeeze_c = pd.get(2, 0); 30 | squeeze_d = pd.get(11, 0); 31 | axes = pd.get(3, Mat()); 32 | 33 | return 0; 34 | } 35 | 36 | int Squeeze::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const 37 | { 38 | int w = bottom_blob.w; 39 | int h = bottom_blob.h; 40 | int d = bottom_blob.d; 41 | int channels = bottom_blob.c; 42 | int dims = bottom_blob.dims; 43 | 44 | bool _squeeze_w = false; 45 | bool _squeeze_h = false; 46 | bool _squeeze_c = false; 47 | bool _squeeze_d = false; 48 | 49 | if (axes.empty()) 50 | { 51 | _squeeze_w = w == 1 && squeeze_w; 52 | _squeeze_h = h == 1 && squeeze_h; 53 | _squeeze_c = channels == 1 && squeeze_c; 54 | } 55 | else 56 | { 57 | const int* axes_ptr = axes; 58 | for (int i = 0; i < axes.w; i++) 59 | { 60 | int axis = axes_ptr[i]; 61 | if (axis < 0) 62 | axis = dims + axis; 63 | 64 | if (dims == 1 && axis == 0) 65 | { 66 | _squeeze_w = w == 1; 67 | } 68 | if (dims == 2 && axis == 0) 69 | { 70 | _squeeze_h = h == 1; 71 | } 72 | if (dims == 2 && axis == 1) 73 | { 74 | _squeeze_w = w == 1; 75 | } 76 | if (dims == 3 && axis == 0) 77 | { 78 | _squeeze_c = channels == 1; 79 | } 80 | if (dims == 3 && axis == 1) 81 | { 82 | _squeeze_h = h == 1; 83 | } 84 | if (dims == 3 && axis == 2) 85 | { 86 | _squeeze_w = w == 1; 87 | } 88 | if (dims == 4 && axis == 1) 89 | { 90 | _squeeze_d = d = 1; 91 | } 92 | } 93 | } 94 | 95 | top_blob = bottom_blob; 96 | 97 | if (dims == 1) 98 | { 99 | if (_squeeze_w) 100 | { 101 | top_blob = bottom_blob.reshape(1, opt.blob_allocator); 102 | } 103 | } 104 | 105 | if (dims == 2) 106 | { 107 | if (_squeeze_w && _squeeze_h) 108 | { 109 | top_blob = bottom_blob.reshape(1, opt.blob_allocator); 110 | } 111 | else if (_squeeze_w) 
112 | { 113 | top_blob = bottom_blob.reshape(h, opt.blob_allocator); 114 | } 115 | else if (_squeeze_h) 116 | { 117 | top_blob = bottom_blob.reshape(w, opt.blob_allocator); 118 | } 119 | } 120 | 121 | if (dims == 3) 122 | { 123 | if (_squeeze_w && _squeeze_h && _squeeze_c) 124 | { 125 | top_blob = bottom_blob.reshape(1, opt.blob_allocator); 126 | } 127 | else if (_squeeze_w && _squeeze_h) 128 | { 129 | top_blob = bottom_blob.reshape(channels, opt.blob_allocator); 130 | } 131 | else if (_squeeze_h && _squeeze_c) 132 | { 133 | top_blob = bottom_blob.reshape(w, opt.blob_allocator); 134 | } 135 | else if (_squeeze_w && _squeeze_c) 136 | { 137 | top_blob = bottom_blob.reshape(h, opt.blob_allocator); 138 | } 139 | else if (_squeeze_w) 140 | { 141 | top_blob = bottom_blob.reshape(h, channels, opt.blob_allocator); 142 | } 143 | else if (_squeeze_h) 144 | { 145 | top_blob = bottom_blob.reshape(w, channels, opt.blob_allocator); 146 | } 147 | else if (_squeeze_c) 148 | { 149 | top_blob = bottom_blob.reshape(w, h, opt.blob_allocator); 150 | } 151 | } 152 | 153 | if (dims == 4) 154 | { 155 | if (_squeeze_w && _squeeze_h && _squeeze_c && _squeeze_d) 156 | { 157 | top_blob = bottom_blob.reshape(1, opt.blob_allocator); 158 | } 159 | else if (_squeeze_w && _squeeze_h && _squeeze_c) 160 | { 161 | top_blob = bottom_blob.reshape(d, opt.blob_allocator); 162 | } 163 | else if (_squeeze_w && _squeeze_h) 164 | { 165 | top_blob = bottom_blob.reshape(d,channels, opt.blob_allocator); 166 | } 167 | else if (_squeeze_h && _squeeze_c) 168 | { 169 | top_blob = bottom_blob.reshape(w,d, opt.blob_allocator); 170 | } 171 | else if (_squeeze_w && _squeeze_c) 172 | { 173 | top_blob = bottom_blob.reshape(h,d, opt.blob_allocator); 174 | } 175 | else if (_squeeze_w) 176 | { 177 | top_blob = bottom_blob.reshape(h, d, channels, opt.blob_allocator); 178 | } 179 | else if (_squeeze_h) 180 | { 181 | top_blob = bottom_blob.reshape(w, d, channels, opt.blob_allocator); 182 | } 183 | else if (_squeeze_c) 184 | { 
185 | top_blob = bottom_blob.reshape(w, h, d, opt.blob_allocator); 186 | } 187 | else if (_squeeze_d) 188 | { 189 | top_blob = bottom_blob.reshape(w, h, channels, opt.blob_allocator); 190 | } 191 | } 192 | 193 | if (top_blob.empty()) 194 | return -100; 195 | 196 | return 0; 197 | } 198 | 199 | } // namespace ncnn 200 | -------------------------------------------------------------------------------- /replace/squeeze.h: -------------------------------------------------------------------------------- 1 | // Tencent is pleased to support the open source community by making ncnn available. 2 | // 3 | // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved. 4 | // 5 | // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except 6 | // in compliance with the License. You may obtain a copy of the License at 7 | // 8 | // https://opensource.org/licenses/BSD-3-Clause 9 | // 10 | // Unless required by applicable law or agreed to in writing, software distributed 11 | // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 12 | // CONDITIONS OF ANY KIND, either express or implied. See the License for the 13 | // specific language governing permissions and limitations under the License. 
14 | 15 | #ifndef LAYER_SQUEEZE_H 16 | #define LAYER_SQUEEZE_H 17 | 18 | #include "layer.h" 19 | 20 | namespace ncnn { 21 | 22 | class Squeeze : public Layer 23 | { 24 | public: 25 | Squeeze(); 26 | 27 | virtual int load_param(const ParamDict& pd); 28 | 29 | virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; 30 | 31 | public: 32 | int squeeze_w; 33 | int squeeze_h; 34 | int squeeze_c; 35 | int squeeze_d; 36 | Mat axes; 37 | }; 38 | 39 | } // namespace ncnn 40 | 41 | #endif // LAYER_SQUEEZE_H 42 | -------------------------------------------------------------------------------- /result1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiGeChuanShu/DocTr-ncnn/b8584412038ff058e8a48acf3d39516b3dc3acdc/result1.jpg -------------------------------------------------------------------------------- /result2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiGeChuanShu/DocTr-ncnn/b8584412038ff058e8a48acf3d39516b3dc3acdc/result2.jpg -------------------------------------------------------------------------------- /src/demo.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "net.h" 4 | #include "cpu.h" 5 | #include 6 | static void interp(const ncnn::Mat& in, ncnn::Mat& out) 7 | { 8 | ncnn::Option opt; 9 | opt.num_threads = 4; 10 | opt.use_fp16_storage = false; 11 | opt.use_packing_layout = false; 12 | 13 | ncnn::Layer* op = ncnn::create_layer("Interp"); 14 | 15 | // set param 16 | ncnn::ParamDict pd; 17 | pd.set(0, 2);// 18 | pd.set(3, 288);// 19 | pd.set(4, 288);// 20 | op->load_param(pd); 21 | 22 | op->create_pipeline(opt); 23 | 24 | // forward 25 | op->forward(in, out, opt); 26 | 27 | op->destroy_pipeline(opt); 28 | 29 | delete op; 30 | } 31 | static void scale(const ncnn::Mat& in, const float& scale, int scale_data_size, 
ncnn::Mat& out) 32 | { 33 | ncnn::Option opt; 34 | opt.num_threads = 4; 35 | opt.use_fp16_storage = false; 36 | opt.use_packing_layout = false; 37 | 38 | ncnn::Layer* op = ncnn::create_layer("Scale"); 39 | 40 | // set param 41 | ncnn::ParamDict pd; 42 | pd.set(0, scale_data_size);// scale_data_size 43 | pd.set(1, 0);// 44 | 45 | op->load_param(pd); 46 | 47 | // set weights 48 | ncnn::Mat scales[1]; 49 | scales[0].create(scale_data_size);// scale_data 50 | 51 | for (int i = 0; i < scale_data_size; i++) 52 | { 53 | scales[0][i] = scale; 54 | } 55 | 56 | op->load_model(ncnn::ModelBinFromMatArray(scales)); 57 | 58 | op->create_pipeline(opt); 59 | 60 | // forward 61 | op->forward(in, out, opt); 62 | 63 | op->destroy_pipeline(opt); 64 | 65 | delete op; 66 | } 67 | static void binary_op(const ncnn::Mat& a, const ncnn::Mat& b, ncnn::Mat& c,int op_type) 68 | { 69 | ncnn::Option opt; 70 | opt.num_threads = 4; 71 | opt.use_fp16_storage = false; 72 | opt.use_packing_layout = false; 73 | 74 | ncnn::Layer* op = ncnn::create_layer("BinaryOp"); 75 | 76 | // set param 77 | ncnn::ParamDict pd; 78 | pd.set(0, op_type);// op_type 79 | 80 | op->load_param(pd); 81 | 82 | op->create_pipeline(opt); 83 | 84 | // forward 85 | std::vector bottoms(2); 86 | bottoms[0] = a; 87 | bottoms[1] = b; 88 | 89 | std::vector tops(1); 90 | op->forward(bottoms, tops, opt); 91 | 92 | c = tops[0]; 93 | 94 | op->destroy_pipeline(opt); 95 | 96 | delete op; 97 | } 98 | static void concat(const ncnn::Mat& a, const ncnn::Mat& b, ncnn::Mat& c,int axis) 99 | { 100 | ncnn::Option opt; 101 | opt.num_threads = 2; 102 | opt.use_fp16_storage = false; 103 | opt.use_packing_layout = false; 104 | 105 | ncnn::Layer* op = ncnn::create_layer("Concat"); 106 | 107 | // set param 108 | ncnn::ParamDict pd; 109 | pd.set(0, axis);// axis 110 | 111 | op->load_param(pd); 112 | 113 | op->create_pipeline(opt); 114 | 115 | // forward 116 | std::vector bottoms(2); 117 | bottoms[0] = a; 118 | bottoms[1] = b; 119 | 120 | std::vector 
tops(1); 121 | op->forward(bottoms, tops, opt); 122 | 123 | c = tops[0]; 124 | 125 | op->destroy_pipeline(opt); 126 | 127 | delete op; 128 | } 129 | static void transpose(const ncnn::Mat& in, ncnn::Mat& out,const int& order_type) 130 | { 131 | ncnn::Option opt; 132 | opt.num_threads = 2; 133 | opt.use_fp16_storage = false; 134 | opt.use_packing_layout = true; 135 | 136 | ncnn::Layer* op = ncnn::create_layer("Permute"); 137 | 138 | // set param 139 | ncnn::ParamDict pd; 140 | pd.set(0, order_type);// order_type 141 | 142 | op->load_param(pd); 143 | 144 | op->create_pipeline(opt); 145 | 146 | ncnn::Mat in_packed = in; 147 | { 148 | // resolve dst_elempack 149 | int dims = in.dims; 150 | int elemcount = 0; 151 | if (dims == 1) elemcount = in.elempack * in.w; 152 | if (dims == 2) elemcount = in.elempack * in.h; 153 | if (dims == 3) elemcount = in.elempack * in.c; 154 | 155 | int dst_elempack = 1; 156 | if (op->support_packing) 157 | { 158 | if (elemcount % 8 == 0 && (ncnn::cpu_support_x86_avx2() || ncnn::cpu_support_x86_avx())) 159 | dst_elempack = 8; 160 | else if (elemcount % 4 == 0) 161 | dst_elempack = 4; 162 | } 163 | 164 | if (in.elempack != dst_elempack) 165 | { 166 | convert_packing(in, in_packed, dst_elempack, opt); 167 | } 168 | } 169 | 170 | // forward 171 | op->forward(in_packed, out, opt); 172 | 173 | op->destroy_pipeline(opt); 174 | 175 | delete op; 176 | } 177 | static void reduction(const ncnn::Mat& in, ncnn::Mat& out) 178 | { 179 | ncnn::Option opt; 180 | opt.num_threads = 4; 181 | opt.use_fp16_storage = false; 182 | opt.use_packing_layout = false; 183 | 184 | ncnn::Layer* op = ncnn::create_layer("Reduction"); 185 | 186 | // set param 187 | ncnn::ParamDict pd; 188 | pd.set(0, 0);// sum 189 | pd.set(1, 0);// reduce_all 190 | pd.set(4, 1);//keepdims 191 | ncnn::Mat axes = ncnn::Mat(1); 192 | axes.fill(0); 193 | pd.set(3, axes); 194 | 195 | op->load_param(pd); 196 | 197 | op->create_pipeline(opt); 198 | 199 | // forward 200 | op->forward(in, out, opt); 
// Fetch pixel (row, col) of plane `channel` from a CHW float image `im`,
// where the coordinates are expressed in the padded frame: `pad` is
// subtracted first and any out-of-bounds access returns 0 (zero padding).
// `channels` is unused but kept for the darknet-style signature.
static float im2col_get_pixel(const float* im, int height, int width, int channels, int row, int col, int channel, int pad)
{
    (void)channels;

    const int r = row - pad;
    const int c = col - pad;

    const bool inside = (r >= 0) && (c >= 0) && (r < height) && (c < width);
    return inside ? im[(channel * height + r) * width + c] : 0;
}
im2col_get_pixel((const float*)data_im.data, height, width, channels, im_row, im_col, c_im, pad); 270 | } 271 | } 272 | } 273 | 274 | return data_col.reshape(height_col * width_col, channels_col); 275 | } 276 | 277 | static int position_embedding(ncnn::Mat& mask,int num_pos_feats,ncnn::Mat& pos) 278 | { 279 | ncnn::Mat y_embed = ncnn::Mat(mask.w, mask.h, mask.c); 280 | ncnn::Mat x_embed = ncnn::Mat(mask.w, mask.h, mask.c); 281 | 282 | for (int i = 0; i < mask.c; i++) 283 | { 284 | for (int j = 0; j < mask.h; j++) 285 | { 286 | float* mask_data = mask.channel(i).row(j); 287 | float* x_embed_data = x_embed.channel(i).row(j); 288 | for (int k = 0; k < mask.w; k++) 289 | { 290 | for (int l = k; l >= 0; l--) 291 | x_embed_data[k] += mask_data[l]; 292 | } 293 | } 294 | float* mask_data = mask.channel(i); 295 | for (int j = 0; j < mask.w; j++) 296 | { 297 | for (int k = 0; k < mask.h; k++) 298 | { 299 | float* y_embed_data = y_embed.channel(i).row(k); 300 | for (int l = k; l >= 0; l--) 301 | y_embed_data[j] += mask_data[l * mask.w]; 302 | } 303 | } 304 | } 305 | for (int i = 0; i < y_embed.c; i++) 306 | { 307 | for (int j = 0; j < y_embed.h; j++) 308 | { 309 | for (int k = 0; k < y_embed.w; k++) 310 | { 311 | y_embed[j * y_embed.w + k] = y_embed[j * y_embed.w + k]*6.283185307179586/(y_embed.row(y_embed.h-1)[k] + 0.000001); 312 | 313 | } 314 | } 315 | } 316 | for (int i = 0; i < x_embed.c; i++) 317 | { 318 | for (int j = 0; j < x_embed.h; j++) 319 | { 320 | for (int k = 0; k < x_embed.w; k++) 321 | { 322 | x_embed[j * x_embed.w + k] = x_embed[j * x_embed.w + k] * 6.283185307179586 / (x_embed[j * x_embed.w + x_embed.w - 1] + 0.000001); 323 | } 324 | } 325 | } 326 | 327 | 328 | std::vector dim_t; 329 | for (int i = 0; i < num_pos_feats; i++) 330 | dim_t.push_back(i); 331 | for (int i = 0; i < num_pos_feats; i++) 332 | { 333 | dim_t[i] = std::pow(10000.0, 2 * std::floor(dim_t[i] / 2) / num_pos_feats); 334 | } 335 | 336 | ncnn::Mat pos_x = ncnn::Mat(num_pos_feats, mask.w, 
mask.h); 337 | ncnn::Mat pos_y = ncnn::Mat(num_pos_feats, mask.w, mask.h); 338 | 339 | for (int i = 0; i < pos_x.c; i++) 340 | { 341 | float* pos_x_data = pos_x.channel(i); 342 | for (int j = 0; j < pos_x.h; j++) 343 | { 344 | for (int k = 0; k < pos_x.w; k++) 345 | { 346 | pos_x_data[j * pos_x.w + k] = x_embed[i * pos_x.h + j] / dim_t[k]; 347 | } 348 | } 349 | } 350 | for (int i = 0; i < pos_y.c; i++) 351 | { 352 | float* pos_y_data = pos_y.channel(i); 353 | for (int j = 0; j < pos_y.h; j++) 354 | { 355 | for (int k = 0; k < pos_y.w; k++) 356 | { 357 | pos_y_data[j * pos_y.w + k] = y_embed[i * pos_y.h + j] / dim_t[k]; 358 | } 359 | } 360 | } 361 | 362 | 363 | for (int i = 0; i < pos_x.c; i++) 364 | { 365 | float* data = pos_x.channel(i); 366 | for (int j = 0; j < pos_x.h; j++) 367 | { 368 | for (int k = 0; k < pos_x.w;) 369 | { 370 | data[j * pos_x.w + k] = std::sin(data[j * pos_x.w + k]); 371 | k += 2; 372 | } 373 | 374 | for (int k = 1; k < pos_x.w;) 375 | { 376 | data[j * pos_x.w + k] = std::cos(data[j * pos_x.w + k]); 377 | k += 2; 378 | } 379 | 380 | } 381 | } 382 | 383 | for (int i = 0; i < pos_y.c; i++) 384 | { 385 | float* data = pos_y.channel(i); 386 | for (int j = 0; j < pos_y.h; j++) 387 | { 388 | for (int k = 0; k < pos_y.w;) 389 | { 390 | data[j * pos_y.w + k] = std::sin(data[j * pos_y.w + k]); 391 | k += 2; 392 | } 393 | 394 | for (int k = 1; k < pos_y.w;) 395 | { 396 | data[j * pos_y.w + k] = std::cos(data[j * pos_y.w + k]); 397 | k += 2; 398 | } 399 | 400 | } 401 | } 402 | 403 | concat(pos_y, pos_x, pos,2); 404 | transpose(pos, pos, 4); 405 | 406 | return 0; 407 | 408 | } 409 | 410 | static void coords_grid(int h, int w,ncnn::Mat& coords) 411 | { 412 | coords.create(w, h, 2); 413 | float* ptr0 = coords.channel(0); 414 | for (int i = 0; i < h; i++) 415 | { 416 | for (int j = 0; j < w; j++) 417 | { 418 | ptr0[i * w + j] = j; 419 | } 420 | } 421 | float* ptr1 = coords.channel(1); 422 | for (int i = 0; i < h; i++) 423 | { 424 | for (int j = 0; j < w; 
j++) 425 | { 426 | ptr1[i * w + j] = i; 427 | } 428 | } 429 | } 430 | 431 | 432 | static float within_bounds_2d(const ncnn::Mat& data, int x, int y, int c, int H, int W) 433 | { 434 | if (y >= 0 && y < H && x >= 0 && x < W) 435 | return data.channel(c)[y * W + x]; 436 | else 437 | return 0; 438 | } 439 | //from https://github.com/open-mmlab/mmdeploy/blob/master/csrc/backend_ops/onnxruntime/grid_sample/grid_sample.cpp 440 | static ncnn::Mat grid_sample(const ncnn::Mat& input,const ncnn::Mat& grid) 441 | { 442 | int channel = input.c; 443 | int input_height = input.h; 444 | int input_width = input.w; 445 | int output_height = grid.c; 446 | int output_width = grid.h; 447 | 448 | ncnn::Mat out(input.w, input.h, input.c); 449 | out.fill(0.0f); 450 | 451 | for (int h = 0; h < output_height; h++) 452 | { 453 | for (int w = 0; w < output_width; w++) 454 | { 455 | float x = grid.channel(h)[w * 2 + 0]; 456 | float y = grid.channel(h)[w * 2 + 1]; 457 | 458 | float ix = (x + 1) * input_width * 0.5 - 0.5; 459 | float iy = (y + 1) * input_height * 0.5 - 0.5; 460 | 461 | int ix_nw = static_cast(std::floor(ix)); 462 | int iy_nw = static_cast(std::floor(iy)); 463 | int ix_ne = ix_nw + 1; 464 | int iy_ne = iy_nw; 465 | 466 | int ix_sw = ix_nw; 467 | int iy_sw = iy_nw + 1; 468 | 469 | int ix_se = ix_nw + 1; 470 | int iy_se = iy_nw + 1; 471 | 472 | float nw = (ix_se - ix) * (iy_se - iy); 473 | float ne = (ix - ix_sw) * (iy_sw - iy); 474 | float sw = (ix_ne - ix) * (iy - iy_ne); 475 | float se = (ix - ix_nw) * (iy - iy_nw); 476 | 477 | 478 | for (int c = 0; c < channel; c++) 479 | { 480 | float nw_res = within_bounds_2d(input, ix_nw, iy_nw, c, input_height, input_width); 481 | float ne_res = within_bounds_2d(input, ix_ne, iy_ne, c, input_height, input_width); 482 | float sw_res = within_bounds_2d(input, ix_sw, iy_sw, c, input_height, input_width); 483 | float se_res = within_bounds_2d(input, ix_se, iy_se, c, input_height, input_width); 484 | out.channel(c)[h * input_width + w] = nw_res 
* nw + ne_res * ne + sw_res * sw + se_res * se; 485 | } 486 | } 487 | } 488 | return out; 489 | } 490 | 491 | static void to_ocv(const ncnn::Mat& result, cv::Mat& out) 492 | { 493 | cv::Mat cv_result_32F = cv::Mat::zeros(cv::Size(result.w, result.h), CV_32FC3); 494 | for (int i = 0; i < result.h; i++) 495 | { 496 | for (int j = 0; j < result.w; j++) 497 | { 498 | cv_result_32F.at(i, j)[0] = result.channel(0)[i * result.w + j]; 499 | cv_result_32F.at(i, j)[1] = result.channel(1)[i * result.w + j]; 500 | cv_result_32F.at(i, j)[2] = result.channel(2)[i * result.w + j]; 501 | } 502 | } 503 | 504 | cv::Mat cv_result_8U; 505 | cv_result_32F.convertTo(cv_result_8U, CV_8UC3, 255.0, 0); 506 | 507 | cv_result_8U.copyTo(out); 508 | 509 | } 510 | 511 | ncnn::Mat seg(cv::Mat& img) 512 | { 513 | ncnn::Net seg_net; 514 | seg_net.load_param("./models/seg.param"); 515 | seg_net.load_model("./models/seg.bin"); 516 | 517 | 518 | ncnn::Mat seg_in = ncnn::Mat::from_pixels_resize(img.data, ncnn::Mat::PIXEL_BGR2RGB, img.cols, img.rows, 288, 288); 519 | const float norm_vals[3] = { 1 / 255.f, 1 / 255.f, 1 / 255.f }; 520 | seg_in.substract_mean_normalize(0, norm_vals); 521 | ncnn::Extractor ex0 = seg_net.create_extractor(); 522 | 523 | ex0.input("input", seg_in); 524 | ncnn::Mat seg_out; 525 | ex0.extract("out", seg_out); 526 | threshold(seg_out, 0.5f); 527 | 528 | binary_op(seg_in, seg_out, seg_in, 2); 529 | 530 | cv::Mat seg_result = cv::Mat(cv::Size(seg_out.w, seg_out.h), CV_32FC1, (float*)seg_out.data); 531 | cv::Mat seg_result_8u; 532 | seg_result.convertTo(seg_result_8u, CV_8UC1, 255.0, 0); 533 | cv::imwrite("seg_result.jpg", seg_result_8u); 534 | 535 | //interp(seg_in, seg_out); 536 | 537 | 538 | //cv::imshow("seg_result", seg_result_8u); 539 | //cv::waitKey(); 540 | 541 | return seg_in; 542 | } 543 | cv::Mat warp_image(const ncnn::Mat& lbl, const cv::Mat& img) 544 | { 545 | ncnn::Mat im_ori = ncnn::Mat::from_pixels(img.data, ncnn::Mat::PIXEL_BGR, img.cols, img.rows); 546 | const 
float norm_vals[3] = { 1 / 255.f, 1 / 255.f, 1 / 255.f }; 547 | im_ori.substract_mean_normalize(0, norm_vals); 548 | ncnn::Mat out; 549 | out = grid_sample(im_ori, lbl); 550 | 551 | cv::Mat cv_out; 552 | to_ocv(out, cv_out); 553 | 554 | return cv_out; 555 | } 556 | ncnn::Mat geo(const ncnn::Mat& seg_out,const cv::Mat& img) 557 | { 558 | ncnn::Mat coords0, coords1, coodslar; 559 | coords_grid(288, 288, coodslar); 560 | coords_grid(36, 36, coords0); 561 | coords_grid(36, 36, coords1); 562 | 563 | ncnn::Net decoder_net; 564 | decoder_net.opt.use_packing_layout = false;//there is some bug in packing layout 565 | decoder_net.load_param("./models/decoder.param"); 566 | decoder_net.load_model("./models/decoder.bin"); 567 | 568 | ncnn::Net fbnet_net; 569 | fbnet_net.load_param("./models/fbnet.param"); 570 | fbnet_net.load_model("./models/fbnet.bin"); 571 | 572 | ncnn::Net encoder_net; 573 | encoder_net.load_param("./models/encoder.param"); 574 | encoder_net.load_model("./models/encoder.bin"); 575 | 576 | ncnn::Net update_block; 577 | update_block.load_param("./models/update_block.param"); 578 | update_block.load_model("./models/update_block.bin"); 579 | 580 | ncnn::Mat posf; 581 | ncnn::Mat mask = ncnn::Mat(36, 36, 1); 582 | mask.fill(1.0f); 583 | position_embedding(mask, 128, posf); 584 | 585 | ncnn::Extractor ex1 = fbnet_net.create_extractor(); 586 | ex1.input("input", seg_out); 587 | ncnn::Mat fmap1; 588 | ex1.extract("out", fmap1); 589 | 590 | //encoder 591 | ncnn::Extractor ex2 = encoder_net.create_extractor(); 592 | ex2.input("imgf", fmap1); 593 | ex2.input("pos", posf); 594 | 595 | ncnn::Mat fmap2; 596 | ex2.extract("out", fmap2); 597 | //decoder 598 | ncnn::Extractor ex3 = decoder_net.create_extractor(); 599 | ex3.input("imgf", fmap2); 600 | ex3.input("pos", posf); 601 | 602 | ncnn::Mat fmap3; 603 | ex3.extract("out", fmap3); 604 | 605 | ncnn::Extractor ex4 = update_block.create_extractor(); 606 | ex4.input("imgf", fmap3); 607 | ex4.input("coords", coords1); 608 | 
609 | ncnn::Mat fmask, coords1_out; 610 | ex4.extract("mask", fmask); 611 | ex4.extract("coords1", coords1_out); 612 | 613 | ncnn::Mat coords; 614 | binary_op(coords1_out, coords0, coords, 1);//sub 615 | scale(coords, 8.0, coords.c, coords); 616 | ncnn::Mat up_flow = im2col_cpu(coords, 3, 1, 1); 617 | 618 | ncnn::Mat up_flow1 = up_flow.reshape(1296, 1, 9, 2); 619 | 620 | ncnn::Mat up_flow11 = up_flow1.channel(0); 621 | ncnn::Mat up_flow12 = up_flow1.channel(1); 622 | ncnn::Mat fmask_up_flow11; 623 | binary_op(fmask, up_flow11, fmask_up_flow11, 2);//mul 624 | ncnn::Mat fmask_up_flow12; 625 | binary_op(fmask, up_flow12, fmask_up_flow12, 2);//mul 626 | 627 | ncnn::Mat fmask_up_flow11_sum, fmask_up_flow12_sum; 628 | reduction(fmask_up_flow11, fmask_up_flow11_sum); 629 | reduction(fmask_up_flow12, fmask_up_flow12_sum); 630 | 631 | fmask_up_flow11_sum = fmask_up_flow11_sum.reshape(1296, 64, 1, 1); 632 | ncnn::Mat fmask_up_flow11_sum_t; 633 | transpose(fmask_up_flow11_sum, fmask_up_flow11_sum_t, 2); 634 | fmask_up_flow11_sum_t = fmask_up_flow11_sum_t.reshape(36, 36, 64, 1); 635 | transpose(fmask_up_flow11_sum_t, fmask_up_flow11_sum_t, 6); 636 | fmask_up_flow11_sum_t = fmask_up_flow11_sum_t.reshape(36, 36, 8, 8); 637 | 638 | fmask_up_flow12_sum = fmask_up_flow12_sum.reshape(1296, 64, 1, 1); 639 | ncnn::Mat fmask_up_flow12_sum_t; 640 | transpose(fmask_up_flow12_sum, fmask_up_flow12_sum_t, 2); 641 | fmask_up_flow12_sum_t = fmask_up_flow12_sum_t.reshape(36, 36, 64, 1); 642 | transpose(fmask_up_flow12_sum_t, fmask_up_flow12_sum_t, 6); 643 | fmask_up_flow12_sum_t = fmask_up_flow12_sum_t.reshape(36, 36, 8, 8); 644 | 645 | transpose(fmask_up_flow11_sum_t, fmask_up_flow11_sum_t, 13); 646 | transpose(fmask_up_flow12_sum_t, fmask_up_flow12_sum_t, 13); 647 | 648 | fmask_up_flow11_sum_t = fmask_up_flow11_sum_t.reshape(288, 288, 1); 649 | fmask_up_flow12_sum_t = fmask_up_flow12_sum_t.reshape(288, 288, 1); 650 | concat(fmask_up_flow11_sum_t, fmask_up_flow12_sum_t, up_flow, 0); 651 | 
ncnn::Mat bm_up; 652 | binary_op(coodslar, up_flow, bm_up, 0);//add 653 | const float mean[2] = { 286.8f / 2, 286.8f / 2 }; 654 | const float norm[2] = { 2 * 0.99 / 286.8f, 2 * 0.99 / 286.8f }; 655 | bm_up.substract_mean_normalize(mean, norm); 656 | 657 | cv::Mat cv_bm0 = cv::Mat(cv::Size(bm_up.w, bm_up.h), CV_32FC1, bm_up.channel(0)); 658 | cv::Mat cv_bm1 = cv::Mat(cv::Size(bm_up.w, bm_up.h), CV_32FC1, bm_up.channel(1)); 659 | 660 | cv::resize(cv_bm0, cv_bm0, img.size(), 0, 0, 1); 661 | cv::resize(cv_bm1, cv_bm1, img.size(), 0, 0, 1); 662 | cv::blur(cv_bm0, cv_bm0, cv::Size(3, 3)); 663 | cv::blur(cv_bm1, cv_bm1, cv::Size(3, 3)); 664 | 665 | ncnn::Mat bm0 = ncnn::Mat(cv_bm0.cols * cv_bm0.rows, (void*)cv_bm0.data).reshape(cv_bm0.cols, cv_bm0.rows, 1); 666 | ncnn::Mat bm1 = ncnn::Mat(cv_bm0.cols * cv_bm0.rows, (void*)cv_bm1.data).reshape(cv_bm0.cols, cv_bm0.rows, 1); 667 | 668 | ncnn::Mat lbl; 669 | concat(bm0, bm1, lbl, 0); 670 | transpose(lbl, lbl, 3); 671 | 672 | return lbl; 673 | } 674 | 675 | int preprocess(const cv::Mat& img, cv::Mat& pad_img, int& img_pad_h, int& img_pad_w) 676 | { 677 | if (img.cols % 2 != 0) 678 | { 679 | img_pad_w = (2 - img.cols % 2); 680 | } 681 | if (img.rows % 2 != 0) 682 | { 683 | img_pad_h = (2 - img.rows % 2); 684 | } 685 | cv::copyMakeBorder(img, pad_img, 0, img_pad_h, 0, img_pad_w, cv::BORDER_CONSTANT, cv::Scalar(0)); 686 | 687 | return 0; 688 | } 689 | static ncnn::Mat depatch_embed(const ncnn::Mat& x) 690 | { 691 | int num_patches = 1024; 692 | ncnn::Mat out(128, 128, 16); 693 | int p = 4; 694 | int i = 0, j = 0; 695 | for (int k = 0; k < num_patches; k++)//1,1024,256 696 | { 697 | if (i + p > out.w) 698 | { 699 | i = 0; 700 | j += p; 701 | } 702 | ncnn::Mat temp = ncnn::Mat(256, (void*)x.row(k)).reshape(p, p, out.c);//16,4,4=256 703 | 704 | for (int c = 0; c < 16; c++)//16 705 | { 706 | float* outptr = out.channel(c); 707 | float* ptr = temp.channel(c); 708 | for (int h = i; h < i + p; h++)//4 709 | { 710 | for (int w = j; w < j 
+ p; w++)//4 711 | { 712 | outptr[h * out.w + w] = *ptr; 713 | ptr++; 714 | } 715 | } 716 | } 717 | i += p; 718 | } 719 | 720 | return out; 721 | } 722 | static ncnn::Mat patch_embed(const ncnn::Mat& x) 723 | { 724 | ncnn::Mat out(256, 1024/*, 1*/); 725 | 726 | int p = 4; 727 | int i = 0, j = 0; 728 | for (int k = 0; k < 1024; k++) 729 | { 730 | if (i + p > x.w) 731 | { 732 | i = 0; 733 | j += p; 734 | } 735 | float* ptr = out/*.channel(0)*/.row(k); 736 | for (int c = 0; c < 16; c++)//16 737 | { 738 | for (int h = i; h < i + p; h++)//4 739 | { 740 | for (int w = j; w < j + p; w++)//4 741 | { 742 | *ptr = x.channel(c)[h * x.w + w]; 743 | ptr++; 744 | } 745 | } 746 | 747 | } 748 | i += p; 749 | } 750 | 751 | return out; 752 | } 753 | static void to_ocv(const ncnn::Mat& result, cv::Mat& out, int img_h, int img_w) 754 | { 755 | cv::Mat cv_result_32F = cv::Mat::zeros(cv::Size(result.w, result.h), CV_32FC3); 756 | for (int i = 0; i < result.h; i++) 757 | { 758 | for (int j = 0; j < result.w; j++) 759 | { 760 | cv_result_32F.at(i, j)[2] = result.channel(0)[i * result.w + j]; 761 | cv_result_32F.at(i, j)[1] = result.channel(1)[i * result.w + j]; 762 | cv_result_32F.at(i, j)[0] = result.channel(2)[i * result.w + j]; 763 | } 764 | } 765 | 766 | cv::Mat cv_result_8U; 767 | cv_result_32F.convertTo(cv_result_8U, CV_8UC3, 255.0, 0); 768 | cv::resize(cv_result_8U, out, cv::Size(img_w, img_h), 0, 0, cv::INTER_LANCZOS4); 769 | //cv_result_8U.copyTo(out); 770 | 771 | } 772 | int ill_inference(const ncnn::Net & head_net, const ncnn::Net & encoder_net, const ncnn::Net & decoder_net, 773 | const cv::Mat & img, cv::Mat & ill_result, int img_h, int img_w) 774 | { 775 | cv::Mat img1 = img.clone(); 776 | ncnn::Mat in = ncnn::Mat::from_pixels_resize(img1.data, ncnn::Mat::PIXEL_BGR2RGB, img1.cols, img1.rows, 128, 128); 777 | const float norm_vals1[3] = { 1 / 255.f, 1 / 255.f, 1 / 255.f }; 778 | in.substract_mean_normalize(0, norm_vals1); 779 | 780 | ncnn::Extractor ex0 = 
head_net.create_extractor(); 781 | ex0.input("ill_head_in", in); 782 | ncnn::Mat head_out; 783 | ex0.extract("ill_head_out", head_out); 784 | 785 | ncnn::Mat patch_embed_out = patch_embed(head_out); 786 | 787 | ex0.input("pos_embed_in", patch_embed_out); 788 | ncnn::Mat pos_embed_out; 789 | ex0.extract("pos_embed_out", pos_embed_out); 790 | 791 | ncnn::Extractor ex1 = encoder_net.create_extractor(); 792 | 793 | ex1.input("0", pos_embed_out); 794 | ncnn::Mat encoder_out; 795 | ex1.extract("722", encoder_out); 796 | 797 | ncnn::Extractor ex2 = decoder_net.create_extractor(); 798 | ex2.input("ill_decoder_in", encoder_out); 799 | ncnn::Mat decoder_out; 800 | ex2.extract("ill_decoder_out", decoder_out); 801 | 802 | ncnn::Mat depatch_embed_out = depatch_embed(decoder_out); 803 | 804 | ncnn::Mat out; 805 | ex2.input("ill_tail_in", depatch_embed_out); 806 | ex2.extract("ill_tail_out", out); 807 | 808 | //cv::Mat ill_result; 809 | to_ocv(out, ill_result, img_h, img_w); 810 | 811 | return 0; 812 | } 813 | 814 | int tile_process(const cv::Mat & inimage, cv::Mat & outimage) 815 | { 816 | ncnn::Net head_net; 817 | head_net.load_param("./models/ill_head.param"); 818 | head_net.load_model("./models/ill_head.bin"); 819 | 820 | ncnn::Net encoder_net; 821 | encoder_net.opt.use_packing_layout = false; 822 | encoder_net.load_param("./models/ill_encoder.param"); 823 | encoder_net.load_model("./models/ill_encoder.bin"); 824 | 825 | ncnn::Net decoder_net; 826 | decoder_net.opt.use_packing_layout = false; 827 | decoder_net.load_param("./models/ill_decoder.param"); 828 | decoder_net.load_model("./models/ill_decoder.bin"); 829 | 830 | const int tile_size = 256; 831 | const int tile_pad = 10; 832 | const int scale = 1; 833 | cv::Mat pad_inimage; 834 | int img_pad_w = 0, img_pad_h = 0; 835 | preprocess(inimage, pad_inimage, img_pad_w, img_pad_h); 836 | 837 | int tiles_x = std::ceil((float)inimage.cols / tile_size); 838 | int tiles_y = std::ceil((float)inimage.rows / tile_size); 839 | 840 | 
cv::Mat out = cv::Mat(cv::Size(pad_inimage.cols, pad_inimage.rows), CV_8UC3); 841 | std::vector tile_imgs; 842 | int num_threads = ncnn::get_cpu_count(); 843 | #pragma omp parallel for num_threads(num_threads) 844 | for (int i = 0; i < tiles_y; i++) 845 | { 846 | for (int j = 0; j < tiles_x; j++) 847 | { 848 | int ofs_x = j * tile_size; 849 | int ofs_y = i * tile_size; 850 | 851 | int input_start_x = ofs_x; 852 | int input_end_x = std::min(ofs_x + tile_size, pad_inimage.cols); 853 | int input_start_y = ofs_y; 854 | int input_end_y = std::min(ofs_y + tile_size, pad_inimage.rows); 855 | 856 | int input_start_x_pad = std::max(input_start_x - tile_pad, 0); 857 | int input_end_x_pad = std::min(input_end_x + tile_pad, pad_inimage.cols); 858 | int input_start_y_pad = std::max(input_start_y - tile_pad, 0); 859 | int input_end_y_pad = std::min(input_end_y + tile_pad, pad_inimage.rows); 860 | 861 | int input_tile_width = input_end_x - input_start_x; 862 | int input_tile_height = input_end_y - input_start_y; 863 | 864 | cv::Mat input_tile = pad_inimage(cv::Rect(input_start_x_pad, input_start_y_pad, input_end_x_pad - input_start_x_pad, input_end_y_pad - input_start_y_pad)).clone(); 865 | 866 | cv::Mat out_tile; 867 | ill_inference(head_net, encoder_net, decoder_net, input_tile, out_tile, input_tile.rows, input_tile.cols); 868 | //to mat 869 | 870 | int output_start_x = input_start_x * scale; 871 | int output_end_x = input_end_x * scale; 872 | int output_start_y = input_start_y * scale; 873 | int output_end_y = input_end_y * scale; 874 | 875 | int output_start_x_tile = (input_start_x - input_start_x_pad) * scale; 876 | int output_end_x_tile = output_start_x_tile + input_tile_width * scale; 877 | int output_start_y_tile = (input_start_y - input_start_y_pad) * scale; 878 | int output_end_y_tile = output_start_y_tile + input_tile_height * scale; 879 | cv::Rect tile_roi = cv::Rect(output_start_x_tile, output_start_y_tile, 880 | output_end_x_tile - output_start_x_tile, 881 | 
output_end_y_tile - output_start_y_tile); 882 | cv::Rect out_roi = cv::Rect(output_start_x, output_start_y, 883 | output_end_x - output_start_x, output_end_y - output_start_y); 884 | out_tile(tile_roi).copyTo(out(out_roi)); 885 | } 886 | } 887 | 888 | out(cv::Rect(0, 0, inimage.cols, inimage.rows)).copyTo(outimage); 889 | return 0; 890 | 891 | } 892 | 893 | int main(int argc, char** argv) 894 | { 895 | if (argc != 2) 896 | { 897 | fprintf(stderr, "Usage: %s [imagepath]\n", argv[0]); 898 | return -1; 899 | } 900 | 901 | const char* imagepath = argv[1]; 902 | 903 | cv::Mat img = cv::imread(imagepath, 1); 904 | if (img.empty()) 905 | { 906 | fprintf(stderr, "cv::imread %s failed\n", imagepath); 907 | return -1; 908 | } 909 | 910 | ncnn::Mat seg_out = seg(img); 911 | ncnn::Mat lbl = geo(seg_out, img); 912 | cv::Mat warp_result = warp_image(lbl, img); 913 | cv::imwrite("warp_result.jpg", warp_result); 914 | 915 | cv::Mat ill_result; 916 | tile_process(warp_result, ill_result); 917 | cv::imwrite("ill_result.jpg", ill_result); 918 | 919 | //cv::imshow("warp_result", warp_result); 920 | //cv::waitKey(); 921 | 922 | 923 | return 0; 924 | } 925 | -------------------------------------------------------------------------------- /windows_gui.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiGeChuanShu/DocTr-ncnn/b8584412038ff058e8a48acf3d39516b3dc3acdc/windows_gui.jpg --------------------------------------------------------------------------------