├── LICENSE ├── README.md ├── cuda ├── CMakeLists.txt ├── README.md ├── cws.cpp ├── deps │ └── argtable3 │ │ ├── argtable3.c │ │ └── argtable3.h ├── setup.py └── src │ ├── IO_helper_functions.h │ ├── addKernel.cuh │ ├── common.h │ ├── computeDCTweightsKernel.cuh │ ├── cws_A_phi.cu │ ├── cws_A_phi.h │ ├── medianfilteringKernel.cuh │ ├── prepare_cufft_warmup.cu │ ├── prepare_cufft_warmup.h │ ├── prepare_precomputations.cu │ ├── prepare_precomputations.h │ ├── prox_gKernel.cuh │ └── x_updateKernel.cuh ├── data ├── HeLa │ ├── cap.tif │ └── ref.tif ├── MCF-7 │ ├── cap.tif │ └── ref.tif ├── MLA-150-7AR-M │ ├── Zygo.dat │ ├── cap.tif │ └── ref.tif ├── blood │ ├── cap_1.tif │ ├── cap_2.tif │ ├── ref_1.tif │ └── ref_2.tif └── cheek │ ├── cap.tif │ └── ref.tif ├── matlab ├── cpu_gpu_comparison.m ├── cws.m ├── cws_gpu_wrapper.m ├── main_wavefront_solver.m ├── speckle_pattern_baseline.m └── utils │ ├── LoadMetroProData.m │ ├── poisson_solver.m │ └── tilt_removal.m └── scripts ├── Figure2.m ├── Figure3.m └── Figure4.m /LICENSE: -------------------------------------------------------------------------------- 1 | ## creative commons 2 | 3 | # Attribution-NonCommercial-ShareAlike 4.0 International 4 | 5 | Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible. 
6 | 7 | ### Using Creative Commons Public Licenses 8 | 9 | Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses. 10 | 11 | * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors). 12 | 13 | * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. 
Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees). 14 | 15 | ## Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License 16 | 17 | By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions. 18 | 19 | ### Section 1 – Definitions. 20 | 21 | a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image. 22 | 23 | b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License. 24 | 25 | c. __BY-NC-SA Compatible License__ means a license listed at [creativecommons.org/compatiblelicenses](http://creativecommons.org/compatiblelicenses), approved by Creative Commons as essentially the equivalent of this Public License. 
26 | 27 | d. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights. 28 | 29 | e. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements. 30 | 31 | f. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material. 32 | 33 | g. __License Elements__ means the license attributes listed in the name of a Creative Commons Public License. The License Elements of this Public License are Attribution, NonCommercial, and ShareAlike. 34 | 35 | h. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License. 36 | 37 | i. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license. 38 | 39 | h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License. 40 | 41 | i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. 
For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange. 42 | 43 | j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them. 44 | 45 | k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world. 46 | 47 | l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning. 48 | 49 | ### Section 2 – Scope. 50 | 51 | a. ___License grant.___ 52 | 53 | 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to: 54 | 55 | A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and 56 | 57 | B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only. 58 | 59 | 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions. 60 | 61 | 3. __Term.__ The term of this Public License is specified in Section 6(a). 
62 | 63 | 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material. 64 | 65 | 5. __Downstream recipients.__ 66 | 67 | A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License. 68 | 69 | B. __Additional offer from the Licensor – Adapted Material.__ Every recipient of Adapted Material from You automatically receives an offer from the Licensor to exercise the Licensed Rights in the Adapted Material under the conditions of the Adapter’s License You apply. 70 | 71 | C. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material. 72 | 73 | 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i). 74 | 75 | b. ___Other rights.___ 76 | 77 | 1. 
Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise. 78 | 79 | 2. Patent and trademark rights are not licensed under this Public License. 80 | 81 | 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes. 82 | 83 | ### Section 3 – License Conditions. 84 | 85 | Your exercise of the Licensed Rights is expressly made subject to the following conditions. 86 | 87 | a. ___Attribution.___ 88 | 89 | 1. If You Share the Licensed Material (including in modified form), You must: 90 | 91 | A. retain the following if it is supplied by the Licensor with the Licensed Material: 92 | 93 | i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated); 94 | 95 | ii. a copyright notice; 96 | 97 | iii. a notice that refers to this Public License; 98 | 99 | iv. a notice that refers to the disclaimer of warranties; 100 | 101 | v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable; 102 | 103 | B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and 104 | 105 | C. 
indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License. 106 | 107 | 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information. 108 | 109 | 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable. 110 | 111 | b. ___ShareAlike.___ 112 | 113 | In addition to the conditions in Section 3(a), if You Share Adapted Material You produce, the following conditions also apply. 114 | 115 | 1. The Adapter’s License You apply must be a Creative Commons license with the same License Elements, this version or later, or a BY-NC-SA Compatible License. 116 | 117 | 2. You must include the text of, or the URI or hyperlink to, the Adapter's License You apply. You may satisfy this condition in any reasonable manner based on the medium, means, and context in which You Share Adapted Material. 118 | 119 | 3. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, Adapted Material that restrict exercise of the rights granted under the Adapter's License You apply. 120 | 121 | ### Section 4 – Sui Generis Database Rights. 122 | 123 | Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material: 124 | 125 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only; 126 | 127 | b. 
if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material, including for purposes of Section 3(b); and 128 | 129 | c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database. 130 | 131 | For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights. 132 | 133 | ### Section 5 – Disclaimer of Warranties and Limitation of Liability. 134 | 135 | a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__ 136 | 137 | b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__ 138 | 139 | c. 
The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability. 140 | 141 | ### Section 6 – Term and Termination. 142 | 143 | a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically. 144 | 145 | b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates: 146 | 147 | 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or 148 | 149 | 2. upon express reinstatement by the Licensor. 150 | 151 | For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License. 152 | 153 | c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License. 154 | 155 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License. 156 | 157 | ### Section 7 – Other Terms and Conditions. 158 | 159 | a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed. 160 | 161 | b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License. 162 | 163 | ### Section 8 – Interpretation. 164 | 165 | a. 
For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License. 166 | 167 | b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions. 168 | 169 | c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor. 170 | 171 | d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority. 172 | 173 | > Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses. 
174 | > 175 | > Creative Commons may be contacted at creativecommons.org -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Quantitative Phase and Intensity Microscope 2 | This is the open source repository for our paper in Scientific Reports: 3 | 4 | [**Quantitative Phase and Intensity Microscopy Using Snapshot White Light Wavefront Sensing**]() 5 | 6 | [Congli Wang](https://congliwang.github.io), [Qiang Fu](http://vccimaging.org/People/fuq/), [Xiong Dun](http://vccimaging.org/People/dunx/), and [Wolfgang Heidrich](http://vccimaging.org/People/heidriw/) 7 | 8 | King Abdullah University of Science and Technology (KAUST) 9 | 10 | ### Overview 11 | 12 | This repository contains: 13 | 14 | - An improved version of the wavefront solver in [1], implemented in MATLAB and CUDA: 15 | - For MATLAB code, simply plug & play. 16 | - For the CUDA solver, you need an NVIDIA graphics card with CUDA to compile and run. Also refer to [`./cuda/README.md`](./cuda/README.md) for how to compile the code. 17 | - Other solvers [2, 3] (our implementation). 18 | - Scripts and data for generating Figure 2, Figure 3, and Figure 4 in the paper. 19 | 20 | Our solver is not multi-scale because microscopy wavefronts are small; however based on our solver it is simple to implement such a pyramid scheme. 21 | 22 | ### Related 23 | 24 | - Sensor principle: [The Coded Wavefront Sensor](https://vccimaging.org/Publications/Wang2017CWS/) (Optics Express 2017). 25 | - An adaptive optics application using this sensor: [Megapixel Adaptive Optics]() (SIGGRAPH 2018). 26 | - Sensor simulation or old solvers, refer to repository . 
27 | 28 | 29 | ### Citation 30 | 31 | ```bibtex 32 | @article{wang2019quantitative, 33 | title = {Quantitative Phase and Intensity Microscopy Using Snapshot White Light Wavefront Sensing}, 34 | author = {Wang, Congli and Fu, Qiang and Dun, Xiong and Heidrich, Wolfgang}, 35 | journal = {Scientific Reports}, 36 | volume = {9}, 37 | pages = {13795}, 38 | year = {2019}, 39 | publisher = {Nature Publishing Group} 40 | } 41 | ``` 42 | 43 | ### Contact 44 | 45 | We welcome any questions or comments. Please either open up an issue, or email to congli.wang@kaust.edu.sa. 46 | 47 | ### References 48 | 49 | [1] Congli Wang, Qiang Fu, Xiong Dun, and Wolfgang Heidrich. "Ultra-high resolution coded wavefront sensor." *Optics Express* 25.12 (2017): 13736-13746. 50 | 51 | [2] Pascal Berto, Hervé Rigneault, and Marc Guillon. "Wavefront sensing with a thin diffuser." *Optics Letters* 42.24 (2017): 5117-5120. 52 | 53 | [3] Sebastien Berujon and Eric Ziegler. "Near-field speckle-scanning-based X-ray imaging." *Physical Review A* 92.1 (2015): 013837. 
54 | -------------------------------------------------------------------------------- /cuda/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.8) 2 | project(CWS LANGUAGES C CXX CUDA) 3 | 4 | # check requirements 5 | if (UNIX) 6 | find_package(OpenCV 3.0 REQUIRED) 7 | endif (UNIX) 8 | if (WIN32) 9 | find_package(OpenCV 3.0 REQUIRED PATHS C:/Program\ Files/OpenCV) 10 | endif (WIN32) 11 | 12 | # set include directories 13 | if (UNIX) 14 | set(CUDA_TARGET_INC ${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/include) 15 | set(CUDA_SAMPLE_INC ${CUDA_TOOLKIT_ROOT_DIR}/samples/common/inc) 16 | endif (UNIX) 17 | if (WIN32) 18 | set(CUDA_SAMPLE_INC C:/ProgramData/NVIDIA\ Corporation/CUDA\ Samples/v10.0/common/inc) 19 | endif (WIN32) 20 | include_directories(deps src ${CUDA_TARGET_INC} ${CUDA_SAMPLE_INC} ${OpenCV_INCLUDE_DIRS}) 21 | 22 | # set link directories 23 | if (UNIX) 24 | link_directories(${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/lib) 25 | endif (UNIX) 26 | 27 | # build our static library 28 | set(STATIC_LIB_CWS_A_PHI cws_A_phi) 29 | add_library(${STATIC_LIB_CWS_A_PHI} STATIC 30 | src/cws_A_phi.cu 31 | src/prepare_cufft_warmup.cu 32 | ) 33 | 34 | # build cws 35 | set(APP cws) 36 | add_executable(${APP} cws.cpp ${CMAKE_SOURCE_DIR}/deps/argtable3/argtable3.c) 37 | target_link_libraries(${APP} PRIVATE ${STATIC_LIB_CWS_A_PHI} cufft ${OpenCV_LIBS}) 38 | install(TARGETS ${APP} RUNTIME DESTINATION ${CMAKE_SOURCE_DIR}/bin) 39 | 40 | # install 41 | install(TARGETS ${STATIC_LIB_CWS_A_PHI} 42 | ARCHIVE DESTINATION ${CMAKE_SOURCE_DIR}/lib 43 | LIBRARY DESTINATION ${CMAKE_SOURCE_DIR}/lib) 44 | install(DIRECTORY ${CMAKE_SOURCE_DIR}/src 45 | DESTINATION ${CMAKE_SOURCE_DIR}/include 46 | FILES_MATCHING PATTERN "*.h") 47 | # NEED: Copy opencv_imgcodecs343.dll, opencv_imgproc343.dll and opencv_core343.dll to ${CMAKE_SOURCE_DIR}/bin, or make the folder contains these files be in PATH (as suggested by 
CMake warning). 48 | -------------------------------------------------------------------------------- /cuda/README.md: -------------------------------------------------------------------------------- 1 | ### Installation 2 | 3 | The command line solver `cws` is written with C++ and CUDA, with code hierarchy managed by CMake. 4 | 5 | After successful compilation, you may call the MATLAB wrapper (`../matlab/cws_gpu_wrapper.m`) for this command line solver. See `../matlab/cpu_gpu_comparison.m` for a simple demo. Due to floating point errors, the CPU (`cws.m`) and GPU solvers (`cws`) do not produce exactly the same result, but the difference is small. 6 | 7 | #### Prerequisites 8 | 9 | - CUDA Toolkit 8.0 or above (We used CUDA 10.0); 10 | - CMake 2.8 or above (We used CMake 3.13); 11 | - OpenCV 3.0 or above (We used OpenCV 3.4.3). 12 | 13 | Make sure you complete above dependencies before proceeding. 14 | 15 | #### Linux 16 | 17 | Additional prerequisites: 18 | 19 | - GNU make; 20 | - GNU Compiler Collection (GCC). 21 | 22 | Simply run `setup.py` to get it done. You can also build it manually, for example: 23 | 24 | ```shell 25 | cd 26 | mkdir build 27 | cd build 28 | cmake .. 29 | make -j8 30 | sudo make install 31 | ``` 32 | 33 | #### Windows 34 | 35 | Here we only provide one compilation approach. You can do any modifications for your convenience at will. 36 | 37 | It has been tested using Visual Studio 2017 and CUDA 10.0 on Windows 10. 38 | 39 | Steps: 40 | 41 | - Add OpenCV to the system `PATH` variable. The specific path depends on your installation option. 42 | 43 | - Run `cmake-gui` and configure your Visual Studio Solution. If CMake failed to auto-detect the paths: 44 | 45 | - Manually find the path (depends on specific installation option), and modify `./CMakeLists.txt`; 46 | - Then, try to configure and generate the project again. 47 | 48 | If success, you will see `CWS.sln` be generated in your build folder. 
49 | 50 | - In Visual Studio 2017 (opened as Administrator), build the solution. To install, in Visual Studio: 51 | 52 | - Build -> Configuration Manager -> click on Build option for INSTALL; 53 | - Then build the solution. 54 | 55 | If success, you will see new folders `bin`, `include` and `lib` appear. 56 | 57 | ### Usage 58 | 59 | Run `./cws --help` in command line prompt to see all parameters: 60 | 61 | ``` 62 | cws version 1.0 (Nov 2018) 63 | Simultaneous intensity and wavefront recovery GPU solver for the coded wavefront sensor. Solve for: 64 | 65 | min || i(x+\nabla phi) - A i_0(x) ||_2^2 + 66 | A,phi alpha || \nabla phi ||_1 + 67 | beta ( || \nabla phi ||_2^2 + || \nabla^2 phi ||_2^2 ) + 68 | gamma ( || \nabla A ||_1 + || \nabla^2 A ||_1 ) + 69 | tau ( || \nabla A ||_2^2 + || \nabla^2 A ||_2^2 ). 70 | 71 | Inputs : i_0 (reference), i (measure). 72 | Outputs: A (intensity), phi (phase). 73 | 74 | by Congli Wang, VCC Imaging @ KAUST. 75 | 76 | Usage: cws [-v] [--help] [--version] [-p ]... [-i ]... 
[-m ] [-m ] [-t ] [-s ] [-s ] [-l ] [-l ] [-o <.flo>] [-f ] [-f ] 77 | --help display this help and exit 78 | --version display version info and exit 79 | -p, --priors= prior weights {alpha,beta,gamma,tau} (default {0.1,0.1,100,5}) 80 | -i, --iter= iterations {total alternating iter, A-update iter, phi-update iter} (default {3,20,20}) 81 | -m, --mu= ADMM parameters {mu_A,mu_phi} (default {0.1,100}) 82 | -t, --tol_phi= phi-update tolerance stopping criteria (default 0.05) 83 | -v, --verbose verbose output (default 0) 84 | -s, --size= output size {width,height} (default input size) 85 | -l, --L= padding size {pad_width,pad_height} (default nearest power of 2 of out_size, each in range [2, 32]) 86 | -o, --output=<.flo> save output file (intensity A & wavefront phi) as *.flo file (default "./out.flo") 87 | -f, --files= input file names (reference & measurement) 88 | ``` 89 | -------------------------------------------------------------------------------- /cuda/cws.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "argtable3/argtable3.h" 3 | #include "common.h" 4 | #include "cws_A_phi.h" 5 | #include "IO_helper_functions.h" 6 | 7 | // default: disable demo 8 | bool isdemo = false; 9 | 10 | bool read_rep_images(const char *image_name, cv::Mat &img, int *size = NULL, double max = 255.0, double min = 0.0) 11 | { 12 | cv::Mat temp = cv::imread(image_name, CV_LOAD_IMAGE_GRAYSCALE); 13 | if (temp.empty()) 14 | { 15 | printf("%s is not found; Program exits.\n", image_name); // BUG FIX: format string must come first, then the filename argument 16 | return true; 17 | } 18 | 19 | // set region of interest 20 | if (size != NULL) 21 | { 22 | int L_w = (temp.cols - size[0]) / 2; 23 | int L_h = (temp.rows - size[1]) / 2; 24 | temp = temp(cv::Rect(L_w, L_h, size[0], size[1])); 25 | } 26 | temp.convertTo(img, CV_32F); 27 | 28 | // convert to approximately [0 255] 29 | img.convertTo(img, CV_32F, 255.0f / (max-min)); 30 | 31 | return false; 32 | } 33 | 34 | 35 | bool is_size_invalid(int s, int 
s_ref) 36 | { 37 | if (0 < s && s <= s_ref && (s % 2 == 0)) return false; 38 | else return true; 39 | } 40 | 41 | 42 | int cws(double *priors, int *iter, double *mu, double phi_tol, int v, 43 | int *out_size, int nsize, int *L, int Lsize, const char *outfile, const char **infiles, int ninfiles) 44 | { 45 | // remove CUDA timing latency 46 | cudaFree(0); 47 | cudaSetDevice(0); 48 | 49 | // images 50 | cv::Mat img_ref, img_cap; 51 | char image_name[80]; 52 | 53 | // read images 54 | if (read_rep_images(infiles[0], img_ref, out_size)) return -1; 55 | if (read_rep_images(infiles[1], img_cap, out_size)) return -1; 56 | 57 | // output size check 58 | if (is_size_invalid(out_size[0], img_ref.cols) || is_size_invalid(out_size[1], img_ref.rows)) 59 | { 60 | out_size[0] = img_ref.cols; 61 | out_size[1] = img_ref.rows; 62 | printf("out_size wrong or unspecified; Set as input image size = [%d, %d].\n\n", out_size[0], out_size[1]); 63 | } 64 | 65 | // define container variables 66 | cv::Mat A, phi; 67 | opt_algo para_algo; 68 | 69 | // tradeoff parameters 70 | para_algo.alpha = priors[0]; 71 | para_algo.beta = priors[1]; 72 | para_algo.gamma = priors[2]; 73 | para_algo.tau = priors[3]; 74 | 75 | // if verbose for sub-update energy report 76 | para_algo.isverbose = v > 0 ? 
true : false; 77 | 78 | // alternating iterations 79 | para_algo.iter = iter[0]; 80 | para_algo.A_iter = iter[1]; 81 | para_algo.phi_iter = iter[2]; 82 | 83 | // ADMM parameters 84 | para_algo.mu_A = mu[0]; 85 | para_algo.mu_phi = mu[1]; 86 | 87 | // tolerance for incremental phase 88 | para_algo.phi_tol = phi_tol; 89 | 90 | // set L 91 | para_algo.L.width = L[0]; 92 | para_algo.L.height = L[1]; 93 | 94 | // run the solver 95 | cws_A_phi(img_ref, img_cap, A, phi, para_algo); 96 | 97 | // save result 98 | WriteFloFile(outfile, img_ref.cols, img_ref.rows, A.ptr(0), phi.ptr(0)); 99 | 100 | // return 101 | cudaDeviceReset(); 102 | return 0; 103 | } 104 | 105 | 106 | void print_solver_info(char *progname) 107 | { 108 | printf("%s version 1.0 (Nov 2018) \n", progname); 109 | printf("Simultaneous intensity and wavefront recovery GPU solver for the coded wavefront sensor. Solve for:\n\n"); 110 | printf(" min || i(x+\\nabla phi) - A i_0(x) ||_2^2 +\n"); 111 | printf("A,phi alpha || \\nabla phi ||_1 +\n"); 112 | printf(" beta ( || \\nabla phi ||_2^2 + || \\nabla^2 phi ||_2^2 ) +\n"); 113 | printf(" gamma ( || \\nabla A ||_1 + || \\nabla^2 A ||_1 ) +\n"); 114 | printf(" tau ( || \\nabla A ||_2^2 + || \\nabla^2 A ||_2^2 ).\n\n"); 115 | printf("Inputs : i_0 (reference), i (measure).\n"); 116 | printf("Outputs: A (intensity), phi (phase).\n"); 117 | printf("\n"); 118 | printf("by Congli Wang, VCC Imaging @ KAUST.\n"); 119 | } 120 | 121 | 122 | int main(int argc, char **argv) 123 | { 124 | // help & version info 125 | struct arg_lit *help = arg_litn(NULL, "help", 0, 1, "display this help and exit"); 126 | struct arg_lit *version = arg_litn(NULL, "version", 0, 1, "display version info and exit"); 127 | 128 | // model: tradeoff parameters 129 | struct arg_dbl *priors = arg_dbln("p", "priors", "", 0, 4, "prior weights {alpha,beta,gamma,tau} (default {0.1,0.1,100,5})"); 130 | 131 | // algorithm: ADMM parameters 132 | struct arg_int *iter = arg_intn("i", "iter", "", 0, 3, "iterations 
{total alternating iter, A-update iter, phi-update iter} (default {3,20,20})"); 133 | struct arg_dbl *mu = arg_dbln("m", "mu", "", 0, 2, "ADMM parameters {mu_A,mu_phi} (default {0.1,100})"); 134 | struct arg_dbl *phi_tol = arg_dbl0("t", "tol_phi","", "phi-update tolerance stopping criteria (default 0.05)"); 135 | struct arg_lit *verbose = arg_litn("v", "verbose", 0, 1, "verbose output (default 0)"); 136 | struct arg_int *out_size = arg_intn("s", "size", "", 0, 2, "output size {width,height} (default input size)"); 137 | struct arg_int *L = arg_intn("l", "L", "", 0, 2, "padding size {pad_width,pad_height} (default nearest power of 2 of out_size, each in range [2, 32])"); 138 | 139 | // input & output files 140 | struct arg_file *out = arg_filen("o", "output", "<.flo>", 0, 1, "save output file (intensity A & wavefront phi) as *.flo file (default \"./out.flo\")"); 141 | struct arg_file *file = arg_filen("f", "files", "", 0, 2, "input file names (reference & measurement)"); 142 | struct arg_end *end = arg_end(20); 143 | 144 | /* the global arg_xxx structs are initialised within the argtable */ 145 | void *argtable[] = {help, version, priors, iter, mu, phi_tol, verbose, out_size, L, out, file, end}; 146 | 147 | int exitcode = 0; 148 | char progname[] = "cws"; 149 | 150 | /* set any command line default values prior to parsing */ 151 | priors ->dval[0] = 0.1; // alpha 152 | priors ->dval[1] = 0.1; // beta 153 | priors ->dval[2] = 100.0; // gamma 154 | priors ->dval[3] = 5.0; // tau 155 | iter ->ival[0] = 3; // total alternating iterations 156 | iter ->ival[1] = 20; // ADMM iterations for A-update 157 | iter ->ival[2] = 20; // ADMM iterations for phi-update 158 | mu ->dval[0] = 0.1; // ADMM parameter mu for A-update 159 | mu ->dval[1] = 100.0; // ADMM parameter mu for phi-update 160 | phi_tol->dval[0] = 0.05; // tolerance stopping criteria for phi-update 161 | verbose->count = 0; 162 | L ->ival[0] = 1; // L pad size for width 163 | L ->ival[1] = 1; // L pad size for 
height 164 | out->filename[0] = "./out.flo"; 165 | 166 | // parse arguments 167 | int nerrors; 168 | nerrors = arg_parse(argc,argv,argtable); 169 | 170 | /* special case: '--help' takes precedence over error reporting */ 171 | if (help->count > 0) 172 | { 173 | print_solver_info(progname); 174 | printf("\n"); 175 | printf("Usage: %s", progname); 176 | arg_print_syntax(stdout, argtable, "\n"); 177 | arg_print_glossary(stdout, argtable, " %-25s %s \n"); 178 | exitcode = 0; 179 | goto exit; 180 | } 181 | 182 | /* special case: '--version' takes precedence error reporting */ 183 | if (version->count > 0) 184 | { 185 | print_solver_info(progname); 186 | exitcode = 0; 187 | goto exit; 188 | } 189 | 190 | /* If the parser returned any errors then display them and exit */ 191 | if (nerrors > 0) 192 | { 193 | /* Display the error details contained in the arg_end struct.*/ 194 | arg_print_errors(stdout, end, progname); 195 | #if defined(_WIN32) || defined(_WIN64) 196 | printf("Try '%s --help' for more information.\n", progname); 197 | #elif defined(__unix) 198 | printf("Try './%s --help' for more information.\n", progname); 199 | #endif 200 | exitcode = 1; 201 | goto exit; 202 | } 203 | 204 | /* if no input files specified, use the test data */ 205 | if (file->count < 2) 206 | { 207 | printf("Number of input files < 2; use test data instead; Set out_size as [992 992]\n\n"); 208 | out_size->ival[0] = 992; 209 | out_size->ival[1] = 992; 210 | file->filename[0] = "../tests/solver_accuracy/data/img_reference.png"; 211 | file->filename[1] = "../tests/solver_accuracy/data/img_capture.png"; 212 | isdemo = true; 213 | exitcode = 1; 214 | } 215 | 216 | /* Command line parsing is complete, do the main processing */ 217 | exitcode = cws(priors->dval, iter->ival, mu->dval, phi_tol->dval[0], verbose->count, 218 | out_size->ival, out_size->count, L->ival, L->count, out->filename[0], file->filename, file->count); 219 | 220 | exit: 221 | /* deallocate each non-null entry in argtable[] */ 
222 | arg_freetable(argtable, sizeof(argtable) / sizeof(argtable[0])); 223 | 224 | return exitcode; 225 | } -------------------------------------------------------------------------------- /cuda/deps/argtable3/argtable3.h: -------------------------------------------------------------------------------- 1 | /******************************************************************************* 2 | * This file is part of the argtable3 library. 3 | * 4 | * Copyright (C) 1998-2001,2003-2011,2013 Stewart Heitmann 5 | * 6 | * All rights reserved. 7 | * 8 | * Redistribution and use in source and binary forms, with or without 9 | * modification, are permitted provided that the following conditions are met: 10 | * * Redistributions of source code must retain the above copyright 11 | * notice, this list of conditions and the following disclaimer. 12 | * * Redistributions in binary form must reproduce the above copyright 13 | * notice, this list of conditions and the following disclaimer in the 14 | * documentation and/or other materials provided with the distribution. 15 | * * Neither the name of STEWART HEITMANN nor the names of its contributors 16 | * may be used to endorse or promote products derived from this software 17 | * without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 22 | * ARE DISCLAIMED. 
IN NO EVENT SHALL STEWART HEITMANN BE LIABLE FOR ANY DIRECT, 23 | * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 24 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 25 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 26 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 28 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | ******************************************************************************/ 30 | 31 | #ifndef ARGTABLE3 32 | #define ARGTABLE3 33 | 34 | #include <stdio.h> /* FILE */ 35 | #include <time.h> /* struct tm */ 36 | 37 | #ifdef __cplusplus 38 | extern "C" { 39 | #endif 40 | 41 | #define ARG_REX_ICASE 1 42 | 43 | /* bit masks for arg_hdr.flag */ 44 | enum 45 | { 46 | ARG_TERMINATOR=0x1, 47 | ARG_HASVALUE=0x2, 48 | ARG_HASOPTVALUE=0x4 49 | }; 50 | 51 | typedef void (arg_resetfn)(void *parent); 52 | typedef int (arg_scanfn)(void *parent, const char *argval); 53 | typedef int (arg_checkfn)(void *parent); 54 | typedef void (arg_errorfn)(void *parent, FILE *fp, int error, const char *argval, const char *progname); 55 | 56 | 57 | /* 58 | * The arg_hdr struct defines properties that are common to all arg_xxx structs. 59 | * The argtable library requires each arg_xxx struct to have an arg_hdr 60 | * struct as its first data member. 61 | * The argtable library functions then use this data to identify the 62 | * properties of the command line option, such as its option tags, 63 | * datatype string, and glossary strings, and so on. 64 | * Moreover, the arg_hdr struct contains pointers to custom functions that 65 | * are provided by each arg_xxx struct which perform the tasks of parsing 66 | * that particular arg_xxx arguments, performing post-parse checks, and 67 | * reporting errors.
68 | * These functions are private to the individual arg_xxx source code 69 | * and are the pointer to them are initiliased by that arg_xxx struct's 70 | * constructor function. The user could alter them after construction 71 | * if desired, but the original intention is for them to be set by the 72 | * constructor and left unaltered. 73 | */ 74 | struct arg_hdr 75 | { 76 | char flag; /* Modifier flags: ARG_TERMINATOR, ARG_HASVALUE. */ 77 | const char *shortopts; /* String defining the short options */ 78 | const char *longopts; /* String defiing the long options */ 79 | const char *datatype; /* Description of the argument data type */ 80 | const char *glossary; /* Description of the option as shown by arg_print_glossary function */ 81 | int mincount; /* Minimum number of occurences of this option accepted */ 82 | int maxcount; /* Maximum number of occurences if this option accepted */ 83 | void *parent; /* Pointer to parent arg_xxx struct */ 84 | arg_resetfn *resetfn; /* Pointer to parent arg_xxx reset function */ 85 | arg_scanfn *scanfn; /* Pointer to parent arg_xxx scan function */ 86 | arg_checkfn *checkfn; /* Pointer to parent arg_xxx check function */ 87 | arg_errorfn *errorfn; /* Pointer to parent arg_xxx error function */ 88 | void *priv; /* Pointer to private header data for use by arg_xxx functions */ 89 | }; 90 | 91 | struct arg_rem 92 | { 93 | struct arg_hdr hdr; /* The mandatory argtable header struct */ 94 | }; 95 | 96 | struct arg_lit 97 | { 98 | struct arg_hdr hdr; /* The mandatory argtable header struct */ 99 | int count; /* Number of matching command line args */ 100 | }; 101 | 102 | struct arg_int 103 | { 104 | struct arg_hdr hdr; /* The mandatory argtable header struct */ 105 | int count; /* Number of matching command line args */ 106 | int *ival; /* Array of parsed argument values */ 107 | }; 108 | 109 | struct arg_dbl 110 | { 111 | struct arg_hdr hdr; /* The mandatory argtable header struct */ 112 | int count; /* Number of matching command 
line args */ 113 | double *dval; /* Array of parsed argument values */ 114 | }; 115 | 116 | struct arg_str 117 | { 118 | struct arg_hdr hdr; /* The mandatory argtable header struct */ 119 | int count; /* Number of matching command line args */ 120 | const char **sval; /* Array of parsed argument values */ 121 | }; 122 | 123 | struct arg_rex 124 | { 125 | struct arg_hdr hdr; /* The mandatory argtable header struct */ 126 | int count; /* Number of matching command line args */ 127 | const char **sval; /* Array of parsed argument values */ 128 | }; 129 | 130 | struct arg_file 131 | { 132 | struct arg_hdr hdr; /* The mandatory argtable header struct */ 133 | int count; /* Number of matching command line args*/ 134 | const char **filename; /* Array of parsed filenames (eg: /home/foo.bar) */ 135 | const char **basename; /* Array of parsed basenames (eg: foo.bar) */ 136 | const char **extension; /* Array of parsed extensions (eg: .bar) */ 137 | }; 138 | 139 | struct arg_date 140 | { 141 | struct arg_hdr hdr; /* The mandatory argtable header struct */ 142 | const char *format; /* strptime format string used to parse the date */ 143 | int count; /* Number of matching command line args */ 144 | struct tm *tmval; /* Array of parsed time values */ 145 | }; 146 | 147 | enum {ARG_ELIMIT=1, ARG_EMALLOC, ARG_ENOMATCH, ARG_ELONGOPT, ARG_EMISSARG}; 148 | struct arg_end 149 | { 150 | struct arg_hdr hdr; /* The mandatory argtable header struct */ 151 | int count; /* Number of errors encountered */ 152 | int *error; /* Array of error codes */ 153 | void **parent; /* Array of pointers to offending arg_xxx struct */ 154 | const char **argval; /* Array of pointers to offending argv[] string */ 155 | }; 156 | 157 | 158 | /**** arg_xxx constructor functions *********************************/ 159 | 160 | struct arg_rem* arg_rem(const char* datatype, const char* glossary); 161 | 162 | struct arg_lit* arg_lit0(const char* shortopts, 163 | const char* longopts, 164 | const char* glossary); 165 
| struct arg_lit* arg_lit1(const char* shortopts, 166 | const char* longopts, 167 | const char *glossary); 168 | struct arg_lit* arg_litn(const char* shortopts, 169 | const char* longopts, 170 | int mincount, 171 | int maxcount, 172 | const char *glossary); 173 | 174 | struct arg_key* arg_key0(const char* keyword, 175 | int flags, 176 | const char* glossary); 177 | struct arg_key* arg_key1(const char* keyword, 178 | int flags, 179 | const char* glossary); 180 | struct arg_key* arg_keyn(const char* keyword, 181 | int flags, 182 | int mincount, 183 | int maxcount, 184 | const char* glossary); 185 | 186 | struct arg_int* arg_int0(const char* shortopts, 187 | const char* longopts, 188 | const char* datatype, 189 | const char* glossary); 190 | struct arg_int* arg_int1(const char* shortopts, 191 | const char* longopts, 192 | const char* datatype, 193 | const char *glossary); 194 | struct arg_int* arg_intn(const char* shortopts, 195 | const char* longopts, 196 | const char *datatype, 197 | int mincount, 198 | int maxcount, 199 | const char *glossary); 200 | 201 | struct arg_dbl* arg_dbl0(const char* shortopts, 202 | const char* longopts, 203 | const char* datatype, 204 | const char* glossary); 205 | struct arg_dbl* arg_dbl1(const char* shortopts, 206 | const char* longopts, 207 | const char* datatype, 208 | const char *glossary); 209 | struct arg_dbl* arg_dbln(const char* shortopts, 210 | const char* longopts, 211 | const char *datatype, 212 | int mincount, 213 | int maxcount, 214 | const char *glossary); 215 | 216 | struct arg_str* arg_str0(const char* shortopts, 217 | const char* longopts, 218 | const char* datatype, 219 | const char* glossary); 220 | struct arg_str* arg_str1(const char* shortopts, 221 | const char* longopts, 222 | const char* datatype, 223 | const char *glossary); 224 | struct arg_str* arg_strn(const char* shortopts, 225 | const char* longopts, 226 | const char* datatype, 227 | int mincount, 228 | int maxcount, 229 | const char *glossary); 230 | 231 | 
struct arg_rex* arg_rex0(const char* shortopts, 232 | const char* longopts, 233 | const char* pattern, 234 | const char* datatype, 235 | int flags, 236 | const char* glossary); 237 | struct arg_rex* arg_rex1(const char* shortopts, 238 | const char* longopts, 239 | const char* pattern, 240 | const char* datatype, 241 | int flags, 242 | const char *glossary); 243 | struct arg_rex* arg_rexn(const char* shortopts, 244 | const char* longopts, 245 | const char* pattern, 246 | const char* datatype, 247 | int mincount, 248 | int maxcount, 249 | int flags, 250 | const char *glossary); 251 | 252 | struct arg_file* arg_file0(const char* shortopts, 253 | const char* longopts, 254 | const char* datatype, 255 | const char* glossary); 256 | struct arg_file* arg_file1(const char* shortopts, 257 | const char* longopts, 258 | const char* datatype, 259 | const char *glossary); 260 | struct arg_file* arg_filen(const char* shortopts, 261 | const char* longopts, 262 | const char* datatype, 263 | int mincount, 264 | int maxcount, 265 | const char *glossary); 266 | 267 | struct arg_date* arg_date0(const char* shortopts, 268 | const char* longopts, 269 | const char* format, 270 | const char* datatype, 271 | const char* glossary); 272 | struct arg_date* arg_date1(const char* shortopts, 273 | const char* longopts, 274 | const char* format, 275 | const char* datatype, 276 | const char *glossary); 277 | struct arg_date* arg_daten(const char* shortopts, 278 | const char* longopts, 279 | const char* format, 280 | const char* datatype, 281 | int mincount, 282 | int maxcount, 283 | const char *glossary); 284 | 285 | struct arg_end* arg_end(int maxerrors); 286 | 287 | 288 | /**** other functions *******************************************/ 289 | int arg_nullcheck(void **argtable); 290 | int arg_parse(int argc, char **argv, void **argtable); 291 | void arg_print_option(FILE *fp, const char *shortopts, const char *longopts, const char *datatype, const char *suffix); 292 | void arg_print_syntax(FILE 
*fp, void **argtable, const char *suffix); 293 | void arg_print_syntaxv(FILE *fp, void **argtable, const char *suffix); 294 | void arg_print_glossary(FILE *fp, void **argtable, const char *format); 295 | void arg_print_glossary_gnu(FILE *fp, void **argtable); 296 | void arg_print_errors(FILE* fp, struct arg_end* end, const char* progname); 297 | void arg_freetable(void **argtable, size_t n); 298 | 299 | /**** deprecated functions, for back-compatibility only ********/ 300 | void arg_free(void **argtable); 301 | 302 | #ifdef __cplusplus 303 | } 304 | #endif 305 | #endif 306 | -------------------------------------------------------------------------------- /cuda/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import sys 4 | import os 5 | 6 | def print_y(s): 7 | print('\033[1;33m' + s + '\033[m') 8 | return 9 | 10 | def print_bw(s): 11 | print('\033[1;1m' + s + '\033[m') 12 | return 13 | 14 | def errorcheck_sys(error): 15 | if error: 16 | print_y("Error occurs. Please run setup.py first to compile the binary files.") 17 | sys.exit() 18 | return 19 | 20 | def errorcheck_prog(error): 21 | if error: 22 | print_y("Binary error occurs.") 23 | sys.exit() 24 | return 25 | 26 | print_bw("Compiling AO_CWS code ...") 27 | 28 | if not os.path.exists("build"): 29 | errorcheck_sys(os.mkdir("build")) 30 | errorcheck_sys(os.chdir("build")) 31 | 32 | if os.name == 'nt': 33 | errorcheck_sys(os.system('cmake -DCMAKE_GENERATOR_PLATFORM=x64 --target install ..')) 34 | # errorcheck_sys(os.system("msbuild CWS_and_AO.sln /p:Configuration=Release")) // need to manually build the solution 35 | elif os.name == 'posix': 36 | errorcheck_sys(os.system("cmake ..")) 37 | errorcheck_sys(os.system("make -j8")) 38 | errorcheck_sys(os.system("make install")) 39 | else: 40 | print_y("Unknown platform. Compilation ends.") 41 | 42 | print_bw("You are done. AO_CWS CUDA code compilation is finished. 
Program exits.") 43 | -------------------------------------------------------------------------------- /cuda/src/IO_helper_functions.h: -------------------------------------------------------------------------------- 1 | #ifndef IO_HELPER_FUNCTIONS_H 2 | #define IO_HELPER_FUNCTIONS_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | // use fopen under Linux 12 | #ifdef __unix 13 | #define fopen_s(pFile,filename,mode) ((*(pFile))=fopen((filename), (mode)))==NULL 14 | #endif 15 | 16 | /////////////////////////////////////////////////////////////////////////////// 17 | /// \brief save optical flow in format described on vision.middlebury.edu/flow 18 | /// \param[in] name output file name 19 | /// \param[in] w optical flow field width 20 | /// \param[in] h optical flow field height 21 | /// \param[in] u horizontal displacement 22 | /// \param[in] v vertical displacement 23 | /////////////////////////////////////////////////////////////////////////////// 24 | void WriteFloFile(const char *name, int w, int h, const float *u, const float *v) 25 | { 26 | FILE *stream; 27 | 28 | if ((fopen_s(&stream, name, "wb")) != 0) 29 | { 30 | printf("Could not save flow to \"%s\"\n", name); 31 | return; 32 | } 33 | 34 | float data = 202021.25f; 35 | fwrite(&data, sizeof(float), 1, stream); 36 | fwrite(&w, sizeof(w), 1, stream); 37 | fwrite(&h, sizeof(h), 1, stream); 38 | 39 | for (int i = 0; i < h; ++i) 40 | { 41 | for (int j = 0; j < w; ++j) 42 | { 43 | const int pos = j + i * w; 44 | fwrite(u + pos, sizeof(float), 1, stream); 45 | fwrite(v + pos, sizeof(float), 1, stream); 46 | } 47 | } 48 | 49 | fclose(stream); 50 | } 51 | 52 | 53 | std::string readFile(const char *filePath) { 54 | std::string content; 55 | std::ifstream fileStream(filePath, std::ios::in); 56 | 57 | if(!fileStream.is_open()) { 58 | std::cerr << "Could not read file " << filePath << ". File does not exist." 
<< std::endl; 59 | return ""; 60 | } 61 | 62 | std::string line = ""; 63 | while(!fileStream.eof()) { 64 | std::getline(fileStream, line); 65 | content.append(line + "\n"); 66 | } 67 | 68 | fileStream.close(); 69 | return content; 70 | } 71 | 72 | 73 | template 74 | void readtxt(const char *dataPath, T *data, int N) 75 | { 76 | FILE *file = fopen(dataPath,"r"); 77 | for(int i = 0; i < N; i++) 78 | fscanf(file, "%f\n", &data[i]); 79 | 80 | // fflush(file); 81 | fclose(file); 82 | 83 | printf("%s data loaded\n", dataPath); 84 | } 85 | 86 | template 87 | void savetxt(const char *dataPath, T *data, int N) 88 | { 89 | std::ofstream file(dataPath); 90 | if (file.is_open()){ 91 | for (int i = 0; i < N; i++) 92 | file << std::setprecision(std::numeric_limits::digits10 + 1) << data[i] << "\n"; 93 | file.close(); 94 | } 95 | printf("%s data saved\n", dataPath); 96 | } 97 | 98 | 99 | #endif 100 | 101 | -------------------------------------------------------------------------------- /cuda/src/addKernel.cuh: -------------------------------------------------------------------------------- 1 | #include "common.h" 2 | 3 | /////////////////////////////////////////////////////////////////////////////// 4 | /// \brief add two vectors of size _count_ 5 | /// 6 | /// CUDA kernel 7 | /// \param[in] op1 term one 8 | /// \param[in] op2 term two 9 | /// \param[in] count vector size 10 | /// \param[out] sum result 11 | /////////////////////////////////////////////////////////////////////////////// 12 | __global__ 13 | void AddKernel(const float *op1, const float *op2, int count, float *sum) 14 | { 15 | const int pos = threadIdx.x + blockIdx.x * blockDim.x; 16 | 17 | if (pos >= count) return; 18 | 19 | sum[pos] = op1[pos] + op2[pos]; 20 | } 21 | 22 | /////////////////////////////////////////////////////////////////////////////// 23 | /// \brief add two vectors of size _count_ 24 | /// \param[in] op1 term one 25 | /// \param[in] op2 term two 26 | /// \param[in] count vector size 27 | /// 
\param[out] sum result 28 | /////////////////////////////////////////////////////////////////////////////// 29 | static 30 | void Add(const float *op1, const float *op2, int count, float *sum) 31 | { 32 | dim3 threads(256); 33 | dim3 blocks(iDivUp(count, threads.x)); 34 | 35 | AddKernel<<<blocks, threads>>>(op1, op2, count, sum); 36 | } 37 | -------------------------------------------------------------------------------- /cuda/src/common.h: -------------------------------------------------------------------------------- 1 | /////////////////////////////////////////////////////////////////////////////// 2 | // Header for common includes and utility functions 3 | /////////////////////////////////////////////////////////////////////////////// 4 | 5 | #ifndef COMMON_H 6 | #define COMMON_H 7 | 8 | 9 | /////////////////////////////////////////////////////////////////////////////// 10 | // Common includes 11 | /////////////////////////////////////////////////////////////////////////////// 12 | 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | #include 20 | #include 21 | #include 22 | 23 | #include 24 | #include 25 | 26 | /////////////////////////////////////////////////////////////////////////////// 27 | // Error checking functions 28 | /////////////////////////////////////////////////////////////////////////////// 29 | //static const char *_cudaGetErrorEnum(cufftResult error) 30 | //{ 31 | // switch (error) 32 | // { 33 | // case CUFFT_SUCCESS: 34 | // return "CUFFT_SUCCESS"; 35 | 36 | // case CUFFT_INVALID_PLAN: 37 | // return "CUFFT_INVALID_PLAN"; 38 | 39 | // case CUFFT_ALLOC_FAILED: 40 | // return "CUFFT_ALLOC_FAILED"; 41 | 42 | // case CUFFT_INVALID_TYPE: 43 | // return "CUFFT_INVALID_TYPE"; 44 | 45 | // case CUFFT_INVALID_VALUE: 46 | // return "CUFFT_INVALID_VALUE"; 47 | 48 | // case CUFFT_INTERNAL_ERROR: 49 | // return "CUFFT_INTERNAL_ERROR"; 50 | 51 | // case CUFFT_EXEC_FAILED: 52 | // return "CUFFT_EXEC_FAILED"; 53 | 54 | // case CUFFT_SETUP_FAILED:
55 | // return "CUFFT_SETUP_FAILED"; 56 | 57 | // case CUFFT_INVALID_SIZE: 58 | // return "CUFFT_INVALID_SIZE"; 59 | 60 | // case CUFFT_UNALIGNED_DATA: 61 | // return "CUFFT_UNALIGNED_DATA"; 62 | // } 63 | 64 | // return ""; 65 | //} 66 | 67 | //#define cufftSafeCall(err) __cufftSafeCall(err, __FILE__, __LINE__) 68 | //inline void __cufftSafeCall(cufftResult err, const char *file, const int line) 69 | //{ 70 | // if( CUFFT_SUCCESS != err) { 71 | // fprintf(stderr, "CUFFT error in file '%s', line %d\n %s\nerror %d: %s\nterminating!\n",__FILE__, __LINE__,err, \ 72 | // _cudaGetErrorEnum(err)); \ 73 | // cudaDeviceReset(); assert(0); \ 74 | // } 75 | //} 76 | 77 | 78 | /////////////////////////////////////////////////////////////////////////////// 79 | // Common constants 80 | /////////////////////////////////////////////////////////////////////////////// 81 | const int StrideAlignment = 32; 82 | 83 | // #ifdef DOUBLE_PRECISION 84 | // typedef double real; 85 | // typedef cufftDoubleComplex complex; 86 | // #else 87 | // typedef float real; 88 | // typedef cufftComplex complex; 89 | // #endif 90 | 91 | typedef cufftComplex complex; 92 | 93 | #define PI 3.1415926535897932384626433832795028841971693993751 94 | 95 | // block size for shared memory 96 | #define BLOCK_X 32 97 | #define BLOCK_Y 32 98 | 99 | 100 | /////////////////////////////////////////////////////////////////////////////// 101 | // Common functions 102 | /////////////////////////////////////////////////////////////////////////////// 103 | 104 | // A GPU timer 105 | struct GpuTimer 106 | { 107 | cudaEvent_t start; 108 | cudaEvent_t stop; 109 | 110 | GpuTimer() 111 | { 112 | cudaEventCreate(&start); 113 | cudaEventCreate(&stop); 114 | } 115 | 116 | ~GpuTimer() 117 | { 118 | cudaEventDestroy(start); 119 | cudaEventDestroy(stop); 120 | } 121 | 122 | void Start() 123 | { 124 | cudaEventRecord(start, 0); 125 | } 126 | 127 | void Stop() 128 | { 129 | cudaEventRecord(stop, 0); 130 | } 131 | 132 | float Elapsed() 
133 | { 134 | float elapsed; 135 | cudaEventSynchronize(stop); 136 | cudaEventElapsedTime(&elapsed, start, stop); 137 | return elapsed; 138 | } 139 | }; 140 | 141 | // Align up n to the nearest multiple of m 142 | inline int iAlignUp(int n, int m = StrideAlignment) 143 | { 144 | int mod = n % m; 145 | 146 | if (mod) 147 | return n + m - mod; 148 | else 149 | return n; 150 | } 151 | 152 | // round up n/m 153 | inline int iDivUp(int n, int m) 154 | { 155 | return (n + m - 1) / m; 156 | } 157 | 158 | // swap two values 159 | template <typename T> 160 | inline void Swap(T &a, T &b) 161 | { 162 | T t = a; 163 | a = b; 164 | b = t; 165 | } 166 | 167 | // Wrap to [-0.5 0.5], then add 0.5 to [0 1] for final phase show 168 | __host__ 169 | __device__ 170 | inline float wrap(float x) 171 | { 172 | return x - floor(x + 0.5f) + 0.5f; 173 | } 174 | 175 | // nearest integer of power of 2 176 | __host__ 177 | inline int nearest_power_of_two(int x) 178 | { 179 | return 2 << (static_cast<int>(std::log2(x) + 1.0) - 1); 180 | } 181 | #endif 182 | -------------------------------------------------------------------------------- /cuda/src/computeDCTweightsKernel.cuh: -------------------------------------------------------------------------------- 1 | #include "common.h" 2 | 3 | /////////////////////////////////////////////////////////////////////////////// 4 | /// \brief compute mat_x_hat 5 | /// 6 | /// CUDA kernel 7 | /// \param[in] mu proximal parameter 8 | /// \param[in] alpha regularization parameter 9 | /// \param[in] M_width unknown width 10 | /// \param[in] M_height unknown height 11 | /// \param[out] mat_x_hat result 12 | /////////////////////////////////////////////////////////////////////////////// 13 | __global__ 14 | void computeDCTweightsKernel(int N_height, complex *out) 15 | { 16 | const int pos = threadIdx.x + blockIdx.x * blockDim.x; 17 | 18 | if (pos >= N_height) return; 19 | 20 | // we use double to enhance accuracy 21 | out[pos].x = (float) (sqrt((double)(2*N_height)) *
cos(pos*PI/(2*N_height))); 22 | out[pos].y = (float) (sqrt((double)(2*N_height)) * sin(pos*PI/(2*N_height))); 23 | } 24 | 25 | /////////////////////////////////////////////////////////////////////////////// 26 | /// \brief compute DCT weights 27 | /// \param[in] mu proximal parameter 28 | /// \param[in] alpha regularization parameter 29 | /// \param[in] M_width unknown width 30 | /// \param[in] M_height unknown height 31 | /// \param[out] mat_x_hat result 32 | /////////////////////////////////////////////////////////////////////////////// 33 | static 34 | void computeDCTweights(int N_width, int N_height, complex *ww_1, complex *ww_2) 35 | { 36 | dim3 threads(256); 37 | dim3 blocks1(iDivUp(N_height, threads.x)); 38 | dim3 blocks2(iDivUp(N_width, threads.x)); 39 | 40 | computeDCTweightsKernel<<<blocks1, threads>>>(N_height, ww_1); 41 | computeDCTweightsKernel<<<blocks2, threads>>>(N_width, ww_2); 42 | } 43 | -------------------------------------------------------------------------------- /cuda/src/cws_A_phi.cu: -------------------------------------------------------------------------------- 1 | // Please keep this include order to make sure a successful compilation!
// NOTE(review): the original #include targets were stripped by the HTML
// extraction this file came from; the headers below are reconstructed from
// usage (printf, fabs/cos, OpenCV types, thrust reductions) — confirm
// against the upstream repository.
#include <cstdio>
#include <cstdlib>
#include <ctime>   // for profiling
#include <cmath>

// OpenCV
#include <opencv2/opencv.hpp>

// include kernels
#include "common.h"
#include "prepare_cufft_warmup.h"
#include "computeDCTweightsKernel.cuh"
#include "x_updateKernel.cuh"
#include "prox_gKernel.cuh"
#include "addKernel.cuh"
#include "medianfilteringKernel.cuh"

// thrust headers
#include <thrust/device_ptr.h>
#include <thrust/transform_reduce.h>
#include <thrust/functional.h>
#include <thrust/reduce.h>
#include <thrust/execution_policy.h>

#include "cws_A_phi.h"


// =============================================================
// thrust definitions
// =============================================================
// Unary functors used by the norm/mean reductions below.
// |x| — element-wise absolute value (L1 building block)
template <typename T>
struct absolute{
    __host__ __device__
    T operator()(const T& x) const {
        return fabs(x);
    }
};
// x^2 — element-wise square (squared-L2 building block)
template <typename T>
struct square{
    __host__ __device__
    T operator()(const T& x) const {
        return x * x;
    }
};
// =============================================================


// =============================================================
// norm & mean functions
// =============================================================
// norm<1>(d_x, n) = sum_i |x_i|   (L1 norm)
// norm<2>(d_x, n) = sum_i x_i^2   (SQUARED L2 norm — no sqrt)
// over the first n elements of device array d_x.
//
// The reduction state (device_ptr, functors) is local to the call; the
// original kept these in mutable file-scope globals, which made concurrent
// calls unsafe for no benefit.
template <int N>
real norm(real *d_x, int n)
{
    thrust::device_ptr<real> p = thrust::device_pointer_cast(d_x);
    switch (N)
    {
        case 1:
            return thrust::transform_reduce(p, p + n, absolute<real>(), (real)0, thrust::plus<real>());
        case 2:
            return thrust::transform_reduce(p, p + n, square<real>(), (real)0, thrust::plus<real>());
        default:
        {
            printf("Undefined norm!\n");
            return 0;
        }
    }
}
// mean<1>(d_x, n) = mean |x_i|; mean<2>(d_x, n) = mean x_i^2.
template <int N>
real mean(real *d_x, int n)
{
    switch (N)
    {
        case 1:
            return norm<1>(d_x, n) / static_cast<real>(n);
        case 2:
            return norm<2>(d_x, n) / static_cast<real>(n);
        default:
        {
            printf("Undefined mean!\n");
            return 0;
        }
    }
}
// =============================================================


// =============================================================
// helper functions
// =============================================================
// ----- set to ones -----
__global__
void setonesKernel(real *I, cv::Size N)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;
    if (ix >= N.width || iy >= N.height) return;

    I[ix + iy * N.width] = 1.0f;
}
// Fill the device image I (size N, row-major) with ones.
void setasones(real *I, cv::Size N)
{
    dim3 threads(16, 16);
    dim3 blocks(iDivUp(N.width, threads.x), iDivUp(N.height, threads.y));
    setonesKernel<<<blocks, threads>>>(I, N);
}
// =============================================================


// =============================================================
// constant memory cache
// =============================================================
// shrinkage_value[0]: L1 soft-threshold for the A-update,
// shrinkage_value[1]: L1 soft-threshold for the phi-update
// (both set once in the CWS_A_Phi constructor).
__constant__ real shrinkage_value[2];
// L_SIZE = {L.width, L.height}: border padding between the captured grid M
// and the padded grid N (N = M + 2L), cached for the crop/residual kernels.
__constant__ int L_SIZE[2];
// =============================================================


// =============================================================
// pre-computation: DCT basis function
// =============================================================
// Eigenvalue of the (negated) 5-point Laplacian at DCT frequency (ix, iy):
// 4 - 2cos(pi ix / w) - 2cos(pi iy / h).
__device__
double DCT_kernel(const int ix, const int iy, int w, int h)
{
    return 4.0 - 2*cos(PI*(double)ix/(double)w) - 2*cos(PI*(double)iy/(double)h);
}
// Denominator for the A-update inversion in the DCT domain:
// 1 + (tau + mu) * (lap^2 + lap), i.e. identity + weighted (nabla^4 + nabla^2).
__global__
void mat_A_hatKernel(opt_A opt_A_t, cv::Size M, real *mat_A_hat)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;
    const int pos = ix + iy * M.width;

    if (ix >= M.width || iy >= M.height) return;

    // \nabla^2
    double lap_mat = DCT_kernel(ix, iy, M.width, M.height);

    // \nabla^4 + \nabla^2
    double K_mat = lap_mat * (lap_mat + 1.0);

    // get mat_A_hat
    mat_A_hat[pos] = (double) ( 1.0 + (opt_A_t.tau_new + opt_A_t.mu) * K_mat );
}
// Denominator for the phi-update inversion in the DCT domain.  The DC
// component (pos == 0) is pinned to 1 since the Laplacian has a zero
// eigenvalue there (phi is only determined up to a constant).
__global__
void mat_phi_hatKernel(opt_phi opt_phi_t, real beta, cv::Size N, real *mat_phi_hat)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;
    const int pos = ix + iy * N.width;

    if (ix >= N.width || iy >= N.height) return;

    // \nabla^2
    double lap_mat = DCT_kernel(ix, iy, N.width, N.height);

    // get mat_phi_hat (pre-divide mu so there is no need for re-scaling phi in ADMM)
    if (pos == 0){
        mat_phi_hat[pos] = 1.0;
    }
    else{
        mat_phi_hat[pos] = (double) (
            (opt_phi_t.mu + beta + beta * lap_mat) * lap_mat );
    }
}
// Precompute both DCT-domain inversion matrices (A on grid M, phi on grid N).
void compute_inverse_mat(opt_A opt_A_t, opt_phi opt_phi_t, real beta,
    cv::Size M, cv::Size N, real *mat_A_hat, real *mat_phi_hat)
{
    dim3 threads(16, 16);
    dim3 blocks_M(iDivUp(M.width, threads.x), iDivUp(M.height, threads.y));
    dim3 blocks_N(iDivUp(N.width, threads.x), iDivUp(N.height, threads.y));
    mat_A_hatKernel<<<blocks_M, threads>>>(opt_A_t, M, mat_A_hat);
    mat_phi_hatKernel<<<blocks_N, threads>>>(opt_phi_t, beta, N, mat_phi_hat);
}
// =============================================================


// =============================================================
// operators
// =============================================================
// ----- nabla -----
// Forward differences: Ix = I(x,y) - I(x-1,y), Iy = I(x,y) - I(x,y-1),
// zero along the first row/column (consistent with symmetric extension).
__global__
void nablaKernel(const real *__restrict__ I, cv::Size N, real *Ix, real *Iy)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;
    const int pos = ix + iy * N.width;

    if (ix >= N.width || iy >= N.height) return;

    // replicate/symmetric boundary condition
    if(ix == 0){
        Ix[pos] = 0.0f;
    }
    else{
        Ix[pos] = I[pos] - I[(ix-1) + iy * N.width];
    }
    if (iy == 0){
        Iy[pos] = 0.0f;
    }
    else{
        Iy[pos] = I[pos] - I[ix + (iy-1) * N.width];
    }
}
void nabla(const real *I, cv::Size N, real *nabla_x, real *nabla_y)
{
    dim3 threads(32, 32);
    dim3 blocks(iDivUp(N.width, threads.x), iDivUp(N.height, threads.y));
    nablaKernel<<<blocks, threads>>>(I, N, nabla_x, nabla_y);
}
// -----------------

// ----- nablaT -----
// Adjoint of nabla (negative divergence), matching the boundary handling
// of nablaKernel above.
__global__
void nablaTKernel(const real *__restrict__ Ix, const real *__restrict__ Iy, cv::Size N, real *div)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;
    const int pos = ix + iy * N.width;

    real val1, val2;
    if (ix >= N.width || iy >= N.height) return;

    // replicate/symmetric boundary condition
    if(ix == N.width-1){
        val1 = 0.0f;
    }
    else{
        val1 = Ix[pos] - Ix[ix+1 + iy * N.width];
    }
    if (iy == N.height-1){
        val2 = 0.0f;
    }
    else{
        val2 = Iy[pos] - Iy[ix + (iy+1) * N.width];
    }
    div[pos] = val1 + val2;
}
void nablaT(const real *__restrict__ Ix, const real *__restrict__ Iy, cv::Size N, real *div)
{
    dim3 threads(32, 32);
    dim3 blocks(iDivUp(N.width, threads.x), iDivUp(N.height, threads.y));
    nablaTKernel<<<blocks, threads>>>(Ix, Iy, N, div);
}
// -----------------

// ----- nabla2 -----
// 5-point Laplacian with clamped (replicate) indices at the borders.
__global__
void nabla2Kernel(const real *__restrict__ I, cv::Size N, real *L)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;
    const int py = iy * N.width;
    const int pos = ix + py;

    if (ix >= N.width || iy >= N.height) return;

    // replicate/symmetric boundary condition (3x3 stencil)
    const int x_min = max(ix-1, 0);
    const int x_max = min(ix+1, N.width-1);
    const int y_min = max(iy-1, 0);
    const int y_max = min(iy+1, N.height-1);

    L[pos] = -4 * I[pos] + I[x_min + py] + I[x_max + py] + I[ix + y_min*N.width] + I[ix + y_max*N.width];
}
void nabla2(const real *I, cv::Size N, real *L)
{
    dim3 threads(32, 32);
    dim3 blocks(iDivUp(N.width, threads.x), iDivUp(N.height, threads.y));
    nabla2Kernel<<<blocks, threads>>>(I, N, L);
}
// -----------------

// ----- K & KT -----
// K stacks first- and second-order operators: K(I) = (nabla_x I, nabla_y I, nabla^2 I).
void K(const real *I, cv::Size N, real *grad_x, real *grad_y, real *L)
{
    nabla(I, N, grad_x, grad_y);
    nabla2(I, N, L);
}
// Adjoint of K in a single fused kernel: divergence of (Ix, Iy) plus the
// (self-adjoint) Laplacian of L.
__global__
void KTKernel(const real *__restrict__ Ix, const real *__restrict__ Iy,
    const real *__restrict__ L, cv::Size N, real *I)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;
    const int py = iy * N.width;
    const int pos = ix + py;

    if (ix >= N.width || iy >= N.height) return;

    // replicate/symmetric boundary condition (3x3 stencil)
    const int x_min = max(ix-1, 0);
    const int x_max = min(ix+1, N.width-1);
    const int y_min = max(iy-1, 0);
    const int y_max = min(iy+1, N.height-1);

    real Div = Ix[pos] - Ix[x_max + iy*N.width] + Iy[pos] - Iy[ix + y_max*N.width];
    real Lap = -4 * L[pos] + L[x_min + py] + L[x_max + py] + L[ix + y_min*N.width] + L[ix + y_max*N.width];
    I[pos] = Div + Lap;
}
void KT(const real *grad_x, const real *grad_y, const real *L, cv::Size N, real *I)
{
    dim3 threads(32, 32);
    dim3 blocks(iDivUp(N.width, threads.x), iDivUp(N.height, threads.y));
    KTKernel<<<blocks, threads>>>(grad_x, grad_y, L, N, I);
}
// Crop the interior of an N-sized (padded) image into an M-sized output:
// out(M) = in(N) restricted to the centre region (padding L on each side).
__global__
void MKernel(const real *in, cv::Size M, cv::Size N, real *out)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;

    const int pos_N = ix + iy * N.width;
    const int pos_M = ix - L_SIZE[0] + (iy - L_SIZE[1]) * M.width;

    if (ix >= N.width || iy >= N.height) return;
    else if (ix >= L_SIZE[0] && ix < N.width - L_SIZE[0] &&
             iy >= L_SIZE[1] && iy < N.height - L_SIZE[1]) {
        out[pos_M] = in[pos_N];
    }
    else return;
}
void Mask_func(const real *in, cv::Size M, cv::Size N, real *out)
{
    dim3 threads(32, 32);
    dim3 blocks(iDivUp(N.width, threads.x), iDivUp(N.height, threads.y));
    // NOTE(review): every <<<blocks, threads>>> in this file was stripped to
    // <<>> by the HTML extraction; restored from the dim3 variables in scope.
    MKernel<<<blocks, threads>>>(in, M, N, out);
}
// -----------------
// =============================================================


// =============================================================
// objective functions
// =============================================================
// tmp = A - b/I0 (element-wise, grid N) — residual of the intensity data term.
__global__
void res1Kernel(const real *A, const real *b, const real *I0, cv::Size N, real *tmp)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;
    const int pos = ix + iy * N.width;

    if (ix >= N.width || iy >= N.height) return;
    tmp[pos] = A[pos] - b[pos]/I0[pos];
}
// tmp = I0*A - b (element-wise, grid M) — residual of the total objective.
__global__
void res2Kernel(const real *A, const real *b, const real *I0, cv::Size M, real *tmp)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;
    const int pos = ix + iy * M.width;

    if (ix >= M.width || iy >= M.height) return;
    tmp[pos] = I0[pos] * A[pos] - b[pos];
}
// out = gx*wx + gy*wy + gt — linearized brightness residual.  The g* arrays
// live on grid M, the w* arrays on grid N; out is written on grid M
// (isM == true) or grid N with zeros outside the centre (isM == false).
__global__
void res3Kernel(real *gx, real *gy, real *gt, real *wx, real *wy,
    cv::Size M, cv::Size N, real *out, bool isM = true)
{
    // This function will be used by:
    // - 1. Cost function of phi (function: obj_phi)
    // - 2. Update I_warp (function: update_I_warp)
    // To fulfill both goals, size of out is designed to be either M or N.
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;

    const int pos_N = ix + iy * N.width;
    const int pos_M = ix - L_SIZE[0] + (iy - L_SIZE[1]) * M.width;
    const int pos_tmp = isM ? pos_M : pos_N;

    if (ix >= N.width || iy >= N.height) return;
    else if (ix >= L_SIZE[0] && ix < N.width - L_SIZE[0] &&
             iy >= L_SIZE[1] && iy < N.height - L_SIZE[1]) {
        out[pos_tmp] = gx[pos_M] * wx[pos_N] + gy[pos_M] * wy[pos_N] + gt[pos_M];
    }
    else if (!isM) {
        out[pos_tmp] = 0.0f; // outside area be zeros; in this case tmp is of size N
    }
    else return;
}
// Objective of the A sub-problem:
// ||A - b/I0||_2^2 + gamma*||KA||_1 + tau*||KA||_2^2.
// tmp1..tmp3 are scratch buffers of size N (clobbered).
real obj_A(const real *A, const real *b, const real *I0, opt_A opt_A_t, cv::Size N,
    real *tmp1, real *tmp2, real *tmp3)
{
    // compute A - b/I0, and data term
    dim3 threads(32, 32);
    dim3 blocks(iDivUp(N.width, threads.x), iDivUp(N.height, threads.y));
    res1Kernel<<<blocks, threads>>>(A, b, I0, N, tmp1);
    real data_term = norm<2>(tmp1, N.area());

    // compute KA, and priors
    K(A, N, tmp1, tmp2, tmp3);
    real L1_term = opt_A_t.gamma_new * ( norm<1>(tmp1, N.area()) + norm<1>(tmp2, N.area()) + norm<1>(tmp3, N.area()) );
    real L2_term = opt_A_t.tau_new * ( norm<2>(tmp1, N.area()) + norm<2>(tmp2, N.area()) + norm<2>(tmp3, N.area()) );

    return data_term + L1_term + L2_term;
}
// Objective of the phi sub-problem:
// ||g . [nabla phi; 1]||_2^2 + alpha*||nabla phi||_1 + beta*(L2 priors).
// tmp1_N..tmp3_N are scratch buffers of size N (clobbered).
real obj_phi(const real *phi, real *gx, real *gy, real *gt,
    real alpha, real beta, cv::Size M, cv::Size N,
    real *tmp1_N, real *tmp2_N, real *tmp3_N)
{
    // compute data term
    dim3 threads(32, 32);
    dim3 blocks(iDivUp(N.width, threads.x), iDivUp(N.height, threads.y));
    nabla(phi, N, tmp1_N, tmp2_N);
    // out aliases wx (tmp1_N): each thread reads wx[pos] before writing
    // out[pos] at the same index, so the in-place use is safe.
    res3Kernel<<<blocks, threads>>>(gx, gy, gt, tmp1_N, tmp2_N, M, N, tmp1_N, false);
    real data_term = norm<2>(tmp1_N, N.area());

    // compute nabla phi, and phi priors
    K(phi, N, tmp1_N, tmp2_N, tmp3_N);
    real L1_term = alpha * ( norm<1>(tmp1_N, N.area()) + norm<1>(tmp2_N, N.area()) );
    real L2_term = beta * ( norm<2>(tmp1_N, N.area()) + norm<2>(tmp2_N, N.area()) + norm<2>(tmp3_N, N.area()) );

    return data_term + L1_term + L2_term;
}
// Full objective over both unknowns (see the formula in CWS_A_Phi::solver).
// The tmp*_M / tmp*_N scratch buffers are clobbered.
real obj_total(const real *A, const real *phi, const real *b, const real *I0,
    real alpha, real beta, real gamma, real tau, cv::Size M, cv::Size N,
    real *tmp1_M, real *tmp2_M, real *tmp3_M,
    real *tmp1_N, real *tmp2_N, real *tmp3_N)
{
    // compute I0*A - b, and data term
    dim3 threads(16, 16);
    dim3 blocks(iDivUp(M.width, threads.x), iDivUp(M.height, threads.y));
    res2Kernel<<<blocks, threads>>>(A, b, I0, M, tmp1_M);
    real data_term = norm<2>(tmp1_M, M.area());

    // compute nabla phi, and phi priors
    K(phi, N, tmp1_N, tmp2_N, tmp3_N);
    real phi_L1_term = alpha * ( norm<1>(tmp1_N, N.area()) + norm<1>(tmp2_N, N.area()) );
    real phi_L2_term = beta * ( norm<2>(tmp1_N, N.area()) + norm<2>(tmp2_N, N.area()) + norm<2>(tmp3_N, N.area()) );

    // compute KA, and A priors
    K(A, M, tmp1_M, tmp2_M, tmp3_M);
    real A_L1_term = gamma * ( norm<1>(tmp1_M, M.area()) + norm<1>(tmp2_M, M.area()) + norm<1>(tmp3_M, M.area()) );
    real A_L2_term = tau * ( norm<2>(tmp1_M, M.area()) + norm<2>(tmp2_M, M.area()) + norm<2>(tmp3_M, M.area()) );

    return data_term + phi_L1_term + phi_L2_term + A_L1_term + A_L2_term;
}
// =============================================================


// =============================================================
// sub functions for updates
// =============================================================
// Fused B-update (soft-threshold) and dual (zeta) update of the A-ADMM.
// On entry tmp_* hold KA; on exit they hold 2B - (KA + zeta) = B - zeta
// ready for the next KT application.
__global__
void ADMM_AKernel(real *zeta_x, real *zeta_y, real *zeta_z,
    real *tmp_x, real *tmp_y, real *tmp_z, cv::Size N)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;
    const int pos = ix + iy * N.width;

    if (ix >= N.width || iy >= N.height) return;

    // compute u = KA + zeta
    real u_x = tmp_x[pos] + zeta_x[pos];
    real u_y = tmp_y[pos] + zeta_y[pos];
    real u_z = tmp_z[pos] + zeta_z[pos];

    // B-update
    real B_x = prox_l1(u_x, shrinkage_value[0]);
    real B_y = prox_l1(u_y, shrinkage_value[0]);
    real B_z = prox_l1(u_z, shrinkage_value[0]);

    // zeta-update
    zeta_x[pos] = u_x - B_x;
    zeta_y[pos] = u_y - B_y;
    zeta_z[pos] = u_z - B_z;

    // store B-zeta
    tmp_x[pos] = 2*B_x - u_x;
    tmp_y[pos] = 2*B_y - u_y;
    tmp_z[pos] = 2*B_z - u_z;
}
// Right-hand side of the A-update linear system: I_warp/I0 + mu * KT(B - zeta).
__global__
void ADMM_precomputeAKernel(real *I_warp, real *I0, real *KT_res, real mu, cv::Size N, real *A)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;
    const int py = iy * N.width;
    const int pos = ix + py;

    if (ix >= N.width || iy >= N.height) return;

    // pre-compute for A-update (I_warp/I0 can be pre-computed)
    A[pos] = I_warp[pos] / I0[pos] + mu * KT_res[pos];
}
// ADMM solve of the A (amplitude) sub-problem on grid M.  A is updated in
// place; zeta_* are dual buffers, grad_*_M / lap_M are K(A) scratch, and the
// DCT plans/weights implement the diagonal inversion via x_update.
void A_update(real *A, real *I_warp, real *I0,
    real *zeta_x, real *zeta_y, real *zeta_z,
    real *grad_x_M, real *grad_y_M, real *lap_M,
    complex *tmp_dct_M, real *mat_A_hat,
    complex *ww_1_M, complex *ww_2_M, cufftHandle plan_dct_1_M, cufftHandle plan_dct_2_M,
    opt_A opt_A_t, cv::Size M)
{
    dim3 threads(32, 32);
    dim3 blocks(iDivUp(M.width, threads.x), iDivUp(M.height, threads.y));

    // initial objective
    real objval;
    if (opt_A_t.isverbose)
    {
        printf("-- A update\n");
        objval = obj_A(A, I_warp, I0, opt_A_t, M, grad_x_M, grad_y_M, lap_M);
        printf("---- init, obj = %.4e \n", objval);
    }

    // initialization
    checkCudaErrors(cudaMemset(zeta_x, 0, M.area()*sizeof(real)));
    checkCudaErrors(cudaMemset(zeta_y, 0, M.area()*sizeof(real)));
    checkCudaErrors(cudaMemset(zeta_z, 0, M.area()*sizeof(real)));
    checkCudaErrors(cudaMemset(grad_x_M, 0, M.area()*sizeof(real)));
    checkCudaErrors(cudaMemset(grad_y_M, 0, M.area()*sizeof(real)));
    checkCudaErrors(cudaMemset(lap_M, 0, M.area()*sizeof(real)));

    // ADMM loop
    for (signed int i = 1; i <= opt_A_t.iter; ++i)
    {
        // compute KT(B - zeta)
        KT(grad_x_M, grad_y_M, lap_M, M, A);

        // A-update inversion
        ADMM_precomputeAKernel<<<blocks, threads>>>(I_warp, I0, A, opt_A_t.mu, M, A);
        x_update(A, tmp_dct_M, mat_A_hat, ww_1_M, ww_2_M, M.width, M.height, plan_dct_1_M, plan_dct_2_M);

        // B-update & zeta-update
        K(A, M, grad_x_M, grad_y_M, lap_M);
        ADMM_AKernel<<<blocks, threads>>>(zeta_x, zeta_y, zeta_z, grad_x_M, grad_y_M, lap_M, M);
    }

    // show final objective
    if (opt_A_t.isverbose)
    {
        objval = obj_A(A, I_warp, I0, opt_A_t, M, grad_x_M, grad_y_M, lap_M);
        printf("---- iter = %d, obj = %.4e \n", opt_A_t.iter, objval);
    }
}
// I_warp = A * I0 (element-wise, grid M).
__global__
void I_warp_updateKernel(real *A, real *I0, cv::Size M, real *I_warp)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;
    const int pos = ix + iy * M.width;

    if (ix >= M.width || iy >= M.height) return;
    I_warp[pos] = A[pos] * I0[pos];
}
void I_warp_update(real *A, real *I0, cv::Size M, real *I_warp)
{
    dim3 threads(32, 32);
    dim3 blocks(iDivUp(M.width, threads.x), iDivUp(M.height, threads.y));
    I_warp_updateKernel<<<blocks, threads>>>(A, I0, M, I_warp);
}
// I_warp += [gx; gy] . nabla(phi): res3Kernel with gt = I_warp and
// out = I_warp accumulates the linearized warp in place (safe: each thread
// reads and writes the same M-index).  tmp1/tmp2 hold nabla(phi) on grid N.
void update_I_warp(real *I_warp, real *phi, real *gx, real *gy,
    real *tmp1, real *tmp2, cv::Size M, cv::Size N)
{
    dim3 threads(32, 32);
    dim3 blocks(iDivUp(N.width, threads.x), iDivUp(N.height, threads.y));
    nabla(phi, N, tmp1, tmp2);
    res3Kernel<<<blocks, threads>>>(gx, gy, I_warp, tmp1, tmp2, M, N, I_warp, true);
}
// NOTE(review): the texture template arguments were stripped by the
// extraction; reconstructed as the standard 2D float read — the data path
// uses CV_32F throughout.  Confirm against upstream.
texture<float, 2, cudaReadModeElementType> texTarget;
// Compute image derivatives (5-tap central differences) of I/texTarget and
// the per-pixel 2x2 system coefficients used by the L1 prox of the phi-update:
//   [A11 A12; A12 A22] = inv(mu * (g g^T + mu I)),  (a,b,c) = (gx, gy, gt).
// mods == 0: sample the warped image from texture memory (sizes multiple of 32);
// mods == 1: load I through shared memory with replicated borders.
template <int mods> // template global kernel to handle all cases
__global__
void ComputeDerivativesL1Kernel(const real *__restrict__ I0, cv::Size M, real mu,
    real *A11, real *A12, real *A22, real *a, real *b, real *c, const real *__restrict__ I = NULL)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;
    const int pos = ix + iy * M.width;

    if (ix >= M.width || iy >= M.height) return;

    // compute the derivatives
    real gx, gy, gt;
    switch (mods)
    {
        case 0: // image size is multiple of 32; use texture memory
        {
            real dx = 1.0f / (real)M.width;
            real dy = 1.0f / (real)M.height;

            real x = ((real)ix + 0.5f) * dx;
            real y = ((real)iy + 0.5f) * dy;

            // x derivative
            gx = tex2D(texTarget, x - 2.0f * dx, y);
            gx -= tex2D(texTarget, x - 1.0f * dx, y) * 8.0f;
            gx += tex2D(texTarget, x + 1.0f * dx, y) * 8.0f;
            gx -= tex2D(texTarget, x + 2.0f * dx, y);
            gx /= 12.0f;

            // t derivative
            gt = tex2D(texTarget, x, y) - I0[pos];

            // y derivative
            gy = tex2D(texTarget, x, y - 2.0f * dy);
            gy -= tex2D(texTarget, x, y - 1.0f * dy) * 8.0f;
            gy += tex2D(texTarget, x, y + 1.0f * dy) * 8.0f;
            gy -= tex2D(texTarget, x, y + 2.0f * dy);
            gy /= 12.0f;
        }
        break;

        case 1: // image size is not multiple of 32; use shared memory
        {
            const int tx = threadIdx.x + 2;
            const int ty = threadIdx.y + 2;

            // if (1): shared-memory path is active; the else branch is a
            // deliberately disabled global-memory fallback kept for debugging.
            if (1)
            {
                // save to shared memory
                __shared__ real smem[BLOCK_X+4][BLOCK_Y+4];
                smem[tx][ty] = I[pos];
                set_bd_shared_memory<5>(smem, I, tx, ty, ix, iy, M.width, M.height);
                __syncthreads();

                // x derivative
                gx = smem[tx-2][ty];
                gx -= smem[tx-1][ty] * 8.0f;
                gx += smem[tx+1][ty] * 8.0f;
                gx -= smem[tx+2][ty];
                gx /= 12.0f;

                // t derivative
                gt = smem[tx][ty] - I0[pos];

                // y derivative
                gy = smem[tx][ty-2];
                gy -= smem[tx][ty-1] * 8.0f;
                gy += smem[tx][ty+1] * 8.0f;
                gy -= smem[tx][ty+2];
                gy /= 12.0f;
            }
            else
            {
                // x derivative
                gx = I[max(0,ix-2) + iy * M.width];
                gx -= I[max(0,ix-1) + iy * M.width] * 8.0f;
                gx += I[min(M.width-1,ix+1) + iy * M.width] * 8.0f;
                gx -= I[min(M.width-1,ix+2) + iy * M.width];
                gx /= 12.0f;

                // t derivative
                gt = I[pos] - I0[pos];

                // y derivative
                gy = I[ix + max(0,iy-2) * M.width];
                gy -= I[ix + max(0,iy-1) * M.width] * 8.0f;
                gy += I[ix + min(M.height-1,iy+1) * M.width] * 8.0f;
                gy -= I[ix + min(M.height-1,iy+2) * M.width];
                gy /= 12.0f;
            }
        }
        break;
    }

    // pre-caching
    real gxx = gx*gx;
    real gxy = gx*gy;
    real gyy = gy*gy;
    real denom = 1.0f / (mu * (gxx + gyy + mu));

    // L1 + L2
    A11[pos] = denom * (gyy + mu); // A11
    A12[pos] = -gxy * denom;       // A12
    A22[pos] = denom * (gxx + mu); // A22
    a[pos] = gx;
    b[pos] = gy;
    c[pos] = gt;
}
// Host dispatcher: texture path when both dimensions are multiples of 32,
// shared-memory path otherwise.  s is the row pitch of I1 in elements.
void ComputeDerivativesL1(const real *__restrict__ I0, const real *__restrict__ I1,
    cv::Size M, int s, real mu,
    real *A11, real *A12, real *A22, real *a, real *b, real *c)
{
    int mods = (((M.width % 32) == 0) && ((M.height % 32) == 0)) ? 0 : 1;
    switch (mods)
    {
        case 0: // image size is multiple of 32; use texture memory
        {
            dim3 threads(32, 32);
            dim3 blocks(iDivUp(M.width, threads.x), iDivUp(M.height, threads.y));

            // replicate if a coordinate value is out-of-range
            texTarget.addressMode[0] = cudaAddressModeClamp;
            texTarget.addressMode[1] = cudaAddressModeClamp;
            texTarget.filterMode = cudaFilterModeLinear;
            texTarget.normalized = true;
            // NOTE(review): desc is unused — the templated cudaBindTexture2D
            // overload below takes its channel description from texTarget.
            cudaChannelFormatDesc desc = cudaCreateChannelDesc<float>();
            cudaBindTexture2D(0, texTarget, I1, M.width, M.height, s*sizeof(real));

            ComputeDerivativesL1Kernel<0><<<blocks, threads>>>(I0, M, mu, A11, A12, A22, a, b, c);
        }
        break;

        case 1: // image size is not multiple of 32; use shared memory
        {
            dim3 threads(BLOCK_X, BLOCK_Y);
            dim3 blocks(iDivUp(M.width, threads.x), iDivUp(M.height, threads.y));
            ComputeDerivativesL1Kernel<1><<<blocks, threads>>>(I0, M, mu, A11, A12, A22, a, b, c, I1);
        }
        break;
    }
}
// ADMM solve of the phi (wavefront) sub-problem on padded grid N, followed by
// re-linearization of I_warp.  phi is updated in place; zeta_* are duals,
// grad_*_N scratch, A11..c the cached per-pixel prox coefficients on grid M.
void phi_update(real *phi, real *I_warp, real *I0,
    real alpha, real beta,
    real *zeta_x, real *zeta_y, real *grad_x_N, real *grad_y_N,
    real *A11, real *A12, real *A22, real *gx, real *gy, real *gt,
    complex *tmp_dct_N, real *mat_phi_hat,
    complex *ww_1_N, complex *ww_2_N, cufftHandle plan_dct_1_N, cufftHandle plan_dct_2_N,
    opt_phi opt_phi_t, cv::Size M, cv::Size N)
{
    dim3 threads(32, 32);
    dim3 blocks(iDivUp(N.width, threads.x), iDivUp(N.height, threads.y));

    // compute derivatives and constants
    ComputeDerivativesL1(I0, I_warp, M, M.width, opt_phi_t.mu,
                         A11, A12, A22, gx, gy, gt);

    // initial objective
    real objval;
    if (opt_phi_t.isverbose)
    {
        printf("-- phi update\n");
        objval = obj_phi(phi, gx, gy, gt, alpha, beta, M, N, grad_x_N, grad_y_N, zeta_x);
        printf("---- init, obj = %.4e \n", objval);
    }

    // initialization
    checkCudaErrors(cudaMemset(zeta_x, 0, N.area()*sizeof(real)));
    checkCudaErrors(cudaMemset(zeta_y, 0, N.area()*sizeof(real)));
    checkCudaErrors(cudaMemset(grad_x_N, 0, N.area()*sizeof(real)));
    checkCudaErrors(cudaMemset(grad_y_N, 0, N.area()*sizeof(real)));

    // ADMM loop
    for (signed int i = 1; i <= opt_phi_t.iter; ++i)
    {
        // compute nablaT(w - zeta)
        nablaT(grad_x_N, grad_y_N, N, phi);

        // phi-update inversion
        x_update(phi, tmp_dct_N, mat_phi_hat, ww_1_N, ww_2_N, N.width, N.height, plan_dct_1_N, plan_dct_2_N);

        // w-update & zeta-update
        nabla(phi, N, grad_x_N, grad_y_N);
        prox_gL1Kernel<<<blocks, threads>>>(zeta_x, zeta_y, grad_x_N, grad_y_N, opt_phi_t.mu, alpha,
                                            A11, A12, A22, gx, gy, gt, N.width, N.height, M.width, M.height);
    }

    // show final objective
    if (opt_phi_t.isverbose)
    {
        objval = obj_phi(phi, gx, gy, gt, alpha, beta, M, N, grad_x_N, grad_y_N, zeta_x);
        printf("---- iter = %d, obj = %.4e \n", opt_phi_t.iter, objval);
    }

    // update I_warp as: I_warp += [gx; gy] * \nabla\phi
    update_I_warp(I_warp, phi, gx, gy, grad_x_N, grad_y_N, M, N);
}
// =============================================================


// Constructor: captures the reference image I0, derives the padded grid N,
// allocates all device buffers, caches constants, and prepares the cuFFT
// plans and DCT weights for the diagonal inversions.
CWS_A_Phi::CWS_A_Phi(cv::Mat I0, cv::Mat &h_A, cv::Mat &h_phi, opt_algo opt_algo_out)
{
    opt_algo_t = opt_algo_out;

    // is verbose
    opt_A_t.isverbose = opt_algo_t.isverbose_sub;
    opt_phi_t.isverbose = opt_algo_t.isverbose_sub;

    // ADMM parameters
    opt_A_t.mu = opt_algo_t.mu_A;
    opt_phi_t.mu = opt_algo_t.mu_phi;

    // total number of alternation iterations
    opt_A_t.iter = opt_algo_t.A_iter;
    opt_phi_t.iter = opt_algo_t.phi_iter;

    // determine sizes
    M = I0.size();

    // validate L
    if (opt_algo_t.L.width < 2 || opt_algo_t.L.height < 2)
    {
        // NOTE(review): "\in" was an invalid escape sequence; escaped the backslash.
        printf("L unspecified or wrong; will use nearest power of two of M; L \\in [2, 32].\n");
        opt_algo_t.L.width = min(max(2, (nearest_power_of_two(M.width) - M.width) /2), 32);
        opt_algo_t.L.height = min(max(2, (nearest_power_of_two(M.height) - M.height)/2), 32);
    }

    // set L
    opt_phi_t.L.width = opt_algo_t.L.width;
    opt_phi_t.L.height = opt_algo_t.L.height;

    // set N (pad L on every side)
    N = M + opt_algo_t.L + opt_algo_t.L;
    printf("M = [%d, %d], N = [%d, %d] \n", M.width, M.height, N.width, N.height);

    // allocate host cv::Mat containers
    h_A = cv::Mat::zeros(M, CV_32F);
    h_phi = cv::Mat::zeros(N, CV_32F);

    // allocate device pointers
    checkCudaErrors(cudaMalloc(&d_I0, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&d_I, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&d_I0_tmp, M.area()*sizeof(real)));

    // copy from host to device
    checkCudaErrors(cudaMemcpy(d_I0, I0.ptr(0), M.area()*sizeof(real), cudaMemcpyHostToDevice));
    checkCudaErrors(cudaMemcpy(d_I0_tmp, d_I0, M.area()*sizeof(real), cudaMemcpyDeviceToDevice));

    // compute A-update parameters (priors normalized by mean intensity)
    opt_A_t.gamma_new = opt_algo_t.gamma / mean<2>(d_I0, M.area());
    opt_A_t.tau_new = opt_algo_t.tau / mean<2>(d_I0, M.area());

    // set shrinkage value & cache to constant memory
    real shrinkage_value_array [2] = { opt_A_t.gamma_new/(2.0f*opt_algo_t.mu_A), opt_algo_t.alpha/(2.0f*opt_algo_t.mu_A) };
    cudaMemcpyToSymbol(shrinkage_value, shrinkage_value_array, 2*sizeof(real));

    // cache L to constant memory
    int L_size_array [2] = { opt_algo_t.L.width, opt_algo_t.L.height };
    cudaMemcpyToSymbol(L_SIZE, L_size_array, 2*sizeof(int));

    // allocate variable arrays
    // -- main variables
    checkCudaErrors(cudaMalloc(&A, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&phi, N.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&Delta_phi, N.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&I_warp, M.area()*sizeof(real)));

    // -- update variables
    checkCudaErrors(cudaMalloc(&zeta_x_M, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&zeta_y_M, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&zeta_z_M, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&zeta_x_N, N.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&zeta_y_N, N.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&zeta_z_N, N.area()*sizeof(real)));

    // -- temp arrays
    checkCudaErrors(cudaMalloc(&grad_x_M, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&grad_y_M, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&lap_M, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&grad_x_N, N.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&grad_y_N, N.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&lap_N, N.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&tmp_dct_M,M.area()*sizeof(complex)));
    checkCudaErrors(cudaMalloc(&tmp_dct_N,N.area()*sizeof(complex)));

    // -- temp variables for constant coefficients caching
    checkCudaErrors(cudaMalloc(&A11, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&A12, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&A22, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&a, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&b, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&c, M.area()*sizeof(real)));

    // prepare FFT plans
    int pH_M[1] = {M.height}, pW_M[1] = {M.width};
    int pH_N[1] = {N.height}, pW_N[1] = {N.width};
    cufft_prepare(1, pH_M, pW_M, &plan_dct_1_M, &plan_dct_2_M, NULL, NULL, NULL, NULL, NULL, NULL);
    cufft_prepare(1, pH_N, pW_N, &plan_dct_1_N, &plan_dct_2_N, NULL, NULL, NULL, NULL, NULL, NULL);

    // prepare DCT coefficients
    checkCudaErrors(cudaMalloc(&ww_1_M, M.height*sizeof(complex)));
    checkCudaErrors(cudaMalloc(&ww_2_M, M.width *sizeof(complex)));
    checkCudaErrors(cudaMalloc(&ww_1_N, N.height*sizeof(complex)));
    checkCudaErrors(cudaMalloc(&ww_2_N, N.width *sizeof(complex)));
    computeDCTweights(M.width, M.height, ww_1_M, ww_2_M);
    computeDCTweights(N.width, N.height, ww_1_N, ww_2_N);

    // prepare inversion matrices
    checkCudaErrors(cudaMalloc(&mat_A_hat, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&mat_phi_hat, N.area()*sizeof(real)));
    compute_inverse_mat(opt_A_t, opt_phi_t, opt_algo_t.beta, M, N, mat_A_hat, mat_phi_hat);
}

// Destructor: releases every device buffer and cuFFT plan acquired in the
// constructor (mirrors the allocation list one-for-one).
CWS_A_Phi::~CWS_A_Phi()
{
    checkCudaErrors(cudaFree(d_I0));
    checkCudaErrors(cudaFree(d_I));
    checkCudaErrors(cudaFree(d_I0_tmp));

    // result variables -------------
    checkCudaErrors(cudaFree(A));
    checkCudaErrors(cudaFree(phi));
    checkCudaErrors(cudaFree(Delta_phi));
    checkCudaErrors(cudaFree(I_warp));

    // update variables -------------
    checkCudaErrors(cudaFree(zeta_x_M));
    checkCudaErrors(cudaFree(zeta_y_M));
    checkCudaErrors(cudaFree(zeta_z_M));
    checkCudaErrors(cudaFree(zeta_x_N));
    checkCudaErrors(cudaFree(zeta_y_N));
    checkCudaErrors(cudaFree(zeta_z_N));

    // temp variables -------------
    checkCudaErrors(cudaFree(grad_x_M));
    checkCudaErrors(cudaFree(grad_y_M));
    checkCudaErrors(cudaFree(lap_M));
    checkCudaErrors(cudaFree(grad_x_N));
    checkCudaErrors(cudaFree(grad_y_N));
    checkCudaErrors(cudaFree(lap_N));
    checkCudaErrors(cudaFree(tmp_dct_M));
    checkCudaErrors(cudaFree(tmp_dct_N));
    checkCudaErrors(cudaFree(A11));
    checkCudaErrors(cudaFree(A12));
    checkCudaErrors(cudaFree(A22));
    checkCudaErrors(cudaFree(a));
    checkCudaErrors(cudaFree(b));
    checkCudaErrors(cudaFree(c));

    // FFT plans -------------
    cufftDestroy(plan_dct_1_M);
    cufftDestroy(plan_dct_2_M);
    cufftDestroy(plan_dct_1_N);
    cufftDestroy(plan_dct_2_N);

    // DCT coefficients -------------
    checkCudaErrors(cudaFree(ww_1_M));
    checkCudaErrors(cudaFree(ww_2_M));
    checkCudaErrors(cudaFree(ww_1_N));
    checkCudaErrors(cudaFree(ww_2_N));

    // inversion matrices -------------
    checkCudaErrors(cudaFree(mat_A_hat));
    checkCudaErrors(cudaFree(mat_phi_hat));
}

// Replace the algorithm parameter set (does not re-derive cached constants).
void CWS_A_Phi::setParas(opt_algo opt_algo_out)
{
    opt_algo_t = opt_algo_out;
}

// =============================================================
// main algorithm
// =============================================================
void CWS_A_Phi::solver(cv::Mat I)
{
    // CWS_A_PHI Simultaneous intensity and wavefront recovery. Solve for:
    //  min  || i(x+\nabla phi) - A i_0(x) ||_2^2 + ...
    // A,phi alpha || \nabla phi ||_1 + ...
    //       beta  ( || \nabla phi ||_2^2 + || \nabla^2 phi ||_2^2 ) + ...
    //       gamma ( || \nabla A ||_1 + || \nabla^2 A ||_1 ) + ...
    //       tau   ( || \nabla A ||_2^2 + || \nabla^2 A ||_2^2 )

    // copy from host to device
    checkCudaErrors(cudaMemcpy(d_I, I.ptr(0), M.area()*sizeof(real), cudaMemcpyHostToDevice));

    // record variables
    real objval, mean_Delta_phi;

    // initialization
    setasones(A, M);
    checkCudaErrors(cudaMemset(phi, 0, N.area()*sizeof(real)));
    checkCudaErrors(cudaMemset(Delta_phi, 0, N.area()*sizeof(real)));
    checkCudaErrors(cudaMemcpy(I_warp, d_I, M.area()*sizeof(real), cudaMemcpyDeviceToDevice));

    // initial objective
    if (opt_algo_t.isverbose)
    {
        objval = obj_total(A, phi, I_warp, d_I0, opt_algo_t.alpha, opt_algo_t.beta, opt_algo_t.gamma, opt_algo_t.tau, M, N,
                           grad_x_M, grad_y_M, lap_M, grad_x_N, grad_y_N, lap_N);
        printf("iter = %d, obj = %.4e\n", 0, objval);
    }

    // the loop
    // NOTE(review): 0x7f800000 is the IEEE-754 *bit pattern* of +Inf, but as
    // an integer literal it converts to ~2.147e9; the objective is assumed to
    // stay below that in practice. Value kept; original comment was wrong.
    real obj_min = 0x7f800000; // large sentinel (~2.1e9), not +Inf
    for (signed int outer_loop = 1; outer_loop <= opt_algo_t.iter; ++outer_loop)
    {
        // === A-update ===
        A_update(A, I_warp, d_I0,
                 zeta_x_M, zeta_y_M, zeta_z_M,
                 grad_x_M, grad_y_M, lap_M,
                 tmp_dct_M, mat_A_hat, ww_1_M, ww_2_M, plan_dct_1_M, plan_dct_2_M, opt_A_t, M);

        // median filtering A
        medfilt2(A, M.width, M.height, A);

        // update d_I0_tmp
        I_warp_update(A, d_I0, M, d_I0_tmp);

        // === phi-update ===
        phi_update(Delta_phi, I_warp, d_I0_tmp, opt_algo_t.alpha, opt_algo_t.beta, zeta_x_N, zeta_y_N, grad_x_N, grad_y_N,
                   A11, A12, A22, a, b, c, tmp_dct_N, mat_phi_hat,
                   ww_1_N, ww_2_N, plan_dct_1_N, plan_dct_2_N, opt_phi_t, M, N);

        // medfilt2(Delta_phi, N.width, N.height, Delta_phi);

        // === accmulate phi ===
        mean_Delta_phi = mean<1>(Delta_phi, N.area());
        if (mean_Delta_phi < opt_algo_t.phi_tol)
        {
            printf("-- mean(|Delta phi|) = %.4e < %.4e = tol; Quit. \n", mean_Delta_phi, opt_algo_t.phi_tol);
            break;
        }
        else
        {
            printf("-- mean(|Delta phi|) = %.4e \n", mean_Delta_phi);
            Add(phi, Delta_phi, N.area(), phi);
        }

        // === records ===
        if (opt_algo_t.isverbose)
        {
            objval = obj_total(A, phi, I_warp, d_I0, opt_algo_t.alpha, opt_algo_t.beta, opt_algo_t.gamma, opt_algo_t.tau, M, N,
                               grad_x_M, grad_y_M, lap_M, grad_x_N, grad_y_N, lap_N);
            printf("iter = %d, obj = %.4e\n", outer_loop, objval);
            if (objval > obj_min)
            {
                printf("Obj increasing; Quit. \n");
                break;
            }
            else
                obj_min = objval;
        }
    }

    // median filtering phi
    medfilt2(phi, N.width, N.height, phi);
}


// download results from GPU to host, and crop phi
void CWS_A_Phi::download(cv::Mat &h_A, cv::Mat &h_phi)
{
    // dummy variable
    cv::Mat phi_tmp = cv::Mat::zeros(N, CV_32F);

    // transfer result from device to host
    checkCudaErrors(cudaMemcpy(h_A.ptr(0), A, M.area()*sizeof(real), cudaMemcpyDeviceToHost));
    checkCudaErrors(cudaMemcpy(phi_tmp.ptr(0), phi, N.area()*sizeof(real), cudaMemcpyDeviceToHost));

    // crop phi to preserve only the center part
    h_phi = phi_tmp(cv::Rect(opt_algo_t.L.width, opt_algo_t.L.height, M.width, M.height)).clone();
}

// wrapper function for cws
void cws_A_phi(cv::Mat I0, cv::Mat I, cv::Mat &h_A, cv::Mat &h_phi, opt_algo para_algo)
{
    // define cws object
    CWS_A_Phi cws_obj(I0, h_A, h_phi, para_algo);

    // record variables
    real fps;
    GpuTimer timer;

    // timing started
    timer.Start();

    // the loop
    int rep_times = 1;
    for (signed int i = 0; i < rep_times; ++i)
    {
        // run cws solver
        cws_obj.solver(I);
        cws_obj.download(h_A, h_phi);

        // timing
ended 1070 | timer.Stop(); 1071 | fps = 1000/timer.Elapsed()*rep_times; 1072 | printf("Mean elapsed time: %.4f ms. Mean frame rate: %.4f fps. \n", timer.Elapsed()/rep_times, fps); 1073 | } 1074 | } 1075 | -------------------------------------------------------------------------------- /cuda/src/cws_A_phi.h: -------------------------------------------------------------------------------- 1 | #ifndef CWS_A_PHI_H 2 | #define CWS_A_PHI_H 3 | 4 | #include 5 | 6 | // accuracy 7 | #ifdef USE_DOUBLES 8 | typedef double real; 9 | typedef cufftDoubleComplex complex; 10 | #else 11 | typedef float real; 12 | typedef cufftComplex complex; 13 | #endif 14 | 15 | // default parameters 16 | struct opt_algo { 17 | bool isverbose_sub = false; 18 | bool isverbose = false; 19 | int iter = 3; 20 | int A_iter = 20; 21 | int phi_iter = 20; 22 | float alpha = 0.1f; 23 | float beta = 0.1f; 24 | float gamma = 100.0f; 25 | float tau = 5.0f; 26 | float phi_tol = 0.05f; 27 | float mu_A = 0.1f; 28 | float mu_phi = 100.0f; 29 | cv::Size L = cv::Size(1,1); 30 | // (1,1) for initializers; if unspecified, will be updated in CWS_A_Phi::CWS_A_Phi as L = cv::Size(2,2) 31 | }; 32 | 33 | 34 | // ============================================================= 35 | // algorithm parameter structures 36 | // ============================================================= 37 | struct opt_A { 38 | bool isverbose = false; 39 | int iter = 10; 40 | real mu = 0.1; 41 | real gamma_new; 42 | real tau_new; 43 | }; 44 | 45 | struct opt_phi { 46 | bool isverbose = false; 47 | int iter = 10; 48 | real mu = 100.0; 49 | cv::Size L = cv::Size(2,2); 50 | }; 51 | // ============================================================= 52 | 53 | 54 | class CWS_A_Phi 55 | { 56 | // algorithm parameters 57 | opt_A opt_A_t; 58 | opt_phi opt_phi_t; 59 | 60 | // device pointers 61 | real *d_I0, *d_I, *d_I0_tmp; 62 | 63 | // update variables ------------- 64 | real *zeta_x_M, *zeta_y_M, *zeta_z_M; 65 | real *zeta_x_N, *zeta_y_N, *zeta_z_N; 66 
| 67 | // temp variables ------------- 68 | real *grad_x_M, *grad_y_M, *lap_M; 69 | real *grad_x_N, *grad_y_N, *lap_N; 70 | complex *tmp_dct_M, *tmp_dct_N; 71 | real *A11, *A12, *A22, *a, *b, *c; 72 | 73 | // FFT plans ------------- 74 | cufftHandle plan_dct_1_M, plan_dct_2_M; 75 | cufftHandle plan_dct_1_N, plan_dct_2_N; 76 | 77 | // DCT coefficients ------------- 78 | complex *ww_1_M, *ww_2_M, *ww_1_N, *ww_2_N; 79 | 80 | // inversion matrices ------------- 81 | real *mat_A_hat, *mat_phi_hat; 82 | 83 | public: 84 | // dimensions 85 | cv::Size M, N; 86 | 87 | // algorithm parameters 88 | opt_algo opt_algo_t; 89 | 90 | // result variables ------------- 91 | real *A, *phi, *Delta_phi, *I_warp; 92 | 93 | CWS_A_Phi(cv::Mat I0, cv::Mat &h_A, cv::Mat &h_phi, opt_algo opt_algo_out); 94 | ~CWS_A_Phi(); 95 | void setParas(opt_algo opt_algo_out); 96 | void solver(cv::Mat I); 97 | void download(cv::Mat &h_A, cv::Mat &h_phi); 98 | }; 99 | 100 | void cws_A_phi(cv::Mat I0, cv::Mat I, cv::Mat &A, cv::Mat &phi, opt_algo para_algo); 101 | 102 | #endif 103 | -------------------------------------------------------------------------------- /cuda/src/medianfilteringKernel.cuh: -------------------------------------------------------------------------------- 1 | #include "common.h" 2 | 3 | // replicate boundary condition (for stencil size of 3x3 or 5x5) 4 | template 5 | __device__ 6 | void set_bd_shared_memory(volatile float (*smem)[BLOCK_Y+stencil_size-1], const float *d_in, 7 | const int tx, const int ty, const int ix, const int iy, int w, int h) 8 | { 9 | const int py = iy*w; 10 | const int x_min = max(ix-1, 0); 11 | const int x_max = min(ix+1, w-1); 12 | const int y_min = max(iy-1, 0) *w; 13 | const int y_max = min(iy+1, h-1)*w; 14 | switch (stencil_size) 15 | { 16 | case 3: 17 | { 18 | if (tx == 1) 19 | { 20 | smem[0][ty] = d_in[x_min + py]; 21 | if (ty == 1) 22 | { 23 | smem[0][0] = d_in[x_min + y_min]; 24 | smem[0][BLOCK_Y+1] = d_in[x_min + y_max]; 25 | } 26 | } 27 | if (tx == 
BLOCK_X) 28 | 		{ 29 | 			smem[BLOCK_X+1][ty] = d_in[x_max + py]; 30 | 			if (ty == 1) 31 | 			{ 32 | 				smem[BLOCK_X+1][0]         = d_in[x_max + y_min]; 33 | 				smem[BLOCK_X+1][BLOCK_Y+1] = d_in[x_max + y_max]; 34 | 			} 35 | 		} 36 | 		if (ty == 1) 37 | 		{ 38 | 			smem[tx][0] = d_in[ix + y_min]; 39 | 		} 40 | 		if (ty == BLOCK_Y) 41 | 		{ 42 | 			smem[tx][BLOCK_Y+1] = d_in[ix + y_max]; 43 | 		} 44 | 	} 45 | 	break; 46 | 47 | 	case 5: 48 | 	{ 49 | 		// boundary guards 50 | 		const int x_min2 = max(ix-2, 0); 51 | 		const int x_max2 = min(ix+2, w-1); 52 | 		const int y_min2 = max(iy-2, 0) *w; 53 | 		const int y_max2 = min(iy+2, h-1)*w; 54 | 55 | 		if (tx == 2) 56 | 		{ 57 | 			smem[0][ty] = d_in[x_min2 + py]; 58 | 			smem[1][ty] = d_in[x_min  + py]; 59 | 			if (ty == 2) 60 | 			{ 61 | 				smem[0][0] = d_in[x_min2 + y_min2]; 62 | 				smem[1][0] = d_in[x_min  + y_min2]; 63 | 				smem[0][1] = d_in[x_min2 + y_min ]; 64 | 				smem[1][1] = d_in[x_min  + y_min ]; 65 | 			} 66 | 		} 67 | 		if (tx == BLOCK_X+1) 68 | 		{ 69 | 			smem[BLOCK_X+2][ty] = d_in[x_max  + py]; 70 | 			smem[BLOCK_X+3][ty] = d_in[x_max2 + py]; 71 | 			if (ty == BLOCK_Y+1) 72 | 			{ 73 | 				smem[BLOCK_X+2][BLOCK_Y+2] = d_in[x_max  + y_max ]; 74 | 				smem[BLOCK_X+3][BLOCK_Y+2] = d_in[x_max2 + y_max ]; 75 | 				smem[BLOCK_X+2][BLOCK_Y+3] = d_in[x_max  + y_max2]; 76 | 				smem[BLOCK_X+3][BLOCK_Y+3] = d_in[x_max2 + y_max2]; 77 | 			} 78 | 		} 79 | 		if (ty == 2) 80 | 		{ 81 | 			smem[tx][0] = d_in[ix + y_min2]; 82 | 			smem[tx][1] = d_in[ix + y_min]; 83 | 			if (tx == BLOCK_X+1) 84 | 			{ 85 | 				smem[BLOCK_X+2][0] = d_in[x_max  + y_min2]; 86 | 				smem[BLOCK_X+3][0] = d_in[x_max2 + y_min2]; 87 | 				smem[BLOCK_X+2][1] = d_in[x_max  + y_min ]; 88 | 				smem[BLOCK_X+3][1] = d_in[x_max2 + y_min ]; 89 | 			} 90 | 		} 91 | 		if (ty == BLOCK_Y+1) 92 | 		{ 93 | 			smem[tx][BLOCK_Y+2] = d_in[ix + y_max]; 94 | 			smem[tx][BLOCK_Y+3] = d_in[ix + y_max2]; 95 | 			if (tx == 2) 96 | 			{ 97 | 				smem[0][BLOCK_X+2] = d_in[x_min2 + y_max ]; 98 | 				smem[0][BLOCK_X+3] = d_in[x_min2 + y_max2]; 99 | 				smem[1][BLOCK_X+2] = d_in[x_min  + y_max ]; 100 | 				smem[1][BLOCK_X+3] = d_in[x_min  + y_max2]; /* NOTE(review): the four corner writes just above index the second (row) dimension with BLOCK_X+2/BLOCK_X+3, but every mirrored corner fill in this function (e.g. smem[BLOCK_X+2][BLOCK_Y+2] and smem[tx][BLOCK_Y+2/+3]) uses BLOCK_Y there. Whenever BLOCK_X != BLOCK_Y this writes the wrong halo cells and can index past the smem row extent -- these should presumably be smem[0..1][BLOCK_Y+2] and smem[0..1][BLOCK_Y+3]; confirm against the shared-memory declaration at the call site. */ 101 | 			} 102 | 		} 103 | 	} 104 
| 	break; 105 | 	} 106 | } 107 | 108 | // Exchange trick: Morgan McGuire, ShaderX 2008 109 | #define s2(a,b)            { float tmp = a; a = min(a,b); b = max(tmp,b); } 110 | #define mn3(a,b,c)         s2(a,b); s2(a,c); 111 | #define mx3(a,b,c)         s2(b,c); s2(a,c); 112 | 113 | #define mnmx3(a,b,c)       mx3(a,b,c); s2(a,b);                               // 3 exchanges 114 | #define mnmx4(a,b,c,d)     s2(a,b); s2(c,d); s2(a,c); s2(b,d);                // 4 exchanges 115 | #define mnmx5(a,b,c,d,e)   s2(a,b); s2(c,d); mn3(a,c,e); mx3(b,d,e);          // 6 exchanges 116 | #define mnmx6(a,b,c,d,e,f) s2(a,d); s2(b,e); s2(c,f); mn3(a,b,c); mx3(d,e,f); // 7 exchanges 117 | 118 | __global__ void medfilt2_exch(int width, int height, float *d_out, float *d_in) 119 | { 120 | 	const int tx = threadIdx.x + 1; 121 | 	const int ty = threadIdx.y + 1; 122 | 	const int ix = threadIdx.x + blockIdx.x * blockDim.x; 123 | 	const int iy = threadIdx.y + blockIdx.y * blockDim.y; 124 | 	const int pos = ix + iy*width; 125 | 126 | 	if (ix >= width || iy >= height) return; /* NOTE(review): out-of-range threads exit here, but the in-range threads of the same block still execute the __syncthreads() below. __syncthreads() must be reached by every thread of the block, so any partial edge block (width or height not a multiple of the block dims) hits undefined behavior / possible hang. Safer pattern: keep all threads alive through the barrier and guard only the global-memory load and the final store with the bounds test. */ 127 | 128 | 	// shared memory for 3x3 stencil 129 | 	__shared__ float smem[BLOCK_X+2][BLOCK_Y+2]; 130 | 131 | 	// load in shared memory 132 | 	smem[tx][ty] = d_in[pos]; 133 | 	set_bd_shared_memory<3>(smem, d_in, tx, ty, ix, iy, width, height); 134 | 	__syncthreads(); 135 | 136 | 	// pull top six from shared memory 137 | 	float v[6] = { smem[tx-1][ty-1], smem[tx][ty-1], smem[tx+1][ty-1], 138 | 				   smem[tx-1][ty  ], smem[tx][ty  ], smem[tx+1][ty  ]}; 139 | 140 | 	// with each pass, remove min and max values and add new value 141 | 	mnmx6(v[0], v[1], v[2], v[3], v[4], v[5]); 142 | 	v[5] = smem[tx-1][ty+1]; // add new contestant 143 | 	mnmx5(v[1], v[2], v[3], v[4], v[5]); 144 | 	v[5] = smem[tx][ty+1]; 145 | 	mnmx4(v[2], v[3], v[4], v[5]); 146 | 	v[5] = smem[tx+1][ty+1]; 147 | 	mnmx3(v[3], v[4], v[5]); 148 | 149 | 	// pick the middle one 150 | 	d_out[pos] = v[4]; 151 | } 152 | 153 | 154 | /////////////////////////////////////////////////////////////////////////////// 155 | /// \brief 3-by-3 median filtering of an image 156 | /// \param[in]  img_in      
input image 157 | /// \param[in] w width of input image 158 | /// \param[in] h height of input image 159 | /// \param[out] img_out output image 160 | /////////////////////////////////////////////////////////////////////////////// 161 | static 162 | void medfilt2(float *img_in, int w, int h, float *img_out) 163 | { 164 | dim3 threads(BLOCK_X, BLOCK_Y); 165 | dim3 blocks(iDivUp(w, threads.x), iDivUp(h, threads.y)); 166 | 167 | medfilt2_exch<<>>(w, h, img_out, img_in); 168 | 169 | } 170 | -------------------------------------------------------------------------------- /cuda/src/prepare_cufft_warmup.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "assert.h" 7 | 8 | typedef cufftComplex complex; 9 | 10 | void cufft_warper(complex *h_in, int n, int m, cufftHandle plan, complex *h_out) 11 | { 12 | const int data_size = n*m*sizeof(complex); 13 | 14 | // device memory allocation 15 | complex *d_temp; 16 | checkCudaErrors(cudaMalloc(&d_temp, data_size)); 17 | 18 | // transfer data from host to device 19 | checkCudaErrors(cudaMemcpy(d_temp, h_in, data_size, cudaMemcpyHostToDevice)); 20 | 21 | // Compute the FFT 22 | cufftExecC2C(plan, d_temp, d_temp, CUFFT_FORWARD); 23 | 24 | // transfer result from device to host 25 | checkCudaErrors(cudaMemcpy(h_out, d_temp, data_size, cudaMemcpyDeviceToHost)); 26 | 27 | // cleanup 28 | checkCudaErrors(cudaFree(d_temp)); 29 | } 30 | 31 | void cufft_prepare(int nLevels, int *pH_N, int *pW_N, 32 | cufftHandle *plan_dct_1, cufftHandle *plan_dct_2, 33 | cufftHandle *plan_dct_3, cufftHandle *plan_dct_4, 34 | cufftHandle *plan_dct_5, cufftHandle *plan_dct_6, 35 | cufftHandle *plan_dct_7, cufftHandle *plan_dct_8) 36 | { 37 | // prepare cufft plans & warmup 38 | printf("Preparing CuFFT plans and warmups ... 
"); 39 | 40 | int Length1[1], Length2[1]; 41 | if (nLevels >= 1) 42 | { 43 | Length1[0] = pH_N[0]; // for each FFT, the Length1 is N_height 44 | Length2[0] = pW_N[0]; // for each FFT, the Length2 is N_width 45 | cufftPlanMany(plan_dct_1, 1, Length1, 46 | Length1, pW_N[0], 1, 47 | Length1, pW_N[0], 1, 48 | CUFFT_C2C, pW_N[0]); 49 | cufftPlanMany(plan_dct_2, 1, Length2, 50 | Length2, pH_N[0], 1, 51 | Length2, pH_N[0], 1, 52 | CUFFT_C2C, pH_N[0]); 53 | } 54 | else 55 | { 56 | printf("No CuFFT plans prepared; out ... \n"); 57 | } 58 | 59 | if (nLevels >= 2) 60 | { 61 | Length1[0] = pH_N[1]; // for each FFT, the Length1 is N_height 62 | Length2[0] = pW_N[1]; // for each FFT, the Length2 is N_width 63 | cufftPlanMany(plan_dct_3, 1, Length1, 64 | Length1, pW_N[1], 1, 65 | Length1, pW_N[1], 1, 66 | CUFFT_C2C, pW_N[1]); 67 | cufftPlanMany(plan_dct_4, 1, Length2, 68 | Length2, pH_N[1], 1, 69 | Length2, pH_N[1], 1, 70 | CUFFT_C2C, pH_N[1]); 71 | } 72 | 73 | if (nLevels >= 3) 74 | { 75 | Length1[0] = pH_N[2]; // for each FFT, the Length1 is N_height 76 | Length2[0] = pW_N[2]; // for each FFT, the Length2 is N_width 77 | cufftPlanMany(plan_dct_5, 1, Length1, 78 | Length1, pW_N[2], 1, 79 | Length1, pW_N[2], 1, 80 | CUFFT_C2C, pW_N[2]); 81 | cufftPlanMany(plan_dct_6, 1, Length2, 82 | Length2, pH_N[2], 1, 83 | Length2, pH_N[2], 1, 84 | CUFFT_C2C, pH_N[2]); 85 | } 86 | 87 | if (nLevels >= 4) 88 | { 89 | Length1[0] = pH_N[3]; // for each FFT, the Length1 is N_height 90 | Length2[0] = pW_N[3]; // for each FFT, the Length2 is N_width 91 | cufftPlanMany(plan_dct_7, 1, Length2, 92 | Length1, pW_N[3], 1, 93 | Length1, pW_N[3], 1, 94 | CUFFT_C2C, pW_N[3]); 95 | cufftPlanMany(plan_dct_8, 1, Length2, 96 | Length2, pH_N[3], 1, 97 | Length2, pH_N[3], 1, 98 | CUFFT_C2C, pH_N[3]); 99 | } 100 | 101 | // cufft warmup 102 | int N_width = pW_N[0]; 103 | int N_height = pH_N[0]; 104 | complex *h_warmup_in = new complex[N_width * N_height]; 105 | complex *h_warmup_out = new complex[N_width * 
N_height]; 106 | cufft_warper(h_warmup_in, N_width, N_height, *plan_dct_1, h_warmup_out); 107 | cufft_warper(h_warmup_in, N_width, N_height, *plan_dct_2, h_warmup_out); 108 | delete[] h_warmup_in; 109 | delete[] h_warmup_out; 110 | printf("Done.\n"); 111 | } -------------------------------------------------------------------------------- /cuda/src/prepare_cufft_warmup.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | typedef cufftComplex complex; 7 | 8 | #ifndef PREPARE_CUFFT_WARMUP_H 9 | #define PREPARE_CUFFT_WARMUP_H 10 | 11 | void cufft_warper(complex *h_in, int n, int m, cufftHandle plan, complex *h_out); 12 | 13 | void cufft_prepare(int nLevels, int *pH_N, int *pW_N, 14 | cufftHandle *plan_dct_1, cufftHandle *plan_dct_2, 15 | cufftHandle *plan_dct_3, cufftHandle *plan_dct_4, 16 | cufftHandle *plan_dct_5, cufftHandle *plan_dct_6, 17 | cufftHandle *plan_dct_7, cufftHandle *plan_dct_8); 18 | 19 | #endif 20 | -------------------------------------------------------------------------------- /cuda/src/prepare_precomputations.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include "common.h" 3 | 4 | #include 5 | 6 | // include kernels 7 | #include "computemat_x_hatKernel.cuh" 8 | #include "computeDCTweightsKernel.cuh" 9 | 10 | // for thrust functions 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | /////////////////////////////////////////////////////////////////////////////// 18 | /// texture references 19 | /////////////////////////////////////////////////////////////////////////////// 20 | 21 | /// image to downscale 22 | texture texFine; 23 | 24 | 25 | static 26 | __global__ 27 | void transposeKernel(float *in, int width, int height, int stride1, int stride2, float *out) 28 | { 29 | const int ix = threadIdx.x + blockIdx.x * blockDim.x; 30 | const int iy = threadIdx.y + blockIdx.y * 
blockDim.y; 31 | 32 | if (ix >= width || iy >= height) return; 33 | 34 | out[iy + ix * stride2] = in[ix + iy * stride1]; 35 | } 36 | 37 | 38 | static 39 | __device__ void tridisolve(float *x, float *b, float *mu, int m) 40 | { 41 | // forward 42 | for (int i = 0; i < m-1; i++) 43 | x[i+1] -= mu[i] * x[i]; 44 | 45 | // backward 46 | x[m-1] /= b[m-1]; 47 | for (int i = m-2; i >= 0; i--) 48 | x[i] = (x[i] - x[i+1]) / b[i]; 49 | 50 | // maginify by 6 51 | for (int i = 0; i < m; i++) 52 | x[i] = 6.0f * x[i]; 53 | } 54 | 55 | 56 | static 57 | __global__ 58 | void tridisolve_parallel(float *img, float *b_h, float *mu_h, 59 | int width, int height, int stride2) 60 | { 61 | const int iy = threadIdx.x + blockIdx.x * blockDim.x; 62 | 63 | if (iy >= width) return; 64 | 65 | tridisolve(img + iy * stride2, b_h, mu_h, height); 66 | } 67 | 68 | 69 | void cbanal(float *in, float *out, int width, int height, int stride1, int stride2) 70 | { 71 | float *b = new float [height]; 72 | float *mu = new float [height-1]; 73 | 74 | // set b 75 | for (int j = 0; j < height; j++) 76 | b[j] = 4.0f; 77 | 78 | // pre-computations 79 | for (int j = 0; j < height-1; j++) 80 | { 81 | mu[j] = 1.0f / b[j]; 82 | b[j+1] -= mu[j]; 83 | } 84 | 85 | float *d_b, *d_mu; 86 | checkCudaErrors(cudaMalloc(&d_b, height*sizeof(float))); 87 | checkCudaErrors(cudaMalloc(&d_mu, (height-1)*sizeof(float))); 88 | 89 | checkCudaErrors(cudaMemcpy(d_b, b, height*sizeof(float), cudaMemcpyHostToDevice)); 90 | checkCudaErrors(cudaMemcpy(d_mu, mu, (height-1)*sizeof(float), cudaMemcpyHostToDevice)); 91 | 92 | dim3 threads_1D(256); 93 | dim3 blocks_1D(iDivUp(width, threads_1D.x)); 94 | 95 | dim3 threads_2D(32, 6); 96 | dim3 blocks_2D(iDivUp(width, threads_2D.x), iDivUp(height, threads_2D.y)); 97 | 98 | transposeKernel<<>>(in, width, height, stride1, stride2, out); 99 | tridisolve_parallel<<>>(out, d_b, d_mu, width, height, stride2); 100 | 101 | // cleanup 102 | checkCudaErrors(cudaFree(d_b)); 103 | 
checkCudaErrors(cudaFree(d_mu)); 104 | delete [] b; 105 | delete [] mu; 106 | } 107 | 108 | 109 | 110 | void cbanal2D(float *img, int width, int height, int stride1, int stride2) 111 | { 112 | float *temp; 113 | checkCudaErrors(cudaMalloc(&temp, height*width*sizeof(float))); 114 | 115 | // compute the cubic B-spline coefficients 116 | cbanal(img, temp, width, height, stride1, stride2); 117 | cbanal(temp, img, height, width, stride2, stride1); 118 | 119 | // cleanup 120 | checkCudaErrors(cudaFree(temp)); 121 | } 122 | 123 | 124 | 125 | static 126 | __global__ void anti_weight_x_Kernel(int width, int height, int stride, float *out) 127 | { 128 | const int ix = threadIdx.x + blockIdx.x * blockDim.x; 129 | const int iy = threadIdx.y + blockIdx.y * blockDim.y; 130 | 131 | if (ix >= width || iy >= height) 132 | return; 133 | 134 | float dx = 1.0f/(float)width; 135 | float dy = 1.0f/(float)height; 136 | 137 | float x = ((float)ix + 0.5f) * dx; 138 | float y = ((float)iy + 0.5f) * dy; 139 | 140 | x -= 0.25*dx; 141 | y -= 0.25*dy; 142 | 143 | dx /= 2.0f; 144 | dy /= 2.0f; 145 | 146 | // magic number, in MATLAB: conv2([0.125 0.375 0.375 0.125],[0.125 0.375 0.375 0.125]') 147 | out[ix + iy * stride] = 148 | 0.015625f*tex2D(texFine, x-dx, y-1*dy) + 0.046875f*tex2D(texFine, x, y-1*dy) + 0.046875f*tex2D(texFine, x+dx, y-1*dy) + 0.015625f*tex2D(texFine, x+2*dx, y-1*dy) + 149 | 0.046875f*tex2D(texFine, x-dx, y ) + 0.140625f*tex2D(texFine, x, y ) + 0.140625f*tex2D(texFine, x+dx, y ) + 0.046875f*tex2D(texFine, x+2*dx, y ) + 150 | 0.046875f*tex2D(texFine, x-dx, y+1*dy) + 0.140625f*tex2D(texFine, x, y+1*dy) + 0.140625f*tex2D(texFine, x+dx, y+1*dy) + 0.046875f*tex2D(texFine, x+2*dx, y+1*dy) + 151 | 0.015625f*tex2D(texFine, x-dx, y+2*dy) + 0.046875f*tex2D(texFine, x, y+2*dy) + 0.046875f*tex2D(texFine, x+dx, y+2*dy) + 0.015625f*tex2D(texFine, x+2*dx, y+2*dy); 152 | } 153 | 154 | 155 | extern 156 | void Downscale_Anti(const float *src, int width, int height, int stride, 157 | int 
newWidth, int newHeight, int newStride, float *out) 158 | { 159 | dim3 threads(32, 8); 160 | dim3 blocks(iDivUp(newWidth, threads.x), iDivUp(newHeight, threads.y)); 161 | 162 | // mirror if a coordinate value is out-of-range 163 | texFine.addressMode[0] = cudaAddressModeMirror; 164 | texFine.addressMode[1] = cudaAddressModeMirror; 165 | texFine.filterMode = cudaFilterModeLinear; 166 | texFine.normalized = true; 167 | 168 | cudaChannelFormatDesc desc = cudaCreateChannelDesc(); 169 | 170 | checkCudaErrors(cudaBindTexture2D(0, texFine, src, width, height, stride * sizeof(float))); 171 | 172 | anti_weight_x_Kernel<<>>(newWidth, newHeight, newStride, out); 173 | } 174 | 175 | 176 | 177 | 178 | /////////////////////////////////////////////////////////////////////////////// 179 | /// \brief method logic 180 | /// 181 | /// handles memory allocations, control flow 182 | /// \param[in] N_W unknown phase width 183 | /// \param[in] N_H unknown phase height 184 | /// \param[in] alpha degree of displacement field smoothness 185 | /// \param[in] mu proximal parameter 186 | /// \param[out] h_mat_x_hat pre-computed mat_x_hat 187 | /// \param[out] ww_1 ww_1 coefficient 188 | /// \param[out] ww_2 ww_2 coefficient 189 | /////////////////////////////////////////////////////////////////////////////// 190 | void prepare_precomputations(int nLevels, int L_width, int L_height, 191 | int N_width, int N_height, 192 | int M_width, int M_height, int M_stride, 193 | int N_W, int N_H, int currentLevel, 194 | int *pW_N, int *pH_N, int *pS_N, 195 | int *pW_M, int *pH_M, int *pS_M, 196 | int *pW_L, int *pH_L, 197 | float **pI0, float **pI1, float **d_I0_coeff, 198 | float alpha, float *mu, const float **mat_x_hat, 199 | const complex **ww_1, const complex **ww_2) 200 | { 201 | // determine L value for best performance (we use mod 2 here) 202 | for (int i = nLevels - 1; i >= 0; i--) { 203 | pW_L[i] = L_width >> (nLevels - 1 - i); 204 | pH_L[i] = L_height >> (nLevels - 1 - i); 205 | if (pW_L[i] < 2) 
pW_L[i] = 2; 206 | if (pH_L[i] < 2) pH_L[i] = 2; 207 | } 208 | 209 | pW_N[nLevels - 1] = M_width + 2 * pW_L[nLevels - 1]; 210 | pH_N[nLevels - 1] = M_height + 2 * pH_L[nLevels - 1]; 211 | pS_N[nLevels - 1] = iAlignUp(pW_N[nLevels - 1]); 212 | pW_M[nLevels - 1] = M_width; 213 | pH_M[nLevels - 1] = M_height; 214 | pS_M[nLevels - 1] = M_stride; 215 | 216 | printf("initial sizes: W_N = %d, H_N = %d, S_N = %d\n", N_width, N_height, iAlignUp(pW_N[nLevels - 1])); 217 | printf("initial sizes: W_M = %d, H_M = %d, S_M = %d\n", M_width, M_height, M_stride); 218 | 219 | 220 | printf("Pre-compute the variables on GPU...\n"); 221 | 222 | if (currentLevel != 0){ // prepare pyramid 223 | for (; currentLevel > 0; currentLevel--) 224 | { 225 | int nw = pW_M[currentLevel] / 2; 226 | int nh = pH_M[currentLevel] / 2; 227 | int ns = iAlignUp(nw); 228 | 229 | pW_N[currentLevel - 1] = nw + 2*pW_L[currentLevel-1]; 230 | pH_N[currentLevel - 1] = nh + 2*pH_L[currentLevel-1]; 231 | pS_N[currentLevel - 1] = iAlignUp(pW_N[currentLevel - 1]); 232 | 233 | // pre-calculate mat_x_hat and store it 234 | checkCudaErrors(cudaMalloc(mat_x_hat + currentLevel, 235 | pW_N[currentLevel] * pH_N[currentLevel] * sizeof(float))); 236 | computemat_x_hat(mu[currentLevel], alpha, pW_N[currentLevel], pH_N[currentLevel], 237 | (float *)mat_x_hat[currentLevel]); 238 | 239 | // pre-compute DCT weights and store them in device 240 | checkCudaErrors(cudaMalloc(ww_1 + currentLevel, pH_N[currentLevel]*sizeof(complex))); 241 | checkCudaErrors(cudaMalloc(ww_2 + currentLevel, pW_N[currentLevel]*sizeof(complex))); 242 | computeDCTweights(pW_N[currentLevel], pH_N[currentLevel], 243 | (complex *)ww_1[currentLevel], (complex *)ww_2[currentLevel]); 244 | 245 | checkCudaErrors(cudaMalloc(pI0 + currentLevel - 1, ns * nh * sizeof(float))); 246 | checkCudaErrors(cudaMalloc(pI1 + currentLevel - 1, ns * nh * sizeof(float))); 247 | checkCudaErrors(cudaMalloc(d_I0_coeff + currentLevel - 1, ns * nh * sizeof(float))); 248 | 249 | // 
downscale 250 | // Downscale(pI0[currentLevel], pW_M[currentLevel], pH_M[currentLevel], pS_M[currentLevel], 251 | // nw, nh, ns, (float *)pI0[currentLevel - 1]); 252 | // Downscale(d_I0_coeff[currentLevel], pW_M[currentLevel], pH_M[currentLevel], pS_M[currentLevel], 253 | // nw, nh, ns, (float *)d_I0_coeff[currentLevel - 1]); 254 | Downscale_Anti(pI0[currentLevel], pW_M[currentLevel], pH_M[currentLevel], pS_M[currentLevel], 255 | nw, nh, ns, (float *)pI0[currentLevel - 1]); 256 | Downscale_Anti(d_I0_coeff[currentLevel], pW_M[currentLevel], pH_M[currentLevel], pS_M[currentLevel], 257 | nw, nh, ns, (float *)d_I0_coeff[currentLevel - 1]); 258 | 259 | // pre-compute cubic coefficients and store them in device 260 | cbanal2D(d_I0_coeff[currentLevel], pW_M[currentLevel], pH_M[currentLevel], 261 | pS_M[currentLevel], pH_M[currentLevel]); 262 | 263 | pW_M[currentLevel - 1] = nw; 264 | pH_M[currentLevel - 1] = nh; 265 | pS_M[currentLevel - 1] = ns; 266 | 267 | printf("pW_M[%d] = %d, pH_M[%d] = %d, pS_M[%d] = %d\n", 268 | currentLevel, pW_M[currentLevel], 269 | currentLevel, pH_M[currentLevel], 270 | currentLevel, pS_M[currentLevel]); 271 | printf("pW_N[%d] = %d, pH_N[%d] = %d, pS_N[%d] = %d\n", 272 | currentLevel, pW_N[currentLevel], 273 | currentLevel, pH_N[currentLevel], 274 | currentLevel, pS_N[currentLevel]); 275 | printf("pW_L[%d] = %d, pH_L[%d] = %d\n", 276 | currentLevel, pW_L[currentLevel], 277 | currentLevel, pH_L[currentLevel]); 278 | } 279 | } 280 | 281 | // pre-calculate mat_x_hat and store it 282 | checkCudaErrors(cudaMalloc(mat_x_hat + currentLevel, 283 | pW_N[currentLevel] * pH_N[currentLevel] * sizeof(float))); 284 | computemat_x_hat(mu[currentLevel], alpha, pW_N[currentLevel], pH_N[currentLevel], (float *)mat_x_hat[currentLevel]); 285 | 286 | // pre-compute DCT weights and store them in device 287 | checkCudaErrors(cudaMalloc(ww_1 + currentLevel, pH_N[currentLevel]*sizeof(complex))); 288 | checkCudaErrors(cudaMalloc(ww_2 + currentLevel, 
pW_N[currentLevel]*sizeof(complex))); 289 | computeDCTweights(pW_N[currentLevel], pH_N[currentLevel], 290 | (complex *)ww_1[currentLevel], (complex *)ww_2[currentLevel]); 291 | 292 | // pre-compute cubic coefficients and store them in device 293 | cbanal2D(d_I0_coeff[currentLevel], pW_M[currentLevel], pH_M[currentLevel], 294 | pS_M[currentLevel], pH_M[currentLevel]); 295 | 296 | printf("pW_M[%d] = %d, pH_M[%d] = %d, pS_M[%d] = %d\n", 297 | currentLevel, pW_M[currentLevel], 298 | currentLevel, pH_M[currentLevel], 299 | currentLevel, pS_M[currentLevel]); 300 | printf("pW_N[%d] = %d, pH_N[%d] = %d, pS_N[%d] = %d\n", 301 | currentLevel, pW_N[currentLevel], 302 | currentLevel, pH_N[currentLevel], 303 | currentLevel, pS_N[currentLevel]); 304 | printf("pW_L[%d] = %d, pH_L[%d] = %d\n", 305 | currentLevel, pW_L[currentLevel], 306 | currentLevel, pH_L[currentLevel]); 307 | } 308 | 309 | 310 | -------------------------------------------------------------------------------- /cuda/src/prepare_precomputations.h: -------------------------------------------------------------------------------- 1 | #ifndef PREPARE_PRECOMPUTATIONS_H 2 | #define PREPARE_PRECOMPUTATIONS_H 3 | 4 | void prepare_precomputations(int nLevels, 5 | int L_width, int L_height, 6 | int N_width, int N_height, 7 | int M_width, int M_height, int M_stride, 8 | int N_W, // unknown phase width 9 | int N_H, // unknown phase height 10 | int currentLevel, 11 | int *pW_N, int *pH_N, int *pS_N, 12 | int *pW_M, int *pH_M, int *pS_M, 13 | int *pW_L, int *pH_L, 14 | float **pI0, 15 | float **pI1, 16 | float **d_I0_coeff, 17 | float alpha, // smoothness coefficient 18 | float *mu, // proximal parameter 19 | const float **mat_x_hat, // pre-computed mat_x_hat 20 | const complex **ww_1, // ww_1 coefficient 21 | const complex **ww_2); // ww_2 coefficient 22 | #endif 23 | -------------------------------------------------------------------------------- /cuda/src/prox_gKernel.cuh: 
-------------------------------------------------------------------------------- 1 | #include "common.h" 2 | 3 | /////////////////////////////////////////////////////////////////////////////// 4 | /////////////////////////////////////////////////////////////////////////////// 5 | /////////////////////////////////////////////////////////////////////////////// 6 | ////////////////////////////// L1+L2 Flow //////////////////////////////// 7 | /////////////////////////////////////////////////////////////////////////////// 8 | /////////////////////////////////////////////////////////////////////////////// 9 | /////////////////////////////////////////////////////////////////////////////// 10 | 11 | template 12 | __host__ 13 | __device__ 14 | float sign(T val) 15 | { 16 | return val > T(0) ? 1.0f : -1.0f; // 0 is seldom 17 | } 18 | // proximal operator for L1 operator 19 | __host__ 20 | __device__ 21 | float prox_l1(float u, float tau) 22 | { 23 | return sign(u) * max(fabsf(u) - tau, 0.0f); 24 | } 25 | 26 | // LASSO in R2 (simple) 27 | __device__ 28 | void lassoR2_simple(float A11, float A12, float A22, 29 | float a, float b, float c, 30 | float ux, float uy, float mu, float alpha, 31 | float& ours_x, float& ours_y) 32 | { 33 | // 1. get (x0, y0) 34 | // -- temp 35 | float tx = mu*ux - a*c; 36 | float ty = mu*uy - b*c; 37 | 38 | // -- l2 minimum 39 | float x0_x = A11*tx + A12*ty; 40 | float x0_y = A12*tx + A22*ty; 41 | 42 | // 2. get the sign of optimal 43 | float sign_x = sign(x0_x); 44 | float sign_y = sign(x0_y); 45 | 46 | // 3. get the optimal 47 | ours_x = x0_x - alpha/2.0f * (A11 * sign_x + A12 * sign_y); 48 | ours_y = x0_y - alpha/2.0f * (A12 * sign_x + A22 * sign_y); 49 | 50 | // 4. check sign and map to the range 51 | ours_x = sign(ours_x) == sign_x ? ours_x : 0.0f; 52 | ours_y = sign(ours_y) == sign_y ? 
ours_y : 0.0f; 53 | } 54 | 55 | // LASSO in R2 (complete) 56 | __device__ 57 | void lassoR2_complete(float A11, float A12, float A22, 58 | float a, float b, float c, 59 | float ux, float uy, float mu, float alpha, 60 | float& ours_x, float& ours_y) 61 | { 62 | // temp 63 | float tx = mu*ux - a*c; 64 | float ty = mu*uy - b*c; 65 | 66 | // l2 minimum 67 | float x0_x = A11*tx + A12*ty; 68 | float x0_y = A12*tx + A22*ty; 69 | 70 | // x1 (R^n) [x1 0] & x2 (R^n) [0 x2] 71 | float x1 = (tx > 0) ? ((tx - alpha/2.0f) / (a*a + mu)) : ((tx + alpha/2.0f) / (a*a + mu)); 72 | float x2 = (ty > 0) ? ((ty - alpha/2.0f) / (b*b + mu)) : ((ty + alpha/2.0f) / (b*b + mu)); 73 | 74 | // x3 (R^2n) 75 | tx = alpha/2.0f * (A11 + A12); 76 | ty = alpha/2.0f * (A12 + A22); 77 | bool tb = (tx*x0_x + ty*x0_y) > 0; 78 | float x3_x = tb ? (x0_x - tx) : (x0_x + tx); 79 | float x3_y = tb ? (x0_y - ty) : (x0_y + ty); 80 | 81 | // x4 (R^2n) 82 | tx = alpha/2.0f * (A11 - A12); 83 | ty = alpha/2.0f * (A12 - A22); 84 | tb = (tx*x0_x + ty*x0_y) > 0; 85 | float x4_x = tb ? (x0_x - tx) : (x0_x + tx); 86 | float x4_y = tb ? 
(x0_y - ty) : (x0_y + ty); 87 | 88 | // cost functions 89 | float cost[4]; 90 | 91 | // temp 92 | float uxx = ux*ux; 93 | float uyy = uy*uy; 94 | 95 | // cost function of x1 96 | float t1 = a*x1 + c; 97 | tx = x1 - ux; 98 | cost[0] = t1*t1 + mu*(tx*tx + uyy) + alpha*fabsf(x1); 99 | 100 | // cost function of x2 101 | t1 = b*x2 + c; 102 | ty = x2 - uy; 103 | cost[1] = t1*t1 + mu*(uxx + ty*ty) + alpha*fabsf(x2); 104 | 105 | // cost function of x3 106 | t1 = a*x3_x + b*x3_y + c; 107 | tx = x3_x - ux; 108 | ty = x3_y - uy; 109 | cost[2] = t1*t1 + mu*(tx*tx + ty*ty) + alpha*(fabsf(x3_x) + fabsf(x3_y)); 110 | 111 | // cost function of x4 112 | t1 = a*x4_x + b*x4_y + c; 113 | tx = x4_x - ux; 114 | ty = x4_y - uy; 115 | cost[3] = t1*t1 + mu*(tx*tx + ty*ty) + alpha*(fabsf(x4_x) + fabsf(x4_y)); 116 | 117 | // cost function of x5 118 | float cost_min = c*c + mu*(uxx + uyy); // [0 0] solution 119 | 120 | // find minimum 121 | signed int I = 5; 122 | for (signed int i = 0; i < 4; ++i) 123 | { 124 | if (cost[i] < cost_min) 125 | { 126 | cost_min = cost[i]; 127 | I = i + 1; 128 | } 129 | } 130 | 131 | // set final solution 132 | switch (I) 133 | { 134 | case 1: 135 | ours_x = x1; 136 | ours_y = 0.0f; 137 | break; 138 | case 2: 139 | ours_x = 0.0f; 140 | ours_y = x2; 141 | break; 142 | case 3: 143 | ours_x = x3_x; 144 | ours_y = x3_y; 145 | break; 146 | case 4: 147 | ours_x = x4_x; 148 | ours_y = x4_y; 149 | break; 150 | case 5: 151 | ours_x = 0.0f; 152 | ours_y = 0.0f; 153 | break; 154 | } 155 | } 156 | 157 | 158 | __global__ void prox_gL1Kernel(float *zeta_x, float *zeta_y, 159 | float *temp_x, float *temp_y, float mu, float alpha, 160 | float *A11, float *A12, float *A22, float *a, float *b, float *c, 161 | int N_width, int N_height, int M_width, int M_height) 162 | { 163 | const int L_width = (N_width - M_width) / 2; 164 | const int L_height = (N_height - M_height) / 2; 165 | 166 | const int ix = threadIdx.x + blockIdx.x * blockDim.x; 167 | const int iy = threadIdx.y + 
blockIdx.y * blockDim.y; 168 | 169 | 	const int pos_N = ix + iy * N_width; 170 | 	const int pos_M = ix - L_width + (iy - L_height) * M_width; 171 | 172 | 	float temp_w_x = temp_x[pos_N] + zeta_x[pos_N]; 173 | 	float temp_w_y = temp_y[pos_N] + zeta_y[pos_N]; /* NOTE(review): temp_x/temp_y/zeta_x/zeta_y are read at pos_N BEFORE the ix/iy range check below, so threads launched past the image extent (the grid is rounded up with iDivUp) perform out-of-bounds global reads. The early return should come before these loads. */ 174 | 	float val_x, val_y; 175 | 176 | 	// w-update 177 | 	if (ix >= N_width || iy >= N_height) return; 178 | 	else if (ix >= L_width  && ix < N_width  - L_width && 179 | 			 iy >= L_height && iy < N_height - L_height) 180 | 	{	// update interior flow 181 | 		// lassoR2_simple(A11[pos_M], A12[pos_M], A22[pos_M], 182 | 		// 				  a[pos_M], b[pos_M], c[pos_M], 183 | 		// 				  temp_w_x, temp_w_y, mu, alpha, val_x, val_y); 184 | 		lassoR2_complete(A11[pos_M], A12[pos_M], A22[pos_M], 185 | 						 a[pos_M], b[pos_M], c[pos_M], 186 | 						 temp_w_x, temp_w_y, mu, alpha, val_x, val_y); 187 | 	} 188 | 	else {	// keep exterior flow unchanged 189 | 		val_x = temp_w_x; 190 | 		val_y = temp_w_y; 191 | 	} 192 | 193 | 	// zeta-update 194 | 	zeta_x[pos_N] = temp_w_x - val_x; 195 | 	zeta_y[pos_N] = temp_w_y - val_y; 196 | 197 | 	// pre-store value of (w - zeta) 198 | 	temp_x[pos_N] = mu * (2 * val_x - temp_w_x); 199 | 	temp_y[pos_N] = mu * (2 * val_y - temp_w_y); 200 | } 201 | 202 | static 203 | void prox_gL1(float *zeta_x, float *zeta_y, 204 | 			  float *temp_x, float *temp_y, float mu, float alpha, 205 | 			  float *A11, float *A12, float *A22, 206 | 			  float *a, float *b, float *c, 207 | 			  int N_width, int N_height, int M_width, int M_height) 208 | { 209 | 	dim3 threads(32, 32); 210 | 	dim3 blocks(iDivUp(N_width, threads.x), iDivUp(N_height, threads.y)); 211 | 212 | 	prox_gL1Kernel<<>>(zeta_x, zeta_y, temp_x, temp_y, 213 | 								   mu, alpha, A11, A12, A22, a, b, c, 214 | 								   N_width, N_height, M_width, M_height); /* NOTE(review): the launch configuration reads "<<>>" in this copy of the file -- the grid/block arguments (presumably "<<<blocks, threads>>>") appear to have been stripped during transcription, as with the other kernel launches in this dump; restore from the original source before building. */ 215 | } 216 | 217 | 218 | /////////////////////////////////////////////////////////////////////////////// 219 | /////////////////////////////////////////////////////////////////////////////// 220 | /////////////////////////////////////////////////////////////////////////////// 221 | 
////////////////////////////// L2 Flow ////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////


///////////////////////////////////////////////////////////////////////////////
/// \brief compute proximal operator of g (L2 flow): closed-form linear
///        per-pixel solve in the interior, pass-through in the padded
///        exterior, then dual (zeta) update and caching of mu * (w - zeta).
///
/// \param[in/out] w_x    flow along x
/// \param[in/out] w_y    flow along y
/// \param[in/out] zeta_x dual variable zeta along x
/// \param[in/out] zeta_y dual variable zeta along y
/// \param[in/out] temp_x temp variable (either for nabla(x) or w-zeta) along x
/// \param[in/out] temp_y temp variable (either for nabla(x) or w-zeta) along y
/// \param[in] mu         proximal parameter
/// \param[in] w11, w12_or_w22, w13, w21, w23
///                       pre-computed weights
/// \param[in] N_width    unknown width
/// \param[in] N_height   unknown height
/// \param[in] M_width    image width
/// \param[in] M_height   image height
///////////////////////////////////////////////////////////////////////////////
__global__ void prox_gKernel(float *w_x, float *w_y, float *zeta_x, float *zeta_y,
                             float *temp_x, float *temp_y, float mu,
                             const float *w11, const float *w12_or_w22,
                             const float *w13, const float *w21, const float *w23,
                             int N_width, int N_height,
                             int M_width, int M_height)
{
    const int L_width  = (N_width  - M_width)  / 2;
    const int L_height = (N_height - M_height) / 2;

    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;

    // FIX: bounds check must come BEFORE any global-memory access. The grid
    // is rounded up with iDivUp, so threads past the image edge exist; the
    // original loaded temp_x/zeta_x at pos_N before this guard, reading out
    // of bounds.
    if (ix >= N_width || iy >= N_height) return;

    const int pos_N = ix + iy * N_width;
    const int pos_M = (ix - L_width) + (iy - L_height) * M_width;

    float temp_w_x = temp_x[pos_N] + zeta_x[pos_N];
    float temp_w_y = temp_y[pos_N] + zeta_y[pos_N];
    float val_x, val_y;

    // w-update
    if (ix >= L_width && ix < N_width - L_width &&
        iy >= L_height && iy < N_height - L_height)
    { // update interior flow with the pre-computed linear weights
        val_x = w11[pos_M] * temp_w_x + w12_or_w22[pos_M] * temp_w_y + w13[pos_M];
        val_y = w21[pos_M] * temp_w_y + w12_or_w22[pos_M] * temp_w_x + w23[pos_M];
    }
    else { // keep exterior flow unchanged
        val_x = temp_w_x;
        val_y = temp_w_y;
    }
    w_x[pos_N] = val_x;
    w_y[pos_N] = val_y;

    // zeta-update
    zeta_x[pos_N] = temp_w_x - val_x;
    zeta_y[pos_N] = temp_w_y - val_y;

    // pre-store value of mu * (w - zeta)
    temp_x[pos_N] = mu * (2*val_x - temp_w_x);
    temp_y[pos_N] = mu * (2*val_y - temp_w_y);
}



///////////////////////////////////////////////////////////////////////////////
/// \brief host-side launcher for prox_gKernel (L2 flow prox step).
///
/// \param[in/out] w_x    flow along x
/// \param[in/out] w_y    flow along y
/// \param[in/out] zeta_x dual variable zeta along x
/// \param[in/out] zeta_y dual variable zeta along y
/// \param[in/out] temp_x temp variable (either for nabla(x) or w-zeta) along x
/// \param[in/out] temp_y temp variable (either for nabla(x) or w-zeta) along y
/// \param[in] mu         proximal parameter
/// \param[in] w11, w12_or_w22, w13, w21, w23
///                       pre-computed weights
/// \param[in] N_width    unknown width
/// \param[in] N_height   unknown height
/// \param[in] M_width    image width
/// \param[in] M_height   image height
///////////////////////////////////////////////////////////////////////////////
static
void prox_g(float *w_x, float *w_y, float *zeta_x, float *zeta_y,
            float *temp_x, float *temp_y, float mu,
            const float *w11, const float *w12_or_w22,
            const float *w13, const float *w21, const float *w23,
            int
N_width, int N_height, int M_width, int M_height)
{
    dim3 threads(32, 6);
    dim3 blocks(iDivUp(N_width, threads.x), iDivUp(N_height, threads.y));

    // FIX: launch configuration restored — it was rendered away as "<<>>"
    // in this dump; blocks/threads are the variables defined just above.
    prox_gKernel<<<blocks, threads>>>(w_x, w_y, zeta_x, zeta_y, temp_x, temp_y,
                                      mu, w11, w12_or_w22, w13, w21, w23,
                                      N_width, N_height, M_width, M_height);
}
--------------------------------------------------------------------------------
/cuda/src/x_updateKernel.cuh:
--------------------------------------------------------------------------------
// NOTE(review): the angle-bracket include targets and the template parameter
// lists were stripped by the HTML rendering of this dump ("#include" with no
// target, "template" with no <...>). <cufft.h> is certain (cufftHandle and
// cufftExecC2C are used below); the other headers are best guesses — confirm
// against the repository.
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cufft.h>

#include "common.h"

///////////////////////////////////////////////////////////////////////////////
/// \brief reorder rows of phi into the even/odd permutation used to compute
///        a DCT via a complex FFT; output y is complex with zero imaginary
///        part. T is the real sample type, W the complex (cufft) type.
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename W>
__global__
void dct_ReorderEvenKernel(T *phi, int N_width, int N_height, W *y)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;

    if (ix >= N_width || iy >= N_height) return;

    const int pos = ix + iy * N_width;

    y[pos].y = 0.0f;
    if (iy < N_height/2){
        y[pos].x = phi[ix + 2*iy*N_width];
    }
    else{
        y[pos].x = phi[ix + (2*(N_height-iy)-1)*N_width];
    }
}

///////////////////////////////////////////////////////////////////////////////
/// \brief multiply FFT output by the DCT twiddle weights ww (per row index),
///        normalize by N_height, and write the real result TRANSPOSED into b.
///        Row 0 gets an extra 1/sqrt(2) (orthonormal DCT scaling).
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename W>
__global__
void dct_MultiplyFFTWeightsKernel(int N_width, int N_height, W *y, T *b, const W *ww)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;
    if (ix >= N_width || iy >= N_height) return;

    const int pos      = ix + iy * N_width;
    const int pos_tran = iy + ix * N_height; // transpose on the output b

    b[pos_tran] = (ww[iy].x * y[pos].x + ww[iy].y * y[pos].y) / (T)N_height;
    if (iy == 0)
        b[pos_tran] /= sqrtf(2);
}

///////////////////////////////////////////////////////////////////////////////
/// \brief element-wise division of phi by mat_x_hat (solve in DCT domain).
///////////////////////////////////////////////////////////////////////////////
template <typename T>
__global__
void divide_mat_x_hatKernel(T *phi, const T *mat_x_hat, int N_width, int N_height)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;
    if (ix >= N_width || iy >= N_height) return;

    const int pos = ix + iy * N_width;

    phi[pos] /= mat_x_hat[pos];
}

///////////////////////////////////////////////////////////////////////////////
/// \brief multiply the real DCT coefficients b by the inverse-DCT twiddle
///        weights ww, producing the complex FFT input y. Row 0 gets an extra
///        1/sqrt(2) (orthonormal DCT scaling).
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename W>
__global__
void idct_MultiplyFFTWeightsKernel(int N_width, int N_height, T *b, W *y, const W *ww)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;
    if (ix >= N_width || iy >= N_height) return;

    const int pos = ix + iy * N_width;

    y[pos].x = ww[iy].x * b[pos];
    y[pos].y = ww[iy].y * b[pos];
    if (iy == 0){
        y[pos].x /= sqrtf(2);
        y[pos].y /= sqrtf(2);
    }
}

///////////////////////////////////////////////////////////////////////////////
/// \brief undo the even/odd row permutation after the inverse FFT, normalize
///        by N_height, and write the real result TRANSPOSED into phi.
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename W>
__global__
void idct_ReorderEvenKernel(W *y, int N_width, int N_height, T *phi)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;

    if (ix >= N_width || iy >= N_height) return;

    // const int pos = ix + iy * N_width;
    const int pos = iy + ix * N_height; // transposed output

    if ((iy & 1) == 0){ // iy is even
        phi[pos] = y[ix + iy/2*N_width].x / (T) N_height;
    }
    else{ // iy is odd
        phi[pos] = y[ix + (N_height-(iy+1)/2)*N_width].x / (T) N_height;
    }
}


///////////////////////////////////////////////////////////////////////////////
/// \brief x-update: solve the linear system in the DCT domain.
///        Pipeline: 2D DCT -> divide by mat_x_hat -> 2D inverse DCT, where
///        each 1D (I)DCT is reorder + cuFFT + twiddle multiply; the multiply
///        kernels also transpose, so two passes restore the orientation.
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename W>
void x_update(T *phi, W *y, const T *mat_x_hat,
              const W *ww_1, const W *ww_2,
              int N_width, int N_height, cufftHandle plan_dct_1, cufftHandle plan_dct_2)
{
    dim3 threads(32, 32);
    dim3 blocks_1(iDivUp(N_width, threads.x), iDivUp(N_height, threads.y));
    dim3 blocks_2(iDivUp(N_height, threads.x), iDivUp(N_width, threads.y));

    // NOTE(review): the <<<...>>> launch configurations were lost in this
    // dump ("<<>>" artifacts). Reconstructed so blocks_1 covers an
    // N_width x N_height domain and blocks_2 the transposed one, matching
    // each kernel's (width, height) arguments — confirm against the
    // repository.

    // first DCT (weight multiply writes transposed)
    dct_ReorderEvenKernel<<<blocks_1, threads>>>(phi, N_width, N_height, y);
    // cufftSafeCall(cufftExecC2C(plan_dct_1, y, y, CUFFT_FORWARD));
    cufftExecC2C(plan_dct_1, y, y, CUFFT_FORWARD);
    dct_MultiplyFFTWeightsKernel<<<blocks_1, threads>>>(N_width, N_height, y, phi, ww_1);

    // second DCT (transposes back to the original orientation)
    dct_ReorderEvenKernel<<<blocks_2, threads>>>(phi, N_height, N_width, y);
    cufftExecC2C(plan_dct_2, y, y, CUFFT_FORWARD);
    dct_MultiplyFFTWeightsKernel<<<blocks_2, threads>>>(N_height, N_width, y, phi, ww_2);

    // divided by mat_x_hat
    divide_mat_x_hatKernel<<<blocks_1, threads>>>(phi, mat_x_hat, N_width, N_height);

    // first IDCT
    idct_MultiplyFFTWeightsKernel<<<blocks_1, threads>>>(N_width, N_height, phi, y, ww_1);
    cufftExecC2C(plan_dct_1, y, y, CUFFT_INVERSE);
    idct_ReorderEvenKernel<<<blocks_1, threads>>>(y, N_width, N_height, phi);

    // second IDCT
    idct_MultiplyFFTWeightsKernel<<<blocks_2, threads>>>(N_height, N_width, phi, y, ww_2);
    cufftExecC2C(plan_dct_2, y, y, CUFFT_INVERSE);
    idct_ReorderEvenKernel<<<blocks_2, threads>>>(y, N_height, N_width, phi);
}
--------------------------------------------------------------------------------
/data/HeLa/cap.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/PhaseIntensityMicroscope/007be7a30a7cbe782027a01dae91b11b95b16909/data/HeLa/cap.tif
--------------------------------------------------------------------------------
/data/HeLa/ref.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/PhaseIntensityMicroscope/007be7a30a7cbe782027a01dae91b11b95b16909/data/HeLa/ref.tif
--------------------------------------------------------------------------------
/data/MCF-7/cap.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/PhaseIntensityMicroscope/007be7a30a7cbe782027a01dae91b11b95b16909/data/MCF-7/cap.tif
--------------------------------------------------------------------------------
/data/MCF-7/ref.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/PhaseIntensityMicroscope/007be7a30a7cbe782027a01dae91b11b95b16909/data/MCF-7/ref.tif
--------------------------------------------------------------------------------
/data/MLA-150-7AR-M/Zygo.dat:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/PhaseIntensityMicroscope/007be7a30a7cbe782027a01dae91b11b95b16909/data/MLA-150-7AR-M/Zygo.dat -------------------------------------------------------------------------------- /data/MLA-150-7AR-M/cap.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/PhaseIntensityMicroscope/007be7a30a7cbe782027a01dae91b11b95b16909/data/MLA-150-7AR-M/cap.tif -------------------------------------------------------------------------------- /data/MLA-150-7AR-M/ref.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/PhaseIntensityMicroscope/007be7a30a7cbe782027a01dae91b11b95b16909/data/MLA-150-7AR-M/ref.tif -------------------------------------------------------------------------------- /data/blood/cap_1.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/PhaseIntensityMicroscope/007be7a30a7cbe782027a01dae91b11b95b16909/data/blood/cap_1.tif -------------------------------------------------------------------------------- /data/blood/cap_2.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/PhaseIntensityMicroscope/007be7a30a7cbe782027a01dae91b11b95b16909/data/blood/cap_2.tif -------------------------------------------------------------------------------- /data/blood/ref_1.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/PhaseIntensityMicroscope/007be7a30a7cbe782027a01dae91b11b95b16909/data/blood/ref_1.tif -------------------------------------------------------------------------------- /data/blood/ref_2.tif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/PhaseIntensityMicroscope/007be7a30a7cbe782027a01dae91b11b95b16909/data/blood/ref_2.tif -------------------------------------------------------------------------------- /data/cheek/cap.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/PhaseIntensityMicroscope/007be7a30a7cbe782027a01dae91b11b95b16909/data/cheek/cap.tif -------------------------------------------------------------------------------- /data/cheek/ref.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/PhaseIntensityMicroscope/007be7a30a7cbe782027a01dae91b11b95b16909/data/cheek/ref.tif -------------------------------------------------------------------------------- /matlab/cpu_gpu_comparison.m: -------------------------------------------------------------------------------- 1 | clc;clear;close all; 2 | addpath('./utils/'); 3 | 4 | % datapath 5 | datapath = '../data/MLA-150-7AR-M/'; 6 | 7 | % read image 8 | ref = double(imread([datapath 'ref.tif'])); 9 | cap = double(imread([datapath 'cap.tif'])); 10 | 11 | % set active area 12 | M = [992 992]; 13 | 14 | % crop size 15 | S = (size(ref)-M)/2; 16 | 17 | % crop data 18 | ref = ref(1+S(1):end-S(1),1+S(2):end-S(2)); 19 | cap = cap(1+S(1):end-S(1),1+S(2):end-S(2)); 20 | 21 | % normalize images 22 | norm_img = @(x) double(uint8( 255 * x ./ max( max(ref(:)), max(cap(:)) ) )); 23 | ref = norm_img(ref); 24 | cap = norm_img(cap); 25 | 26 | % set gpu parameters 27 | opt.priors = [0.1 0.1 100 5]; 28 | opt.iter = [3 10 20]; 29 | opt.mu = [0.1 100]; 30 | opt.tol = 0.05; 31 | opt.L = min((2.^ceil(log2(M)) - M) / 2, 256); 32 | opt.size = M; 33 | opt.isverbose = 1; 34 | 35 | % run cpu algorithm 36 | tic; 37 | [A_cpu, phi_cpu, ~, ~, ~] = cws(ref, cap, opt); 38 | toc 39 | 40 | % run gpu 
algorithm 41 | [A_gpu, phi_gpu] = cws_gpu_wrapper(ref, cap, opt); 42 | 43 | % mean normalized phase 44 | phi_gpu = phi_gpu - mean2(phi_gpu); 45 | 46 | % show results 47 | figure; imshow([A_cpu A_gpu A_cpu-A_gpu],[]); 48 | title('A: CPU / GPU / Difference'); 49 | disp(['A: max diff = ' num2str(norm(A_gpu(:) - A_cpu(:),'inf'))]); 50 | figure; imshow([phi_cpu phi_gpu phi_cpu-phi_gpu],[]); 51 | title('phi: CPU / GPU / Difference'); 52 | disp(['phi: max diff = ' num2str(norm(phi_gpu(:) - phi_cpu(:),'inf'))]); 53 | figure;mesh(phi_cpu - phi_gpu) 54 | title('phi: CPU / GPU / Difference'); 55 | -------------------------------------------------------------------------------- /matlab/cws.m: -------------------------------------------------------------------------------- 1 | function [A, phi, wavefront_lap, I_warp, objvaltotal] = cws(I0, I, opt) 2 | % Simutanous intensity and wavefront recovery. Solve for: 3 | % 4 | % min_{A,phi} || i(x+\nabla phi) - A i_0(x) ||_2^2 + ... 5 | % alpha || \nabla phi ||_1 + ... 6 | % beta ( || \nabla phi ||_2^2 + || \nabla^2 phi ||_2^2 ) + ... 7 | % gamma ( || \nabla A ||_1 + || \nabla^2 A ||_1 ) + ... 
8 | % tau ( || \nabla A ||_2^2 + || \nabla^2 A ||_2^2 ) 9 | % 10 | % Inputs: 11 | % - I0: reference image 12 | % - I: measurement image 13 | % - opt: solver options 14 | % - priors: prior weights [alpha beta gamma beta] 15 | % (default = [0.1,0.1,100,5]) 16 | % - iter: [total_alternating_iter A-update_iter phi-update_iter] 17 | % (default = [3 10 20]) 18 | % - mu: ADMM parameters (default = [0.1 100]) 19 | % - tol: phi-update tolerance stopping criteria (default = 0.05) 20 | % - L: padding size [pad_width pad_height] 21 | % (default nearest power of 2 of out_size, each in range [2, 256]) 22 | % - isverbose: output verbose [A_verbose phi_verbose] 23 | % (default = [0 0]) 24 | % Outputs: 25 | % - A: intensity 26 | % - phi: wavefront 27 | % - wavefront_lap: wavefront Laplacian 28 | % - I_warp: warped image of I 29 | % - objvaltotal: objective function 30 | % 31 | % See also `cws_gpu_wrapper.m` for its GPU version. 32 | 33 | % check images 34 | if ~isa(I0,'double') 35 | I0 = double(I0); 36 | end 37 | if ~isa(I,'double') 38 | I = double(I); 39 | end 40 | 41 | if nargin < 3 42 | disp('Using default parameter settings ...') 43 | 44 | % tradeoff parameters 45 | alpha = 0.1; 46 | beta = 0.1; 47 | gamma = 100; 48 | tau = 5; 49 | 50 | % total number of alternation iterations 51 | iter = 3; 52 | 53 | % A-update parameters 54 | opt_A.isverbose = 0; 55 | opt_A.iter = 10; 56 | opt_A.mu_A = 1e-1; 57 | 58 | % phi-update parameters 59 | opt_phase.isverbose = 0; 60 | opt_phase.L = min((2.^ceil(log2(size(I0))) - size(I0)) / 2, 256); % suit to [1024 1024] 61 | opt_phase.mu = 100; 62 | opt_phase.iter = 20; 63 | 64 | % tolerance 65 | tol = 0.05; 66 | else 67 | % check option, if not existed, use default parameters 68 | if ~isfield(opt,'priors') 69 | opt.priors = [0.1 0.1 100 5]; 70 | end 71 | if ~isfield(opt,'iter') 72 | opt.iter = [3 10 20]; 73 | end 74 | if ~isfield(opt,'mu') 75 | opt.mu = [0.1 100]; 76 | end 77 | if ~isfield(opt,'tol') 78 | opt.tol = 0.05; 79 | end 80 | if 
~isfield(opt,'L') 81 | opt.L = min((2.^ceil(log2(size(I0))) - size(I0)) / 2, 256); % suit to [1024 1024] 82 | end 83 | if ~isfield(opt,'isverbose') 84 | opt.isverbose = [0 0]; 85 | end 86 | 87 | % parameters checkers 88 | opt.priors = max(0, opt.priors); 89 | if length(opt.priors) ~= 4 90 | error('length of opt.priors must equal to 4!'); 91 | end 92 | opt.iter = round(opt.iter); 93 | if length(opt.iter) ~= 3 94 | error('length of opt.iter must equal to 3!'); 95 | end 96 | opt.mu = max(0, opt.mu); 97 | if length(opt.mu) ~= 2 98 | error('length of opt.mu must equal to 2!'); 99 | end 100 | opt.tol = max(0, opt.tol); 101 | if length(opt.tol) ~= 1 102 | error('length of opt.tol must equal to 1!'); 103 | end 104 | opt.L = max(0, opt.L); 105 | if length(opt.L) ~= 2 106 | error('length of opt.L must equal to 2!'); 107 | end 108 | 109 | % tradeoff parameters 110 | alpha = opt.priors(1); 111 | beta = opt.priors(2); 112 | gamma = opt.priors(3); 113 | tau = opt.priors(4); 114 | 115 | % total number of alternation iterations 116 | iter = opt.iter(1); 117 | 118 | % A-update parameters 119 | opt_A.isverbose = 0; 120 | opt_A.iter = opt.iter(2); 121 | opt_A.mu_A = opt.mu(1); 122 | 123 | % phi-update parameters 124 | opt_phase.isverbose = 0; 125 | opt_phase.L = opt.L; 126 | opt_phase.mu = opt.mu(2); 127 | opt_phase.iter = opt.iter(3); 128 | 129 | % tolerance 130 | tol = opt.tol; 131 | end 132 | 133 | % compute A-update parameters 134 | gamma_new = gamma / mean2(abs(I0).^2); 135 | tau_new = tau / mean2(abs(I0).^2); 136 | 137 | % define operators in spatial domain 138 | x_k = [0 0 0; -1 1 0; 0 0 0]; 139 | y_k = [0 -1 0; 0 1 0; 0 0 0]; 140 | l_k = [0 1 0; 1 -4 1; 0 1 0]; 141 | 142 | % define sizes 143 | M = size(I0); 144 | L = opt_phase.L; 145 | N = M + 2*L; 146 | 147 | % boundary mask 148 | M1 = @(u) cat(3, u(L(1)+1:end-L(1),L(2)+1:end-L(2),1), ... 
149 | u(L(1)+1:end-L(1),L(2)+1:end-L(2),2)); 150 | 151 | % specify boundary conditions 152 | bc = 'symmetric'; % use DCT 153 | 154 | % define operators 155 | nabla = @(phi) cat(3, imfilter(phi, x_k, bc), imfilter(phi, y_k, bc)); 156 | nabla2 = @(phi) imfilter(phi, l_k, bc); 157 | nablaT = @(adj_phi) imfilter(adj_phi(:,:,1), rot90(x_k,2), bc) + ... 158 | imfilter(adj_phi(:,:,2), rot90(y_k,2), bc); 159 | K = @(x) cat(3, nabla(x), nabla2(x)); 160 | KT = @(u) nablaT(u(:,:,1:2)) + nabla2(u(:,:,3)); 161 | 162 | % pre-calculated variables 163 | [~, K_mat] = prepare_DCT_basis(M); 164 | 165 | % inversion basis in DCT domain for A-update and phi-update 166 | mat_A_hat = 1 + (tau_new + opt_A.mu_A) * K_mat; 167 | 168 | % define proximal operator of A 169 | prox_A = @(u) sign(u) .* max(abs(u) - gamma_new/(2*opt_A.mu_A), 0); 170 | 171 | % define norms 172 | norm2 = @(x) sum(abs(x(:)).^2); 173 | norm1 = @(x) sum(abs(x(:))); 174 | 175 | % define objective functions 176 | obj_A = @(A, b) norm2(A - b./I0) + gamma_new * norm1(K(A)) + tau_new * norm2(K(A)); 177 | obj_total = @(A, b, phi) norm2(I0.*A - b) + ... 178 | alpha * norm1(nabla(phi)) + beta * norm2(K(phi)) + ... 179 | gamma * norm1(K(A)) + tau * norm2(K(A)); 180 | 181 | % define obj handle 182 | disp_obj = @(s,i,objval) disp([s num2str(i) ', obj = ' num2str(objval,'%.4e')]); 183 | 184 | % initialization 185 | A = ones(M); 186 | phi = zeros(N); 187 | I_warp = I; 188 | 189 | % main loop 190 | objvaltotal = zeros(iter,1); 191 | disp_obj('iter = ', 0, obj_total(A, I_warp, phi)); 192 | for outer_loop = 1:iter 193 | % === A-update === 194 | B = zeros([size(A) 3]); 195 | zeta = zeros([size(A) 3]); 196 | disp('-- A-update') 197 | for i_A = 1:opt_A.iter 198 | % A-update 199 | A = idct(idct(dct(dct( I_warp./I0 + opt_A.mu_A*KT(B-zeta) ).').' 
./ mat_A_hat).').'; 200 | 201 | % pre-cache 202 | u = K(A) + zeta; 203 | 204 | % B-update 205 | B = prox_A(u); 206 | 207 | % zeta-update 208 | zeta = u - B; 209 | 210 | % show objective function 211 | if opt_A.isverbose 212 | disp_obj('---- iter = ', i_A, obj_A(A,I_warp)); 213 | end 214 | end 215 | % median filter A 216 | A = medfilt2(A, [3 3], 'symmetric'); 217 | 218 | % === phi-update === 219 | disp('-- phi-update') 220 | img = cat(3, A.*I0, I_warp); 221 | [~, Delta_phi, I_warp] = phase_update_ADMM(img, alpha, beta, opt_phase); 222 | 223 | disp(['-- mean(|\Delta\phi|) = ' num2str(mean(abs(Delta_phi(:))),'%.3e')]) 224 | if mean(abs(Delta_phi(:))) < tol 225 | disp('-- mean(|\Delta\phi|) too small; quit'); 226 | break; 227 | end 228 | phi = phi + Delta_phi; 229 | 230 | % === records === 231 | objvaltotal(outer_loop) = obj_total(A, I_warp, phi); 232 | disp_obj('iter = ', outer_loop, objvaltotal(outer_loop)); 233 | end 234 | 235 | % compute wavefront Laplacian 236 | wavefront_lap = nabla2(phi); 237 | wavefront_lap = M1(cat(3,wavefront_lap,wavefront_lap)); 238 | wavefront_lap = wavefront_lap(:,:,1); 239 | 240 | % median filtering phi 241 | phi = medfilt2(phi, [3 3], 'symmetric'); 242 | 243 | % return phi 244 | phi = M1(cat(3,phi,phi)); 245 | phi = phi(:,:,1); 246 | phi = phi - mean2(phi); 247 | 248 | return; 249 | 250 | function [x, x_full, I_warp] = phase_update_ADMM(img, alpha, beta, opt) 251 | 252 | M = [size(img,1) size(img,2)]; 253 | L = opt.L; 254 | N = M + 2*L; 255 | 256 | % boundary mask 257 | M1 = @(u) cat(3, u(L(1)+1:end-L(1),L(2)+1:end-L(2),1), ... 
258 | u(L(1)+1:end-L(1),L(2)+1:end-L(2),2)); 259 | [lap_mat, ~] = prepare_DCT_basis(N); 260 | mat_x_hat = (opt.mu+beta)*lap_mat + beta*lap_mat.^2; 261 | mat_x_hat(1) = 1; 262 | 263 | % initialization 264 | x = zeros(N); 265 | u = zeros([N 2]); 266 | zeta = zeros([N 2]); 267 | 268 | % get the matrices 269 | [gt, gx, gy] = partial_deriv(img); 270 | 271 | % pre-compute 272 | mu = opt.mu; 273 | gxy = gx.*gy; 274 | gxx = gx.^2; 275 | gyy = gy.^2; 276 | 277 | % store in memory at run-time 278 | denom = mu*(gxx + gyy + mu); 279 | A11 = (gyy + mu) ./ denom; 280 | A12 = -gxy ./ denom; 281 | A22 = (gxx + mu) ./ denom; 282 | 283 | % proximal algorithm 284 | objval = zeros(opt.iter,1); 285 | time = zeros(opt.iter,1); 286 | 287 | % the loop 288 | tic; 289 | for k = 1:opt.iter 290 | 291 | % x-update 292 | if k > 1 293 | x = idct(idct(dct(dct(mu*nablaT(u-zeta)).').' ./ mat_x_hat).').'; 294 | end 295 | 296 | % pre-compute nabla_x 297 | nabla_x = nabla(x); 298 | 299 | % u-update 300 | u = nabla_x + zeta; 301 | u_temp = M1(u); 302 | 303 | % R2 LASSO (in practice, difference between the two solutions 304 | % are subtle; for speed's sake, the simple version is 305 | % recommended) 306 | [w_opt_x,w_opt_y] = R2LASSO(gx,gy,gt,u_temp(:,:,1),u_temp(:,:,2),alpha,mu,A11,A12,A22,'complete'); 307 | 308 | % update u 309 | u(L(1)+1:end-L(1), L(2)+1:end-L(2), 1) = w_opt_x; 310 | u(L(1)+1:end-L(1), L(2)+1:end-L(2), 2) = w_opt_y; 311 | 312 | % zeta-update 313 | zeta = zeta + nabla_x - u; 314 | 315 | % record 316 | if opt.isverbose 317 | if k == 1 318 | % define operators 319 | G = @(u) gx.*u(:,:,1) + gy.*u(:,:,2); 320 | 321 | % define objective function 322 | obj = @(phi) norm2(G(M1(nabla(phi)))+gt) + ... 323 | alpha * norm1(nabla(phi)) + ... 
324 | beta * (norm2(nabla(phi)) + norm2(nabla2(phi))); 325 | end 326 | objval(k) = obj(x); 327 | time(k) = toc; 328 | disp(['---- ADMM iter: ' num2str(k) ', obj = ' num2str(objval(k),'%.4e')]) 329 | end 330 | end 331 | 332 | % compute the warped image by Taylor expansion 333 | w = M1(nabla(x)); 334 | I_warp = gx.*w(:,:,1) + gy.*w(:,:,2) + img(:,:,2); 335 | 336 | % return masked x 337 | x = x - mean2(x); 338 | x_full = x; 339 | x = M1(cat(3,x,x)); 340 | x = x(:,:,1); 341 | end 342 | end 343 | 344 | 345 | function [It,Ix,Iy] = partial_deriv(images) 346 | 347 | % derivative kernel 348 | h = [1 -8 0 8 -1]/12; 349 | 350 | % temporal gradient 351 | It = images(:,:,2) - images(:,:,1); 352 | 353 | % First compute derivative then warp 354 | Ix = imfilter(images(:,:,2), h, 'replicate'); 355 | Iy = imfilter(images(:,:,2), h', 'replicate'); 356 | 357 | end 358 | 359 | 360 | function [lap_mat, K_mat] = prepare_DCT_basis(M) 361 | % Prepare DCT basis of size M: 362 | % lap_mat: \nabla^2 363 | % K_mat: \nabla^2 + \nabla^4 364 | 365 | H = M(1); 366 | W = M(2); 367 | [x_coord,y_coord] = meshgrid(0:W-1,0:H-1); 368 | lap_mat = 4 - 2*cos(pi*x_coord/W) - 2*cos(pi*y_coord/H); 369 | K_mat = lap_mat .* (lap_mat + 1); 370 | 371 | end 372 | 373 | 374 | function [x, y] = R2LASSO(a,b,c,ux,uy,alpha,mu,A11,A12,A22,option) 375 | switch option 376 | case 'simple' 377 | [x, y] = R2LASSO_simple(a,b,c,ux,uy,alpha,mu,A11,A12,A22); 378 | case 'complete' 379 | [x, y] = R2LASSO_complete(a,b,c,ux,uy,alpha,mu,A11,A12,A22); 380 | otherwise 381 | error('invalid option for R2LASSO solver!') 382 | end 383 | end 384 | 385 | 386 | function [x, y] = R2LASSO_simple(a,b,c,ux,uy,alpha,mu,A11,A12,A22) 387 | % This function attempts to solve the R2 LASSO problem, in a fast but not accurate sense: 388 | % min (a x + b y + c)^2 + \mu [(x - ux)^2 + (y - uy)^2] + \alpha (|x| + |y|) 389 | % x,y 390 | % 391 | % See also R2LASSO_complete for a more accurate but slower one. 
%
% A11, A12 and A22 can be computed as:
%   A_tmp = 1/(mu*(a*a + b*b + mu)) .* [b*b + mu -a*b; -a*b a*a + mu];
%   A11 = A_tmp(1);
%   A12 = A_tmp(2);
%   A22 = A_tmp(4);

if ~exist('A11','var') || ~exist('A12','var') || ~exist('A22','var')
    denom = 1 ./ ( mu * (a.*a + b.*b + mu) );
    % FIX: use the element-wise product .* here. The original had
    % `A11 = denom * (b.*b + mu);` — a MATRIX product — which is wrong for
    % non-scalar inputs and inconsistent with the same computation in
    % R2LASSO_complete.
    A11 = denom .* (b.*b + mu);
    A12 = -a.*b .* denom;
    A22 = denom .* (a.*a + mu);
end

% 1. get (x0, y0), the unconstrained l2 minimizer

% temp
tx = mu*ux - a.*c;
ty = mu*uy - b.*c;

% l2 minimum
x0_x = A11.*tx + A12.*ty;
x0_y = A12.*tx + A22.*ty;

% 2. get the sign of optimal (assumed equal to the sign of the l2 minimum)
sign_x = sign(x0_x);
sign_y = sign(x0_y);

% 3. get the optimal (shrink towards zero under the metric A)
x = x0_x - alpha/2 * (A11 .* sign_x + A12 .* sign_y);
y = x0_y - alpha/2 * (A12 .* sign_x + A22 .* sign_y);

% 4. check sign and map to the range (clamp sign-crossings to zero)
x( sign(x) ~= sign_x ) = 0;
y( sign(y) ~= sign_y ) = 0;

end


function [x, y] = R2LASSO_complete(a,b,c,ux,uy,alpha,mu,A11,A12,A22)
% This function attempts to solve the R2 LASSO problem in a most accurate sense:
%   min (a x + b y + c)^2 + \mu [(x - ux)^2 + (y - uy)^2] + \alpha (|x| + |y|)
%   x,y
%
% See also R2LASSO_simple for a faster but less accurate one.
437 | % 438 | % A11, A12 and A22 can be computed as: 439 | % A_tmp = 1/(mu*(a*a + b*b + mu)) .* [b*b + mu -a*b; -a*b a*a + mu]; 440 | % A11 = A_tmp(1); 441 | % A12 = A_tmp(2); 442 | % A22 = A_tmp(4); 443 | 444 | if ~exist('A11','var') || ~exist('A12','var') || ~exist('A22','var') 445 | denom = 1 ./ ( mu * (a.*a + b.*b + mu) ); 446 | A11 = denom .* (b.*b + mu); 447 | A12 = -a.*b .* denom; 448 | A22 = denom .* (a.*a + mu); 449 | end 450 | 451 | % temp 452 | tx = mu*ux - a.*c; 453 | ty = mu*uy - b.*c; 454 | 455 | % l2 minimum 456 | x0_x = A11.*tx + A12.*ty; 457 | x0_y = A12.*tx + A22.*ty; 458 | 459 | % x1 (R^n) [x1 0] & x2 (R^n) [0 x2] 460 | % if (tx > 0) 461 | % x1 = (tx - alpha/2) ./ (a.*a + mu); 462 | % else 463 | % x1 = (tx + alpha/2) ./ (a.*a + mu); 464 | % end 465 | % if (ty > 0) 466 | % x2 = (ty - alpha/2) ./ (b.*b + mu); 467 | % else 468 | % x2 = (ty + alpha/2) ./ (b.*b + mu); 469 | % end 470 | x1 = (tx + alpha/2) ./ (a.*a + mu); 471 | tmp = (tx - alpha/2) ./ (a.*a + mu); 472 | ind = tx > 0; 473 | x1(ind) = tmp(ind); 474 | 475 | x2 = (ty + alpha/2) ./ (b.*b + mu); 476 | tmp = (ty - alpha/2) ./ (b.*b + mu); 477 | ind = ty > 0; 478 | x2(ind) = tmp(ind); 479 | 480 | % x3 (R^2n) 481 | % tx = alpha/2 * (A11 + A12); 482 | % ty = alpha/2 * (A12 + A22); 483 | % if (tx*x0_x + ty*x0_y) > 0 484 | % x3_x = x0_x - tx; 485 | % x3_y = x0_y - ty; 486 | % else 487 | % x3_x = x0_x + tx; 488 | % x3_y = x0_y + ty; 489 | % end 490 | tx = alpha/2 * (A11 + A12); 491 | ty = alpha/2 * (A12 + A22); 492 | x3_x = x0_x + tx; 493 | x3_y = x0_y + ty; 494 | tmp_x = x0_x - tx; 495 | tmp_y = x0_y - ty; 496 | ind = (tx.*x0_x + ty.*x0_y) > 0; 497 | x3_x(ind) = tmp_x(ind); 498 | x3_y(ind) = tmp_y(ind); 499 | 500 | % x4 (R^2n) 501 | % tx = alpha/2 * (A11 - A12); 502 | % ty = alpha/2 * (A12 - A22); 503 | % if (tx*x0_x + ty*x0_y) > 0 504 | % x4_x = x0_x - tx; 505 | % x4_y = x0_y - ty; 506 | % else 507 | % x4_x = x0_x + tx; 508 | % x4_y = x0_y + ty; 509 | % end 510 | tx = alpha/2 * (A11 - A12); 511 | 
ty = alpha/2 * (A12 - A22);
x4_x = x0_x + tx;
x4_y = x0_y + ty;
tmp_x = x0_x - tx;
tmp_y = x0_y - ty;
ind = (tx.*x0_x + ty.*x0_y) > 0;
x4_x(ind) = tmp_x(ind);
x4_y(ind) = tmp_y(ind);

% cost functions of the five candidate solutions
cost = inf([size(a) 5]);

% x1 = [x1 0]
t1 = a.*x1 + c;
tx = x1 - ux;
cost(:,:,1) = t1.*t1 + mu*(tx.*tx + uy.*uy) + alpha*abs(x1);

% x2 = [0 x2]
t1 = b.*x2 + c;
ty = x2 - uy;
cost(:,:,2) = t1.*t1 + mu*(ux.*ux + ty.*ty) + alpha*abs(x2);

% x3
t1 = a.*x3_x + b.*x3_y + c;
tx = x3_x - ux;
ty = x3_y - uy;
cost(:,:,3) = t1.*t1 + mu*(tx.*tx + ty.*ty) + alpha*(abs(x3_x) + abs(x3_y));

% x4
t1 = a.*x4_x + b.*x4_y + c;
tx = x4_x - ux;
ty = x4_y - uy;
cost(:,:,4) = t1.*t1 + mu*(tx.*tx + ty.*ty) + alpha*(abs(x4_x) + abs(x4_y));

% x5
cost(:,:,5) = c.*c + mu*(ux.*ux + uy.*uy); % [0 0] solution

% find minimum
% FIX: select exactly ONE minimizing candidate per pixel. The original used
% `ind = cost == cost_min; x = sum(ind .* x_can, 3);`, which sums ALL tying
% candidates and can double a pixel's value on exact cost ties (e.g. when
% x3 == x4). min() along dim 3 returns the first minimizing index.
[~, imin] = min(cost, [], 3);
x_can = cat(3, x1, zeros(size(a)), x3_x, x4_x, zeros(size(a)));
y_can = cat(3, zeros(size(a)), x2, x3_y, x4_y, zeros(size(a)));
[rr, cc] = ndgrid(1:size(a,1), 1:size(a,2));
lin = sub2ind(size(x_can), rr, cc, imin);

% return
x = x_can(lin);
y = y_can(lin);

end
--------------------------------------------------------------------------------
/matlab/cws_gpu_wrapper.m:
--------------------------------------------------------------------------------
function [A, phi] = cws_gpu_wrapper(I0, I, opt)
% This is a MATLAB wrapper for GPU solver `cws`.
3 | % 4 | % Inputs: 5 | % - I0: reference image 6 | % - I: measurement image 7 | % - opt: solver options 8 | % - priors: prior weights [alpha beta gamma beta] 9 | % (default = [0.1,0.1,100,5]) 10 | % - iter: [total_alternating_iter A-update_iter phi-update_iter] 11 | % (default = [3 10 20]) 12 | % - mu: ADMM parameters (default = [0.1 100]) 13 | % - tol: phi-update tolerance stopping criteria (default = 0.05) 14 | % - L: padding size [pad_width pad_height] 15 | % (default nearest power of 2 of out_size, each in range [2, 256]) 16 | % - isverbose: output verbose [A_verbose phi_verbose] 17 | % (default = [0 0]) 18 | % - size: output size (this is unavailable in the CPU version) 19 | % 20 | % Outputs: 21 | % - A: intensity 22 | % - phi: wavefront 23 | % 24 | % See also `cws.m` for its CPU version, and `cpu_gpu_comparison` as demo. 25 | 26 | % get cws solver 27 | tpath = mfilename('fullpath'); 28 | if ismac || isunix 29 | cws_path = [tpath(1:end-23) '/cuda/bin/cws']; 30 | elseif ispc 31 | cws_path = ['"' tpath(1:end-23) '\cuda\bin\cws' '"']; 32 | else 33 | disp('Platform not supported.'); 34 | end 35 | 36 | % check options 37 | if nargin == 3 38 | % check option, if not existed, use default parameters 39 | if ~isfield(opt,'priors') 40 | opt.priors = [0.1 0.1 100 5]; 41 | end 42 | if ~isfield(opt,'iter') 43 | opt.iter = [3 10 20]; 44 | end 45 | if ~isfield(opt,'mu') 46 | opt.mu = [0.1 100]; 47 | end 48 | if ~isfield(opt,'tol') 49 | opt.tol = 0.05; 50 | end 51 | if ~isfield(opt,'L') 52 | opt.L = min((2.^ceil(log2(size(I0))) - size(I0)) / 2, 256); % suit to [1024 1024] 53 | end 54 | if ~isfield(opt,'isverbose') 55 | opt.isverbose = 0; 56 | end 57 | 58 | % parameters checkers 59 | opt.priors = max(0, opt.priors); 60 | if length(opt.priors) ~= 4 61 | error('length of opt.priors must equal to 4!'); 62 | end 63 | opt.iter = round(opt.iter); 64 | if length(opt.iter) ~= 3 65 | error('length of opt.iter must equal to 3!'); 66 | end 67 | opt.mu = max(0, opt.mu); 68 | if 
length(opt.mu) ~= 2 69 | error('length of opt.mu must equal to 2!'); 70 | end 71 | opt.tol = max(0, opt.tol); 72 | if length(opt.tol) ~= 1 73 | error('length of opt.tol must equal to 1!'); 74 | end 75 | opt.L = max(0, opt.L); 76 | if length(opt.L) ~= 2 77 | error('length of opt.L must equal to 2!'); 78 | end 79 | if length(opt.isverbose) ~= 1 80 | error('length of opt.isverbose must equal to 1!'); 81 | end 82 | else % use default parameters 83 | opt.priors = [0.1 0.1 100 5]; 84 | opt.iter = [3 10 20]; 85 | opt.mu = [0.1 100]; 86 | opt.tol = 0.05; 87 | opt.L = min((2.^ceil(log2(size(I0))) - size(I0)) / 2, 256); % suit to [1024 1024] 88 | opt.isverbose = 0; 89 | end 90 | 91 | if isfield(opt,'size') 92 | opt.size = max(0, 2*round(opt.size/2)); 93 | issize = [' -s ' num2str(opt.size(1)) ' -s ' num2str(opt.size(2))]; 94 | else % default the same size as inputs 95 | issize = [' -s ' num2str(size(I0,1)) ' -s ' num2str(size(I0,2))]; 96 | end 97 | 98 | if abs(sign(opt.isverbose)) 99 | isverbose_str = ' --verbose'; 100 | else 101 | isverbose_str = ' '; 102 | end 103 | 104 | % write input data to disk 105 | J = cat(3, I0, I); 106 | if ~isa(J,'uint8') 107 | J = uint8(255 * J / max(J(:))); 108 | end 109 | I_path = {['I1_' char(datetime('today')) '.png']; ... 110 | ['I2_' char(datetime('today')) '.png']}; 111 | arrayfun(@(i) imwrite(J(:,:,i), I_path{i}), 1:2); 112 | 113 | % output name 114 | out_name = ['out_tmp_' char(datetime('today')) '.flo']; 115 | 116 | % cat solver info 117 | cws_run = [cws_path ... 118 | ' -p ' num2str(opt.priors(1)) ' -p ' num2str(opt.priors(2)) ... 119 | ' -p ' num2str(opt.priors(3)) ' -p ' num2str(opt.priors(4)) ... 120 | ' -i ' num2str(opt.iter(1)) ' -i ' num2str(opt.iter(2)) ... 121 | ' -i ' num2str(opt.iter(3)) ... 122 | ' -m ' num2str(opt.mu(1)) ' -m ' num2str(opt.mu(2)) ... 123 | ' -t ' num2str(opt.tol) isverbose_str issize ... 124 | ' -l ' num2str(opt.L(1)) ' -l ' num2str(opt.L(2))... 
' -o ' out_name ' -f "' I_path{1} '" -f "' I_path{2} '"'];

% run GPU solver
system(cws_run);

% read and load the results
test = readFlowFile(out_name);
A = test(:,:,1);
phi = test(:,:,2);

% remove output file
if ismac || isunix
    system(['rm ' out_name]);
    arrayfun(@(i) system(['rm ' I_path{i}]), 1:numel(I_path));
elseif ispc
    system(['del ' out_name]);
    arrayfun(@(i) system(['del ' I_path{i}]), 1:numel(I_path));
else
    disp('Platform not supported.');
end

end


function img = readFlowFile(filename)
% READFLOWFILE Read a Middlebury '.flo' optical-flow file into a 2-band image.
%
% Input:
%   - filename: path to a file with extension '.flo'
% Output:
%   - img: H x W x 2 array; band 1 and band 2 are the interleaved channels
%
% File layout (little-endian): float32 tag, int32 width, int32 height, then
% width*height*2 interleaved float32 values.
%
% According to the c++ source code of Daniel Scharstein
% Contact: schar@middlebury.edu
%
% Author: Deqing Sun, Department of Computer Science, Brown University
% Contact: dqsun@cs.brown.edu
% $Date: 2007-10-31 16:45:40 (Wed, 31 Oct 2006) $
%
% Copyright 2007, Deqing Sun.
%
% All Rights Reserved
%
% Permission to use, copy, modify, and distribute this software and its
% documentation for any purpose other than its incorporation into a
% commercial product is hereby granted without fee, provided that the
% above copyright notice appear in all copies and that both that
% copyright notice and this permission notice appear in supporting
% documentation, and that the name of the author and Brown University not be used in
% advertising or publicity pertaining to distribution of the software
% without specific, written prior permission.
%
% THE AUTHOR AND BROWN UNIVERSITY DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE,
% INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY
% PARTICULAR PURPOSE.  IN NO EVENT SHALL THE AUTHOR OR BROWN UNIVERSITY BE LIABLE FOR
% ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
% WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
% ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
% OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

TAG_FLOAT = 202021.25;  % check for this when READING the file

% sanity checks on the filename
if isempty(filename)
    error('readFlowFile: empty filename');
end

idx = strfind(filename, '.');
% bug fix: guard against a filename with no '.' at all, which used to
% crash on idx(end) with an index error instead of a clear message
if isempty(idx)
    error('readFlowFile: extension required in filename %s', filename);
end
idx = idx(end);

if length(filename(idx:end)) == 1
    error('readFlowFile: extension required in filename %s', filename);
end

if ~strcmp(filename(idx:end), '.flo')
    error('readFlowFile: filename %s should have extension ''.flo''', filename);
end

fid = fopen(filename, 'r');
if fid < 0
    error('readFlowFile: could not open %s', filename);
end

tag    = fread(fid, 1, 'float32');
width  = fread(fid, 1, 'int32');
height = fread(fid, 1, 'int32');

% sanity check the header; bug fix: close the file before raising so the
% handle is not leaked on malformed input
if tag ~= TAG_FLOAT
    fclose(fid);
    error('readFlowFile(%s): wrong tag (possibly due to big-endian machine?)', filename);
end

if width < 1 || width > 99999
    fclose(fid);
    error('readFlowFile(%s): illegal width %d', filename, width);
end

if height < 1 || height > 99999
    fclose(fid);
    error('readFlowFile(%s): illegal height %d', filename, height);
end

nBands = 2;

% read the interleaved payload and arrange into matrix form
tmp = fread(fid, inf, 'float32');
fclose(fid);

tmp = reshape(tmp, [width*nBands, height]);
tmp = tmp';
img(:,:,1) = tmp(:, (1:width)*nBands-1);
img(:,:,2) = tmp(:, (1:width)*nBands);

end

--------------------------------------------------------------------------------
/matlab/main_wavefront_solver.m:
--------------------------------------------------------------------------------
function phi = main_wavefront_solver(img, beta, opt, warping_iter)
% MAIN_WAVEFRONT_SOLVER Coarse-to-fine wavefront recovery with warping.
%
% [in]  img           the stacked reference and captured image pair (H x W x 2)
% [in]  beta          smoothness parameter \beta
% [in]  opt           solver options:
%       opt.mu        [vector] proximal parameter at each pyramid level
%       opt.L         [cell]   additional boundary unknowns at each pyramid
%                              level (better larger than [2 2] per element)
%       opt.isverbose [bool]   is the solver verbose or not
%       opt.ls        [char]   which linear solver to use; 'ADMM' or 'CG'
%       opt.iter      [double] linear solver iterations
% [in]  warping_iter  [vector] warping iterations at each pyramid level
% [out] phi           the recovered wavefront
%
% From https://github.com/vccimaging/MegapixelAO.

% set default values
if ~isfield(opt,'mu')
    opt.mu = 100;
end
if ~isfield(opt,'L')
    opt.L = repmat({[2 2]}, [1 numel(opt.mu)]);
end
if ~isfield(opt,'isverbose')
    % bug fix: this branch used to assign opt.ls = 0 by mistake, which both
    % left isverbose undefined and clobbered the linear-solver choice
    opt.isverbose = 0;
end
% bug fix: char-array '~=' errors when the strings have different lengths
% ('CG' vs 'ADMM'); use strcmp for string comparison instead
if ~isfield(opt,'ls') || ~any(strcmp(opt.ls, {'ADMM','CG'}))
    opt.ls = 'ADMM';
    opt.iter = 10;
end
if ~isfield(opt,'iter')
    switch opt.ls
        case 'ADMM'
            opt.iter = 10;
        case 'CG'
            opt.iter = 1000;
    end
end
if ~exist('warping_iter','var')
    warping_iter = repmat(2, [1 numel(opt.mu)]);
end

% check parameter length
if ~((numel(warping_iter) == numel(opt.L)) && (numel(opt.L) == numel(opt.mu)))
    error('numel(warping_iter) == numel(opt.L) == numel(opt.mu) not satisfied!');
end

% get number of pyramid levels
pyramid_level = numel(warping_iter)-1;

% define scale for x and y dimensions
scale = [2 2];
dim_org = [size(img,1) size(img,2)];

% check size of L
if length(opt.L) ~= pyramid_level + 1
    error('length(opt.L) ~= pyramid_level + 1!');
end

% the warping scheme
for i = pyramid_level:-1:0 % (pyramid warping)
    img_pyramid = imresize(img, dim_org./scale.^i, 'bilinear');
    dim = size(img_pyramid(:,:,2));
    [x, y] = meshgrid(1:dim(2),1:dim(1));

    % pre-cache cubic spline coefficients
    c = cbanal(cbanal(img_pyramid(:,:,1)).').';

    % per-level options
    opt_small = opt;
    opt_small.L = opt.L{i+1};
    opt_small.mu = opt.mu(i+1);

    % Matrix M: crop away the boundary unknowns
    M = @(x) x(1+opt_small.L(1):end-opt_small.L(1), ...
               1+opt_small.L(2):end-opt_small.L(2));

    for j = 1:warping_iter(i+1) % (in-level warping)
        if exist('phi','var')
            % warp the reference by the current wavefront gradient estimate
            wx = imfilter(phi, [-1 1 0], 'replicate');
            wy = imfilter(phi, [-1 1 0]', 'replicate');
            x_new = -M(wx) + x;
            y_new = -M(wy) + y;
            temp_img = cat(3, cbinterp(c,y_new,x_new), img_pyramid(:,:,2));
        else
            temp_img = img_pyramid;
        end
        switch opt_small.ls % linear solver
            case 'ADMM'
                [~, phi_delta] = main_ADMM_fast(temp_img, beta, [], opt_small);
            case 'CG'
                [~, phi_delta] = main_cg_fast(temp_img, beta, opt_small);
        end

        % check if the mean of phi_delta is too small; for early termination
        mean_phi = mean(abs(phi_delta(:)));
        if mean_phi < 0 % 0.1/prod(scale)^i % ( [0.314 0.628 1.257 2.094] = 2*pi./[20 10 5 3] )
            disp(['Pyr ' num2str(i) ', Warp ' num2str(j) ...
                  ', Mean of delta phi = ' num2str(mean_phi) ' < eps: ' ...
                  'Early termination'])
            if i == pyramid_level && j == 1
                phi_delta = zeros(dim+2*opt_small.L);
                phi = phi_delta;
            end
            break;
        else
            disp(['Pyr ' num2str(i) ', Warp ' num2str(j) ...
                  ', Mean of delta phi = ' num2str(mean_phi)])
        end

        % update phi
        if i == pyramid_level && j == 1
            phi = phi_delta;
        else
            phi = phi_delta + phi;
        end
    end

    if i > 0
        % upsample to the next (finer) level; rescale to conserve energy
        dim_next = scale.*dim + 2*opt.L{i};
        phi = prod( dim_next./(size(phi)) ) * imresize(phi, dim_next, 'bilinear');
    end
end

% singular case
if nnz(phi) == 0
    disp('All estimations were rejected; set output to zero')
    phi = zeros(size(phi_delta));
end

% get center part
phi = phi(1+opt.L{1}(1):end-opt.L{1}(1), 1+opt.L{1}(2):end-opt.L{1}(2));

end


% ADMM linear solver
function [x, x_full] = main_ADMM_fast(img, beta, x, opt)

%% define constants and operators

% define operators in spatial domain
nabla_x_kern = [0 0 0; -1 1 0; 0 0 0];
nabla_y_kern = [0 -1 0; 0 1 0; 0 0 0];
nabla2_kern = [0 1 0; 1 -4 1; 0 1 0];

% define sizes
M = size(img(:,:,1));
L1 = (size(nabla_x_kern) - 1) / 2; L1 = L1.*opt.L;
N = M + 2*L1;

% boundary mask
M1 = @(u) cat(3, u(L1(1)+1:end-L1(1),L1(2)+1:end-L1(2),1), ...
                 u(L1(1)+1:end-L1(1),L1(2)+1:end-L1(2),2));


%% formulate the problem

% specify boundary conditions
bc = 'symmetric'; % use DCT

% define forward operators
nabla = @(phi) cat(3, imfilter(phi, nabla_x_kern, bc), ...
                      imfilter(phi, nabla_y_kern, bc));

% define adjoint operators
nablaT = @(grad_phi) imfilter(grad_phi(:,:,1), rot90(nabla_x_kern,2), bc) + ...
167 | imfilter(grad_phi(:,:,2), rot90(nabla_y_kern,2), bc); 168 | 169 | 170 | %% run proximal algorithm 171 | 172 | % initialization 173 | if isempty(x) 174 | x = zeros(N); 175 | end 176 | zeta = zeros([N 2]); 177 | 178 | % get the matrices 179 | [gt, gx, gy] = partial_deriv(img); 180 | 181 | % pre-compute 182 | mu = opt.mu; 183 | gxy = gx.*gy; 184 | gxx = gx.^2; 185 | gyy = gy.^2; 186 | denom = gxx + gyy + mu/2; 187 | 188 | % store in memory at run-time 189 | w11 = (mu/2 + gyy) ./ denom; 190 | w12 = - gxy ./ denom; 191 | w13 = - gx.*gt ./ denom; 192 | w21 = (mu/2 + gxx) ./ denom; 193 | w22 = - gxy ./ denom; 194 | w23 = - gy.*gt ./ denom; 195 | 196 | % proximal algorithm 197 | if opt.isverbose 198 | disp('start ADMM iteration ...') 199 | end 200 | objval = zeros(opt.iter,1); 201 | res = zeros(opt.iter,1); 202 | time = zeros(opt.iter,1); 203 | 204 | tic; 205 | for k = 1:opt.iter 206 | % x-update 207 | if k == 1 208 | H = N(1); 209 | W = N(2); 210 | [x_coord,y_coord] = meshgrid(0:W-1,0:H-1); 211 | mat_x_hat = - (mu+2*beta) * ... 212 | (2*cos(pi*x_coord/W) + 2*cos(pi*y_coord/H) - 4); 213 | mat_x_hat(1) = 1; 214 | else 215 | x = idct(idct(dct(dct(mu*nablaT(u-zeta)).').' ./ mat_x_hat).').'; 216 | end 217 | 218 | % pre-compute nabla_x 219 | nabla_x = nabla(x); 220 | 221 | % u-update 222 | u = nabla_x + zeta; 223 | u_temp = M1(u); 224 | Mu_x = u_temp(:,:,1); 225 | Mu_y = u_temp(:,:,2); 226 | 227 | % update u 228 | u(L1(1)+1:end-L1(1), L1(2)+1:end-L1(2), 1) = ... 229 | w11.*Mu_x + w12.*Mu_y + w13; 230 | u(L1(1)+1:end-L1(1), L1(2)+1:end-L1(2), 2) = ... 
231 | w21.*Mu_y + w22.*Mu_x + w23; 232 | 233 | % zeta-update 234 | zeta = zeta + nabla_x - u; 235 | 236 | % record 237 | if opt.isverbose 238 | if k == 1 239 | % define operators 240 | G = @(u) gx.*u(:,:,1) + gy.*u(:,:,2); 241 | nabla2 = @(phi) imfilter(phi, nabla2_kern, bc); 242 | GT = @(Gu) cat(3, gx.*Gu, gy.*Gu); 243 | M1T = @(Mu) padarray(Mu, L1); 244 | 245 | % define objective function 246 | obj = @(phi) sum(sum(abs(G(M1(nabla(phi)))+gt).^2)) + ... 247 | beta * sum(sum(sum(abs(nabla(phi)).^2))); 248 | 249 | % define its gradient 250 | grad = @(phi) nablaT(M1T(GT(G(M1(nabla(phi)))+gt))) + ... 251 | 2*beta*nabla2(phi); 252 | end 253 | objval(k) = obj(x); 254 | res(k) = sum(sum(abs(grad(x)).^2)); 255 | time(k) = toc; 256 | disp(['ADMM iter: ' num2str(k) ... 257 | ', obj = ' num2str(objval(k),'%e') ... 258 | ', res = ' num2str(res(k),'%e')]) 259 | end 260 | end 261 | 262 | % do median filtering at the last step 263 | temp = cat(3, medfilt2( u(:,:,1) - zeta(:,:,1), [3 3] ), ... 264 | medfilt2( u(:,:,2) - zeta(:,:,2), [3 3] )); 265 | x = idct(idct( dct(dct( mu*nablaT(temp) ).').' 
./ mat_x_hat ).').';
if opt.isverbose
    disp(['final objective: ' num2str(obj(x),'%e')])
end
toc

% return masked x
x_full = x;
x = M1(cat(3,x,x));
x = x(:,:,1);

end


% CG linear solver
function [x, x_full] = main_cg_fast(img, beta, opt)
% Solves the same linearized data term as main_ADMM_fast with plain
% conjugate gradients. `beta` is the smoothness weight of the objective.

%% define consts and operators

% define operators in spatial domain
nabla_x_kern = [0 0 0; 1 -1 0; 0 0 0];
nabla_y_kern = [0 1 0; 0 -1 0; 0 0 0];
nabla2_kern = -[0 1 0; 1 -4 1; 0 1 0];

% define sizes
M = size(img(:,:,1));
L1 = (size(nabla_x_kern) - 1) / 2; L1 = L1.*opt.L;
N = M + 2*L1;


%% formulate the problem

% specify boundary conditions
boundary_cond = 'symmetric'; % use DCT

% get the matrices
[gt, gx, gy] = partial_deriv(img);

% define forward operators
nabla = @(phi) cat(3, imfilter(phi, nabla_x_kern, boundary_cond), ...
                      imfilter(phi, nabla_y_kern, boundary_cond));

% define adjoint operators
nablaT = @(grad_phi) imfilter(grad_phi(:,:,1), rot90(nabla_x_kern,2), boundary_cond) + ...
                     imfilter(grad_phi(:,:,2), rot90(nabla_y_kern,2), boundary_cond);

% define Laplacian operator
nabla2 = @(phi) imfilter(phi, nabla2_kern, boundary_cond);

% boundary masks
M1 = @(u) cat(3, u(L1(1)+1:end-L1(1),L1(2)+1:end-L1(2),1), ...
                 u(L1(1)+1:end-L1(1),L1(2)+1:end-L1(2),2));
MT = @(Mu) padarray(Mu, L1);


%% run CG

% define G and GT
G = @(x) gx.*x(:,:,1) + gy.*x(:,:,2);
GT = @(x) cat(3, gx.*x, gy.*x);

% define A and b
A = @(x) nablaT(MT(GT(G(M1(nabla(x)))))) + beta*nabla2(x);
b = -nablaT(MT(GT(gt)));

% define the objective function once, while `beta` still holds the
% smoothness weight.  Bug fix: the CG momentum coefficient used to be
% assigned to `beta` inside the loop, so `obj` (captured at k==1) and the
% final objective were computed with the wrong weight.
obj = @(phi) sum(sum(abs(G(M1(nabla(phi)))+gt).^2)) + ...
      beta * sum(sum(sum(abs(nabla(phi)).^2)));

% initialize
x = zeros(N);

% get initial r and p
r = b - A(x);
p = r;

% CG iterations
disp('start CG iteration ...')
objval = zeros(opt.iter,1);
res = zeros(opt.iter,1);
time = zeros(opt.iter,1);

tic;
for k = 1:opt.iter
    alpha = sum(abs(r(:)).^2) / sum(sum(conj(p).*A(p)));
    x = x + alpha*p;
    r_new = r - alpha*A(p);
    beta_cg = sum(abs(r_new(:)).^2) / sum(abs(r(:)).^2);  % momentum coefficient
    r = r_new;
    p = r + beta_cg*p;

    if opt.isverbose
        objval(k) = obj(x);
        res(k) = sum(abs(r(:)));
        time(k) = toc;
        disp(['CG iter: ' num2str(k) ...
              ', obj = ' num2str(objval(k),'%e') ...
              ', res = ' num2str(res(k),'%e')])
    else
        if ~mod(k,100)
            disp(['CG iter: ' num2str(k)])
        end
    end
end

% record final objective
disp(['final objective: ' num2str(obj(x),'%e')])
toc

% return masked x
x_full = x;
x = M1(cat(3,x,x));
x = x(:,:,1);

end


% function to get gx, gy and gt (temporal and blended spatial derivatives)
function [It,Ix,Iy] = partial_deriv(images)

% 5-tap derivative kernel
h = [1 -8 0 8 -1]/12;

% blending ratio between the two frames' spatial derivatives
b = 0.5;

% get images
img1 = images(:,:,1);
img2 = images(:,:,2);

% temporal derivative
It = img2 - img1;

% spatial derivatives of each frame
I2x = imfilter(img2, h, 'replicate');
I2y = imfilter(img2, h', 'replicate');
I1x = imfilter(img1, h, 'replicate');
I1y = imfilter(img1, h', 'replicate');

% blended spatial derivatives
Ix = b*I2x + (1-b)*I1x;
Iy = b*I2y + (1-b)*I1y;

end


% cubic spline interpolation function (backward: coefficient -> image)
function out = cbinterp(c,x,y)

% calculate movements
px = floor(x);
fx = x - px;
py = floor(y);
fy = y - py;

% define device functions
w0 = @(a) (1/6)*(a.*(a.*(-a + 3) - 3) + 1);
w1 = @(a) (1/6)*(a.*a.*(3*a - 6) + 4);
w2 = @(a) (1/6)*(a.*(a.*(-3*a + 3) + 3) + 1);
w3 = @(a) (1/6)*(a.*a.*a);
r = @(x,c0,c1,c2,c3) c0.*w0(x) + c1.*w1(x) + c2.*w2(x) + c3.*w3(x);

% define texture function
tex = @(x,y) interp2(c,y,x,'nearest',0);

% elementwise lookup
out = r(fy, ...
    r(fx, tex(px-1,py-1), tex(px,py-1), tex(px+1,py-1), tex(px+2,py-1)), ...
    r(fx, tex(px-1,py  ), tex(px,py  ), tex(px+1,py  ), tex(px+2,py  )), ...
    r(fx, tex(px-1,py+1), tex(px,py+1), tex(px+1,py+1), tex(px+2,py+1)), ...
440 | r(fx, tex(px-1,py+2), tex(px,py+2), tex(px+1,py+2), tex(px+2,py+2))); 441 | end 442 | 443 | 444 | % cubic spline interpolation function (forward: image -> coeffcient) 445 | function c = cbanal(img) 446 | 447 | [m,n] = size(img); 448 | 449 | % A = toeplitz([4 1 zeros(1,m-2)]) / 6; 450 | 451 | c = zeros(m,n); 452 | for i = 1:n 453 | % c(:,i) = A \ y(:,i); 454 | c(:,i) = 6 * tridisolve(img(:,i)); 455 | end 456 | 457 | end 458 | 459 | 460 | % triangular linear solver 461 | function x = tridisolve(d) 462 | 463 | % initialize 464 | x = d; 465 | 466 | % get length 467 | m = length(x); 468 | 469 | % define a, b and c 470 | b = 4*ones(m,1); 471 | 472 | % forward 473 | for j = 1:m-1 474 | mu = 1/b(j); 475 | b(j+1) = b(j+1) - mu; 476 | x(j+1) = x(j+1) - mu*x(j); 477 | end 478 | 479 | % backward 480 | x(m) = x(m)/b(m); 481 | for j = m-1:-1:1 482 | x(j) = (x(j)-x(j+1))/b(j); 483 | end 484 | 485 | end 486 | -------------------------------------------------------------------------------- /matlab/speckle_pattern_baseline.m: -------------------------------------------------------------------------------- 1 | function [A, phi, D] = speckle_pattern_baseline(I0, I) 2 | % A: amplitude 3 | % phi: phase image 4 | % D: dark-field scattering 5 | 6 | % window size 7 | n = [3 3]; 8 | h = ones(n); 9 | 10 | % calculate A2 11 | A = meanfilt(I,h) ./ meanfilt(I0,h); 12 | % A = medfilt2(A, n, 'symmetric'); 13 | 14 | % calculate D 15 | D = stdfilt(I,h) ./ stdfilt(I0,h); 16 | D = D ./ A; 17 | 18 | % calculate local pixel shifts (wavefront slopes) 19 | [w, ~] = imregdemons(I, A.*I0); 20 | 21 | % integrate phi from w 22 | phi = poisson_solver(w(:,:,1), w(:,:,2)); 23 | phi = phi - mean2(phi); 24 | 25 | % square to get amplitude 26 | A = sqrt(A); 27 | 28 | end 29 | 30 | 31 | function J = meanfilt(I, h) 32 | %MEANFILT Local mean of image. 

J = imfilter(I, h/sum(h(:)) , 'replicate');

end
--------------------------------------------------------------------------------
/matlab/utils/LoadMetroProData.m:
--------------------------------------------------------------------------------
%
% function [ dat, xl, yl ] = LoadMetroProData( filename )
%
% Read a MetroPro binary data file and return it in dat. If it fails, dat is empty.
%
% x and y coordinates of data points are calculated as (see NB below)
%   xl = dxy * ( -(size(dat,2)-1)/2 : (size(dat,2)-1)/2 );
%   yl = dxy * ( -(size(dat,1)-1)/2 : (size(dat,1)-1)/2 );
%
% dat(iy,jx) is the value at the location (x,y) = ( xl(jx), yl(iy) ),
% which is the convention of MATLAB.
%
% The following commands will show data
%   plot( xl, dat(floor(end/2), :), yl, dat(:, floor(end/2) )
%   mesh( xl, yl, dat*1e9 )
%
% NB) absolute values of the coordinates are not useful, because the
% measurement origin (0,0) is located at the top-left corner
% and not at the center of the mirror.
%
function [ dat, xl, yl ] = LoadMetroProData( filename )

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%
% History:
%   March 20, 2014     - Zernike term/tilt handling removed; this function
%                        now works without other functions.
%   September 17, 2013 - third output argument, xl, added.
%   September 9, 2013  (HY) - Zernike term removal added.
%   August 30, 2013    (HY) - piston/tilt removal added; matrix transposed
%                        to follow the MATLAB convention.
%   August 19, 2013    (Hiro Yamamoto) - based on MetroPro Reference Guide,
%                        Section 12 "Data Format and Conversion", which
%                        covers header formats 2 and 3 and phase_res
%                        formats 0, 1 and 2.
%
% MetroPro stores data row-major starting from the top-left corner, i.e.,
% in units of spacing dxy, data at (x,y) = (1,1), (2,1), (3,1), ... are
% stored in sequence.  To follow the MATLAB convention (x = second index),
% the raw matrix is transposed, so dat(iy,jx) is the value at physical
% location (jx,iy)*dxy.  Since the first stored row is the TOP of the
% image while positive y should point up, the matrix is then flipped
% upside down (flipud), so the physical y of dat(1,j) lies below dat(2,j).
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% try to open the file (data are stored big-endian)
fid = fopen( filename, 'r', 'b' );
if fid == -1
    error('File ''%s'' does not exist', filename);
end

% read the header information
hData = readerHeader( fid );
if hData.hFormat < 0
    fclose(fid);
    error('Format unknown');
end

% read the phasemap data:
% skip the header and intensity data, then read XN*YN int32 samples
fseek( fid, hData.hSize + hData.IntNBytes, -1 );
[dat, count] = fread( fid, hData.XN*hData.YN, 'int32' );
fclose( fid );

if count ~= hData.XN*hData.YN
    error('data could not fully read');
end

% mark unmeasured data as NaN
dat( dat >= hData.invalid ) = NaN;
% scale data to units of meters
dat = dat * hData.convFactor;
% reshape data to an XN x YN matrix
dat = reshape( dat, hData.XN, hData.YN );
% transpose to make the matrix (YN, XN)
dat = dat';
% change the y-axis direction
dat = flipud( dat );

% auxiliary data to return
dxy = hData.CameraRes;

if nargout >= 2
    xl = dxy * ( -(size(dat,2)-1)/2 : (size(dat,2)-1)/2 );
end
if nargout >= 3
    yl = dxy * ( -(size(dat,1)-1)/2 : (size(dat,1)-1)/2 );
end

return

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% read header information, and hData.hFormat = -1 when failed.
% information of the header segment is from MetroPro Reference Guide
% The manual covers format 1, 2 and 3, and this function fails if the data has unknown format.

function hData = readerHeader( fid )
% Parse the MetroPro file header; hData.hFormat is set to -1 on failure.

% first, check the magic number to make sure this is MetroPro binary data
hData.hFormat = -1;
[magicNum, count] = fread( fid, 1, 'uint32' );
if count == 0
    return;
end

[hData.hFormat, count] = readI16( fid );
if count == 0
    return;
end

[hData.hSize, count] = readI32( fid );
if count == 0
    return;
end

% check if the magic string and format are known ones
if (hData.hFormat >= 1) && (hData.hFormat <= 3) && (magicNum - hData.hFormat == hex2dec('881B036E'))
    % sprintf('MetroPro format %d with header size %d', hData.hFormat, hData.hSize )
else
    % sprintf('====> warning format unknown : %d\n', hData.hFormat);
    hData.hFormat = -1;
    return;
end

% sentinel value marking unmeasured pixels
hData.invalid = hex2dec('7FFFFFF8');

% intensity data, which we will skip over
hData.IntNBytes = readI32( fid, 61-1 );

% top-left coordinate (not used downstream)
hData.X0 = readI16( fid, 65-1 );
hData.Y0 = readI16( fid, 67-1 );

% number of data points along x and y
hData.XN = readI16( fid, 69-1 );
hData.YN = readI16( fid, 71-1 );

% total phase data bytes, 4 * XN * YN
hData.PhaNBytes = readI32( fid, 73-1 );

% scale factor is determined by the phase resolution tag
phaseResTag = readI16( fid, 219-1 );
switch phaseResTag
    case 0
        phaseResVal = 4096;
    case 1
        phaseResVal = 32768;
    case 2
        phaseResVal = 131072;
    otherwise
        phaseResVal = 0;
end

hData.waveLength = readReal( fid, 169-1 );
% Eq. in p12-6 in MetroPro Reference Guide
hData.convFactor = readReal( fid, 165-1 ) * readReal( fid, 177-1 ) * hData.waveLength / phaseResVal;

% bin size of each measurement
hData.CameraRes = readReal( fid, 185-1 );

return

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% utilities to read scalars stored in big-endian format,
% optionally seeking to an absolute offset first
function [val, count] = readI16( fid, offset )
if nargin == 2
    fseek( fid, offset, -1 );
end
[val, count] = fread( fid, 1, 'int16' );
return

function [val, count] = readI32( fid, offset )
if nargin == 2
    fseek( fid, offset, -1 );
end
[val, count] = fread( fid, 1, 'int32' );
return

function [val, count] = readReal( fid, offset )
if nargin == 2
    fseek( fid, offset, -1 );
end
[val, count] = fread( fid, 1, 'float' );
return
--------------------------------------------------------------------------------
/matlab/utils/poisson_solver.m:
--------------------------------------------------------------------------------
function rec = poisson_solver(gx,gy)
% A DCT-based Poisson solver to integrate surface from the gradients.
%
% Inputs:
% - gx: gradient on x
% - gy: gradient on y
%
% Output:
% - rec: reconstructed surface

% pad size
wid = 1;
gx = padarray(gx,[wid wid]);
gy = padarray(gy,[wid wid]);

% define operators in spatial domain
nabla_x_kern = [0 0 0; -1 1 0; 0 0 0];
nabla_y_kern = [0 -1 0; 0 1 0; 0 0 0];

% specify boundary conditions
bc = 'symmetric'; % use DCT

% define adjoint operator
nablaT = @(gx,gy) imfilter(gx, rot90(nabla_x_kern,2), bc) + ...
imfilter(gy, rot90(nabla_y_kern,2), bc);

% generate inverse kernel
[H,W] = size(gx);
[x_coord,y_coord] = meshgrid(0:W-1,0:H-1);
mat_x_hat = 2*cos(pi*x_coord/W) + 2*cos(pi*y_coord/H) - 4;
mat_x_hat(1) = 1;

% do inverse filtering
rec = idct2( dct2(nablaT(gx,gy)) ./ -mat_x_hat );
rec = rec(1+wid:end-wid,1+wid:end-wid);

% redeem on boundary
rec(1,:) = [];
rec(end,:) = [];
rec(:,1) = [];
rec(:,end) = [];
rec = padarray(rec, [1 1], 'replicate');

% zero-normalize
rec = rec - mean2(rec);
--------------------------------------------------------------------------------
/matlab/utils/tilt_removal.m:
--------------------------------------------------------------------------------
function phi = tilt_removal(phi)
% TILT_REMOVAL Subtract the best-fit plane (piston + tilt) from a phase
% map and shift the result so that its minimum is zero.

phi = phi - min(phi(:));
[x,y] = meshgrid(1:size(phi,2), 1:size(phi,1));
[~,~,~,S] = affine_fit(x, y, phi);
phi = phi - S;
phi = phi - min(phi(:));

end


function [n,V,p,S] = affine_fit(x,y,z)
%Computes the plane that best fits (least squares of the normal distance
%to the plane) a set of sample points.
%INPUTS:
%  x: 2D x coordinates
%  y: 2D y coordinates
%  z: input 2D array to be fit
%OUTPUTS:
%  n: a unit (column) vector normal to the plane
%  V: a 3-by-2 matrix; the columns of V form an orthonormal basis of the plane
%  p: a point belonging to the plane
%  S: the fitted affine plane evaluated on the (x,y) grid
%
%NB: the eigen-decomposition approach actually works in any dimension (2,3,4,...)
%Author: Adrien Leygue
%Date: August 30 2013

samples = [x(:) y(:) z(:)];

%the centroid of the samples belongs to the plane
p = mean(samples, 1);

%center the samples and compute the principal directions of the cloud
centered = bsxfun(@minus, samples, p);
[V,~] = eig(centered' * centered);

%the smallest-variance direction is the plane normal
n = V(:,1);
V = V(:,2:end);

%evaluate the fitted plane on the input grid
S = - (n(1)/n(3)*x + n(2)/n(3)*y - dot(n,p)/n(3));
end
--------------------------------------------------------------------------------
/scripts/Figure2.m:
--------------------------------------------------------------------------------
clc;clear;close all;
addpath('../matlab/');
addpath('../matlab/utils/');

% input data path
dpath = '../data/MLA-150-7AR-M/';
r = imread([dpath 'ref.tif']);
s = imread([dpath 'cap.tif']);
n = 1.46; % refractive index

% parameters
pixel_size = 6.45; % [um]
z = 1.43e3; % [um]
scale_factor = pixel_size^2/z;
map = jet(256);

% read data
r = double(r)/2^14 * 255;
s = double(s)/2^14 * 255;


%% Methods

%%% Slope-tracking
[w, ~] = imregdemons(s, r, 200);
phi_tracking = poisson_solver(w(:,:,1), w(:,:,2));
phi_tracking = phi_tracking - mean2(phi_tracking);
phi_tracking = tilt_removal(phi_tracking/(n-1)*scale_factor);


%%% Wang et al.
opt.isverbose = 0;   % Wang et al. solver options: quiet mode
opt.L = {[20 20]};
opt.mu = 100;
opt.iter = 30;
warping_iter = 2;
beta = 1;
phi_wang = main_wavefront_solver(cat(3, r, s), beta, opt, warping_iter);
phi_wang = tilt_removal(phi_wang/(n-1)*scale_factor);


%%% Baseline
[A_base, phi_base, D_base] = speckle_pattern_baseline(r, s);
phi_base = tilt_removal(phi_base/(n-1)*scale_factor);


%%% Ours
opt_cws.priors = [0.5 0.5 100 5];
[A_ours, phi, wavefront_lap, I_warp] = cws(r, s, opt_cws);
% first-order amplitude correction using the wavefront Laplacian
A_ours = A_ours .* (1 + pixel_size/z*wavefront_lap);
A_ours = sqrt(A_ours);
I = A_ours; % amplitude
phi = tilt_removal(phi/(n-1)*scale_factor);

% denoise a little bit ...
phi = medfilt2(phi, [3 3], 'symmetric');


%% Show results

% normalize to start from 0
phi_tracking = phi_tracking - min(phi_tracking(:));
phi_wang = phi_wang - min(phi_wang(:));
phi_base = phi_base - min(phi_base(:));
phi = phi - min(phi(:));

% show results
figure;imshow(phi_tracking, []);
axis tight ij;colormap(map);pause(0.2);
title('Berto et al. 2017');

figure;imshow(phi_wang, []);
axis tight ij;colormap(map);pause(0.2)
title('Wang et al. 2017');

figure;imshow(phi_base, []);
axis tight ij;colormap(map);pause(0.2)
title('Berujon et al. 2015');

figure;imshow(phi, []);
axis tight ij;colormap(map);pause(0.2)
title('Ours');


%% Cross-sections

% for MLA
% set parameters
x1=253; y1=360;  % NOTE(review): (x1,y1) is unused below; kept for reference
x2=670; y2=585;
x3=1082; y3=800;

% set two points for cross-sectioning
xx = [x2 x3];
yy = [y2 y3];

% linear indices of the line segment between the two points
a = (yy(2)-yy(1))/(xx(2)-xx(1));
b = yy(1) - a*xx(1);
x = xx(1):xx(2);
y = round(a*x+b);
ind = sub2ind(size(phi),y,x);

% show the cross-section line on the reconstruction
figure; imshow(phi, []); axis tight;
hold on; line([xx(1) xx(2)],[yy(1) yy(2)],'Color',[1 0 0]);

% cross-section extractor (column vector)
get_c = @(phi) phi(ind)';

% print the maxima (no semicolon: echoed to the command window)
max(phi_tracking(:))
max(phi_wang(:))
max(phi_base(:))
max(phi(:))

% get cross-sections
t_tracking = get_c(phi_tracking);
t_wang = get_c(phi_wang);
t_base = get_c(phi_base);
t_ours = get_c(phi);

% min-normalized
t_tracking = t_tracking - min(t_tracking);
t_wang = t_wang - min(t_wang);
t_base = t_base - min(t_base);
t_ours = t_ours - min(t_ours);

% x coordinates; the /20/0.9 factor compensates for small misalignment
tt = 1:length(t_tracking);
tt = tt - mean(tt);
tt = tt * pixel_size/20 / 0.9;

% show plots
figure; plot(tt, [t_tracking t_wang t_base t_ours],'LineWidth',2);
axis tight;
legend('Berto et al. 2017','Wang et al. 2017','Berujon et al. 2015','Ours');


%% Comparison with Zygo data

% read data; fill NaNs with the mean of the valid samples
[dat, xl, yl] = LoadMetroProData([dpath 'Zygo.dat']);
dat(isnan(dat)) = mean(dat(~isnan(dat)));

% tilt removal (data converted to um first)
dat = tilt_removal(dat*1e6);

% show raw data
% full option names: 'i','f' abbreviations are ambiguous in newer MATLAB
% ('i' can match both 'InitialMagnification' and 'Interpolation')
figure; imshow(dat,[],'InitialMagnification','fit'); axis tight on
colormap(map);colorbar;
title('Zygo raw');

% corresponding region in the Zygo map
xx = [188 222];
yy = [157 155];

% get the indices
a = (yy(2)-yy(1))/(xx(2)-xx(1));
b = yy(1) - a*xx(1);
x = xx(1):xx(2);
y = round(a*x+b);
ind = sub2ind(size(dat),y,x);
t_gt = dat(ind);
t_gt = t_gt - min(t_gt);

% resample the x coordinates to the Zygo cross-section length
% (tt is unchanged since the plot above, so it is not recomputed here)
tt = imresize(tt, [1 numel(t_gt)]);

% plot
figure; plot(tt, t_gt, 'o', 'LineWidth',2);
title('Zygo plot')

% calculate RMS against the resampled ground truth
t_gt = imresize(t_gt', [numel(t_tracking) 1]);
calc_rms = @(x) sqrt(mean(abs(x - t_gt).^2));

% show RMS for each method
disp('RMS is:');
disp(['Berto et al. 2017: ' num2str(calc_rms(t_tracking)) ' um']);
disp(['Wang et al. 2017: ' num2str(calc_rms(t_wang)) ' um']);
disp(['Berujon et al.
2015: ' num2str(calc_rms(t_base)) ' um']);
disp(['Ours: ' num2str(calc_rms(t_ours)) ' um']);


--------------------------------------------------------------------------------
/scripts/Figure3.m:
--------------------------------------------------------------------------------
clc;clear;close all;
addpath('../matlab/');
addpath('../matlab/utils/');

% input data path
dpath = '../data/blood/';
r = imread([dpath 'ref_1.tif']);
s = imread([dpath 'cap_1.tif']);
n = 2; % we are doing OPD here

% parameters
pixel_size = 6.45; % [um]
z = 1.43e3; % [um]
scale_factor = pixel_size^2/z;
map = jet(256);

% read data: scale the 14-bit camera counts to [0, 255]
r = double(r)/2^14 * 255;
s = double(s)/2^14 * 255;


%% Methods

%%% Slope-tracking
[w, ~] = imregdemons(s, r, 200);
phi_tracking = poisson_solver(w(:,:,1), w(:,:,2));
phi_tracking = phi_tracking - mean2(phi_tracking);
phi_tracking = tilt_removal(phi_tracking/(n-1)*scale_factor);


%%% Wang et al.
opt.isverbose = 0;
opt.L = {[20 20]};
opt.mu = 100;
opt.iter = 30;
warping_iter = 2;
beta = 1;
phi_wang = main_wavefront_solver(cat(3, r, s), beta, opt, warping_iter);
phi_wang = tilt_removal(phi_wang/(n-1)*scale_factor);


%%% Baseline
[A_base, phi_base, D_base] = speckle_pattern_baseline(r, s);
phi_base = tilt_removal(phi_base/(n-1)*scale_factor);


%%% Ours
opt_cws.priors = [0.5 0.5 100 5];
[A_ours, phi, wavefront_lap, I_warp] = cws(r, s, opt_cws);
A_ours = A_ours .* (1 + pixel_size/z*wavefront_lap);
A_ours = sqrt(A_ours);
I = A_ours; % amplitude
phi = tilt_removal(phi/(n-1)*scale_factor);

% denoise a little bit ...
phi = medfilt2(phi, [3 3], 'symmetric');


%% Show results

% normalize to start from 0
phi_tracking = phi_tracking - min(phi_tracking(:));
phi_wang = phi_wang - min(phi_wang(:));
phi_base = phi_base - min(phi_base(:));
phi = phi - min(phi(:));

% show results
figure;imshow(phi_tracking, []);
axis tight ij;colormap(map);pause(0.2);
title('Berto et al. 2017');

figure;imshow(phi_wang, []);
axis tight ij;colormap(map);pause(0.2)
title('Wang et al. 2017');

figure;imshow(phi_base, []);
axis tight ij;colormap(map);pause(0.2)
title('Berujon et al. 2015');

figure;imshow(phi, []);
axis tight ij;colormap(map);pause(0.2)
title('Ours');


%% Cross-sections

% function handles: fixed-row cross-section and min-normalization
get_c = @(phi) phi(191,390:550)';
min_n = @(x) x - min(x(:));

% print the maxima (no semicolon: echoed to the command window)
max(phi_tracking(:))
max(phi_wang(:))
max(phi_base(:))
max(phi(:))

% get cross-sections
t_tracking = get_c(phi_tracking);
t_wang = get_c(phi_wang);
t_base = get_c(phi_base);
t_ours = get_c(phi);

% show plots
figure;
plot([min_n(t_tracking) min_n(t_wang) min_n(t_base) min_n(t_ours)],'LineWidth',2);
axis tight;
legend('Berto et al. 2017','Wang et al. 2017','Berujon et al.
2015','Ours');


--------------------------------------------------------------------------------
/scripts/Figure4.m:
--------------------------------------------------------------------------------
clc;clear;close all;   % clc added for consistency with Figure2.m / Figure3.m
addpath('../matlab/');
addpath('../matlab/utils/');

% data paths
samples = {'blood', 'cheek', 'HeLa', 'MCF-7'};
dpaths = cellfun(@(x) ['../data/' x '/'], samples, 'un',0);

% parameters
pixel_size = 6.45; % [um]
z = 1.43e3; % [um]
scale_factor = pixel_size^2/z;

% loop for each sub-figure
% generalized: sized from the sample list instead of a hard-coded 4
num_samples = numel(samples);
A_final = cell(num_samples,1);
phi_final = cell(num_samples,1);
for i = 1:num_samples
    disp(['Running sub-figure ' num2str(i) '/' num2str(num_samples)])
    dpath = dpaths{i};

    % read data (the blood sample uses its second capture pair)
    if i == 1
        r = imread([dpath 'ref_2.tif']);
        s = imread([dpath 'cap_2.tif']);
    else
        r = imread([dpath 'ref.tif']);
        s = imread([dpath 'cap.tif']);
    end
    % scale the 14-bit camera counts to [0, 255]
    r = double(r)/2^14 * 255;
    s = double(s)/2^14 * 255;

    % our solver: amplitude with first-order curvature correction, and OPD
    [A_ours, phi, wavefront_lap, I_warp] = cws(r, s);
    A_ours = sqrt(A_ours .* (1 + pixel_size/z*wavefront_lap));
    A_final{i} = A_ours; % amplitude
    phi = tilt_removal(phi*scale_factor); % OPD
    phi_final{i} = phi - min(phi(:));
end

--------------------------------------------------------------------------------