├── .gitignore ├── INSTALL.md ├── LICENSE ├── QUICK_START.md ├── README.md ├── feelvos ├── __init__.py ├── data │ ├── test_folder_list.txt │ └── train_folder_list.txt ├── dataset.py ├── loss.py ├── metric.py ├── models │ ├── Backbone.py │ ├── DynamicSegmentationHead.py │ ├── Embeddings.py │ ├── FEELVOS.py │ ├── Matching.py │ ├── __init__.py │ └── correlation_package │ │ ├── __init__.py │ │ ├── correlation.py │ │ ├── correlation_cuda.cc │ │ ├── correlation_cuda_kernel.cu │ │ ├── correlation_cuda_kernel.cuh │ │ └── setup.py ├── run.py ├── test.py ├── train.py ├── trainer.py ├── transform.py └── util │ ├── __init__.py │ └── toTensor.py ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.gitignore.io/api/pydev,putty,python,pycharm,pycharm+iml,pycharm+all,visualstudio,jupyternotebooks,visualstudiocode 3 | # Edit at https://www.gitignore.io/?templates=pydev,putty,python,pycharm,pycharm+iml,pycharm+all,visualstudio,jupyternotebooks,visualstudiocode 4 | 5 | ### JupyterNotebooks ### 6 | # gitignore template for Jupyter Notebooks 7 | # website: http://jupyter.org/ 8 | 9 | .ipynb_checkpoints 10 | */.ipynb_checkpoints/* 11 | 12 | # IPython 13 | profile_default/ 14 | ipython_config.py 15 | 16 | # Remove previous ipynb_checkpoints 17 | # git rm -r .ipynb_checkpoints/ 18 | 19 | ### PuTTY ### 20 | # Private key 21 | *.ppk 22 | 23 | ### PyCharm ### 24 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 25 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 26 | 27 | # User-specific stuff 28 | .idea/**/workspace.xml 29 | .idea/**/tasks.xml 30 | .idea/**/usage.statistics.xml 31 | .idea/**/dictionaries 32 | .idea/**/shelf 33 | 34 | # Generated files 35 | .idea/**/contentModel.xml 36 | 37 | # Sensitive or high-churn files 38 | .idea/**/dataSources/ 39 | .idea/**/dataSources.ids 40 | .idea/**/dataSources.local.xml 41 | .idea/**/sqlDataSources.xml 42 | .idea/**/dynamic.xml 43 | .idea/**/uiDesigner.xml 44 | .idea/**/dbnavigator.xml 45 | 46 | # Gradle 47 | .idea/**/gradle.xml 48 | .idea/**/libraries 49 | 50 | # Gradle and Maven with auto-import 51 | # When using Gradle or Maven with auto-import, you should exclude module files, 52 | # since they will be recreated, and may cause churn. Uncomment if using 53 | # auto-import. 
54 | # .idea/modules.xml 55 | # .idea/*.iml 56 | # .idea/modules 57 | # *.iml 58 | # *.ipr 59 | 60 | # CMake 61 | cmake-build-*/ 62 | 63 | # Mongo Explorer plugin 64 | .idea/**/mongoSettings.xml 65 | 66 | # File-based project format 67 | *.iws 68 | 69 | # IntelliJ 70 | out/ 71 | 72 | # mpeltonen/sbt-idea plugin 73 | .idea_modules/ 74 | 75 | # JIRA plugin 76 | atlassian-ide-plugin.xml 77 | 78 | # Cursive Clojure plugin 79 | .idea/replstate.xml 80 | 81 | # Crashlytics plugin (for Android Studio and IntelliJ) 82 | com_crashlytics_export_strings.xml 83 | crashlytics.properties 84 | crashlytics-build.properties 85 | fabric.properties 86 | 87 | # Editor-based Rest Client 88 | .idea/httpRequests 89 | 90 | # Android studio 3.1+ serialized cache file 91 | .idea/caches/build_file_checksums.ser 92 | 93 | ### PyCharm Patch ### 94 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 95 | 96 | # *.iml 97 | # modules.xml 98 | # .idea/misc.xml 99 | # *.ipr 100 | 101 | # Sonarlint plugin 102 | .idea/**/sonarlint/ 103 | 104 | # SonarQube Plugin 105 | .idea/**/sonarIssues.xml 106 | 107 | # Markdown Navigator plugin 108 | .idea/**/markdown-navigator.xml 109 | .idea/**/markdown-navigator/ 110 | 111 | ### PyCharm+all ### 112 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 113 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 114 | 115 | # User-specific stuff 116 | 117 | # Generated files 118 | 119 | # Sensitive or high-churn files 120 | 121 | # Gradle 122 | 123 | # Gradle and Maven with auto-import 124 | # When using Gradle or Maven with auto-import, you should exclude module files, 125 | # since they will be recreated, and may cause churn. Uncomment if using 126 | # auto-import. 127 | # .idea/modules.xml 128 | # .idea/*.iml 129 | # .idea/modules 130 | # *.iml 131 | # *.ipr 132 | 133 | # CMake 134 | 135 | # Mongo Explorer plugin 136 | 137 | # File-based project format 138 | 139 | # IntelliJ 140 | 141 | # mpeltonen/sbt-idea plugin 142 | 143 | # JIRA plugin 144 | 145 | # Cursive Clojure plugin 146 | 147 | # Crashlytics plugin (for Android Studio and IntelliJ) 148 | 149 | # Editor-based Rest Client 150 | 151 | # Android studio 3.1+ serialized cache file 152 | 153 | ### PyCharm+all Patch ### 154 | # Ignores the whole .idea folder and all .iml files 155 | # See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360 156 | 157 | .idea/ 158 | 159 | # Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023 160 | 161 | *.iml 162 | modules.xml 163 | .idea/misc.xml 164 | *.ipr 165 | 166 | # Sonarlint plugin 167 | .idea/sonarlint 168 | 169 | ### PyCharm+iml ### 170 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 171 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 172 | 173 | # User-specific stuff 174 | 175 | # Generated files 176 | 177 | # Sensitive or high-churn files 178 | 179 | # Gradle 180 | 181 | # Gradle and Maven with auto-import 182 | # When using Gradle or Maven with auto-import, you should exclude module files, 183 | # since they will be recreated, and may cause churn. Uncomment if using 184 | # auto-import. 
185 | # .idea/modules.xml 186 | # .idea/*.iml 187 | # .idea/modules 188 | # *.iml 189 | # *.ipr 190 | 191 | # CMake 192 | 193 | # Mongo Explorer plugin 194 | 195 | # File-based project format 196 | 197 | # IntelliJ 198 | 199 | # mpeltonen/sbt-idea plugin 200 | 201 | # JIRA plugin 202 | 203 | # Cursive Clojure plugin 204 | 205 | # Crashlytics plugin (for Android Studio and IntelliJ) 206 | 207 | # Editor-based Rest Client 208 | 209 | # Android studio 3.1+ serialized cache file 210 | 211 | ### PyCharm+iml Patch ### 212 | # Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023 213 | 214 | 215 | ### pydev ### 216 | .pydevproject 217 | 218 | ### Python ### 219 | # Byte-compiled / optimized / DLL files 220 | __pycache__/ 221 | *.py[cod] 222 | *$py.class 223 | 224 | # C extensions 225 | *.so 226 | 227 | # Distribution / packaging 228 | .Python 229 | build/ 230 | develop-eggs/ 231 | dist/ 232 | downloads/ 233 | eggs/ 234 | .eggs/ 235 | lib/ 236 | lib64/ 237 | parts/ 238 | sdist/ 239 | var/ 240 | wheels/ 241 | pip-wheel-metadata/ 242 | share/python-wheels/ 243 | *.egg-info/ 244 | .installed.cfg 245 | *.egg 246 | MANIFEST 247 | 248 | # PyInstaller 249 | # Usually these files are written by a python script from a template 250 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 251 | *.manifest 252 | *.spec 253 | 254 | # Installer logs 255 | pip-log.txt 256 | pip-delete-this-directory.txt 257 | 258 | # Unit test / coverage reports 259 | htmlcov/ 260 | .tox/ 261 | .nox/ 262 | .coverage 263 | .coverage.* 264 | .cache 265 | nosetests.xml 266 | coverage.xml 267 | *.cover 268 | .hypothesis/ 269 | .pytest_cache/ 270 | 271 | # Translations 272 | *.mo 273 | *.pot 274 | 275 | # Scrapy stuff: 276 | .scrapy 277 | 278 | # Sphinx documentation 279 | docs/_build/ 280 | 281 | # PyBuilder 282 | target/ 283 | 284 | # pyenv 285 | .python-version 286 | 287 | # pipenv 288 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 289 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 290 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 291 | # install all needed dependencies. 292 | #Pipfile.lock 293 | 294 | # celery beat schedule file 295 | celerybeat-schedule 296 | 297 | # SageMath parsed files 298 | *.sage.py 299 | 300 | # Spyder project settings 301 | .spyderproject 302 | .spyproject 303 | 304 | # Rope project settings 305 | .ropeproject 306 | 307 | # Mr Developer 308 | .mr.developer.cfg 309 | .project 310 | 311 | # mkdocs documentation 312 | /site 313 | 314 | # mypy 315 | .mypy_cache/ 316 | .dmypy.json 317 | dmypy.json 318 | 319 | # Pyre type checker 320 | .pyre/ 321 | 322 | ### VisualStudioCode ### 323 | .vscode/* 324 | !.vscode/settings.json 325 | !.vscode/tasks.json 326 | !.vscode/launch.json 327 | !.vscode/extensions.json 328 | 329 | ### VisualStudioCode Patch ### 330 | # Ignore all local history of files 331 | .history 332 | 333 | ### VisualStudio ### 334 | ## Ignore Visual Studio temporary files, build results, and 335 | ## files generated by popular Visual Studio add-ons. 
336 | ## 337 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore 338 | 339 | # User-specific files 340 | *.rsuser 341 | *.suo 342 | *.user 343 | *.userosscache 344 | *.sln.docstates 345 | 346 | # User-specific files (MonoDevelop/Xamarin Studio) 347 | *.userprefs 348 | 349 | # Mono auto generated files 350 | mono_crash.* 351 | 352 | # Build results 353 | [Dd]ebug/ 354 | [Dd]ebugPublic/ 355 | [Rr]elease/ 356 | [Rr]eleases/ 357 | x64/ 358 | x86/ 359 | [Aa][Rr][Mm]/ 360 | [Aa][Rr][Mm]64/ 361 | bld/ 362 | [Bb]in/ 363 | [Oo]bj/ 364 | [Ll]og/ 365 | 366 | # Visual Studio 2015/2017 cache/options directory 367 | .vs/ 368 | # Uncomment if you have tasks that create the project's static files in wwwroot 369 | #wwwroot/ 370 | 371 | # Visual Studio 2017 auto generated files 372 | Generated\ Files/ 373 | 374 | # MSTest test Results 375 | [Tt]est[Rr]esult*/ 376 | [Bb]uild[Ll]og.* 377 | 378 | # NUnit 379 | *.VisualState.xml 380 | TestResult.xml 381 | nunit-*.xml 382 | 383 | # Build Results of an ATL Project 384 | [Dd]ebugPS/ 385 | [Rr]eleasePS/ 386 | dlldata.c 387 | 388 | # Benchmark Results 389 | BenchmarkDotNet.Artifacts/ 390 | 391 | # .NET Core 392 | project.lock.json 393 | project.fragment.lock.json 394 | artifacts/ 395 | 396 | # StyleCop 397 | StyleCopReport.xml 398 | 399 | # Files built by Visual Studio 400 | *_i.c 401 | *_p.c 402 | *_h.h 403 | *.ilk 404 | *.obj 405 | *.iobj 406 | *.pch 407 | *.pdb 408 | *.ipdb 409 | *.pgc 410 | *.pgd 411 | *.rsp 412 | *.sbr 413 | *.tlb 414 | *.tli 415 | *.tlh 416 | *.tmp 417 | *.tmp_proj 418 | *_wpftmp.csproj 419 | *.log 420 | *.vspscc 421 | *.vssscc 422 | .builds 423 | *.pidb 424 | *.svclog 425 | *.scc 426 | 427 | # Chutzpah Test files 428 | _Chutzpah* 429 | 430 | # Visual C++ cache files 431 | ipch/ 432 | *.aps 433 | *.ncb 434 | *.opendb 435 | *.opensdf 436 | *.sdf 437 | *.cachefile 438 | *.VC.db 439 | *.VC.VC.opendb 440 | 441 | # Visual Studio profiler 442 | *.psess 443 | *.vsp 444 | *.vspx 445 | *.sap 446 | 447 | # Visual Studio Trace Files 448 | *.e2e 449 | 450 | # TFS 2012 Local Workspace 451 | $tf/ 452 | 453 | # Guidance Automation Toolkit 454 | *.gpState 455 | 456 | # ReSharper is a .NET coding add-in 457 | _ReSharper*/ 458 | *.[Rr]e[Ss]harper 459 | *.DotSettings.user 460 | 461 | # JustCode is a .NET coding add-in 462 | .JustCode 463 | 464 | # TeamCity is a build add-in 465 | _TeamCity* 466 | 467 | # DotCover is a Code Coverage Tool 468 | *.dotCover 469 | 470 | # AxoCover is a Code Coverage Tool 471 | .axoCover/* 472 | !.axoCover/settings.json 473 | 474 | # Visual Studio code coverage results 475 | *.coverage 476 | *.coveragexml 477 | 478 | # NCrunch 479 | _NCrunch_* 480 | .*crunch*.local.xml 481 | nCrunchTemp_* 482 | 483 | # MightyMoose 484 | *.mm.* 485 | AutoTest.Net/ 486 | 487 | # Web workbench (sass) 488 | .sass-cache/ 489 | 490 | # Installshield output folder 491 | [Ee]xpress/ 492 | 493 | # DocProject is a documentation generator add-in 494 | DocProject/buildhelp/ 495 | DocProject/Help/*.HxT 496 | DocProject/Help/*.HxC 497 | DocProject/Help/*.hhc 498 | DocProject/Help/*.hhk 499 | DocProject/Help/*.hhp 500 | DocProject/Help/Html2 501 | DocProject/Help/html 502 | 503 | # Click-Once directory 504 | publish/ 505 | 506 | # Publish Web Output 507 | *.[Pp]ublish.xml 508 | *.azurePubxml 509 | # Note: Comment the next line if you want to checkin your web deploy settings, 510 | # but database connection strings (with potential passwords) will be unencrypted 511 | *.pubxml 512 | *.publishproj 513 | 514 | # Microsoft Azure 
Web App publish settings. Comment the next line if you want to 515 | # checkin your Azure Web App publish settings, but sensitive information contained 516 | # in these scripts will be unencrypted 517 | PublishScripts/ 518 | 519 | # NuGet Packages 520 | *.nupkg 521 | # NuGet Symbol Packages 522 | *.snupkg 523 | # The packages folder can be ignored because of Package Restore 524 | **/[Pp]ackages/* 525 | # except build/, which is used as an MSBuild target. 526 | !**/[Pp]ackages/build/ 527 | # Uncomment if necessary however generally it will be regenerated when needed 528 | #!**/[Pp]ackages/repositories.config 529 | # NuGet v3's project.json files produces more ignorable files 530 | *.nuget.props 531 | *.nuget.targets 532 | 533 | # Microsoft Azure Build Output 534 | csx/ 535 | *.build.csdef 536 | 537 | # Microsoft Azure Emulator 538 | ecf/ 539 | rcf/ 540 | 541 | # Windows Store app package directories and files 542 | AppPackages/ 543 | BundleArtifacts/ 544 | Package.StoreAssociation.xml 545 | _pkginfo.txt 546 | *.appx 547 | *.appxbundle 548 | *.appxupload 549 | 550 | # Visual Studio cache files 551 | # files ending in .cache can be ignored 552 | *.[Cc]ache 553 | # but keep track of directories ending in .cache 554 | !?*.[Cc]ache/ 555 | 556 | # Others 557 | ClientBin/ 558 | ~$* 559 | *~ 560 | *.dbmdl 561 | *.dbproj.schemaview 562 | *.jfm 563 | *.pfx 564 | *.publishsettings 565 | orleans.codegen.cs 566 | 567 | # Including strong name files can present a security risk 568 | # (https://github.com/github/gitignore/pull/2483#issue-259490424) 569 | #*.snk 570 | 571 | # Since there are multiple workflows, uncomment next line to ignore bower_components 572 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 573 | #bower_components/ 574 | 575 | # RIA/Silverlight projects 576 | Generated_Code/ 577 | 578 | # Backup & report files from converting an old project file 579 | # to a newer Visual Studio version. Backup files are not needed, 580 | # because we have git ;-) 581 | _UpgradeReport_Files/ 582 | Backup*/ 583 | UpgradeLog*.XML 584 | UpgradeLog*.htm 585 | ServiceFabricBackup/ 586 | *.rptproj.bak 587 | 588 | # SQL Server files 589 | *.mdf 590 | *.ldf 591 | *.ndf 592 | 593 | # Business Intelligence projects 594 | *.rdl.data 595 | *.bim.layout 596 | *.bim_*.settings 597 | *.rptproj.rsuser 598 | *- [Bb]ackup.rdl 599 | *- [Bb]ackup ([0-9]).rdl 600 | *- [Bb]ackup ([0-9][0-9]).rdl 601 | 602 | # Microsoft Fakes 603 | FakesAssemblies/ 604 | 605 | # GhostDoc plugin setting file 606 | *.GhostDoc.xml 607 | 608 | # Node.js Tools for Visual Studio 609 | .ntvs_analysis.dat 610 | node_modules/ 611 | 612 | # Visual Studio 6 build log 613 | *.plg 614 | 615 | # Visual Studio 6 workspace options file 616 | *.opt 617 | 618 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 
619 | *.vbw 620 | 621 | # Visual Studio LightSwitch build output 622 | **/*.HTMLClient/GeneratedArtifacts 623 | **/*.DesktopClient/GeneratedArtifacts 624 | **/*.DesktopClient/ModelManifest.xml 625 | **/*.Server/GeneratedArtifacts 626 | **/*.Server/ModelManifest.xml 627 | _Pvt_Extensions 628 | 629 | # Paket dependency manager 630 | .paket/paket.exe 631 | paket-files/ 632 | 633 | # FAKE - F# Make 634 | .fake/ 635 | 636 | # CodeRush personal settings 637 | .cr/personal 638 | 639 | # Python Tools for Visual Studio (PTVS) 640 | *.pyc 641 | 642 | # Cake - Uncomment if you are using it 643 | # tools/** 644 | # !tools/packages.config 645 | 646 | # Tabs Studio 647 | *.tss 648 | 649 | # Telerik's JustMock configuration file 650 | *.jmconfig 651 | 652 | # BizTalk build output 653 | *.btp.cs 654 | *.btm.cs 655 | *.odx.cs 656 | *.xsd.cs 657 | 658 | # OpenCover UI analysis results 659 | OpenCover/ 660 | 661 | # Azure Stream Analytics local run output 662 | ASALocalRun/ 663 | 664 | # MSBuild Binary and Structured Log 665 | *.binlog 666 | 667 | # NVidia Nsight GPU debugger configuration file 668 | *.nvuser 669 | 670 | # MFractors (Xamarin productivity tool) working folder 671 | .mfractor/ 672 | 673 | # Local History for Visual Studio 674 | .localhistory/ 675 | 676 | # BeatPulse healthcheck temp database 677 | healthchecksdb 678 | 679 | # Backup folder for Package Reference Convert tool in Visual Studio 2017 680 | MigrationBackup/ 681 | 682 | # End of https://www.gitignore.io/api/pydev,putty,python,pycharm,pycharm+iml,pycharm+all,visualstudio,jupyternotebooks,visualstudiocode 683 | 684 | ## Custom .gitignore 685 | .vscode/ 686 | 687 | feelvos/models/example/x1.png 688 | feelvos/models/example/x2.png 689 | feelvos/models/example/x3.png 690 | 691 | feelvos/data/image/ 692 | feelvos/data/mask/ 693 | 694 | runs 695 | 696 | unet 697 | *.pt 698 | 699 | example 700 | 701 | save 702 | save2 -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | ## Requirements 4 | 5 | Stated in [requirements.txt](./requirements.txt). 6 | 7 | ## Util Install 8 | 9 | ```plain 10 | python setup.py sdist 11 | python setup.py install # if using non root, python setup.py install --user 12 | ``` 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 김영한 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/QUICK_START.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/measaverb/FEELVOS/e9f497753ea52a91c6180e9f45fb87810898d309/QUICK_START.md
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FEELVOS
2 |
3 | A PyTorch implementation of FEELVOS: Fast End-to-End Embedding Learning for Video Object Segmentation.
4 |
5 | ![FEELVOS Architecture](https://user-images.githubusercontent.com/40779417/77501200-9dd19b00-6e9a-11ea-8734-060b28735b21.png)
6 |
7 | ## LICENSE
8 |
9 | See [LICENSE](LICENSE).
10 |
11 | ## Installation
12 |
13 | Please follow the installation procedure in [INSTALL.md](INSTALL.md).
14 |
15 | ## Quick Start
16 |
17 | Follow the instructions in [QUICK_START.md](QUICK_START.md).
18 |
19 | ## Acknowledgments
20 |
21 | ```plain
22 | @article{DBLP:journals/corr/abs-1902-09513,
23 |   author = {Paul Voigtlaender and
24 |             Yuning Chai and
25 |             Florian Schroff and
26 |             Hartwig Adam and
27 |             Bastian Leibe and
28 |             Liang{-}Chieh Chen},
29 |   title = {{FEELVOS:} Fast End-to-End Embedding Learning for Video Object Segmentation},
30 |   journal = {CoRR},
31 |   volume = {abs/1902.09513},
32 |   year = {2019},
33 |   url = {http://arxiv.org/abs/1902.09513},
34 |   archivePrefix = {arXiv},
35 |   eprint = {1902.09513},
36 |   timestamp = {Tue, 21 May 2019 18:03:37 +0200},
37 |   biburl = {https://dblp.org/rec/journals/corr/abs-1902-09513.bib},
38 |   bibsource = {dblp computer science bibliography, https://dblp.org}
39 | }
40 | ```
41 |
--------------------------------------------------------------------------------
/feelvos/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/measaverb/FEELVOS/e9f497753ea52a91c6180e9f45fb87810898d309/feelvos/__init__.py
--------------------------------------------------------------------------------
/feelvos/data/test_folder_list.txt:
--------------------------------------------------------------------------------
1 | US_2020022506_m_17_170_70_004
2 | US_2020022506_m_17_170_70_008
3 | US_2020022506_m_17_170_70_010
4 | US_2020022507_m_25_170_70_002
5 | US_2020022507_m_25_170_70_004
6 |
--------------------------------------------------------------------------------
/feelvos/data/train_folder_list.txt:
--------------------------------------------------------------------------------
1 | US_2020022502_m_29_170_72_002
2 | US_2020022502_m_29_170_72_004
3 | US_2020022502_m_29_170_72_008
4 | US_2020022502_m_29_170_72_010
5 | US_2020022504_m_25_174_62_002
6 | US_2020022504_m_25_174_62_004
7 | US_2020022504_m_25_174_62_008
8 | US_2020022504_m_25_174_62_010
9 | US_2020022505_m_26_179_68_002
10 | US_2020022505_m_26_179_68_004
11 | US_2020022505_m_26_179_68_008
12 | US_2020022505_m_26_179_68_010
13 | US_2020022506_m_17_170_70_002
14 |
--------------------------------------------------------------------------------
/feelvos/dataset.py:
--------------------------------------------------------------------------------
1 | from itertools import combinations
2 | from cv2
import cv2 3 | import os 4 | import natsort 5 | import pandas as pd 6 | import numpy as np 7 | 8 | import torch 9 | import torchvision 10 | 11 | from torch.utils.data import Dataset, DataLoader 12 | from torchvision.transforms import ToPILImage 13 | from torchvision import transforms, utils 14 | 15 | from feelvos.transform import preprocessing 16 | 17 | 18 | class FEELVOSTriple(Dataset): 19 | def __init__(self, root='./data/', split='train', transform=None): 20 | super().__init__() 21 | self.root = root 22 | self.split = split 23 | self.transform = transform 24 | self.folder_list = [] 25 | self.items = [] 26 | 27 | folder_f = open(os.path.join(root, self.split+"_folder_list.txt"), "r") 28 | for x in folder_f: 29 | self.folder_list.append(x[:-1]) 30 | 31 | for i in range(len(self.folder_list)): 32 | tmp_list = natsort.natsorted(os.listdir(os.path.join(root, 'image', self.folder_list[i]))) 33 | for j in range(len(tmp_list) - 2): 34 | first = tmp_list[j] 35 | for k in range(len(tmp_list[j+1:])-1): 36 | comb_1 = tmp_list[k+1] 37 | comb_2 = tmp_list[k+2] 38 | self.items.append((os.path.join(self.root, 'image', self.folder_list[i], first), os.path.join(self.root, 'image', self.folder_list[i], comb_1), os.path.join(self.root, 'image', self.folder_list[i], comb_2))) 39 | 40 | def __getitem__(self, index): 41 | src = [] 42 | mask = [] 43 | seltem = self.items[index] 44 | for i in range(3): 45 | src.append(cv2.imread(seltem[i])) 46 | mask.append(cv2.imread(os.path.join(seltem[i].split('/')[1], 'mask', seltem[i].split('/')[3], seltem[i].split('/')[4]))) 47 | sample = (src, mask) 48 | if self.transform is None: 49 | pass 50 | else: 51 | sample = self.transform(*sample) 52 | if self.split == 'train': 53 | sample[0][0] = sample[1][0] 54 | sample[0][1] = sample[1][1] 55 | return sample 56 | 57 | def __len__(self): 58 | return len(self.items) 59 | 60 | 61 | if __name__ == "__main__": 62 | ds_train = FEELVOSTriple(root='./data/', split='train', transform=preprocessing) 63 | ds_test = FEELVOSTriple(root='./data/', split='test', transform=preprocessing) 64 | print("DATA LOADED") 65 | -------------------------------------------------------------------------------- /feelvos/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def dice_loss(pred, target, epsilon=1e-7, use_sigmoid=True): 4 | pred = pred.contiguous() 5 | if use_sigmoid: 6 | pred = torch.nn.Sigmoid()(pred) 7 | target = target.contiguous() 8 | intersection = (pred * target).sum(dim=2).sum(dim=2) 9 | loss = (1 - ((2. 
* intersection + epsilon) / (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + epsilon))) 10 | return loss.mean() 11 | 12 | def tversky_loss(pred, target, epsilon=1e-7): 13 | pred = pred.contiguous() 14 | target = target.contiguous() 15 | intersection = (pred * target).sum(dim=2).sum(dim=2) 16 | false_neg = (target * (1 - pred)).sum(dim=2).sum(dim=2) 17 | false_pos = ((1 - target) * pred).sum(dim=2).sum(dim=2) 18 | alpha = 0.7 19 | loss = (1 - ((intersection + epsilon)/(intersection + alpha * false_neg + (1-alpha)*false_pos + epsilon))) 20 | return loss.mean() -------------------------------------------------------------------------------- /feelvos/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def dice_coeff(pred, target, threshold=0.5, epsilon=1e-6, use_sigmoid = True): 4 | pred = pred.contiguous() 5 | if use_sigmoid: 6 | pred = torch.nn.Sigmoid()(pred) 7 | target = target.contiguous() 8 | pred = (pred > threshold).float() 9 | intersection = (pred * target).sum(dim=2).sum(dim=2) 10 | dice = (2. * intersection + epsilon) / (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + epsilon) 11 | return dice.mean() 12 | 13 | def tversky(pred, target, epsilon=1e-7): 14 | pred = pred.contiguous() 15 | target = target.contiguous() 16 | intersection = (pred * target).sum(dim=2).sum(dim=2) 17 | false_neg = (target * (1 - pred)).sum(dim=2).sum(dim=2) 18 | false_pos = ((1 - target) * pred).sum(dim=2).sum(dim=2) 19 | alpha = 0.7 20 | tversky = ((intersection + epsilon)/(intersection + alpha * false_neg + (1-alpha)*false_pos + epsilon)) 21 | return tversky.mean() -------------------------------------------------------------------------------- /feelvos/models/Backbone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class UNet(nn.Module): 7 | def __init__(self, n_ch, n_classes, bilinear=True): 8 | super(UNet, self).__init__() 9 | self.n_ch = n_ch 10 | self.n_classes = n_classes 11 | self.bilinear = bilinear 12 | 13 | self.inc = Double33Conv(n_ch, 64) 14 | 15 | self.down1 = Down(64, 128) 16 | self.down2 = Down(128, 256) 17 | self.down3 = Down(256, 512) 18 | self.down4 = Down(512, 512) 19 | self.up1 = Up(1024, 256, bilinear) 20 | self.up2 = Up(512, 128, bilinear) 21 | self.up3 = Up(256, 64, bilinear) 22 | self.up4 = Up(128, 64, bilinear) 23 | self.out = Out(64, n_classes) 24 | 25 | def forward(self, x): 26 | x1 = self.inc(x) 27 | x2 = self.down1(x1) 28 | x3 = self.down2(x2) 29 | x4 = self.down3(x3) 30 | x5 = self.down4(x4) 31 | x = self.up1(x5, x4) 32 | x = self.up2(x, x3) 33 | x = self.up3(x, x2) 34 | x = self.up4(x, x1) 35 | y = self.out(x) 36 | return y 37 | 38 | 39 | class Double33Conv(nn.Module): 40 | def __init__(self, in_ch, out_ch): 41 | super().__init__() 42 | self.double33conv = nn.Sequential( 43 | nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), 44 | nn.BatchNorm2d(out_ch), 45 | nn.ReLU(inplace=True), 46 | nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1), 47 | nn.BatchNorm2d(out_ch), 48 | nn.ReLU(inplace=True) 49 | ) 50 | 51 | def forward(self, x): 52 | return self.double33conv(x) 53 | 54 | 55 | class Down(nn.Module): 56 | def __init__(self, in_ch, out_ch): 57 | super().__init__() 58 | self.down = nn.Sequential( 59 | nn.MaxPool2d(2), 60 | Double33Conv(in_ch, out_ch) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.down(x) 65 | 66 | 67 | class Up(nn.Module): 68 | def 
__init__(self, in_ch, out_ch, bilinear=True): 69 | super().__init__() 70 | if bilinear: 71 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 72 | else: 73 | self.up = nn.ConvTranspose2d(in_ch//2, out_ch//2, kernel_size=2, stride=2) 74 | 75 | self.conv = Double33Conv(in_ch, out_ch) 76 | 77 | 78 | def forward(self, x1, x2): 79 | x1 = self.up(x1) 80 | 81 | dy = x2.size()[2]-x1.size()[2] 82 | dx = x2.size()[3]-x1.size()[3] 83 | """ Caution: Padding dimension 84 | N, C, H, W, dx=diffence of W-value 85 | pad=(w_left,w_right,h_top,h_bottom) 86 | """ 87 | x1 = F.pad(input=x1, pad=(dx//2, dx-dx//2, dy//2, dy-dy//2)) 88 | # print('sizes',x1.size(),x2.size(),dx // 2, dx - dx//2, dy // 2, dy - dy//2) 89 | x = torch.cat([x2, x1], dim=1) 90 | return self.conv(x) 91 | 92 | 93 | class Out(nn.Module): 94 | def __init__(self, in_ch, out_ch): 95 | super().__init__() 96 | self.out = nn.Conv2d(in_ch, out_ch, kernel_size=1) 97 | 98 | def forward(self, x): 99 | return self.out(x) -------------------------------------------------------------------------------- /feelvos/models/DynamicSegmentationHead.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from feelvos.models.Embeddings import DepthwiseSeparableConv2D 4 | 5 | 6 | class DynamicSegmentationHead(nn.Module): 7 | def __init__(self, cin, cout): 8 | super(DynamicSegmentationHead, self).__init__() 9 | self.depthwise_l = DepthwiseSeparableConv2D(cin, 256, 7) 10 | self.depthwise_r = DepthwiseSeparableConv2D(256, 256, 7) 11 | self.conv = nn.Conv2d(256, cout, 1) 12 | 13 | def forward(self, x): 14 | x = self.depthwise_l(x) 15 | x = self.depthwise_r(x) 16 | x = self.depthwise_r(x) 17 | x = self.depthwise_r(x) 18 | x = nn.ReLU(inplace=True)(x) 19 | x = self.conv(x) 20 | x = nn.Softmax2d()(x) 21 | 22 | return x -------------------------------------------------------------------------------- /feelvos/models/Embeddings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from modelsummary import summary 5 | 6 | 7 | class DepthwiseSeparableConv2D(nn.Module): 8 | def __init__(self, c_in, c_out, kernel_size=1, stride=1, padding=0, dilation=1, bias=False): 9 | super(DepthwiseSeparableConv2D,self).__init__() 10 | self.conv1 = nn.Conv2d(c_in, c_in, kernel_size, stride, padding, dilation, groups=c_in, bias=bias) 11 | self.pointwise = nn.Conv2d(c_in, c_out, 1, 1, 0, 1, 1, bias=bias) 12 | 13 | def forward(self, x): 14 | x = self.conv1(x) 15 | x = self.pointwise(x) 16 | return x 17 | 18 | 19 | class PixelwiseEmbedding(nn.Module): 20 | def __init__(self, c_in, c_out_1, c_out_2): 21 | super(PixelwiseEmbedding, self).__init__() 22 | self.separable = DepthwiseSeparableConv2D(c_in=c_in, c_out=c_out_1, kernel_size=3, stride=1, padding=1) 23 | self.conv1 = nn.Conv2d(c_out_1, c_out_2, kernel_size=1, stride=1, padding=0) 24 | 25 | def forward(self, x): 26 | x = self.separable(x) 27 | x = self.conv1(x) 28 | return x 29 | -------------------------------------------------------------------------------- /feelvos/models/FEELVOS.py: -------------------------------------------------------------------------------- 1 | from cv2 import cv2 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | from modelsummary import summary 7 | from feelvos.models.Backbone import UNet 8 | from feelvos.models.Embeddings import PixelwiseEmbedding 
9 | from feelvos.models.DynamicSegmentationHead import DynamicSegmentationHead
10 | from feelvos.models.Matching import global_matching, local_matching
11 |
12 |
13 | class FEELVOS(nn.Module):
14 |     def __init__(self, c_in, n_classes, use_gt=True, pretrained=None):
15 |         super(FEELVOS, self).__init__()
16 |         self.n_classes = n_classes
17 |         self.use_gt = use_gt
18 |         # the backbone is required whenever use_gt is False, so build it even without pretrained weights
19 |         self.backbone = UNet(c_in, n_classes)
20 |         if pretrained is not None:
21 |             self.backbone.load_state_dict(torch.load(pretrained))
22 |             self.backbone.eval()
23 |         self.embedding = PixelwiseEmbedding(n_classes, n_classes, 100)
24 |         self.dsh = DynamicSegmentationHead(n_classes+1+1+1, 1)
25 |
26 |     def forward(self, x_list):
27 |         x1 = x_list[0]
28 |         x2 = x_list[1]
29 |         x3 = x_list[2]
30 |
31 |         if not self.use_gt:
32 |             with torch.no_grad():
33 |                 x1 = self.backbone(x1)
34 |                 x2 = self.backbone(x2)
35 |             with torch.no_grad():
36 |                 x3 = self.backbone(x3)
37 |
38 |         x1_l = []; x1_e = []
39 |         x2_l = []; x2_e = []
40 |         x3_l = []; x3_e = []
41 |         gm = []; lm = []
42 |         logits = []
43 |
44 |         x1 = F.interpolate(x1, 32)
45 |         x2 = F.interpolate(x2, 32)
46 |         x3 = F.interpolate(x3, 32)
47 |
48 |         for i in range(self.n_classes):
49 |             x1_l.append(x1[:, i, :, :].unsqueeze(1))
50 |             x1_e.append(self.embedding(x1_l[i]))
51 |             x2_l.append(x2[:, i, :, :].unsqueeze(1))
52 |             x2_e.append(self.embedding(x2_l[i]))
53 |             x3_l.append(x3[:, i, :, :].unsqueeze(1))
54 |             x3_e.append(self.embedding(x3_l[i]))
55 |             with torch.no_grad():
56 |                 gm.append(global_matching(x1_e[i], x3_e[i]))
57 |                 lm.append(global_matching(x2_e[i], x3_e[i]))  # local_matching is still a stub, so the previous-frame term also uses global matching for now
58 |             x_t = torch.cat((x3, gm[i].cuda(), lm[i].cuda(), x2_l[i]), dim=1)
59 |             logits.append(self.dsh(x_t))
60 |         x = None
61 |         for i in range(self.n_classes):
62 |             if i == 0:
63 |                 x = logits[i]
64 |             else:
65 |                 x = torch.cat((x, logits[i]), dim=1)  # accumulate the logits of every class, not only the last two
66 |         return x
67 |
68 |
69 | if __name__ == "__main__":
70 |     device = torch.device("cuda:0")
71 |     model = FEELVOS(3, 1, use_gt=False).cuda(device=device)
72 |
73 |     # summary(model, torch.zeros((1, 3, 512, 512)).cuda(), show_input=True)
74 |     # summary(model, torch.zeros((1, 3, 512, 512)).cuda(), show_input=False)
75 |
76 |     x1 = cv2.imread('example/x2.png')
77 |     x2 = cv2.imread('example/x3.png')
78 |
79 |     x1 = cv2.resize(x1, dsize=(256, 256))
80 |     x1 = torchvision.transforms.ToTensor()(x1)
81 |     x1 = x1.unsqueeze(0).to(device=device)
82 |
83 |     x2 = cv2.resize(x2, dsize=(256, 256))
84 |     x2 = torchvision.transforms.ToTensor()(x2)
85 |     x2 = x2.unsqueeze(0).to(device=device)
86 |
87 |     x = torch.cat((x1, x2), dim=0)
88 |     y = model([x, x, x])  # forward() takes a list of three frames, not three positional tensors
89 |     print(y)
90 |
--------------------------------------------------------------------------------
/feelvos/models/Matching.py:
--------------------------------------------------------------------------------
1 | from cv2 import cv2
2 | import torch
3 | import torch.nn as nn
4 | import torchvision
5 | from torch.autograd.variable import Variable
6 | from .correlation_package.correlation import Correlation
7 |
8 |
9 | def distance(p, q):
10 |     # FEELVOS distance between embedding vectors: d(p, q) = 1 - 2 / (1 + exp(||p - q||^2))
11 |     diff = p - q
12 |     norm = torch.sum(diff * diff, dim=-1)
13 |     res = 1 - (2 / (1 + torch.exp(norm)))
14 |     return res
15 |
16 | def global_matching(x, y):  # compares embeddings at corresponding positions only, a simplification of the paper's minimum over all pixels
17 |     output = torch.zeros(x.size(0), 1, x.size(2), x.size(3))
18 |     for i in range(x.size(0)):
19 |         for j in range(x.size(2)):
20 |             for k in range(x.size(3)):
21 |                 output[i, :, j, k] = distance(x[i, :, j, k], y[i, :, j, k])
22 |     return output
23 |
24 | def local_matching(x, y, window):
25 |     output
= torch.zeros(x.size(0), 1, x.size(2), x.size(3)) 26 | # out_corr = Correlation(pad_size=6, kernel_size=window, max_displacement=0, stride1=1, stride2=1, corr_multiply=1)(x, y) 27 | 28 | return output 29 | -------------------------------------------------------------------------------- /feelvos/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/measaverb/FEELVOS/e9f497753ea52a91c6180e9f45fb87810898d309/feelvos/models/__init__.py -------------------------------------------------------------------------------- /feelvos/models/correlation_package/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/measaverb/FEELVOS/e9f497753ea52a91c6180e9f45fb87810898d309/feelvos/models/correlation_package/__init__.py -------------------------------------------------------------------------------- /feelvos/models/correlation_package/correlation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.modules.module import Module 3 | from torch.autograd import Function 4 | import correlation_cuda 5 | 6 | class CorrelationFunction(Function): 7 | 8 | def __init__(self, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1): 9 | super(CorrelationFunction, self).__init__() 10 | self.pad_size = pad_size 11 | self.kernel_size = kernel_size 12 | self.max_displacement = max_displacement 13 | self.stride1 = stride1 14 | self.stride2 = stride2 15 | self.corr_multiply = corr_multiply 16 | # self.out_channel = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1) 17 | 18 | def forward(self, input1, input2): 19 | self.save_for_backward(input1, input2) 20 | 21 | with torch.cuda.device_of(input1): 22 | rbot1 = input1.new() 23 | rbot2 = input2.new() 24 | output = input1.new() 25 | 26 | correlation_cuda.forward(input1, input2, rbot1, rbot2, output, 27 | self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply) 28 | 29 | return output 30 | 31 | def backward(self, grad_output): 32 | input1, input2 = self.saved_tensors 33 | 34 | with torch.cuda.device_of(input1): 35 | rbot1 = input1.new() 36 | rbot2 = input2.new() 37 | 38 | grad_input1 = input1.new() 39 | grad_input2 = input2.new() 40 | 41 | correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2, 42 | self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply) 43 | 44 | return grad_input1, grad_input2 45 | 46 | 47 | class Correlation(Module): 48 | def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stride1=1, stride2=2, corr_multiply=1): 49 | super(Correlation, self).__init__() 50 | self.pad_size = pad_size 51 | self.kernel_size = kernel_size 52 | self.max_displacement = max_displacement 53 | self.stride1 = stride1 54 | self.stride2 = stride2 55 | self.corr_multiply = corr_multiply 56 | 57 | def forward(self, input1, input2): 58 | 59 | result = CorrelationFunction(self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)(input1, input2) 60 | 61 | return result 62 | 63 | -------------------------------------------------------------------------------- /feelvos/models/correlation_package/correlation_cuda.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 
#include 5 | #include 6 | #include 7 | 8 | #include "correlation_cuda_kernel.cuh" 9 | 10 | int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output, 11 | int pad_size, 12 | int kernel_size, 13 | int max_displacement, 14 | int stride1, 15 | int stride2, 16 | int corr_type_multiply) 17 | { 18 | 19 | int batchSize = input1.size(0); 20 | 21 | int nInputChannels = input1.size(1); 22 | int inputHeight = input1.size(2); 23 | int inputWidth = input1.size(3); 24 | 25 | int kernel_radius = (kernel_size - 1) / 2; 26 | int border_radius = kernel_radius + max_displacement; 27 | 28 | int paddedInputHeight = inputHeight + 2 * pad_size; 29 | int paddedInputWidth = inputWidth + 2 * pad_size; 30 | 31 | int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1); 32 | 33 | int outputHeight = ceil(static_cast(paddedInputHeight - 2 * border_radius) / static_cast(stride1)); 34 | int outputwidth = ceil(static_cast(paddedInputWidth - 2 * border_radius) / static_cast(stride1)); 35 | 36 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 37 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 38 | output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth}); 39 | 40 | rInput1.fill_(0); 41 | rInput2.fill_(0); 42 | output.fill_(0); 43 | 44 | int success = correlation_forward_cuda_kernel( 45 | output, 46 | output.size(0), 47 | output.size(1), 48 | output.size(2), 49 | output.size(3), 50 | output.stride(0), 51 | output.stride(1), 52 | output.stride(2), 53 | output.stride(3), 54 | input1, 55 | input1.size(1), 56 | input1.size(2), 57 | input1.size(3), 58 | input1.stride(0), 59 | input1.stride(1), 60 | input1.stride(2), 61 | input1.stride(3), 62 | input2, 63 | input2.size(1), 64 | input2.stride(0), 65 | input2.stride(1), 66 | input2.stride(2), 67 | input2.stride(3), 68 | rInput1, 69 | rInput2, 70 | pad_size, 71 | kernel_size, 72 | max_displacement, 73 | stride1, 74 | stride2, 75 | corr_type_multiply, 76 | at::cuda::getCurrentCUDAStream() 77 | //at::globalContext().getCurrentCUDAStream() 78 | ); 79 | 80 | //check for errors 81 | if (!success) { 82 | AT_ERROR("CUDA call failed"); 83 | } 84 | 85 | return 1; 86 | 87 | } 88 | 89 | int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput, 90 | at::Tensor& gradInput1, at::Tensor& gradInput2, 91 | int pad_size, 92 | int kernel_size, 93 | int max_displacement, 94 | int stride1, 95 | int stride2, 96 | int corr_type_multiply) 97 | { 98 | 99 | int batchSize = input1.size(0); 100 | int nInputChannels = input1.size(1); 101 | int paddedInputHeight = input1.size(2)+ 2 * pad_size; 102 | int paddedInputWidth = input1.size(3)+ 2 * pad_size; 103 | 104 | int height = input1.size(2); 105 | int width = input1.size(3); 106 | 107 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 108 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 109 | gradInput1.resize_({batchSize, nInputChannels, height, width}); 110 | gradInput2.resize_({batchSize, nInputChannels, height, width}); 111 | 112 | rInput1.fill_(0); 113 | rInput2.fill_(0); 114 | gradInput1.fill_(0); 115 | gradInput2.fill_(0); 116 | 117 | int success = correlation_backward_cuda_kernel(gradOutput, 118 | gradOutput.size(0), 119 | gradOutput.size(1), 120 | gradOutput.size(2), 121 | gradOutput.size(3), 122 | 
gradOutput.stride(0), 123 | gradOutput.stride(1), 124 | gradOutput.stride(2), 125 | gradOutput.stride(3), 126 | input1, 127 | input1.size(1), 128 | input1.size(2), 129 | input1.size(3), 130 | input1.stride(0), 131 | input1.stride(1), 132 | input1.stride(2), 133 | input1.stride(3), 134 | input2, 135 | input2.stride(0), 136 | input2.stride(1), 137 | input2.stride(2), 138 | input2.stride(3), 139 | gradInput1, 140 | gradInput1.stride(0), 141 | gradInput1.stride(1), 142 | gradInput1.stride(2), 143 | gradInput1.stride(3), 144 | gradInput2, 145 | gradInput2.size(1), 146 | gradInput2.stride(0), 147 | gradInput2.stride(1), 148 | gradInput2.stride(2), 149 | gradInput2.stride(3), 150 | rInput1, 151 | rInput2, 152 | pad_size, 153 | kernel_size, 154 | max_displacement, 155 | stride1, 156 | stride2, 157 | corr_type_multiply, 158 | at::cuda::getCurrentCUDAStream() 159 | //at::globalContext().getCurrentCUDAStream() 160 | ); 161 | 162 | if (!success) { 163 | AT_ERROR("CUDA call failed"); 164 | } 165 | 166 | return 1; 167 | } 168 | 169 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 170 | m.def("forward", &correlation_forward_cuda, "Correlation forward (CUDA)"); 171 | m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)"); 172 | } 173 | 174 | -------------------------------------------------------------------------------- /feelvos/models/correlation_package/correlation_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "correlation_cuda_kernel.cuh" 4 | 5 | #define CUDA_NUM_THREADS 1024 6 | #define THREADS_PER_BLOCK 32 7 | #define FULL_MASK 0xffffffff 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | using at::Half; 15 | 16 | template 17 | __forceinline__ __device__ scalar_t warpReduceSum(scalar_t val) { 18 | for (int offset = 16; offset > 0; offset /= 2) 19 | val += __shfl_down_sync(FULL_MASK, val, offset); 20 | return val; 21 | } 22 | 23 | template 24 | __forceinline__ __device__ scalar_t blockReduceSum(scalar_t val) { 25 | 26 | static __shared__ scalar_t shared[32]; 27 | int lane = threadIdx.x % warpSize; 28 | int wid = threadIdx.x / warpSize; 29 | 30 | val = warpReduceSum(val); 31 | 32 | if (lane == 0) 33 | shared[wid] = val; 34 | 35 | __syncthreads(); 36 | 37 | val = (threadIdx.x < blockDim.x / warpSize) ? 
shared[lane] : 0; 38 | 39 | if (wid == 0) 40 | val = warpReduceSum(val); 41 | 42 | return val; 43 | } 44 | 45 | 46 | template 47 | __global__ void channels_first(const scalar_t* __restrict__ input, scalar_t* rinput, int channels, int height, int width, int pad_size) 48 | { 49 | 50 | // n (batch size), c (num of channels), y (height), x (width) 51 | int n = blockIdx.x; 52 | int y = blockIdx.y; 53 | int x = blockIdx.z; 54 | 55 | int ch_off = threadIdx.x; 56 | scalar_t value; 57 | 58 | int dimcyx = channels * height * width; 59 | int dimyx = height * width; 60 | 61 | int p_dimx = (width + 2 * pad_size); 62 | int p_dimy = (height + 2 * pad_size); 63 | int p_dimyxc = channels * p_dimy * p_dimx; 64 | int p_dimxc = p_dimx * channels; 65 | 66 | for (int c = ch_off; c < channels; c += THREADS_PER_BLOCK) { 67 | value = input[n * dimcyx + c * dimyx + y * width + x]; 68 | rinput[n * p_dimyxc + (y + pad_size) * p_dimxc + (x + pad_size) * channels + c] = value; 69 | } 70 | } 71 | 72 | 73 | template 74 | __global__ void correlation_forward(scalar_t* __restrict__ output, const int nOutputChannels, 75 | const int outputHeight, const int outputWidth, const scalar_t* __restrict__ rInput1, 76 | const int nInputChannels, const int inputHeight, const int inputWidth, 77 | const scalar_t* __restrict__ rInput2, const int pad_size, const int kernel_size, 78 | const int max_displacement, const int stride1, const int stride2) { 79 | 80 | int32_t pInputWidth = inputWidth + 2 * pad_size; 81 | int32_t pInputHeight = inputHeight + 2 * pad_size; 82 | 83 | int32_t kernel_rad = (kernel_size - 1) / 2; 84 | 85 | int32_t displacement_rad = max_displacement / stride2; 86 | 87 | int32_t displacement_size = 2 * displacement_rad + 1; 88 | 89 | int32_t n = blockIdx.x; 90 | int32_t y1 = blockIdx.y * stride1 + max_displacement; 91 | int32_t x1 = blockIdx.z * stride1 + max_displacement; 92 | int32_t c = threadIdx.x; 93 | 94 | int32_t pdimyxc = pInputHeight * pInputWidth * nInputChannels; 95 | 96 | int32_t pdimxc = pInputWidth * nInputChannels; 97 | 98 | int32_t pdimc = nInputChannels; 99 | 100 | int32_t tdimcyx = nOutputChannels * outputHeight * outputWidth; 101 | int32_t tdimyx = outputHeight * outputWidth; 102 | int32_t tdimx = outputWidth; 103 | 104 | int32_t nelems = kernel_size * kernel_size * pdimc; 105 | 106 | // element-wise product along channel axis 107 | for (int tj = -displacement_rad; tj <= displacement_rad; ++tj) { 108 | for (int ti = -displacement_rad; ti <= displacement_rad; ++ti) { 109 | int x2 = x1 + ti * stride2; 110 | int y2 = y1 + tj * stride2; 111 | 112 | float acc0 = 0.0f; 113 | 114 | for (int j = -kernel_rad; j <= kernel_rad; ++j) { 115 | for (int i = -kernel_rad; i <= kernel_rad; ++i) { 116 | // THREADS_PER_BLOCK 117 | #pragma unroll 118 | for (int ch = c; ch < pdimc; ch += blockDim.x) { 119 | 120 | int indx1 = n * pdimyxc + (y1 + j) * pdimxc 121 | + (x1 + i) * pdimc + ch; 122 | int indx2 = n * pdimyxc + (y2 + j) * pdimxc 123 | + (x2 + i) * pdimc + ch; 124 | acc0 += static_cast(rInput1[indx1] * rInput2[indx2]); 125 | } 126 | } 127 | } 128 | 129 | if (blockDim.x == warpSize) { 130 | __syncwarp(); 131 | acc0 = warpReduceSum(acc0); 132 | } else { 133 | __syncthreads(); 134 | acc0 = blockReduceSum(acc0); 135 | } 136 | 137 | if (threadIdx.x == 0) { 138 | 139 | int tc = (tj + displacement_rad) * displacement_size 140 | + (ti + displacement_rad); 141 | const int tindx = n * tdimcyx + tc * tdimyx + blockIdx.y * tdimx 142 | + blockIdx.z; 143 | output[tindx] = static_cast(acc0 / nelems); 144 | } 145 | } 146 | } 147 | } 
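// Output layout note for the forward kernel above: each displacement pair (tj, ti)
// is written to channel tc = (tj + displacement_rad) * displacement_size + (ti + displacement_rad),
// so nOutputChannels = (2 * max_displacement / stride2 + 1)^2. For the defaults in
// correlation.py (max_displacement=20, stride2=2) that is 21 * 21 = 441 channels.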
148 | 149 | 150 | template 151 | __global__ void correlation_backward_input1(int item, scalar_t* gradInput1, int nInputChannels, int inputHeight, int inputWidth, 152 | const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth, 153 | const scalar_t* __restrict__ rInput2, 154 | int pad_size, 155 | int kernel_size, 156 | int max_displacement, 157 | int stride1, 158 | int stride2) 159 | { 160 | // n (batch size), c (num of channels), y (height), x (width) 161 | 162 | int n = item; 163 | int y = blockIdx.x * stride1 + pad_size; 164 | int x = blockIdx.y * stride1 + pad_size; 165 | int c = blockIdx.z; 166 | int tch_off = threadIdx.x; 167 | 168 | int kernel_rad = (kernel_size - 1) / 2; 169 | int displacement_rad = max_displacement / stride2; 170 | int displacement_size = 2 * displacement_rad + 1; 171 | 172 | int xmin = (x - kernel_rad - max_displacement) / stride1; 173 | int ymin = (y - kernel_rad - max_displacement) / stride1; 174 | 175 | int xmax = (x + kernel_rad - max_displacement) / stride1; 176 | int ymax = (y + kernel_rad - max_displacement) / stride1; 177 | 178 | if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) { 179 | // assumes gradInput1 is pre-allocated and zero filled 180 | return; 181 | } 182 | 183 | if (xmin > xmax || ymin > ymax) { 184 | // assumes gradInput1 is pre-allocated and zero filled 185 | return; 186 | } 187 | 188 | xmin = max(0,xmin); 189 | xmax = min(outputWidth-1,xmax); 190 | 191 | ymin = max(0,ymin); 192 | ymax = min(outputHeight-1,ymax); 193 | 194 | int pInputWidth = inputWidth + 2 * pad_size; 195 | int pInputHeight = inputHeight + 2 * pad_size; 196 | 197 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels; 198 | int pdimxc = pInputWidth * nInputChannels; 199 | int pdimc = nInputChannels; 200 | 201 | int tdimcyx = nOutputChannels * outputHeight * outputWidth; 202 | int tdimyx = outputHeight * outputWidth; 203 | int tdimx = outputWidth; 204 | 205 | int odimcyx = nInputChannels * inputHeight* inputWidth; 206 | int odimyx = inputHeight * inputWidth; 207 | int odimx = inputWidth; 208 | 209 | scalar_t nelems = kernel_size * kernel_size * nInputChannels; 210 | 211 | __shared__ scalar_t prod_sum[THREADS_PER_BLOCK]; 212 | prod_sum[tch_off] = 0; 213 | 214 | for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) { 215 | 216 | int i2 = (tc % displacement_size - displacement_rad) * stride2; 217 | int j2 = (tc / displacement_size - displacement_rad) * stride2; 218 | 219 | int indx2 = n * pdimyxc + (y + j2)* pdimxc + (x + i2) * pdimc + c; 220 | 221 | scalar_t val2 = rInput2[indx2]; 222 | 223 | for (int j = ymin; j <= ymax; ++j) { 224 | for (int i = xmin; i <= xmax; ++i) { 225 | int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i; 226 | prod_sum[tch_off] += gradOutput[tindx] * val2; 227 | } 228 | } 229 | } 230 | __syncthreads(); 231 | 232 | if(tch_off == 0) { 233 | scalar_t reduce_sum = 0; 234 | for(int idx = 0; idx < THREADS_PER_BLOCK; idx++) { 235 | reduce_sum += prod_sum[idx]; 236 | } 237 | const int indx1 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size); 238 | gradInput1[indx1] = reduce_sum / nelems; 239 | } 240 | 241 | } 242 | 243 | template 244 | __global__ void correlation_backward_input2(int item, scalar_t* gradInput2, int nInputChannels, int inputHeight, int inputWidth, 245 | const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth, 246 | const scalar_t* __restrict__ rInput1, 247 | int pad_size, 248 | int kernel_size, 249 | int 
max_displacement, 250 | int stride1, 251 | int stride2) 252 | { 253 | // n (batch size), c (num of channels), y (height), x (width) 254 | 255 | int n = item; 256 | int y = blockIdx.x * stride1 + pad_size; 257 | int x = blockIdx.y * stride1 + pad_size; 258 | int c = blockIdx.z; 259 | 260 | int tch_off = threadIdx.x; 261 | 262 | int kernel_rad = (kernel_size - 1) / 2; 263 | int displacement_rad = max_displacement / stride2; 264 | int displacement_size = 2 * displacement_rad + 1; 265 | 266 | int pInputWidth = inputWidth + 2 * pad_size; 267 | int pInputHeight = inputHeight + 2 * pad_size; 268 | 269 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels; 270 | int pdimxc = pInputWidth * nInputChannels; 271 | int pdimc = nInputChannels; 272 | 273 | int tdimcyx = nOutputChannels * outputHeight * outputWidth; 274 | int tdimyx = outputHeight * outputWidth; 275 | int tdimx = outputWidth; 276 | 277 | int odimcyx = nInputChannels * inputHeight* inputWidth; 278 | int odimyx = inputHeight * inputWidth; 279 | int odimx = inputWidth; 280 | 281 | scalar_t nelems = kernel_size * kernel_size * nInputChannels; 282 | 283 | __shared__ scalar_t prod_sum[THREADS_PER_BLOCK]; 284 | prod_sum[tch_off] = 0; 285 | 286 | for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) { 287 | int i2 = (tc % displacement_size - displacement_rad) * stride2; 288 | int j2 = (tc / displacement_size - displacement_rad) * stride2; 289 | 290 | int xmin = (x - kernel_rad - max_displacement - i2) / stride1; 291 | int ymin = (y - kernel_rad - max_displacement - j2) / stride1; 292 | 293 | int xmax = (x + kernel_rad - max_displacement - i2) / stride1; 294 | int ymax = (y + kernel_rad - max_displacement - j2) / stride1; 295 | 296 | if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) { 297 | // assumes gradInput2 is pre-allocated and zero filled 298 | continue; 299 | } 300 | 301 | if (xmin > xmax || ymin > ymax) { 302 | // assumes gradInput2 is pre-allocated and zero filled 303 | continue; 304 | } 305 | 306 | xmin = max(0,xmin); 307 | xmax = min(outputWidth-1,xmax); 308 | 309 | ymin = max(0,ymin); 310 | ymax = min(outputHeight-1,ymax); 311 | 312 | int indx1 = n * pdimyxc + (y - j2)* pdimxc + (x - i2) * pdimc + c; 313 | scalar_t val1 = rInput1[indx1]; 314 | 315 | for (int j = ymin; j <= ymax; ++j) { 316 | for (int i = xmin; i <= xmax; ++i) { 317 | int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i; 318 | prod_sum[tch_off] += gradOutput[tindx] * val1; 319 | } 320 | } 321 | } 322 | 323 | __syncthreads(); 324 | 325 | if(tch_off == 0) { 326 | scalar_t reduce_sum = 0; 327 | for(int idx = 0; idx < THREADS_PER_BLOCK; idx++) { 328 | reduce_sum += prod_sum[idx]; 329 | } 330 | const int indx2 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size); 331 | gradInput2[indx2] = reduce_sum / nelems; 332 | } 333 | 334 | } 335 | 336 | int correlation_forward_cuda_kernel(at::Tensor& output, 337 | int ob, 338 | int oc, 339 | int oh, 340 | int ow, 341 | int osb, 342 | int osc, 343 | int osh, 344 | int osw, 345 | 346 | at::Tensor& input1, 347 | int ic, 348 | int ih, 349 | int iw, 350 | int isb, 351 | int isc, 352 | int ish, 353 | int isw, 354 | 355 | at::Tensor& input2, 356 | int gc, 357 | int gsb, 358 | int gsc, 359 | int gsh, 360 | int gsw, 361 | 362 | at::Tensor& rInput1, 363 | at::Tensor& rInput2, 364 | int pad_size, 365 | int kernel_size, 366 | int max_displacement, 367 | int stride1, 368 | int stride2, 369 | int corr_type_multiply, 370 | cudaStream_t stream) 371 | { 372 | 373 | int batchSize = ob; 374 | 
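// Launch layout used below: channels_first runs one block per (batch, y, x)
// position, with THREADS_PER_BLOCK threads striding over channels to convert the
// NCHW inputs into zero-padded NHWC buffers (rInput1/rInput2); correlation_forward
// then runs one block per output position and reduces the per-channel products
// inside the block via warpReduceSum/blockReduceSum.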
375 | int nInputChannels = ic; 376 | int inputWidth = iw; 377 | int inputHeight = ih; 378 | 379 | int nOutputChannels = oc; 380 | int outputWidth = ow; 381 | int outputHeight = oh; 382 | 383 | dim3 blocks_grid(batchSize, inputHeight, inputWidth); 384 | dim3 threads_block(THREADS_PER_BLOCK); 385 | 386 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channels_first_fwd_1", ([&] { 387 | 388 | channels_first<<>>( 389 | input1.data(), rInput1.data(), nInputChannels, inputHeight, inputWidth, pad_size); 390 | 391 | })); 392 | 393 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "channels_first_fwd_2", ([&] { 394 | 395 | channels_first<<>> ( 396 | input2.data(), rInput2.data(), nInputChannels, inputHeight, inputWidth, pad_size); 397 | 398 | })); 399 | 400 | dim3 threadsPerBlock(THREADS_PER_BLOCK); 401 | dim3 totalBlocksCorr(batchSize, outputHeight, outputWidth); 402 | 403 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "correlation_forward", ([&] { 404 | 405 | correlation_forward<<>> 406 | (output.data(), nOutputChannels, outputHeight, outputWidth, 407 | rInput1.data(), nInputChannels, inputHeight, inputWidth, 408 | rInput2.data(), 409 | pad_size, 410 | kernel_size, 411 | max_displacement, 412 | stride1, 413 | stride2); 414 | 415 | })); 416 | 417 | cudaError_t err = cudaGetLastError(); 418 | 419 | 420 | // check for errors 421 | if (err != cudaSuccess) { 422 | printf("error in correlation_forward_cuda_kernel: %s\n", cudaGetErrorString(err)); 423 | return 0; 424 | } 425 | 426 | return 1; 427 | } 428 | 429 | 430 | int correlation_backward_cuda_kernel( 431 | at::Tensor& gradOutput, 432 | int gob, 433 | int goc, 434 | int goh, 435 | int gow, 436 | int gosb, 437 | int gosc, 438 | int gosh, 439 | int gosw, 440 | 441 | at::Tensor& input1, 442 | int ic, 443 | int ih, 444 | int iw, 445 | int isb, 446 | int isc, 447 | int ish, 448 | int isw, 449 | 450 | at::Tensor& input2, 451 | int gsb, 452 | int gsc, 453 | int gsh, 454 | int gsw, 455 | 456 | at::Tensor& gradInput1, 457 | int gisb, 458 | int gisc, 459 | int gish, 460 | int gisw, 461 | 462 | at::Tensor& gradInput2, 463 | int ggc, 464 | int ggsb, 465 | int ggsc, 466 | int ggsh, 467 | int ggsw, 468 | 469 | at::Tensor& rInput1, 470 | at::Tensor& rInput2, 471 | int pad_size, 472 | int kernel_size, 473 | int max_displacement, 474 | int stride1, 475 | int stride2, 476 | int corr_type_multiply, 477 | cudaStream_t stream) 478 | { 479 | 480 | int batchSize = gob; 481 | int num = batchSize; 482 | 483 | int nInputChannels = ic; 484 | int inputWidth = iw; 485 | int inputHeight = ih; 486 | 487 | int nOutputChannels = goc; 488 | int outputWidth = gow; 489 | int outputHeight = goh; 490 | 491 | dim3 blocks_grid(batchSize, inputHeight, inputWidth); 492 | dim3 threads_block(THREADS_PER_BLOCK); 493 | 494 | 495 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "lltm_forward_cuda", ([&] { 496 | 497 | channels_first<<>>( 498 | input1.data(), 499 | rInput1.data(), 500 | nInputChannels, 501 | inputHeight, 502 | inputWidth, 503 | pad_size 504 | ); 505 | })); 506 | 507 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] { 508 | 509 | channels_first<<>>( 510 | input2.data(), 511 | rInput2.data(), 512 | nInputChannels, 513 | inputHeight, 514 | inputWidth, 515 | pad_size 516 | ); 517 | })); 518 | 519 | dim3 threadsPerBlock(THREADS_PER_BLOCK); 520 | dim3 totalBlocksCorr(inputHeight, inputWidth, nInputChannels); 521 | 522 | for (int n = 0; n < num; ++n) { 523 | 524 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] { 
--------------------------------------------------------------------------------
/feelvos/models/correlation_package/correlation_cuda_kernel.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 | 
3 | #include <ATen/ATen.h>
4 | #include <cuda.h>
5 | #include <cuda_runtime.h>
6 | 
7 | int correlation_forward_cuda_kernel(at::Tensor& output,
8 |     int ob,
9 |     int oc,
10 |     int oh,
11 |     int ow,
12 |     int osb,
13 |     int osc,
14 |     int osh,
15 |     int osw,
16 | 
17 |     at::Tensor& input1,
18 |     int ic,
19 |     int ih,
20 |     int iw,
21 |     int isb,
22 |     int isc,
23 |     int ish,
24 |     int isw,
25 | 
26 |     at::Tensor& input2,
27 |     int gc,
28 |     int gsb,
29 |     int gsc,
30 |     int gsh,
31 |     int gsw,
32 | 
33 |     at::Tensor& rInput1,
34 |     at::Tensor& rInput2,
35 |     int pad_size,
36 |     int kernel_size,
37 |     int max_displacement,
38 |     int stride1,
39 |     int stride2,
40 |     int corr_type_multiply,
41 |     cudaStream_t stream);
42 | 
43 | 
44 | int correlation_backward_cuda_kernel(
45 |     at::Tensor& gradOutput,
46 |     int gob,
47 |     int goc,
48 |     int goh,
49 |     int gow,
50 |     int gosb,
51 |     int gosc,
52 |     int gosh,
53 |     int gosw,
54 | 
55 |     at::Tensor& input1,
56 |     int ic,
57 |     int ih,
58 |     int iw,
59 |     int isb,
60 |     int isc,
61 |     int ish,
62 |     int isw,
63 | 
64 |     at::Tensor& input2,
65 |     int gsb,
66 |     int gsc,
67 |     int gsh,
68 |     int gsw,
69 | 
70 |     at::Tensor& gradInput1,
71 |     int gisb,
72 |     int gisc,
73 |     int gish,
74 |     int gisw,
75 | 
76 |     at::Tensor& gradInput2,
77 |     int ggc,
78 |     int ggsb,
79 |     int ggsc,
80 |     int ggsh,
81 |     int ggsw,
82 | 
83 |     at::Tensor& rInput1,
84 |     at::Tensor& rInput2,
85 |     int pad_size,
86 |     int kernel_size,
87 |     int max_displacement,
88 |     int stride1,
89 |     int stride2,
90 |     int corr_type_multiply,
91 |     cudaStream_t stream);
92 | 
--------------------------------------------------------------------------------
/feelvos/models/correlation_package/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import torch
4 | 
5 | from setuptools import setup, find_packages
6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
7 | 
8 | cxx_args = ['-std=c++11']
9 | 
10 | nvcc_args = [
11 |     '-gencode', 'arch=compute_50,code=sm_50',
12 |     '-gencode', 'arch=compute_52,code=sm_52',
13 |     '-gencode', 'arch=compute_60,code=sm_60',
14 |     '-gencode', 'arch=compute_61,code=sm_61',
15 |     '-gencode', 'arch=compute_70,code=sm_70',
16 |     '-gencode', 'arch=compute_70,code=compute_70'
17 | ]
18 | 
19 | setup(
20 |     name='correlation_cuda',
21 |     ext_modules=[
22 |         CUDAExtension('correlation_cuda', [
23 |             'correlation_cuda.cc',
24 |             'correlation_cuda_kernel.cu'
25 |         ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
26 |     ],
27 |     cmdclass={
28 |         'build_ext': BuildExtension
29 |     })
30 | 
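NOTE: The extension must be compiled before the correlation_cuda module can be imported (typically by running python setup.py install from this directory). During development, PyTorch's JIT builder can compile the same sources without installing; a minimal sketch, assuming it is run from inside correlation_package/:

    from torch.utils.cpp_extension import load

    # JIT-compile the same sources that setup.py builds ahead of time
    correlation_cuda = load(
        name='correlation_cuda',
        sources=['correlation_cuda.cc', 'correlation_cuda_kernel.cu'],
        extra_cflags=['-std=c++11'],
    )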
--------------------------------------------------------------------------------
/feelvos/run.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/measaverb/FEELVOS/e9f497753ea52a91c6180e9f45fb87810898d309/feelvos/run.py
--------------------------------------------------------------------------------
/feelvos/test.py:
--------------------------------------------------------------------------------
1 | import random
2 | 
3 | import torch
4 | import torchvision
5 | 
6 | from feelvos.models.Backbone import UNet
7 | from feelvos.dataset import FEELVOSTriple
8 | from feelvos.transform import preprocessing
9 | 
10 | 
11 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12 | print(device)
13 | 
14 | if __name__ == "__main__":
15 |     target_folder = './data/'
16 |     ds_test = FEELVOSTriple(root=target_folder, split='test', transform=preprocessing)
17 | 
18 |     loc = './unet/weight010'
19 |     model = UNet(3, 1)
20 |     model.load_state_dict(torch.load(loc + '.pt', map_location=device))
21 |     model = model.to(device)
22 |     model.eval()
23 | 
24 |     # sample random test indices to visualize
25 |     pick = [random.randrange(0, 500) for _ in range(1)]
26 | 
27 |     for i in pick:
28 |         X, y = ds_test[i]
29 |         torchvision.utils.save_image(X[0], './testimage/' + str(i) + '_X.png')
30 |         torchvision.utils.save_image(y[0], './testimage/' + str(i) + '_y.png')
31 |         X = X[0].view(1, 3, 256, 256).to(device)
32 |         with torch.no_grad():
33 |             y_pred = model(X)
34 |         torchvision.utils.save_image(y_pred, './testimage/' + loc.split('/')[-1] + '_' + str(i) + '_ypred.png')
--------------------------------------------------------------------------------
/feelvos/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | from feelvos.dataset import FEELVOSTriple
4 | from feelvos.transform import preprocessing
5 | from feelvos.models.FEELVOS import FEELVOS
6 | from feelvos.loss import dice_loss
7 | from feelvos.metric import dice_coeff
8 | from feelvos.trainer import Trainer
9 | 
10 | import torch
11 | import torch.nn as nn
12 | from torch.utils.data import DataLoader
13 | from tensorboardX import SummaryWriter
14 | 
15 | 
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument(
18 |     '--batch_size', type=int, default=7
19 | )
20 | parser.add_argument(
21 |     '--epoch', type=int, default=40
22 | )
23 | parser.add_argument(
24 |     '--lr', type=float, default=0.001
25 | )
26 | parser.add_argument(
27 |     '--dataset', type=str, default='./data/'
28 | )
29 | parser.add_argument(
30 |     '--workers', type=int, default=4
31 | )
32 | 
33 | cfg = parser.parse_args()
34 | print(cfg)
35 | 
36 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
37 | print(device)
38 | 
39 | if __name__ == "__main__":
40 |     ds_train = FEELVOSTriple(root=cfg.dataset, split='train', transform=preprocessing)
41 |     ds_test = FEELVOSTriple(root=cfg.dataset, split='test', transform=preprocessing)
42 |     dl_train = DataLoader(ds_train, batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.workers)
43 |     dl_test = DataLoader(ds_test, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.workers)
44 |     print("DATA LOADED")
45 | 
46 |     model = FEELVOS(3, 1, use_gt=True, pretrained='./unet/weight010.pt')
47 |     optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
48 |     criterion = nn.BCELoss()
49 |     success_metric = nn.BCELoss()  # NOTE: Trainer.fit treats higher scores as better; a higher-is-better metric such as the imported dice_coeff may be intended here
50 |     summary = SummaryWriter()
51 | 
52 |     trainer = Trainer(model, criterion, optimizer, success_metric, device, None, False)
53 |     fit = trainer.fit(dl_train, dl_test, num_epochs=cfg.epoch, checkpoints='./save2/' + model.__class__.__name__ + '.pt')
54 |     torch.save(model.state_dict(), './save/final_state_dict.pt')
55 |     torch.save(model, './save/final.pt')
56 | 
57 |     loss_fn_name = "binary cross entropy"
58 |     best_score = fit.best_score
59 |     print(f"Best score (loss function = {loss_fn_name}): {best_score}")
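NOTE: train.py saves both a state dict (./save/final_state_dict.pt) and a fully pickled model (./save/final.pt). A minimal sketch for reloading the state dict for evaluation, mirroring the constructor call above:

    model = FEELVOS(3, 1, use_gt=True, pretrained='./unet/weight010.pt')
    model.load_state_dict(torch.load('./save/final_state_dict.pt', map_location=device))
    model = model.to(device)
    model.eval()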
72 | """ 73 | self.tensorboard_logger = tensorboard_logger 74 | 75 | if experiment_prefix is None: 76 | now = datetime.datetime.now() 77 | self.experiment_prefix = now.strftime(r"%Y-%m-%d\%H:%M:%S") 78 | else: 79 | self.experiment_prefix = experiment_prefix 80 | self.tensorboard_log_images = tensorboard_log_images 81 | self.model = model 82 | self.loss_fn = loss_fn 83 | self.optimizer = optimizer 84 | self.objective_metric = objective_metric 85 | self.device = device 86 | 87 | if self.device: 88 | model.to(self.device) 89 | 90 | def fit(self, dl_train: DataLoader, dl_test: DataLoader, 91 | num_epochs, checkpoints: str = None, 92 | early_stopping: int = None, 93 | print_every=1, **kw) -> FitResult: 94 | """ 95 | Trains the model for multiple epochs with a given training set, 96 | and calculates validation loss over a given validation set. 97 | :param dl_train: Dataloader for the training set. 98 | :param dl_test: Dataloader for the test set. 99 | :param num_epochs: Number of epochs to train for. 100 | :param checkpoints: Whether to save model to file every time the 101 | test set accuracy improves. Should be a string containing a 102 | filename without extension. 103 | :param early_stopping: Whether to stop training early if there is no 104 | test loss improvement for this number of epochs. 105 | :param print_every: Print progress every this number of epochs. 106 | :return: A FitResult object containing train and test losses per epoch. 107 | """ 108 | actual_num_epochs = 0 109 | train_loss, train_acc, test_loss, test_acc = [], [], [], [] 110 | 111 | best_score = None 112 | epochs_without_improvement = 0 113 | 114 | for epoch in range(num_epochs): 115 | verbose = False # pass this to train/test_epoch. 116 | if epoch % print_every == 0 or epoch == num_epochs-1: 117 | verbose = True 118 | self._print(f'--- EPOCH {epoch+1}/{num_epochs} ---', verbose) 119 | 120 | epoch_train_res = self.train_epoch(dl_train, verbose=verbose, **kw) 121 | train_loss.extend([float(x.item()) for x in epoch_train_res.losses]) 122 | train_acc.append(float(epoch_train_res.score)) 123 | 124 | epoch_test_res = self.test_epoch(dl_test, verbose=verbose, **kw) 125 | test_loss.extend([float(x.item()) for x in epoch_test_res.losses]) 126 | test_acc.append(float(epoch_test_res.score)) 127 | 128 | if best_score is None: 129 | best_score = epoch_test_res.score 130 | elif epoch_test_res.score > best_score: 131 | best_score = epoch_test_res.score 132 | if checkpoints is not None: 133 | #torch.save(self.model, checkpoints) 134 | print("**** Checkpoint saved ****") 135 | epochs_without_improvement = 0 136 | else: 137 | if early_stopping is not None and epochs_without_improvement >= early_stopping: 138 | print("Early stopping after %s with out improvement" % epochs_without_improvement) 139 | break 140 | epochs_without_improvement += 1 141 | 142 | torch.save(self.model, f'./epoch_save_model/{self.model.__class__.__name__}_ckpt_'+str(epoch)+'.pt') 143 | 144 | # ======================== 145 | 146 | return FitResult(actual_num_epochs, 147 | train_loss, train_acc, test_loss, test_acc, best_score) 148 | 149 | def train_epoch(self, dl_train: DataLoader, **kw) -> EpochResult: 150 | """ 151 | Train once over a training set (single epoch). 152 | :param dl_train: DataLoader for the training set. 153 | :param kw: Keyword args supported by _foreach_batch. 154 | :return: An EpochResult for the epoch. 
155 | """ 156 | self.model.train() # set train mode 157 | return self._foreach_batch(dl_train, self.train_batch, **kw) 158 | 159 | def test_epoch(self, dl_test: DataLoader, **kw) -> EpochResult: 160 | """ 161 | Evaluate model once over a test set (single epoch). 162 | :param dl_test: DataLoader for the test set. 163 | :param kw: Keyword args supported by _foreach_batch. 164 | :return: An EpochResult for the epoch. 165 | """ 166 | self.model.eval() # set evaluation (test) mode 167 | return self._foreach_batch(dl_test, self.test_batch, **kw) 168 | 169 | def train_batch(self, index, batch_data) -> BatchResult: 170 | """ 171 | Runs a single batch forward through the model, calculates loss, 172 | preforms back-propagation and uses the optimizer to update weights. 173 | :param batch: A single batch of data from a data loader (might 174 | be a tuple of data and labels or anything else depending on 175 | the underlying dataset. 176 | :return: A BatchResult containing the value of the loss function and 177 | the number of correctly classified samples in the batch. 178 | """ 179 | 180 | X, y = batch_data 181 | if self.tensorboard_logger and self.tensorboard_log_images: 182 | B = torch.zeros_like(X.squeeze()) 183 | C = torch.stack([B, X.squeeze(), X.squeeze()]) 184 | C = C.unsqueeze(dim=0) 185 | images = C 186 | grid = make_grid(images, normalize=True, scale_each=True) 187 | self.tensorboard_logger.add_image("exp-%s/batch/test/images" % self.experiment_prefix, grid, index) 188 | if isinstance(X, tuple) or isinstance(X, list): 189 | X = [x.to(self.device) for x in X] 190 | else: 191 | X = X.to(self.device) 192 | pred = self.model(X) 193 | loss = self.loss_fn(pred, F.interpolate(y[2], 8).to(self.device)) 194 | self.optimizer.zero_grad() 195 | loss.backward() 196 | self.optimizer.step() 197 | score = self.objective_metric(pred, F.interpolate(y[2], 8).to(self.device)) 198 | if self.tensorboard_logger: 199 | self.tensorboard_logger.add_scalar('exp-%s/batch/train/loss' % self.experiment_prefix, loss, index) 200 | self.tensorboard_logger.add_scalar('exp-%s/batch/train/score' % self.experiment_prefix, score, index) 201 | if index % 300 == 0: 202 | for tag, value in self.model.named_parameters(): 203 | tag = tag.replace('.', '/') 204 | self.tensorboard_logger.add_histogram('exp-%s/batch/train/param/%s' % (self.experiment_prefix, tag), to_np(value), index) 205 | self.tensorboard_logger.add_histogram('exp-%s/batch/train/param/%s/grad' % (self.experiment_prefix, tag), to_np(value.grad), index) 206 | 207 | return BatchResult(loss, score) 208 | 209 | def test_batch(self, index, batch_data) -> BatchResult: 210 | """ 211 | Runs a single batch forward through the model and calculates loss. 212 | :param batch: A single batch of data from a data loader (might 213 | be a tuple of data and labels or anything else depending on 214 | the underlying dataset. 215 | :return: A BatchResult containing the value of the loss function and 216 | the number of correctly classified samples in the batch. 
217 | """ 218 | with torch.no_grad(): 219 | X, y = batch_data 220 | if isinstance(X, tuple) or isinstance(X, list): 221 | X = [x.to(self.device) for x in X] 222 | else: 223 | X = X.to(self.device) 224 | pred = self.model(X) 225 | loss = self.loss_fn(pred, F.interpolate(y[2], 8).to(self.device)) 226 | score = self.objective_metric(pred, F.interpolate(y[2], 8).to(self.device)) 227 | if self.tensorboard_logger: 228 | self.tensorboard_logger.add_scalar('exp-%s/batch/test/loss' % self.experiment_prefix, loss, index) 229 | self.tensorboard_logger.add_scalar('exp-%s/batch/test/score' % self.experiment_prefix, score, index) 230 | return BatchResult(loss, score) 231 | 232 | @staticmethod 233 | def _print(message, verbose=True): 234 | """ Simple wrapper around print to make it conditional """ 235 | if verbose: 236 | print(message) 237 | 238 | @staticmethod 239 | def _foreach_batch(dl: DataLoader, 240 | forward_fn: Callable[[Any], BatchResult], 241 | verbose=True, max_batches=None) -> EpochResult: 242 | """ 243 | Evaluates the given forward-function on batches from the given 244 | dataloader, and prints progress along the way. 245 | """ 246 | losses = [] 247 | num_samples = len(dl.sampler) 248 | num_batches = len(dl.batch_sampler) 249 | 250 | if max_batches is not None: 251 | if max_batches < num_batches: 252 | num_batches = max_batches 253 | num_samples = num_batches * dl.batch_size 254 | 255 | if verbose: 256 | pbar_file = sys.stdout 257 | else: 258 | pbar_file = open(os.devnull, 'w') 259 | 260 | pbar_name = forward_fn.__name__ 261 | with tqdm.tqdm(desc=pbar_name, total=num_batches, 262 | file=pbar_file) as pbar: 263 | dl_iter = iter(dl) 264 | overall_score = overall_loss = avg_score = avg_loss = counter = 0 265 | min_loss = min_score = 1 266 | max_loss = max_score = 0 267 | for batch_idx in range(num_batches): 268 | counter += 1 269 | data = next(dl_iter) 270 | batch_res = forward_fn(batch_idx, data) 271 | if batch_res.loss > max_loss: 272 | max_loss = batch_res.loss 273 | if batch_res.score > max_score: 274 | max_score = batch_res.score 275 | 276 | if batch_res.loss < min_loss: 277 | min_loss = batch_res.loss 278 | if batch_res.score < min_score: 279 | min_score = batch_res.score 280 | overall_loss += batch_res.loss 281 | overall_score += batch_res.score 282 | losses.append(batch_res.loss) 283 | 284 | avg_loss = overall_loss / counter 285 | avg_score = overall_score / counter 286 | pbar.set_description(f'{pbar_name} (Avg. loss:{avg_loss:.3f}, Avg. score:{avg_score:.3f})') 287 | pbar.update() 288 | 289 | pbar.set_description(f'{pbar_name} ' 290 | f'(Avg. Loss {avg_loss:.3f}, Min {min_loss:.3f}, Max {max_loss:.3f}), ' 291 | f'(Avg. 
292 | 
293 |         return EpochResult(losses=losses, score=avg_score)
--------------------------------------------------------------------------------
/feelvos/transform.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torchvision.transforms as transforms
3 | 
4 | 
5 | def preprocessing(images, masks):
6 |     fin_images = []
7 |     fin_masks = []
8 |     image_transform = transforms.Compose(
9 |         [
10 |             transforms.ToTensor(),
11 |         ]
12 |     )
13 |     for i in range(len(images)):
14 |         tmp_i = cv2.resize(images[i], dsize=(256, 256), interpolation=cv2.INTER_AREA)
15 |         tmp_m = cv2.resize(masks[i], dsize=(256, 256), interpolation=cv2.INTER_AREA)
16 |         tmp_m = cv2.cvtColor(tmp_m, cv2.COLOR_BGR2GRAY)
17 |         # map the annotated object gray value (29) to foreground (255)
18 |         tmp_m[tmp_m == 29] = 255
19 |         fin_images.append(image_transform(tmp_i).float())
20 |         fin_masks.append(image_transform(tmp_m).float())
21 | 
22 |     return fin_images, fin_masks
23 | 
--------------------------------------------------------------------------------
/feelvos/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/measaverb/FEELVOS/e9f497753ea52a91c6180e9f45fb87810898d309/feelvos/util/__init__.py
--------------------------------------------------------------------------------
/feelvos/util/toTensor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def list_to_tensor(t_list, x, y, device):
5 |     """Convert an x-by-y nested list of numpy arrays to tensors on `device`, in place."""
6 |     for i in range(x):
7 |         for j in range(y):
8 |             t_list[i][j] = torch.from_numpy(t_list[i][j]).to(device=device)
9 | 
10 |     return t_list
11 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | pandas
3 | opencv-python
4 | pillow
5 | torch
6 | torchvision
7 | modelsummary
8 | natsort
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | setup(
4 |     name = 'feelvos',
5 |     version = '0.5',
6 |     description = 'FEELVOS implementation in PyTorch; FEELVOS: Fast End-to-End Embedding Learning for Video Object Segmentation',
7 |     author = 'Younghan Kim',
8 |     author_email = 'godppkyh@mosqtech.com',
9 |     install_requires = [],
10 |     packages = find_packages(),
11 |     python_requires = '>=3.6'
12 | )
13 | 
--------------------------------------------------------------------------------
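NOTE: list_to_tensor in feelvos/util/toTensor.py converts an x-by-y nested list of numpy arrays to device tensors in place. A minimal usage sketch (shapes and sizes are illustrative only):

    import numpy as np
    import torch

    from feelvos.util.toTensor import list_to_tensor

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    batch = [[np.zeros((256, 256), dtype=np.float32) for _ in range(3)] for _ in range(2)]
    batch = list_to_tensor(batch, 2, 3, device)  # now a 2 x 3 grid of tensors on `device`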