├── LICENSE
├── README.md
├── cuda
├── CMakeLists.txt
├── README.md
├── cws.cpp
├── deps
│ └── argtable3
│ │ ├── argtable3.c
│ │ └── argtable3.h
├── setup.py
└── src
│ ├── IO_helper_functions.h
│ ├── addKernel.cuh
│ ├── common.h
│ ├── computeDCTweightsKernel.cuh
│ ├── cws_A_phi.cu
│ ├── cws_A_phi.h
│ ├── medianfilteringKernel.cuh
│ ├── prepare_cufft_warmup.cu
│ ├── prepare_cufft_warmup.h
│ ├── prepare_precomputations.cu
│ ├── prepare_precomputations.h
│ ├── prox_gKernel.cuh
│ └── x_updateKernel.cuh
├── data
├── HeLa
│ ├── cap.tif
│ └── ref.tif
├── MCF-7
│ ├── cap.tif
│ └── ref.tif
├── MLA-150-7AR-M
│ ├── Zygo.dat
│ ├── cap.tif
│ └── ref.tif
├── blood
│ ├── cap_1.tif
│ ├── cap_2.tif
│ ├── ref_1.tif
│ └── ref_2.tif
└── cheek
│ ├── cap.tif
│ └── ref.tif
├── matlab
├── cpu_gpu_comparison.m
├── cws.m
├── cws_gpu_wrapper.m
├── main_wavefront_solver.m
├── speckle_pattern_baseline.m
└── utils
│ ├── LoadMetroProData.m
│ ├── poisson_solver.m
│ └── tilt_removal.m
└── scripts
├── Figure2.m
├── Figure3.m
└── Figure4.m
/LICENSE:
--------------------------------------------------------------------------------
1 | ## creative commons
2 |
3 | # Attribution-NonCommercial-ShareAlike 4.0 International
4 |
5 | Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.
6 |
7 | ### Using Creative Commons Public Licenses
8 |
9 | Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
10 |
11 | * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors).
12 |
13 | * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees).
14 |
15 | ## Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License
16 |
17 | By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
18 |
19 | ### Section 1 – Definitions.
20 |
21 | a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
22 |
23 | b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.
24 |
25 | c. __BY-NC-SA Compatible License__ means a license listed at [creativecommons.org/compatiblelicenses](http://creativecommons.org/compatiblelicenses), approved by Creative Commons as essentially the equivalent of this Public License.
26 |
27 | d. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
28 |
29 | e. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
30 |
31 | f. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
32 |
33 | g. __License Elements__ means the license attributes listed in the name of a Creative Commons Public License. The License Elements of this Public License are Attribution, NonCommercial, and ShareAlike.
34 |
35 | h. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
36 |
37 | i. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
38 |
39 | h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License.
40 |
41 | i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
42 |
43 | j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
44 |
45 | k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
46 |
47 | l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.
48 |
49 | ### Section 2 – Scope.
50 |
51 | a. ___License grant.___
52 |
53 | 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
54 |
55 | A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
56 |
57 | B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.
58 |
59 | 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
60 |
61 | 3. __Term.__ The term of this Public License is specified in Section 6(a).
62 |
63 | 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
64 |
65 | 5. __Downstream recipients.__
66 |
67 | A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
68 |
69 | B. __Additional offer from the Licensor – Adapted Material.__ Every recipient of Adapted Material from You automatically receives an offer from the Licensor to exercise the Licensed Rights in the Adapted Material under the conditions of the Adapter’s License You apply.
70 |
71 | C. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
72 |
73 | 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
74 |
75 | b. ___Other rights.___
76 |
77 | 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
78 |
79 | 2. Patent and trademark rights are not licensed under this Public License.
80 |
81 | 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
82 |
83 | ### Section 3 – License Conditions.
84 |
85 | Your exercise of the Licensed Rights is expressly made subject to the following conditions.
86 |
87 | a. ___Attribution.___
88 |
89 | 1. If You Share the Licensed Material (including in modified form), You must:
90 |
91 | A. retain the following if it is supplied by the Licensor with the Licensed Material:
92 |
93 | i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
94 |
95 | ii. a copyright notice;
96 |
97 | iii. a notice that refers to this Public License;
98 |
99 | iv. a notice that refers to the disclaimer of warranties;
100 |
101 | v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
102 |
103 | B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
104 |
105 | C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
106 |
107 | 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
108 |
109 | 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
110 |
111 | b. ___ShareAlike.___
112 |
113 | In addition to the conditions in Section 3(a), if You Share Adapted Material You produce, the following conditions also apply.
114 |
115 | 1. The Adapter’s License You apply must be a Creative Commons license with the same License Elements, this version or later, or a BY-NC-SA Compatible License.
116 |
117 | 2. You must include the text of, or the URI or hyperlink to, the Adapter's License You apply. You may satisfy this condition in any reasonable manner based on the medium, means, and context in which You Share Adapted Material.
118 |
119 | 3. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, Adapted Material that restrict exercise of the rights granted under the Adapter's License You apply.
120 |
121 | ### Section 4 – Sui Generis Database Rights.
122 |
123 | Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
124 |
125 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;
126 |
127 | b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material, including for purposes of Section 3(b); and
128 |
129 | c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
130 |
131 | For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
132 |
133 | ### Section 5 – Disclaimer of Warranties and Limitation of Liability.
134 |
135 | a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__
136 |
137 | b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__
138 |
139 | c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
140 |
141 | ### Section 6 – Term and Termination.
142 |
143 | a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
144 |
145 | b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
146 |
147 | 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
148 |
149 | 2. upon express reinstatement by the Licensor.
150 |
151 | For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
152 |
153 | c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
154 |
155 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
156 |
157 | ### Section 7 – Other Terms and Conditions.
158 |
159 | a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
160 |
161 | b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
162 |
163 | ### Section 8 – Interpretation.
164 |
165 | a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
166 |
167 | b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
168 |
169 | c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
170 |
171 | d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
172 |
173 | > Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
174 | >
175 | > Creative Commons may be contacted at creativecommons.org
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Quantitative Phase and Intensity Microscope
2 | This is the open source repository for our paper in Scientific Reports:
3 |
4 | [**Quantitative Phase and Intensity Microscopy Using Snapshot White Light Wavefront Sensing**]()
5 |
6 | [Congli Wang](https://congliwang.github.io), [Qiang Fu](http://vccimaging.org/People/fuq/), [Xiong Dun](http://vccimaging.org/People/dunx/), and [Wolfgang Heidrich](http://vccimaging.org/People/heidriw/)
7 |
8 | King Abdullah University of Science and Technology (KAUST)
9 |
10 | ### Overview
11 |
12 | This repository contains:
13 |
14 | - An improved version of the wavefront solver in [1], implemented in MATLAB and CUDA:
15 | - For MATLAB code, simply plug & play.
16 | - For the CUDA solver, you need an NVIDIA graphics card with CUDA to compile and run. Also refer to [`./cuda/README.md`](./cuda/README.md) for how to compile the code.
17 | - Other solvers [2, 3] (our implementation).
18 | - Scripts and data for generating Figure 2, Figure 3, and Figure 4 in the paper.
19 |
20 | Our solver is not multi-scale because microscopy wavefronts are small; however based on our solver it is simple to implement such a pyramid scheme.
21 |
22 | ### Related
23 |
24 | - Sensor principle: [The Coded Wavefront Sensor](https://vccimaging.org/Publications/Wang2017CWS/) (Optics Express 2017).
25 | - An adaptive optics application using this sensor: [Megapixel Adaptive Optics]() (SIGGRAPH 2018).
26 | - For sensor simulation or the old solvers, refer to our earlier repository.
27 |
28 |
29 | ### Citation
30 |
31 | ```bibtex
32 | @article{wang2019quantitative,
33 | title = {Quantitative Phase and Intensity Microscopy Using Snapshot White Light Wavefront Sensing},
34 | author = {Wang, Congli and Fu, Qiang and Dun, Xiong and Heidrich, Wolfgang},
35 | journal = {Scientific Reports},
36 | volume = {9},
37 | pages = {13795},
38 | year = {2019},
39 | publisher = {Nature Publishing Group}
40 | }
41 | ```
42 |
43 | ### Contact
44 |
45 | We welcome any questions or comments. Please either open up an issue, or email to congli.wang@kaust.edu.sa.
46 |
47 | ### References
48 |
49 | [1] Congli Wang, Qiang Fu, Xiong Dun, and Wolfgang Heidrich. "Ultra-high resolution coded wavefront sensor." *Optics Express* 25.12 (2017): 13736-13746.
50 |
51 | [2] Pascal Berto, Hervé Rigneault, and Marc Guillon. "Wavefront sensing with a thin diffuser." *Optics Letters* 42.24 (2017): 5117-5120.
52 |
53 | [3] Sebastien Berujon and Eric Ziegler. "Near-field speckle-scanning-based X-ray imaging." *Physical Review A* 92.1 (2015): 013837.
54 |
--------------------------------------------------------------------------------
/cuda/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.8)
project(CWS LANGUAGES C CXX CUDA)

# --- Dependencies ------------------------------------------------------------
# OpenCV >= 3.0 is required on every platform; on Windows, hint the default
# installation prefix so find_package can locate it.
if (UNIX)
    find_package(OpenCV 3.0 REQUIRED)
endif (UNIX)
if (WIN32)
    find_package(OpenCV 3.0 REQUIRED PATHS C:/Program\ Files/OpenCV)
endif (WIN32)

# --- Include directories -----------------------------------------------------
# CUDA toolkit + CUDA sample headers (helper_cuda.h etc.), plus OpenCV.
if (UNIX)
    set(CUDA_TARGET_INC ${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/include)
    set(CUDA_SAMPLE_INC ${CUDA_TOOLKIT_ROOT_DIR}/samples/common/inc)
endif (UNIX)
if (WIN32)
    set(CUDA_SAMPLE_INC C:/ProgramData/NVIDIA\ Corporation/CUDA\ Samples/v10.0/common/inc)
endif (WIN32)
include_directories(deps src ${CUDA_TARGET_INC} ${CUDA_SAMPLE_INC} ${OpenCV_INCLUDE_DIRS})

# --- Link directories --------------------------------------------------------
if (UNIX)
    link_directories(${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/lib)
endif (UNIX)

# --- Static solver library ---------------------------------------------------
set(STATIC_LIB_CWS_A_PHI cws_A_phi)
add_library(${STATIC_LIB_CWS_A_PHI} STATIC
    src/cws_A_phi.cu
    src/prepare_cufft_warmup.cu
)

# --- Command-line executable -------------------------------------------------
set(APP cws)
add_executable(${APP} cws.cpp ${CMAKE_SOURCE_DIR}/deps/argtable3/argtable3.c)
target_link_libraries(${APP} PRIVATE ${STATIC_LIB_CWS_A_PHI} cufft ${OpenCV_LIBS})
install(TARGETS ${APP} RUNTIME DESTINATION ${CMAKE_SOURCE_DIR}/bin)

# --- Install rules -----------------------------------------------------------
# Library and public headers are installed back into the source tree
# (bin/, lib/, include/), not a system prefix.
install(TARGETS ${STATIC_LIB_CWS_A_PHI}
        ARCHIVE DESTINATION ${CMAKE_SOURCE_DIR}/lib
        LIBRARY DESTINATION ${CMAKE_SOURCE_DIR}/lib)
install(DIRECTORY ${CMAKE_SOURCE_DIR}/src
        DESTINATION ${CMAKE_SOURCE_DIR}/include
        FILES_MATCHING PATTERN "*.h")
# NOTE: on Windows, copy opencv_imgcodecs343.dll, opencv_imgproc343.dll and
# opencv_core343.dll into ${CMAKE_SOURCE_DIR}/bin, or ensure the folder that
# contains them is on PATH (as suggested by the CMake warning).
--------------------------------------------------------------------------------
/cuda/README.md:
--------------------------------------------------------------------------------
1 | ### Installation
2 |
3 | The command line solver `cws` is written with C++ and CUDA, with code hierarchy managed by CMake.
4 |
5 | After successful compilation, you may call the MATLAB wrapper (`../matlab/cws_gpu_wrapper.m`) for this command line solver. See `../matlab/cpu_gpu_comparison.m` for a simple demo. Due to floating point errors, the CPU (`cws.m`) and GPU solvers (`cws`) do not produce exactly the same result, but the difference is small.
6 |
7 | #### Prerequisites
8 |
9 | - CUDA Toolkit 8.0 or above (We used CUDA 10.0);
10 | - CMake 2.8 or above (We used CMake 3.13);
11 | - OpenCV 3.0 or above (We used OpenCV 3.4.3).
12 |
13 | Make sure you have installed all of the above dependencies before proceeding.
14 |
15 | #### Linux
16 |
17 | Additional prerequisites:
18 |
19 | - GNU make;
20 | - GNU Compiler Collection (GCC).
21 |
22 | Simply run `setup.py` to get it done. You can also build it manually, for example:
23 |
24 | ```shell
25 | cd cuda
26 | mkdir build
27 | cd build
28 | cmake ..
29 | make -j8
30 | sudo make install
31 | ```
32 |
33 | #### Windows
34 |
35 | Here we only provide one compilation approach. You can do any modifications for your convenience at will.
36 |
37 | It has been tested using Visual Studio 2017 and CUDA 10.0 on Windows 10.
38 |
39 | Steps:
40 |
41 | - Add OpenCV to the system `PATH` variable. The specific path depends on your installation option.
42 |
43 | - Run `cmake-gui` and configure your Visual Studio Solution. If CMake failed to auto-detect the paths:
44 |
45 | - Manually find the path (depends on specific installation option), and modify `./CMakeLists.txt`;
46 | - Then, try to configure and generate the project again.
47 |
48 | If success, you will see `CWS.sln` be generated in your build folder.
49 |
50 | - In Visual Studio 2017 (opened as Administrator), build the solution. To install, in Visual Studio:
51 |
52 | - Build -> Configuration Manager -> click on Build option for INSTALL;
53 | - Then build the solution.
54 |
55 | If success, you will see new folders `bin`, `include` and `lib` appear.
56 |
57 | ### Usage
58 |
59 | Run `./cws --help` in command line prompt to see all parameters:
60 |
61 | ```
62 | cws version 1.0 (Nov 2018)
63 | Simultaneous intensity and wavefront recovery GPU solver for the coded wavefront sensor. Solve for:
64 |
65 | min || i(x+\nabla phi) - A i_0(x) ||_2^2 +
66 | A,phi alpha || \nabla phi ||_1 +
67 | beta ( || \nabla phi ||_2^2 + || \nabla^2 phi ||_2^2 ) +
68 | gamma ( || \nabla A ||_1 + || \nabla^2 A ||_1 ) +
69 | tau ( || \nabla A ||_2^2 + || \nabla^2 A ||_2^2 ).
70 |
71 | Inputs : i_0 (reference), i (measure).
72 | Outputs: A (intensity), phi (phase).
73 |
74 | by Congli Wang, VCC Imaging @ KAUST.
75 |
76 | Usage: cws [-v] [--help] [--version] [-p ]... [-i ]... [-m ] [-m ] [-t ] [-s ] [-s ] [-l ] [-l ] [-o <.flo>] [-f ] [-f ]
77 | --help display this help and exit
78 | --version display version info and exit
79 | -p, --priors= prior weights {alpha,beta,gamma,tau} (default {0.1,0.1,100,5})
80 | -i, --iter= iterations {total alternating iter, A-update iter, phi-update iter} (default {3,20,20})
81 | -m, --mu= ADMM parameters {mu_A,mu_phi} (default {0.1,100})
82 | -t, --tol_phi= phi-update tolerance stopping criteria (default 0.05)
83 | -v, --verbose verbose output (default 0)
84 | -s, --size= output size {width,height} (default input size)
85 | -l, --L= padding size {pad_width,pad_height} (default nearest power of 2 of out_size, each in range [2, 32])
86 | -o, --output=<.flo> save output file (intensity A & wavefront phi) as *.flo file (default "./out.flo")
87 | -f, --files= input file names (reference & measurement)
88 | ```
89 |
--------------------------------------------------------------------------------
/cuda/cws.cpp:
--------------------------------------------------------------------------------
#include <cstdio>

#include "argtable3/argtable3.h"
#include "common.h"
#include "cws_A_phi.h"
#include "IO_helper_functions.h"
6 |
7 | // default: disable demo
8 | bool isdemo = false;
9 |
10 | bool read_rep_images(const char *image_name, cv::Mat &img, int *size = NULL, double max = 255.0, double min = 0.0)
11 | {
12 | cv::Mat temp = cv::imread(image_name, CV_LOAD_IMAGE_GRAYSCALE);
13 | if (temp.empty())
14 | {
15 | printf(image_name, "%s is not found; Program exits.\n");
16 | return true;
17 | }
18 |
19 | // set region of interest
20 | if (size != NULL)
21 | {
22 | int L_w = (temp.cols - size[0]) / 2;
23 | int L_h = (temp.rows - size[1]) / 2;
24 | temp = temp(cv::Rect(L_w, L_h, size[0], size[1]));
25 | }
26 | temp.convertTo(img, CV_32F);
27 |
28 | // convert to approximately [0 255]
29 | img.convertTo(img, CV_32F, 255.0f / (max-min));
30 |
31 | return false;
32 | }
33 |
34 |
// A candidate output dimension `s` is valid only when it is positive, even,
// and no larger than the reference image dimension `s_ref`.
// Returns true when `s` violates any of these constraints.
bool is_size_invalid(int s, int s_ref)
{
    const bool valid = (s > 0) && (s <= s_ref) && (s % 2 == 0);
    return !valid;
}
40 |
41 |
///////////////////////////////////////////////////////////////////////////////
/// \brief run the full coded-wavefront-sensor solve on the GPU
/// \param[in] priors tradeoff weights {alpha, beta, gamma, tau}
/// \param[in] iter iteration counts {alternating, A-update, phi-update}
/// \param[in] mu ADMM parameters {mu_A, mu_phi}
/// \param[in] phi_tol stopping tolerance for the phi-update
/// \param[in] v verbosity flag (> 0 enables sub-update energy reports)
/// \param[in,out] out_size requested output {width, height}; reset in place
///                to the input image size when invalid
/// \param[in] nsize number of --size values parsed (unused in this body)
/// \param[in] L padding sizes {pad_width, pad_height}
/// \param[in] Lsize number of --L values parsed (unused in this body)
/// \param[in] outfile path of the .flo result file
/// \param[in] infiles input file names: [0] reference, [1] measurement
/// \param[in] ninfiles number of input files (unused in this body)
/// \return 0 on success, -1 when an input image cannot be read
///////////////////////////////////////////////////////////////////////////////
int cws(double *priors, int *iter, double *mu, double phi_tol, int v,
    int *out_size, int nsize, int *L, int Lsize, const char *outfile, const char **infiles, int ninfiles)
{
    // remove CUDA timing latency (dummy call forces CUDA context creation)
    cudaFree(0);
    cudaSetDevice(0);

    // images
    cv::Mat img_ref, img_cap;
    char image_name[80]; // NOTE(review): declared but never used in this function

    // read images (read_rep_images returns true on failure)
    if (read_rep_images(infiles[0], img_ref, out_size)) return -1;
    if (read_rep_images(infiles[1], img_cap, out_size)) return -1;

    // output size check: fall back to the full input size when the requested
    // size is non-positive, larger than the input, or odd
    if (is_size_invalid(out_size[0], img_ref.cols) || is_size_invalid(out_size[1], img_ref.rows))
    {
        out_size[0] = img_ref.cols;
        out_size[1] = img_ref.rows;
        printf("out_size wrong or unspecified; Set as input image size = [%d, %d].\n\n", out_size[0], out_size[1]);
    }

    // define container variables for the solver outputs
    cv::Mat A, phi;
    opt_algo para_algo;

    // tradeoff parameters
    para_algo.alpha = priors[0];
    para_algo.beta = priors[1];
    para_algo.gamma = priors[2];
    para_algo.tau = priors[3];

    // if verbose for sub-update energy report
    para_algo.isverbose = v > 0 ? true : false;

    // alternating iterations
    para_algo.iter = iter[0];
    para_algo.A_iter = iter[1];
    para_algo.phi_iter = iter[2];

    // ADMM parameters
    para_algo.mu_A = mu[0];
    para_algo.mu_phi = mu[1];

    // tolerance for incremental phase
    para_algo.phi_tol = phi_tol;

    // set L (padding sizes)
    para_algo.L.width = L[0];
    para_algo.L.height = L[1];

    // run the solver
    cws_A_phi(img_ref, img_cap, A, phi, para_algo);

    // save result
    // NOTE(review): A.ptr(0) yields uchar* by default while WriteFloFile
    // expects const float*; presumably A/phi are CV_32F and the project
    // build accepts this — confirm against cws_A_phi's output types.
    WriteFloFile(outfile, img_ref.cols, img_ref.rows, A.ptr(0), phi.ptr(0));

    // return
    cudaDeviceReset();
    return 0;
}
104 |
105 |
///////////////////////////////////////////////////////////////////////////////
/// \brief print the program banner, model formulation and credits
/// \param[in] progname executable name shown in the version line
///////////////////////////////////////////////////////////////////////////////
void print_solver_info(char *progname)
{
    printf("%s version 1.0 (Nov 2018) \n", progname);

    // fixed banner text, emitted line by line
    static const char *banner[] =
    {
        "Simultaneous intensity and wavefront recovery GPU solver for the coded wavefront sensor. Solve for:\n\n",
        " min || i(x+\\nabla phi) - A i_0(x) ||_2^2 +\n",
        "A,phi alpha || \\nabla phi ||_1 +\n",
        " beta ( || \\nabla phi ||_2^2 + || \\nabla^2 phi ||_2^2 ) +\n",
        " gamma ( || \\nabla A ||_1 + || \\nabla^2 A ||_1 ) +\n",
        " tau ( || \\nabla A ||_2^2 + || \\nabla^2 A ||_2^2 ).\n\n",
        "Inputs : i_0 (reference), i (measure).\n",
        "Outputs: A (intensity), phi (phase).\n",
        "\n",
        "by Congli Wang, VCC Imaging @ KAUST.\n"
    };
    for (size_t k = 0; k < sizeof(banner) / sizeof(banner[0]); ++k)
        printf("%s", banner[k]);
}
120 |
121 |
122 | int main(int argc, char **argv)
123 | {
124 | // help & version info
125 | struct arg_lit *help = arg_litn(NULL, "help", 0, 1, "display this help and exit");
126 | struct arg_lit *version = arg_litn(NULL, "version", 0, 1, "display version info and exit");
127 |
128 | // model: tradeoff parameters
129 | struct arg_dbl *priors = arg_dbln("p", "priors", "", 0, 4, "prior weights {alpha,beta,gamma,beta} (default {0.1,0.1,100,5})");
130 |
131 | // algorithm: ADMM parameters
132 | struct arg_int *iter = arg_intn("i", "iter", "", 0, 3, "iteartions {total alternating iter, A-update iter, phi-update iter} (default {3,20,20})");
133 | struct arg_dbl *mu = arg_dbln("m", "mu", "", 0, 2, "ADMM parameters {mu_A,mu_phi} (default {0.1,100})");
134 | struct arg_dbl *phi_tol = arg_dbl0("t", "tol_phi","", "phi-update tolerance stopping criteria (default 0.05)");
135 | struct arg_lit *verbose = arg_litn("v", "verbose", 0, 1, "verbose output (default 0)");
136 | struct arg_int *out_size = arg_intn("s", "size", "", 0, 2, "output size {width,height} (default input size)");
137 | struct arg_int *L = arg_intn("l", "L", "", 0, 2, "padding size {pad_width,pad_height} (default nearest power of 2 of out_size, each in range [2, 32])");
138 |
139 | // input & output files
140 | struct arg_file *out = arg_filen("o", "output", "<.flo>", 0, 1, "save output file (intensity A & wavefront phi) as *.flo file (default \"./out.flo\")");
141 | struct arg_file *file = arg_filen("f", "files", "", 0, 2, "input file names (reference & measurement)");
142 | struct arg_end *end = arg_end(20);
143 |
144 | /* the global arg_xxx structs are initialised within the argtable */
145 | void *argtable[] = {help, version, priors, iter, mu, phi_tol, verbose, out_size, L, out, file, end};
146 |
147 | int exitcode = 0;
148 | char progname[] = "cws";
149 |
150 | /* set any command line default values prior to parsing */
151 | priors ->dval[0] = 0.1; // alpha
152 | priors ->dval[1] = 0.1; // beta
153 | priors ->dval[2] = 100.0; // gamma
154 | priors ->dval[3] = 5.0; // tau
155 | iter ->ival[0] = 3; // total alternating iterations
156 | iter ->ival[1] = 20; // ADMM iterations for A-update
157 | iter ->ival[2] = 20; // ADMM iterations for phi-update
158 | mu ->dval[0] = 0.1; // ADMM parameter mu for A-update
159 | mu ->dval[1] = 100.0; // ADMM parameter mu for phi-update
160 | phi_tol->dval[0] = 0.05; // tolerance stopping criteria for phi-update
161 | verbose->count = 0;
162 | L ->ival[0] = 1; // L pad size for width
163 | L ->ival[1] = 1; // L pad size for height
164 | out->filename[0] = "./out.flo";
165 |
166 | // parse arguments
167 | int nerrors;
168 | nerrors = arg_parse(argc,argv,argtable);
169 |
170 | /* special case: '--help' takes precedence over error reporting */
171 | if (help->count > 0)
172 | {
173 | print_solver_info(progname);
174 | printf("\n");
175 | printf("Usage: %s", progname);
176 | arg_print_syntax(stdout, argtable, "\n");
177 | arg_print_glossary(stdout, argtable, " %-25s %s \n");
178 | exitcode = 0;
179 | goto exit;
180 | }
181 |
182 | /* special case: '--version' takes precedence error reporting */
183 | if (version->count > 0)
184 | {
185 | print_solver_info(progname);
186 | exitcode = 0;
187 | goto exit;
188 | }
189 |
190 | /* If the parser returned any errors then display them and exit */
191 | if (nerrors > 0)
192 | {
193 | /* Display the error details contained in the arg_end struct.*/
194 | arg_print_errors(stdout, end, progname);
195 | #if defined(_WIN32) || defined(_WIN64)
196 | printf("Try '%s --help' for more information.\n", progname);
197 | #elif defined(__unix)
198 | printf("Try './%s --help' for more information.\n", progname);
199 | #endif
200 | exitcode = 1;
201 | goto exit;
202 | }
203 |
204 | /* if no input files specified, use the test data */
205 | if (file->count < 2)
206 | {
207 | printf("Number of input files < 2; use test data instead; Set out_size as [992 992]\n\n");
208 | out_size->ival[0] = 992;
209 | out_size->ival[1] = 992;
210 | file->filename[0] = "../tests/solver_accuracy/data/img_reference.png";
211 | file->filename[1] = "../tests/solver_accuracy/data/img_capture.png";
212 | isdemo = true;
213 | exitcode = 1;
214 | }
215 |
216 | /* Command line parsing is complete, do the main processing */
217 | exitcode = cws(priors->dval, iter->ival, mu->dval, phi_tol->dval[0], verbose->count,
218 | out_size->ival, out_size->count, L->ival, L->count, out->filename[0], file->filename, file->count);
219 |
220 | exit:
221 | /* deallocate each non-null entry in argtable[] */
222 | arg_freetable(argtable, sizeof(argtable) / sizeof(argtable[0]));
223 |
224 | return exitcode;
225 | }
--------------------------------------------------------------------------------
/cuda/deps/argtable3/argtable3.h:
--------------------------------------------------------------------------------
1 | /*******************************************************************************
2 | * This file is part of the argtable3 library.
3 | *
4 | * Copyright (C) 1998-2001,2003-2011,2013 Stewart Heitmann
5 | *
6 | * All rights reserved.
7 | *
8 | * Redistribution and use in source and binary forms, with or without
9 | * modification, are permitted provided that the following conditions are met:
10 | * * Redistributions of source code must retain the above copyright
11 | * notice, this list of conditions and the following disclaimer.
12 | * * Redistributions in binary form must reproduce the above copyright
13 | * notice, this list of conditions and the following disclaimer in the
14 | * documentation and/or other materials provided with the distribution.
15 | * * Neither the name of STEWART HEITMANN nor the names of its contributors
16 | * may be used to endorse or promote products derived from this software
17 | * without specific prior written permission.
18 | *
19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22 | * ARE DISCLAIMED. IN NO EVENT SHALL STEWART HEITMANN BE LIABLE FOR ANY DIRECT,
23 | * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
26 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | ******************************************************************************/
30 |
31 | #ifndef ARGTABLE3
32 | #define ARGTABLE3
33 |
34 | #include /* FILE */
35 | #include /* struct tm */
36 |
37 | #ifdef __cplusplus
38 | extern "C" {
39 | #endif
40 |
41 | #define ARG_REX_ICASE 1
42 |
43 | /* bit masks for arg_hdr.flag */
44 | enum
45 | {
46 | ARG_TERMINATOR=0x1,
47 | ARG_HASVALUE=0x2,
48 | ARG_HASOPTVALUE=0x4
49 | };
50 |
51 | typedef void (arg_resetfn)(void *parent);
52 | typedef int (arg_scanfn)(void *parent, const char *argval);
53 | typedef int (arg_checkfn)(void *parent);
54 | typedef void (arg_errorfn)(void *parent, FILE *fp, int error, const char *argval, const char *progname);
55 |
56 |
57 | /*
58 | * The arg_hdr struct defines properties that are common to all arg_xxx structs.
59 | * The argtable library requires each arg_xxx struct to have an arg_hdr
60 | * struct as its first data member.
61 | * The argtable library functions then use this data to identify the
62 | * properties of the command line option, such as its option tags,
63 | * datatype string, and glossary strings, and so on.
64 | * Moreover, the arg_hdr struct contains pointers to custom functions that
65 | * are provided by each arg_xxx struct which perform the tasks of parsing
66 | * that particular arg_xxx arguments, performing post-parse checks, and
67 | * reporting errors.
68 | * These functions are private to the individual arg_xxx source code
69 | * and are the pointer to them are initiliased by that arg_xxx struct's
70 | * constructor function. The user could alter them after construction
71 | * if desired, but the original intention is for them to be set by the
72 | * constructor and left unaltered.
73 | */
/* Common header shared by every arg_xxx struct; must be its first member. */
struct arg_hdr
{
    char flag;             /* Modifier flags: ARG_TERMINATOR, ARG_HASVALUE. */
    const char *shortopts; /* String defining the short options */
    const char *longopts;  /* String defining the long options */
    const char *datatype;  /* Description of the argument data type */
    const char *glossary;  /* Description of the option as shown by arg_print_glossary function */
    int mincount;          /* Minimum number of occurrences of this option accepted */
    int maxcount;          /* Maximum number of occurrences of this option accepted */
    void *parent;          /* Pointer to parent arg_xxx struct */
    arg_resetfn *resetfn;  /* Pointer to parent arg_xxx reset function */
    arg_scanfn *scanfn;    /* Pointer to parent arg_xxx scan function */
    arg_checkfn *checkfn;  /* Pointer to parent arg_xxx check function */
    arg_errorfn *errorfn;  /* Pointer to parent arg_xxx error function */
    void *priv;            /* Pointer to private header data for use by arg_xxx functions */
};
90 |
/* Remark entry: contributes glossary text only, consumes no argv values. */
struct arg_rem
{
    struct arg_hdr hdr; /* The mandatory argtable header struct */
};
95 |
/* Literal option (a flag with no value), e.g. --verbose. */
struct arg_lit
{
    struct arg_hdr hdr; /* The mandatory argtable header struct */
    int count;          /* Number of matching command line args */
};
101 |
/* Integer-valued option, e.g. --iter 3. */
struct arg_int
{
    struct arg_hdr hdr; /* The mandatory argtable header struct */
    int count;          /* Number of matching command line args */
    int *ival;          /* Array of parsed argument values */
};
108 |
/* Double-valued option, e.g. --mu 0.1. */
struct arg_dbl
{
    struct arg_hdr hdr; /* The mandatory argtable header struct */
    int count;          /* Number of matching command line args */
    double *dval;       /* Array of parsed argument values */
};
115 |
/* String-valued option. */
struct arg_str
{
    struct arg_hdr hdr; /* The mandatory argtable header struct */
    int count;          /* Number of matching command line args */
    const char **sval;  /* Array of parsed argument values */
};
122 |
/* String-valued option validated against a regular expression. */
struct arg_rex
{
    struct arg_hdr hdr; /* The mandatory argtable header struct */
    int count;          /* Number of matching command line args */
    const char **sval;  /* Array of parsed argument values */
};
129 |
/* Filename-valued option; the path is also split into basename/extension. */
struct arg_file
{
    struct arg_hdr hdr;     /* The mandatory argtable header struct */
    int count;              /* Number of matching command line args*/
    const char **filename;  /* Array of parsed filenames (eg: /home/foo.bar) */
    const char **basename;  /* Array of parsed basenames (eg: foo.bar) */
    const char **extension; /* Array of parsed extensions (eg: .bar) */
};
138 |
/* Date/time-valued option parsed with a strptime-style format string. */
struct arg_date
{
    struct arg_hdr hdr; /* The mandatory argtable header struct */
    const char *format; /* strptime format string used to parse the date */
    int count;          /* Number of matching command line args */
    struct tm *tmval;   /* Array of parsed time values */
};
146 |
/* Parser error codes stored in arg_end.error[]. */
enum {ARG_ELIMIT=1, ARG_EMALLOC, ARG_ENOMATCH, ARG_ELONGOPT, ARG_EMISSARG};
/* Terminator entry that collects errors encountered during arg_parse. */
struct arg_end
{
    struct arg_hdr hdr;  /* The mandatory argtable header struct */
    int count;           /* Number of errors encountered */
    int *error;          /* Array of error codes */
    void **parent;       /* Array of pointers to offending arg_xxx struct */
    const char **argval; /* Array of pointers to offending argv[] string */
};
156 |
157 |
158 | /**** arg_xxx constructor functions *********************************/
159 |
160 | struct arg_rem* arg_rem(const char* datatype, const char* glossary);
161 |
162 | struct arg_lit* arg_lit0(const char* shortopts,
163 | const char* longopts,
164 | const char* glossary);
165 | struct arg_lit* arg_lit1(const char* shortopts,
166 | const char* longopts,
167 | const char *glossary);
168 | struct arg_lit* arg_litn(const char* shortopts,
169 | const char* longopts,
170 | int mincount,
171 | int maxcount,
172 | const char *glossary);
173 |
174 | struct arg_key* arg_key0(const char* keyword,
175 | int flags,
176 | const char* glossary);
177 | struct arg_key* arg_key1(const char* keyword,
178 | int flags,
179 | const char* glossary);
180 | struct arg_key* arg_keyn(const char* keyword,
181 | int flags,
182 | int mincount,
183 | int maxcount,
184 | const char* glossary);
185 |
186 | struct arg_int* arg_int0(const char* shortopts,
187 | const char* longopts,
188 | const char* datatype,
189 | const char* glossary);
190 | struct arg_int* arg_int1(const char* shortopts,
191 | const char* longopts,
192 | const char* datatype,
193 | const char *glossary);
194 | struct arg_int* arg_intn(const char* shortopts,
195 | const char* longopts,
196 | const char *datatype,
197 | int mincount,
198 | int maxcount,
199 | const char *glossary);
200 |
201 | struct arg_dbl* arg_dbl0(const char* shortopts,
202 | const char* longopts,
203 | const char* datatype,
204 | const char* glossary);
205 | struct arg_dbl* arg_dbl1(const char* shortopts,
206 | const char* longopts,
207 | const char* datatype,
208 | const char *glossary);
209 | struct arg_dbl* arg_dbln(const char* shortopts,
210 | const char* longopts,
211 | const char *datatype,
212 | int mincount,
213 | int maxcount,
214 | const char *glossary);
215 |
216 | struct arg_str* arg_str0(const char* shortopts,
217 | const char* longopts,
218 | const char* datatype,
219 | const char* glossary);
220 | struct arg_str* arg_str1(const char* shortopts,
221 | const char* longopts,
222 | const char* datatype,
223 | const char *glossary);
224 | struct arg_str* arg_strn(const char* shortopts,
225 | const char* longopts,
226 | const char* datatype,
227 | int mincount,
228 | int maxcount,
229 | const char *glossary);
230 |
231 | struct arg_rex* arg_rex0(const char* shortopts,
232 | const char* longopts,
233 | const char* pattern,
234 | const char* datatype,
235 | int flags,
236 | const char* glossary);
237 | struct arg_rex* arg_rex1(const char* shortopts,
238 | const char* longopts,
239 | const char* pattern,
240 | const char* datatype,
241 | int flags,
242 | const char *glossary);
243 | struct arg_rex* arg_rexn(const char* shortopts,
244 | const char* longopts,
245 | const char* pattern,
246 | const char* datatype,
247 | int mincount,
248 | int maxcount,
249 | int flags,
250 | const char *glossary);
251 |
252 | struct arg_file* arg_file0(const char* shortopts,
253 | const char* longopts,
254 | const char* datatype,
255 | const char* glossary);
256 | struct arg_file* arg_file1(const char* shortopts,
257 | const char* longopts,
258 | const char* datatype,
259 | const char *glossary);
260 | struct arg_file* arg_filen(const char* shortopts,
261 | const char* longopts,
262 | const char* datatype,
263 | int mincount,
264 | int maxcount,
265 | const char *glossary);
266 |
267 | struct arg_date* arg_date0(const char* shortopts,
268 | const char* longopts,
269 | const char* format,
270 | const char* datatype,
271 | const char* glossary);
272 | struct arg_date* arg_date1(const char* shortopts,
273 | const char* longopts,
274 | const char* format,
275 | const char* datatype,
276 | const char *glossary);
277 | struct arg_date* arg_daten(const char* shortopts,
278 | const char* longopts,
279 | const char* format,
280 | const char* datatype,
281 | int mincount,
282 | int maxcount,
283 | const char *glossary);
284 |
285 | struct arg_end* arg_end(int maxerrors);
286 |
287 |
288 | /**** other functions *******************************************/
289 | int arg_nullcheck(void **argtable);
290 | int arg_parse(int argc, char **argv, void **argtable);
291 | void arg_print_option(FILE *fp, const char *shortopts, const char *longopts, const char *datatype, const char *suffix);
292 | void arg_print_syntax(FILE *fp, void **argtable, const char *suffix);
293 | void arg_print_syntaxv(FILE *fp, void **argtable, const char *suffix);
294 | void arg_print_glossary(FILE *fp, void **argtable, const char *format);
295 | void arg_print_glossary_gnu(FILE *fp, void **argtable);
296 | void arg_print_errors(FILE* fp, struct arg_end* end, const char* progname);
297 | void arg_freetable(void **argtable, size_t n);
298 |
299 | /**** deprecated functions, for back-compatibility only ********/
300 | void arg_free(void **argtable);
301 |
302 | #ifdef __cplusplus
303 | }
304 | #endif
305 | #endif
306 |
--------------------------------------------------------------------------------
/cuda/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import sys
4 | import os
5 |
def print_y(s):
    """Print *s* to stdout in bold yellow via ANSI escape codes."""
    yellow, reset = '\033[1;33m', '\033[m'
    print(yellow + s + reset)
9 |
def print_bw(s):
    """Print *s* to stdout in bold via ANSI escape codes."""
    bold, reset = '\033[1;1m', '\033[m'
    print(bold + s + reset)
13 |
def errorcheck_sys(error):
    """Abort the setup when *error* is truthy (e.g. a non-zero shell status)."""
    if not error:
        return
    print_y("Error occurs. Please run setup.py first to compile the binary files.")
    sys.exit()
19 |
def errorcheck_prog(error):
    """Abort the setup when the compiled binary reports an error."""
    if not error:
        return
    print_y("Binary error occurs.")
    sys.exit()
25 |
# Build the CUDA solver out-of-source in ./build via CMake.
print_bw("Compiling AO_CWS code ...")

# create the build directory on first run; note os.mkdir/os.chdir return
# None, so errorcheck_sys only ever fires for the os.system calls below
if not os.path.exists("build"):
    errorcheck_sys(os.mkdir("build"))
errorcheck_sys(os.chdir("build"))

if os.name == 'nt':
    # Windows: generate a 64-bit Visual Studio solution
    errorcheck_sys(os.system('cmake -DCMAKE_GENERATOR_PLATFORM=x64 --target install ..'))
    # errorcheck_sys(os.system("msbuild CWS_and_AO.sln /p:Configuration=Release")) // need to manually build the solution
elif os.name == 'posix':
    # Linux/macOS: configure, compile in parallel, install
    errorcheck_sys(os.system("cmake .."))
    errorcheck_sys(os.system("make -j8"))
    errorcheck_sys(os.system("make install"))
else:
    print_y("Unknown platform. Compilation ends.")

print_bw("You are done. AO_CWS CUDA code compilation is finished. Program exits.")
43 |
--------------------------------------------------------------------------------
/cuda/src/IO_helper_functions.h:
--------------------------------------------------------------------------------
1 | #ifndef IO_HELPER_FUNCTIONS_H
2 | #define IO_HELPER_FUNCTIONS_H
3 |
#include <stdio.h>
#include <fstream>
#include <iostream>
#include <iomanip>
#include <string>
#include <limits>
10 |
11 | // use fopen under Linux
12 | #ifdef __unix
13 | #define fopen_s(pFile,filename,mode) ((*(pFile))=fopen((filename), (mode)))==NULL
14 | #endif
15 |
16 | ///////////////////////////////////////////////////////////////////////////////
17 | /// \brief save optical flow in format described on vision.middlebury.edu/flow
18 | /// \param[in] name output file name
19 | /// \param[in] w optical flow field width
20 | /// \param[in] h optical flow field height
21 | /// \param[in] u horizontal displacement
22 | /// \param[in] v vertical displacement
23 | ///////////////////////////////////////////////////////////////////////////////
24 | void WriteFloFile(const char *name, int w, int h, const float *u, const float *v)
25 | {
26 | FILE *stream;
27 |
28 | if ((fopen_s(&stream, name, "wb")) != 0)
29 | {
30 | printf("Could not save flow to \"%s\"\n", name);
31 | return;
32 | }
33 |
34 | float data = 202021.25f;
35 | fwrite(&data, sizeof(float), 1, stream);
36 | fwrite(&w, sizeof(w), 1, stream);
37 | fwrite(&h, sizeof(h), 1, stream);
38 |
39 | for (int i = 0; i < h; ++i)
40 | {
41 | for (int j = 0; j < w; ++j)
42 | {
43 | const int pos = j + i * w;
44 | fwrite(u + pos, sizeof(float), 1, stream);
45 | fwrite(v + pos, sizeof(float), 1, stream);
46 | }
47 | }
48 |
49 | fclose(stream);
50 | }
51 |
52 |
///////////////////////////////////////////////////////////////////////////////
/// \brief read an entire text file, appending '\n' after every line read
/// \param[in] filePath path of the file to read
/// \return the file content, or "" when the file cannot be opened
///////////////////////////////////////////////////////////////////////////////
std::string readFile(const char *filePath) {
    std::string content;
    std::ifstream fileStream(filePath, std::ios::in);

    if(!fileStream.is_open()) {
        std::cerr << "Could not read file " << filePath << ". File does not exist." << std::endl;
        return "";
    }

    // BUGFIX: the old loop tested eof() before reading, so the final failed
    // getline still appended a spurious trailing "\n"; testing the stream
    // state returned by getline reads exactly the lines that exist.
    std::string line;
    while(std::getline(fileStream, line)) {
        content.append(line + "\n");
    }

    fileStream.close();
    return content;
}
71 |
72 |
///////////////////////////////////////////////////////////////////////////////
/// \brief load N numeric values (one per line) from a text file
/// \param[in] dataPath path of the text file
/// \param[out] data destination array of at least N elements
/// \param[in] N number of values to read
///////////////////////////////////////////////////////////////////////////////
template <typename T>
void readtxt(const char *dataPath, T *data, int N)
{
    FILE *file = fopen(dataPath,"r");
    if (file == NULL)
    {
        // BUGFIX: a missing file previously reached fscanf(NULL, ...) (UB)
        printf("%s could not be opened\n", dataPath);
        return;
    }

    // values are parsed as float and cast, so T may be any arithmetic type
    for(int i = 0; i < N; i++)
    {
        float value = 0.0f;
        if (fscanf(file, "%f\n", &value) != 1) break; // stop on malformed input
        data[i] = static_cast<T>(value);
    }

    fclose(file);

    printf("%s data loaded\n", dataPath);
}
85 |
///////////////////////////////////////////////////////////////////////////////
/// \brief save N values to a text file, one per line
/// \param[in] dataPath path of the output text file
/// \param[in] data source array of at least N elements
/// \param[in] N number of values to write
///////////////////////////////////////////////////////////////////////////////
template <typename T>
void savetxt(const char *dataPath, T *data, int N)
{
    std::ofstream file(dataPath);
    if (file.is_open()){
        // digits10 + 1 significant digits so values round-trip through text
        for (int i = 0; i < N; i++)
            file << std::setprecision(std::numeric_limits<T>::digits10 + 1) << data[i] << "\n";
        file.close();
        printf("%s data saved\n", dataPath);
    }
    else
    {
        // BUGFIX: previously "data saved" was printed even when the file
        // could not be opened
        printf("Could not save data to %s\n", dataPath);
    }
}
97 |
98 |
99 | #endif
100 |
101 |
--------------------------------------------------------------------------------
/cuda/src/addKernel.cuh:
--------------------------------------------------------------------------------
1 | #include "common.h"
2 |
///////////////////////////////////////////////////////////////////////////////
/// \brief add two vectors of size _count_
///
/// CUDA kernel, one thread per output element
/// \param[in] op1 term one
/// \param[in] op2 term two
/// \param[in] count vector size
/// \param[out] sum result
///////////////////////////////////////////////////////////////////////////////
__global__
void AddKernel(const float *op1, const float *op2, int count, float *sum)
{
    const int pos = threadIdx.x + blockIdx.x * blockDim.x;

    // threads past the end of the vector do nothing
    if (pos >= count) return;

    sum[pos] = op1[pos] + op2[pos];
}
21 |
///////////////////////////////////////////////////////////////////////////////
/// \brief add two vectors of size _count_ (host-side launcher for AddKernel)
/// \param[in] op1 term one
/// \param[in] op2 term two
/// \param[in] count vector size
/// \param[out] sum result
///////////////////////////////////////////////////////////////////////////////
static
void Add(const float *op1, const float *op2, int count, float *sum)
{
    // 256 threads per block, enough blocks to cover all count elements
    dim3 threads(256);
    dim3 blocks(iDivUp(count, threads.x));

    // restored execution configuration (<<<...>>> was lost in the dump)
    AddKernel<<<blocks, threads>>>(op1, op2, count, sum);
}
37 |
--------------------------------------------------------------------------------
/cuda/src/common.h:
--------------------------------------------------------------------------------
1 | ///////////////////////////////////////////////////////////////////////////////
2 | // Header for common includes and utility functions
3 | ///////////////////////////////////////////////////////////////////////////////
4 |
5 | #ifndef COMMON_H
6 | #define COMMON_H
7 |
8 |
9 | ///////////////////////////////////////////////////////////////////////////////
10 | // Common includes
11 | ///////////////////////////////////////////////////////////////////////////////
12 |
13 | #include
14 | #include
15 | #include
16 | #include
17 | #include
18 |
19 | #include
20 | #include
21 | #include
22 |
23 | #include
24 | #include
25 |
26 | ///////////////////////////////////////////////////////////////////////////////
27 | // Error checking functions
28 | ///////////////////////////////////////////////////////////////////////////////
29 | //static const char *_cudaGetErrorEnum(cufftResult error)
30 | //{
31 | // switch (error)
32 | // {
33 | // case CUFFT_SUCCESS:
34 | // return "CUFFT_SUCCESS";
35 |
36 | // case CUFFT_INVALID_PLAN:
37 | // return "CUFFT_INVALID_PLAN";
38 |
39 | // case CUFFT_ALLOC_FAILED:
40 | // return "CUFFT_ALLOC_FAILED";
41 |
42 | // case CUFFT_INVALID_TYPE:
43 | // return "CUFFT_INVALID_TYPE";
44 |
45 | // case CUFFT_INVALID_VALUE:
46 | // return "CUFFT_INVALID_VALUE";
47 |
48 | // case CUFFT_INTERNAL_ERROR:
49 | // return "CUFFT_INTERNAL_ERROR";
50 |
51 | // case CUFFT_EXEC_FAILED:
52 | // return "CUFFT_EXEC_FAILED";
53 |
54 | // case CUFFT_SETUP_FAILED:
55 | // return "CUFFT_SETUP_FAILED";
56 |
57 | // case CUFFT_INVALID_SIZE:
58 | // return "CUFFT_INVALID_SIZE";
59 |
60 | // case CUFFT_UNALIGNED_DATA:
61 | // return "CUFFT_UNALIGNED_DATA";
62 | // }
63 |
64 | // return "";
65 | //}
66 |
67 | //#define cufftSafeCall(err) __cufftSafeCall(err, __FILE__, __LINE__)
68 | //inline void __cufftSafeCall(cufftResult err, const char *file, const int line)
69 | //{
70 | // if( CUFFT_SUCCESS != err) {
71 | // fprintf(stderr, "CUFFT error in file '%s', line %d\n %s\nerror %d: %s\nterminating!\n",__FILE__, __LINE__,err, \
72 | // _cudaGetErrorEnum(err)); \
73 | // cudaDeviceReset(); assert(0); \
74 | // }
75 | //}
76 |
77 |
78 | ///////////////////////////////////////////////////////////////////////////////
79 | // Common constants
80 | ///////////////////////////////////////////////////////////////////////////////
81 | const int StrideAlignment = 32;
82 |
83 | // #ifdef DOUBLE_PRECISION
84 | // typedef double real;
85 | // typedef cufftDoubleComplex complex;
86 | // #else
87 | // typedef float real;
88 | // typedef cufftComplex complex;
89 | // #endif
90 |
91 | typedef cufftComplex complex;
92 |
93 | #define PI 3.1415926535897932384626433832795028841971693993751
94 |
95 | // block size for shared memory
96 | #define BLOCK_X 32
97 | #define BLOCK_Y 32
98 |
99 |
100 | ///////////////////////////////////////////////////////////////////////////////
101 | // Common functions
102 | ///////////////////////////////////////////////////////////////////////////////
103 |
// A GPU timer based on CUDA events; Elapsed() reports milliseconds
struct GpuTimer
{
    cudaEvent_t start; // event recorded by Start()
    cudaEvent_t stop;  // event recorded by Stop()

    // create the two timing events
    GpuTimer()
    {
        cudaEventCreate(&start);
        cudaEventCreate(&stop);
    }

    // release the events
    ~GpuTimer()
    {
        cudaEventDestroy(start);
        cudaEventDestroy(stop);
    }

    // record the start event on the default stream
    void Start()
    {
        cudaEventRecord(start, 0);
    }

    // record the stop event on the default stream (asynchronous)
    void Stop()
    {
        cudaEventRecord(stop, 0);
    }

    // wait for the stop event to complete, then return the time between
    // start and stop in milliseconds; call only after Stop()
    float Elapsed()
    {
        float elapsed;
        cudaEventSynchronize(stop);
        cudaEventElapsedTime(&elapsed, start, stop);
        return elapsed;
    }
};
140 |
141 | // Align up n to the nearest multiple of m
142 | inline int iAlignUp(int n, int m = StrideAlignment)
143 | {
144 | int mod = n % m;
145 |
146 | if (mod)
147 | return n + m - mod;
148 | else
149 | return n;
150 | }
151 |
// round up n/m (ceiling division)
inline int iDivUp(int n, int m)
{
    const int rounded = n + m - 1;
    return rounded / m;
}
157 |
// swap two values (template parameter list restored — lost in the dump)
template <typename T>
inline void Swap(T &a, T &b)
{
    T t = a;
    a = b;
    b = t;
}
166 |
// Wrap to [-0.5 0.5], then add 0.5 to [0 1] for final phase show
__host__
__device__
inline float wrap(float x)
{
    // x - floor(x + 0.5f) is the signed offset of x from its nearest
    // integer, in [-0.5, 0.5); the +0.5f shift maps it into [0, 1)
    return x - floor(x + 0.5f) + 0.5f;
}
174 |
// nearest integer of power of 2
// (restored static_cast<int> — the template argument was lost in the dump)
__host__
inline int nearest_power_of_two(int x)
{
    // computes 2^(floor(log2(x) + 1)), i.e. 5 -> 8, 6 -> 8;
    // NOTE(review): for x an exact power of two this yields the NEXT power
    // (8 -> 16) — confirm this is intended for the FFT padding use case
    return 2 << (static_cast<int>(std::log2(x) + 1.0) - 1);
}
181 | #endif
182 |
--------------------------------------------------------------------------------
/cuda/src/computeDCTweightsKernel.cuh:
--------------------------------------------------------------------------------
1 | #include "common.h"
2 |
///////////////////////////////////////////////////////////////////////////////
/// \brief compute per-index complex weights for one DCT dimension
///
/// CUDA kernel, one thread per index; weight k is
/// sqrt(2*N) * exp(i * k*pi / (2*N))
/// \param[in] N_height number of samples along this dimension
/// \param[out] out complex weight array of length N_height
///////////////////////////////////////////////////////////////////////////////
__global__
void computeDCTweightsKernel(int N_height, complex *out)
{
    const int pos = threadIdx.x + blockIdx.x * blockDim.x;

    if (pos >= N_height) return;

    // we use double to enhance accuracy
    out[pos].x = (float) (sqrt((double)(2*N_height)) * cos(pos*PI/(2*N_height)));
    out[pos].y = (float) (sqrt((double)(2*N_height)) * sin(pos*PI/(2*N_height)));
}
24 |
///////////////////////////////////////////////////////////////////////////////
/// \brief compute DCT weights for both dimensions (host-side launcher)
/// \param[in] N_width unknown width
/// \param[in] N_height unknown height
/// \param[out] ww_1 complex weights along the height dimension
/// \param[out] ww_2 complex weights along the width dimension
///////////////////////////////////////////////////////////////////////////////
static
void computeDCTweights(int N_width, int N_height, complex *ww_1, complex *ww_2)
{
    dim3 threads(256);
    dim3 blocks1(iDivUp(N_height, threads.x));
    dim3 blocks2(iDivUp(N_width, threads.x));

    // restored execution configurations (<<<...>>> was lost in the dump)
    computeDCTweightsKernel<<<blocks1, threads>>>(N_height, ww_1);
    computeDCTweightsKernel<<<blocks2, threads>>>(N_width, ww_2);
}
43 |
--------------------------------------------------------------------------------
/cuda/src/cws_A_phi.cu:
--------------------------------------------------------------------------------
1 | // Please keep this include order to make sure a successful compilation!
2 | #include
3 | #include
4 | #include // for profiling
5 | #include
6 |
7 | // OpenCV
8 | #include
9 |
10 | // include kernels
11 | #include "common.h"
12 | #include "prepare_cufft_warmup.h"
13 | #include "computeDCTweightsKernel.cuh"
14 | #include "x_updateKernel.cuh"
15 | #include "prox_gKernel.cuh"
16 | #include "addKernel.cuh"
17 | #include "medianfilteringKernel.cuh"
18 |
19 | // thrust headers
20 | #include
21 | #include
22 | #include
23 | #include
24 | #include
25 |
26 | #include "cws_A_phi.h"
27 |
28 |
29 | // =============================================================
30 | // thrust definitions
31 | // =============================================================
32 | // thrust operators
// elementwise absolute-value functor for thrust::transform_reduce
// (template parameter list restored — lost in the dump)
template <typename T>
struct absolute{
    __host__ __device__
    T operator()(const T& x) const {
        return fabs(x);
    }
};
// elementwise square functor for thrust::transform_reduce
// (template parameter list restored — lost in the dump)
template <typename T>
struct square{
    __host__ __device__
    T operator()(const T& x) const {
        return x * x;
    }
};
47 |
48 | // define reduction pointer
49 | thrust::device_ptr thrust_ptr;
50 |
51 | // thrust setup arguments
52 | absolute u_op_abs;
53 | square u_op_sq;
54 | thrust::plus b_op;
55 | // =============================================================
56 |
57 |
58 | // =============================================================
59 | // norm & mean functions
60 | // =============================================================
///////////////////////////////////////////////////////////////////////////////
/// \brief reduction of a device vector: N=1 gives sum |x_i| (L1 norm),
///        N=2 gives sum x_i^2 (squared L2 norm — no sqrt is taken)
/// (template parameter list restored from the norm<1>/norm<2> call sites
///  below — it was lost in the dump)
/// \param[in] d_x device pointer to the data
/// \param[in] n number of elements
///////////////////////////////////////////////////////////////////////////////
template <int N>
real norm(real *d_x, int n)
{
    switch (N)
    {
    case 1: // sum of absolute values
    {
        thrust_ptr = thrust::device_pointer_cast(d_x);
        return thrust::transform_reduce(thrust_ptr, thrust_ptr + n, u_op_abs, 0.0, b_op);
    }
    case 2: // sum of squares
    {
        thrust_ptr = thrust::device_pointer_cast(d_x);
        return thrust::transform_reduce(thrust_ptr, thrust_ptr + n, u_op_sq, 0.0, b_op);
    }
    default:
    {
        printf("Undefined norm!\n");
        return 0;
    }
    }
}
83 | template
84 | real mean(real *d_x, int n)
85 | {
86 | switch (N)
87 | {
88 | case 1:
89 | return norm<1>(d_x, n) / static_cast(n);
90 | case 2:
91 | return norm<2>(d_x, n) / static_cast(n);
92 | default:
93 | {
94 | printf("Undefined mean!\n");
95 | return 0;
96 | }
97 | }
98 | }
99 | // =============================================================
100 |
101 |
102 | // =============================================================
103 | // helper functions
104 | // =============================================================
105 | // ----- set to ones -----
106 | __global__
107 | void setonesKernel(real *I, cv::Size N)
108 | {
109 | const int ix = threadIdx.x + blockIdx.x * blockDim.x;
110 | const int iy = threadIdx.y + blockIdx.y * blockDim.y;
111 | if (ix >= N.width || iy >= N.height) return;
112 |
113 | I[ix + iy * N.width] = 1.0f;
114 | }
115 | void setasones(real *I, cv::Size N)
116 | {
117 | dim3 threads(16, 16);
118 | dim3 blocks(iDivUp(N.width, threads.x), iDivUp(N.height, threads.y));
119 | setonesKernel<<>>(I, N);
120 | }
121 | // =============================================================
122 |
123 |
// =============================================================
// constant memory cache
// =============================================================
// L1 shrinkage thresholds, set once in the CWS_A_Phi constructor:
// [0] = gamma_new/(2*mu_A), [1] = alpha/(2*mu_A)
__constant__ real shrinkage_value[2];
// Padding margins {L.width, L.height} between the M (image) grid and
// the enlarged N (padded) grid; set once in the CWS_A_Phi constructor.
__constant__ int L_SIZE[2];
// =============================================================
130 |
131 |
132 | // =============================================================
133 | // pre-computation: DCT basis function
134 | // =============================================================
135 | __device__
136 | double DCT_kernel(const int ix, const int iy, int w, int h)
137 | {
138 | return 4.0 - 2*cos(PI*(double)ix/(double)w) - 2*cos(PI*(double)iy/(double)h);
139 | }
// Build the DCT-domain denominator for the A-update linear system:
// mat_A_hat = 1 + (tau_new + mu) * (lap^2 + lap), one entry per frequency.
__global__
void mat_A_hatKernel(opt_A opt_A_t, cv::Size M, real *mat_A_hat)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;
    const int pos = ix + iy * M.width;

    if (ix >= M.width || iy >= M.height) return;

    // \nabla^2
    double lap_mat = DCT_kernel(ix, iy, M.width, M.height);

    // \nabla^4 + \nabla^2
    double K_mat = lap_mat * (lap_mat + 1.0);

    // get mat_A_hat
    mat_A_hat[pos] = (double) ( 1.0 + (opt_A_t.tau_new + opt_A_t.mu) * K_mat );
}
// Build the DCT-domain denominator for the phi-update linear system.
__global__
void mat_phi_hatKernel(opt_phi opt_phi_t, real beta, cv::Size N, real *mat_phi_hat)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;
    const int pos = ix + iy * N.width;

    if (ix >= N.width || iy >= N.height) return;

    // \nabla^2
    double lap_mat = DCT_kernel(ix, iy, N.width, N.height);

    // get mat_phi_hat (pre-divide mu so there is no need for re-scaling phi in ADMM)
    if (pos == 0){
        // DC frequency: lap_mat is 0 here, so pin the entry to 1 to keep the
        // system invertible (phi is only defined up to a constant anyway).
        mat_phi_hat[pos] = 1.0;
    }
    else{
        mat_phi_hat[pos] = (double) (
            (opt_phi_t.mu + beta + beta * lap_mat) * lap_mat );
    }
}
179 | void compute_inverse_mat(opt_A opt_A_t, opt_phi opt_phi_t, real beta,
180 | cv::Size M, cv::Size N, real *mat_A_hat, real *mat_phi_hat)
181 | {
182 | dim3 threads(16, 16);
183 | dim3 blocks_M(iDivUp(M.width, threads.x), iDivUp(M.height, threads.y));
184 | dim3 blocks_N(iDivUp(N.width, threads.x), iDivUp(N.height, threads.y));
185 | mat_A_hatKernel<<>>(opt_A_t, M, mat_A_hat);
186 | mat_phi_hatKernel<<>>(opt_phi_t, beta, N, mat_phi_hat);
187 | }
188 | // =============================================================
189 |
190 |
191 | // =============================================================
192 | // operators
193 | // =============================================================
194 | // ----- nabla -----
195 | __global__
196 | void nablaKernel(const real *__restrict__ I, cv::Size N, real *Ix, real *Iy)
197 | {
198 | const int ix = threadIdx.x + blockIdx.x * blockDim.x;
199 | const int iy = threadIdx.y + blockIdx.y * blockDim.y;
200 | const int pos = ix + iy * N.width;
201 |
202 | if (ix >= N.width || iy >= N.height) return;
203 |
204 | // replicate/symmetric boundary condition
205 | if(ix == 0){
206 | Ix[pos] = 0.0f;
207 | }
208 | else{
209 | Ix[pos] = I[pos] - I[(ix-1) + iy * N.width];
210 | }
211 | if (iy == 0){
212 | Iy[pos] = 0.0f;
213 | }
214 | else{
215 | Iy[pos] = I[pos] - I[ix + (iy-1) * N.width];
216 | }
217 | }
218 | void nabla(const real *I, cv::Size N, real *nabla_x, real *nabla_y)
219 | {
220 | dim3 threads(32, 32);
221 | dim3 blocks(iDivUp(N.width, threads.x), iDivUp(N.height, threads.y));
222 | nablaKernel<<>>(I, N, nabla_x, nabla_y);
223 | }
224 | // -----------------
225 |
226 | // ----- nablaT -----
227 | __global__
228 | void nablaTKernel(const real *__restrict__ Ix, const real *__restrict__ Iy, cv::Size N, real *div)
229 | {
230 | const int ix = threadIdx.x + blockIdx.x * blockDim.x;
231 | const int iy = threadIdx.y + blockIdx.y * blockDim.y;
232 | const int pos = ix + iy * N.width;
233 |
234 | real val1, val2;
235 | if (ix >= N.width || iy >= N.height) return;
236 |
237 | // replicate/symmetric boundary condition
238 | if(ix == N.width-1){
239 | val1 = 0.0f;
240 | }
241 | else{
242 | val1 = Ix[pos] - Ix[ix+1 + iy * N.width];
243 | }
244 | if (iy == N.height-1){
245 | val2 = 0.0f;
246 | }
247 | else{
248 | val2 = Iy[pos] - Iy[ix + (iy+1) * N.width];
249 | }
250 | div[pos] = val1 + val2;
251 | }
252 | void nablaT(const real *__restrict__ Ix, const real *__restrict__ Iy, cv::Size N, real *div)
253 | {
254 | dim3 threads(32, 32);
255 | dim3 blocks(iDivUp(N.width, threads.x), iDivUp(N.height, threads.y));
256 | nablaTKernel<<>>(Ix, Iy, N, div);
257 | }
258 | // -----------------
259 |
260 | // ----- nabla2 -----
261 | __global__
262 | void nabla2Kernel(const real *__restrict__ I, cv::Size N, real *L)
263 | {
264 | const int ix = threadIdx.x + blockIdx.x * blockDim.x;
265 | const int iy = threadIdx.y + blockIdx.y * blockDim.y;
266 | const int py = iy * N.width;
267 | const int pos = ix + py;
268 |
269 | if (ix >= N.width || iy >= N.height) return;
270 |
271 | // replicate/symmetric boundary condition (3x3 stencil)
272 | const int x_min = max(ix-1, 0);
273 | const int x_max = min(ix+1, N.width-1);
274 | const int y_min = max(iy-1, 0);
275 | const int y_max = min(iy+1, N.height-1);
276 |
277 | L[pos] = -4 * I[pos] + I[x_min + py] + I[x_max + py] + I[ix + y_min*N.width] + I[ix + y_max*N.width];
278 | }
279 | void nabla2(const real *I, cv::Size N, real *L)
280 | {
281 | dim3 threads(32, 32);
282 | dim3 blocks(iDivUp(N.width, threads.x), iDivUp(N.height, threads.y));
283 | nabla2Kernel<<>>(I, N, L);
284 | }
285 | // -----------------
286 |
287 | // ----- K & KT -----
288 | void K(const real *I, cv::Size N, real *grad_x, real *grad_y, real *L)
289 | {
290 | nabla(I, N, grad_x, grad_y);
291 | nabla2(I, N, L);
292 | }
293 | __global__
294 | void KTKernel(const real *__restrict__ Ix, const real *__restrict__ Iy,
295 | const real *__restrict__ L, cv::Size N, real *I)
296 | {
297 | const int ix = threadIdx.x + blockIdx.x * blockDim.x;
298 | const int iy = threadIdx.y + blockIdx.y * blockDim.y;
299 | const int py = iy * N.width;
300 | const int pos = ix + py;
301 |
302 | if (ix >= N.width || iy >= N.height) return;
303 |
304 | // replicate/symmetric boundary condition (3x3 stencil)
305 | const int x_min = max(ix-1, 0);
306 | const int x_max = min(ix+1, N.width-1);
307 | const int y_min = max(iy-1, 0);
308 | const int y_max = min(iy+1, N.height-1);
309 |
310 | real Div = Ix[pos] - Ix[x_max + iy*N.width] + Iy[pos] - Iy[ix + y_max*N.width];
311 | real Lap = -4 * L[pos] + L[x_min + py] + L[x_max + py] + L[ix + y_min*N.width] + L[ix + y_max*N.width];
312 | I[pos] = Div + Lap;
313 | }
314 | void KT(const real *grad_x, const real *grad_y, const real *L, cv::Size N, real *I)
315 | {
316 | dim3 threads(32, 32);
317 | dim3 blocks(iDivUp(N.width, threads.x), iDivUp(N.height, threads.y));
318 | KTKernel<<>>(grad_x, grad_y, L, N, I);
319 | }
320 | __global__
321 | void MKernel(const real *in, cv::Size M, cv::Size N, real *out)
322 | {
323 | const int ix = threadIdx.x + blockIdx.x * blockDim.x;
324 | const int iy = threadIdx.y + blockIdx.y * blockDim.y;
325 |
326 | const int pos_N = ix + iy * N.width;
327 | const int pos_M = ix - L_SIZE[0] + (iy - L_SIZE[1]) * M.width;
328 |
329 | if (ix >= N.width || iy >= N.height) return;
330 | else if (ix >= L_SIZE[0] && ix < N.width - L_SIZE[0] &&
331 | iy >= L_SIZE[1] && iy < N.height - L_SIZE[1]) {
332 | out[pos_M] = in[pos_N];
333 | }
334 | else return;
335 | }
336 | void Mask_func(const real *in, cv::Size M, cv::Size N, real *out)
337 | {
338 | dim3 threads(32, 32);
339 | dim3 blocks(iDivUp(N.width, threads.x), iDivUp(N.height, threads.y));
340 | MKernel<<>>(in, M, N, out);
341 | }
342 | // -----------------
343 | // =============================================================
344 |
345 |
346 | // =============================================================
347 | // objective functions
348 | // =============================================================
// Residual for the A objective in ratio form: tmp = A - b/I0.
// NOTE(review): no guard against I0[pos] == 0 — presumably the reference
// image is strictly positive; confirm against the caller.
__global__
void res1Kernel(const real *A, const real *b, const real *I0, cv::Size N, real *tmp)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;
    const int pos = ix + iy * N.width;

    if (ix >= N.width || iy >= N.height) return;
    tmp[pos] = A[pos] - b[pos]/I0[pos];
}
359 | __global__
360 | void res2Kernel(const real *A, const real *b, const real *I0, cv::Size M, real *tmp)
361 | {
362 | const int ix = threadIdx.x + blockIdx.x * blockDim.x;
363 | const int iy = threadIdx.y + blockIdx.y * blockDim.y;
364 | const int pos = ix + iy * M.width;
365 |
366 | if (ix >= M.width || iy >= M.height) return;
367 | tmp[pos] = I0[pos] * A[pos] - b[pos];
368 | }
// Linearized brightness-constancy residual: out = gx*wx + gy*wy + gt, where
// (gx, gy, gt) live on the M grid and (wx, wy) on the padded N grid.
__global__
void res3Kernel(real *gx, real *gy, real *gt, real *wx, real *wy,
                cv::Size M, cv::Size N, real *out, bool isM = true)
{
    // This function will be used by:
    // - 1. Cost function of phi (function: obj_phi)
    // - 2. Update I_warp (function: update_I_warp)
    // To fulfill both goals, size of out is designed to be either M or N.
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;

    const int pos_N = ix + iy * N.width;
    const int pos_M = ix - L_SIZE[0] + (iy - L_SIZE[1]) * M.width;
    // select the output index according to the requested output size
    const int pos_tmp = isM ? pos_M : pos_N;

    if (ix >= N.width || iy >= N.height) return;
    else if (ix >= L_SIZE[0] && ix < N.width - L_SIZE[0] &&
             iy >= L_SIZE[1] && iy < N.height - L_SIZE[1]) {
        out[pos_tmp] = gx[pos_M] * wx[pos_N] + gy[pos_M] * wy[pos_N] + gt[pos_M];
    }
    else if (!isM) {
        out[pos_tmp] = 0.0f; // outside area be zeros; in this case tmp is of size N
    }
    else return;
}
394 | real obj_A(const real *A, const real *b, const real *I0, opt_A opt_A_t, cv::Size N,
395 | real *tmp1, real *tmp2, real *tmp3)
396 | {
397 | // compute A - b/I0, and data term
398 | dim3 threads(32, 32);
399 | dim3 blocks(iDivUp(N.width, threads.x), iDivUp(N.height, threads.y));
400 | res1Kernel<<>>(A, b, I0, N, tmp1);
401 | real data_term = norm<2>(tmp1, N.area());
402 |
403 | // compute KA, and priors
404 | K(A, N, tmp1, tmp2, tmp3);
405 | real L1_term = opt_A_t.gamma_new * ( norm<1>(tmp1, N.area()) + norm<1>(tmp2, N.area()) + norm<1>(tmp3, N.area()) );
406 | real L2_term = opt_A_t.tau_new * ( norm<2>(tmp1, N.area()) + norm<2>(tmp2, N.area()) + norm<2>(tmp3, N.area()) );
407 |
408 | return data_term + L1_term + L2_term;
409 | }
410 | real obj_phi(const real *phi, real *gx, real *gy, real *gt,
411 | real alpha, real beta, cv::Size M, cv::Size N,
412 | real *tmp1_N, real *tmp2_N, real *tmp3_N)
413 | {
414 | // compute data term
415 | dim3 threads(32, 32);
416 | dim3 blocks(iDivUp(N.width, threads.x), iDivUp(N.height, threads.y));
417 | nabla(phi, N, tmp1_N, tmp2_N);
418 | res3Kernel<<>>(gx, gy, gt, tmp1_N, tmp2_N, M, N, tmp1_N, false);
419 | real data_term = norm<2>(tmp1_N, N.area());
420 |
421 | // compute nabla phi, and phi priors
422 | K(phi, N, tmp1_N, tmp2_N, tmp3_N);
423 | real L1_term = alpha * ( norm<1>(tmp1_N, N.area()) + norm<1>(tmp2_N, N.area()) );
424 | real L2_term = beta * ( norm<2>(tmp1_N, N.area()) + norm<2>(tmp2_N, N.area()) + norm<2>(tmp3_N, N.area()) );
425 |
426 | return data_term + L1_term + L2_term;
427 | }
428 | real obj_total(const real *A, const real *phi, const real *b, const real *I0,
429 | real alpha, real beta, real gamma, real tau, cv::Size M, cv::Size N,
430 | real *tmp1_M, real *tmp2_M, real *tmp3_M,
431 | real *tmp1_N, real *tmp2_N, real *tmp3_N)
432 | {
433 | // compute I0*A - b, and data term
434 | dim3 threads(16, 16);
435 | dim3 blocks(iDivUp(M.width, threads.x), iDivUp(M.height, threads.y));
436 | res2Kernel<<>>(A, b, I0, M, tmp1_M);
437 | real data_term = norm<2>(tmp1_M, M.area());
438 |
439 | // compute nabla phi, and phi priors
440 | K(phi, N, tmp1_N, tmp2_N, tmp3_N);
441 | real phi_L1_term = alpha * ( norm<1>(tmp1_N, N.area()) + norm<1>(tmp2_N, N.area()) );
442 | real phi_L2_term = beta * ( norm<2>(tmp1_N, N.area()) + norm<2>(tmp2_N, N.area()) + norm<2>(tmp3_N, N.area()) );
443 |
444 | // compute KA, and A priors
445 | K(A, M, tmp1_M, tmp2_M, tmp3_M);
446 | real A_L1_term = gamma * ( norm<1>(tmp1_M, M.area()) + norm<1>(tmp2_M, M.area()) + norm<1>(tmp3_M, M.area()) );
447 | real A_L2_term = tau * ( norm<2>(tmp1_M, M.area()) + norm<2>(tmp2_M, M.area()) + norm<2>(tmp3_M, M.area()) );
448 |
449 | return data_term + phi_L1_term + phi_L2_term + A_L1_term + A_L2_term;
450 | }
451 | // =============================================================
452 |
453 |
454 | // =============================================================
455 | // sub functions for updates
456 | // =============================================================
// Fused ADMM B-update and dual (zeta) update for the A subproblem.
// On entry tmp_* hold K A; on exit tmp_* hold 2B - u = B - zeta (the value
// fed to the next KT application) and zeta_* hold the updated duals.
__global__
void ADMM_AKernel(real *zeta_x, real *zeta_y, real *zeta_z,
                  real *tmp_x, real *tmp_y, real *tmp_z, cv::Size N)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;
    const int pos = ix + iy * N.width;

    if (ix >= N.width || iy >= N.height) return;

    // compute u = KA + zeta
    real u_x = tmp_x[pos] + zeta_x[pos];
    real u_y = tmp_y[pos] + zeta_y[pos];
    real u_z = tmp_z[pos] + zeta_z[pos];

    // B-update (soft-thresholding with the cached A-shrinkage value)
    real B_x = prox_l1(u_x, shrinkage_value[0]);
    real B_y = prox_l1(u_y, shrinkage_value[0]);
    real B_z = prox_l1(u_z, shrinkage_value[0]);

    // zeta-update
    zeta_x[pos] = u_x - B_x;
    zeta_y[pos] = u_y - B_y;
    zeta_z[pos] = u_z - B_z;

    // store B-zeta
    tmp_x[pos] = 2*B_x - u_x;
    tmp_y[pos] = 2*B_y - u_y;
    tmp_z[pos] = 2*B_z - u_z;
}
// Build the right-hand side of the A inversion: A = I_warp/I0 + mu*KT_res.
// NOTE(review): no guard against I0[pos] == 0 — confirm I0 is strictly positive.
__global__
void ADMM_precomputeAKernel(real *I_warp, real *I0, real *KT_res, real mu, cv::Size N, real *A)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;
    const int py = iy * N.width;
    const int pos = ix + py;

    if (ix >= N.width || iy >= N.height) return;

    // pre-compute for A-update (I_warp/I0 can be pre-computed)
    A[pos] = I_warp[pos] / I0[pos] + mu * KT_res[pos];
}
500 | void A_update(real *A, real *I_warp, real *I0,
501 | real *zeta_x, real *zeta_y, real *zeta_z,
502 | real *grad_x_M, real *grad_y_M, real *lap_M,
503 | complex *tmp_dct_M, real *mat_A_hat,
504 | complex *ww_1_M, complex *ww_2_M, cufftHandle plan_dct_1_M, cufftHandle plan_dct_2_M,
505 | opt_A opt_A_t, cv::Size M)
506 | {
507 | dim3 threads(32, 32);
508 | dim3 blocks(iDivUp(M.width, threads.x), iDivUp(M.height, threads.y));
509 |
510 | // initial objective
511 | real objval;
512 | if (opt_A_t.isverbose)
513 | {
514 | printf("-- A update\n");
515 | objval = obj_A(A, I_warp, I0, opt_A_t, M, grad_x_M, grad_y_M, lap_M);
516 | printf("---- init, obj = %.4e \n", objval);
517 | }
518 |
519 | // initialization
520 | checkCudaErrors(cudaMemset(zeta_x, 0, M.area()*sizeof(real)));
521 | checkCudaErrors(cudaMemset(zeta_y, 0, M.area()*sizeof(real)));
522 | checkCudaErrors(cudaMemset(zeta_z, 0, M.area()*sizeof(real)));
523 | checkCudaErrors(cudaMemset(grad_x_M, 0, M.area()*sizeof(real)));
524 | checkCudaErrors(cudaMemset(grad_y_M, 0, M.area()*sizeof(real)));
525 | checkCudaErrors(cudaMemset(lap_M, 0, M.area()*sizeof(real)));
526 |
527 | // ADMM loop
528 | for (signed int i = 1; i <= opt_A_t.iter; ++i)
529 | {
530 | // compute KT(B - zeta)
531 | KT(grad_x_M, grad_y_M, lap_M, M, A);
532 |
533 | // A-update inversion
534 | ADMM_precomputeAKernel<<>>(I_warp, I0, A, opt_A_t.mu, M, A);
535 | x_update(A, tmp_dct_M, mat_A_hat, ww_1_M, ww_2_M, M.width, M.height, plan_dct_1_M, plan_dct_2_M);
536 |
537 | // B-update & zeta-update
538 | K(A, M, grad_x_M, grad_y_M, lap_M);
539 | ADMM_AKernel<<>>(zeta_x, zeta_y, zeta_z, grad_x_M, grad_y_M, lap_M, M);
540 | }
541 |
542 | // show final objective
543 | if (opt_A_t.isverbose)
544 | {
545 | objval = obj_A(A, I_warp, I0, opt_A_t, M, grad_x_M, grad_y_M, lap_M);
546 | printf("---- iter = %d, obj = %.4e \n", opt_A_t.iter, objval);
547 | }
548 | }
549 | __global__
550 | void I_warp_updateKernel(real *A, real *I0, cv::Size M, real *I_warp)
551 | {
552 | const int ix = threadIdx.x + blockIdx.x * blockDim.x;
553 | const int iy = threadIdx.y + blockIdx.y * blockDim.y;
554 | const int pos = ix + iy * M.width;
555 |
556 | if (ix >= M.width || iy >= M.height) return;
557 | I_warp[pos] = A[pos] * I0[pos];
558 | }
559 | void I_warp_update(real *A, real *I0, cv::Size M, real *I_warp)
560 | {
561 | dim3 threads(32, 32);
562 | dim3 blocks(iDivUp(M.width, threads.x), iDivUp(M.height, threads.y));
563 | I_warp_updateKernel<<>>(A, I0, M, I_warp);
564 | }
565 | void update_I_warp(real *I_warp, real *phi, real *gx, real *gy,
566 | real *tmp1, real *tmp2, cv::Size M, cv::Size N)
567 | {
568 | dim3 threads(32, 32);
569 | dim3 blocks(iDivUp(N.width, threads.x), iDivUp(N.height, threads.y));
570 | nabla(phi, N, tmp1, tmp2);
571 | res3Kernel<<>>(gx, gy, I_warp, tmp1, tmp2, M, N, I_warp, true);
572 | }
573 | texture texTarget;
574 | template // template global kernel to handle all cases
575 | __global__
576 | void ComputeDerivativesL1Kernel(const real *__restrict__ I0, cv::Size M, real mu,
577 | real *A11, real *A12, real *A22, real *a, real *b, real *c, const real *__restrict__ I = NULL)
578 | {
579 | const int ix = threadIdx.x + blockIdx.x * blockDim.x;
580 | const int iy = threadIdx.y + blockIdx.y * blockDim.y;
581 | const int pos = ix + iy * M.width;
582 |
583 | if (ix >= M.width || iy >= M.height) return;
584 |
585 | // compute the derivatives
586 | real gx, gy, gt;
587 | switch (mods)
588 | {
589 | case 0: // image size is multiple of 32; use texture memory
590 | {
591 | real dx = 1.0f / (real)M.width;
592 | real dy = 1.0f / (real)M.height;
593 |
594 | real x = ((real)ix + 0.5f) * dx;
595 | real y = ((real)iy + 0.5f) * dy;
596 |
597 | // x derivative
598 | gx = tex2D(texTarget, x - 2.0f * dx, y);
599 | gx -= tex2D(texTarget, x - 1.0f * dx, y) * 8.0f;
600 | gx += tex2D(texTarget, x + 1.0f * dx, y) * 8.0f;
601 | gx -= tex2D(texTarget, x + 2.0f * dx, y);
602 | gx /= 12.0f;
603 |
604 | // t derivative
605 | gt = tex2D(texTarget, x, y) - I0[pos];
606 |
607 | // y derivative
608 | gy = tex2D(texTarget, x, y - 2.0f * dy);
609 | gy -= tex2D(texTarget, x, y - 1.0f * dy) * 8.0f;
610 | gy += tex2D(texTarget, x, y + 1.0f * dy) * 8.0f;
611 | gy -= tex2D(texTarget, x, y + 2.0f * dy);
612 | gy /= 12.0f;
613 | }
614 | break;
615 |
616 | case 1: // image size is not multiple of 32; use shared memory
617 | {
618 | const int tx = threadIdx.x + 2;
619 | const int ty = threadIdx.y + 2;
620 |
621 | if (1)
622 | {
623 | // save to shared memory
624 | __shared__ real smem[BLOCK_X+4][BLOCK_Y+4];
625 | smem[tx][ty] = I[pos];
626 | set_bd_shared_memory<5>(smem, I, tx, ty, ix, iy, M.width, M.height);
627 | __syncthreads();
628 |
629 | // x derivative
630 | gx = smem[tx-2][ty];
631 | gx -= smem[tx-1][ty] * 8.0f;
632 | gx += smem[tx+1][ty] * 8.0f;
633 | gx -= smem[tx+2][ty];
634 | gx /= 12.0f;
635 |
636 | // t derivative
637 | gt = smem[tx][ty] - I0[pos];
638 |
639 | // y derivative
640 | gy = smem[tx][ty-2];
641 | gy -= smem[tx][ty-1] * 8.0f;
642 | gy += smem[tx][ty+1] * 8.0f;
643 | gy -= smem[tx][ty+2];
644 | gy /= 12.0f;
645 | }
646 | else
647 | {
648 | // x derivative
649 | gx = I[max(0,ix-2) + iy * M.width];
650 | gx -= I[max(0,ix-1) + iy * M.width] * 8.0f;
651 | gx += I[min(M.width-1,ix+1) + iy * M.width] * 8.0f;
652 | gx -= I[min(M.width-1,ix+2) + iy * M.width];
653 | gx /= 12.0f;
654 |
655 | // t derivative
656 | gt = I[pos] - I0[pos];
657 |
658 | // y derivative
659 | gy = I[ix + max(0,iy-2) * M.width];
660 | gy -= I[ix + max(0,iy-1) * M.width] * 8.0f;
661 | gy += I[ix + min(M.height-1,iy+1) * M.width] * 8.0f;
662 | gy -= I[ix + min(M.height-1,iy+2) * M.width];
663 | gy /= 12.0f;
664 | }
665 | }
666 | break;
667 | }
668 |
669 | // pre-caching
670 | real gxx = gx*gx;
671 | real gxy = gx*gy;
672 | real gyy = gy*gy;
673 | real denom = 1.0f / (mu * (gxx + gyy + mu));
674 |
675 | // L1 + L2
676 | A11[pos] = denom * (gyy + mu); // A11
677 | A12[pos] = -gxy * denom; // A12
678 | A22[pos] = denom * (gxx + mu); // A22
679 | a[pos] = gx;
680 | b[pos] = gy;
681 | c[pos] = gt;
682 | }
683 | void ComputeDerivativesL1(const real *__restrict__ I0, const real *__restrict__ I1,
684 | cv::Size M, int s, real mu,
685 | real *A11, real *A12, real *A22, real *a, real *b, real *c)
686 | {
687 | int mods = (((M.width % 32) == 0) && ((M.height % 32) == 0)) ? 0 : 1;
688 | switch (mods)
689 | {
690 | case 0: // image size is multiple of 32; use texture memory
691 | {
692 | dim3 threads(32, 32);
693 | dim3 blocks(iDivUp(M.width, threads.x), iDivUp(M.height, threads.y));
694 |
695 | // replicate if a coordinate value is out-of-range
696 | texTarget.addressMode[0] = cudaAddressModeClamp;
697 | texTarget.addressMode[1] = cudaAddressModeClamp;
698 | texTarget.filterMode = cudaFilterModeLinear;
699 | texTarget.normalized = true;
700 | cudaChannelFormatDesc desc = cudaCreateChannelDesc();
701 | cudaBindTexture2D(0, texTarget, I1, M.width, M.height, s*sizeof(real));
702 |
703 | ComputeDerivativesL1Kernel<0><<>>(I0, M, mu, A11, A12, A22, a, b, c);
704 | }
705 | break;
706 |
707 | case 1: // image size is not multiple of 32; use shared memory
708 | {
709 | dim3 threads(BLOCK_X, BLOCK_Y);
710 | dim3 blocks(iDivUp(M.width, threads.x), iDivUp(M.height, threads.y));
711 | ComputeDerivativesL1Kernel<1><<>>(I0, M, mu, A11, A12, A22, a, b, c, I1);
712 | }
713 | break;
714 | }
715 | }
716 | void phi_update(real *phi, real *I_warp, real *I0,
717 | real alpha, real beta,
718 | real *zeta_x, real *zeta_y, real *grad_x_N, real *grad_y_N,
719 | real *A11, real *A12, real *A22, real *gx, real *gy, real *gt,
720 | complex *tmp_dct_N, real *mat_phi_hat,
721 | complex *ww_1_N, complex *ww_2_N, cufftHandle plan_dct_1_N, cufftHandle plan_dct_2_N,
722 | opt_phi opt_phi_t, cv::Size M, cv::Size N)
723 | {
724 | dim3 threads(32, 32);
725 | dim3 blocks(iDivUp(N.width, threads.x), iDivUp(N.height, threads.y));
726 |
727 | // compute derivatives and constants
728 | ComputeDerivativesL1(I0, I_warp, M, M.width, opt_phi_t.mu,
729 | A11, A12, A22, gx, gy, gt);
730 |
731 | // initial objective
732 | real objval;
733 | if (opt_phi_t.isverbose)
734 | {
735 | printf("-- phi update\n");
736 | objval = obj_phi(phi, gx, gy, gt, alpha, beta, M, N, grad_x_N, grad_y_N, zeta_x);
737 | printf("---- init, obj = %.4e \n", objval);
738 | }
739 |
740 | // initialization
741 | checkCudaErrors(cudaMemset(zeta_x, 0, N.area()*sizeof(real)));
742 | checkCudaErrors(cudaMemset(zeta_y, 0, N.area()*sizeof(real)));
743 | checkCudaErrors(cudaMemset(grad_x_N, 0, N.area()*sizeof(real)));
744 | checkCudaErrors(cudaMemset(grad_y_N, 0, N.area()*sizeof(real)));
745 |
746 | // ADMM loop
747 | for (signed int i = 1; i <= opt_phi_t.iter; ++i)
748 | {
749 | // compute nablaT(w - zeta)
750 | nablaT(grad_x_N, grad_y_N, N, phi);
751 |
752 | // phi-update inversion
753 | x_update(phi, tmp_dct_N, mat_phi_hat, ww_1_N, ww_2_N, N.width, N.height, plan_dct_1_N, plan_dct_2_N);
754 |
755 | // w-update & zeta-update
756 | nabla(phi, N, grad_x_N, grad_y_N);
757 | prox_gL1Kernel<<>>(zeta_x, zeta_y, grad_x_N, grad_y_N, opt_phi_t.mu, alpha,
758 | A11, A12, A22, gx, gy, gt, N.width, N.height, M.width, M.height);
759 | }
760 |
761 | // show final objective
762 | if (opt_phi_t.isverbose)
763 | {
764 | objval = obj_phi(phi, gx, gy, gt, alpha, beta, M, N, grad_x_N, grad_y_N, zeta_x);
765 | printf("---- iter = %d, obj = %.4e \n", opt_phi_t.iter, objval);
766 | }
767 |
768 | // update I_warp as: I_warp += [gx; gy] * \nabla\phi
769 | update_I_warp(I_warp, phi, gx, gy, grad_x_N, grad_y_N, M, N);
770 | }
771 | // =============================================================
772 |
773 |
// Constructor: copy solver options, derive the padded grid size N from the
// image size M and the margin L, allocate all device buffers, cache constants
// to constant memory, and precompute the cuFFT plans, DCT weights, and
// DCT-domain inversion denominators. h_A / h_phi are (re)allocated here as
// the host-side result containers.
CWS_A_Phi::CWS_A_Phi(cv::Mat I0, cv::Mat &h_A, cv::Mat &h_phi, opt_algo opt_algo_out)
{
    opt_algo_t = opt_algo_out;

    // is verbose
    opt_A_t.isverbose = opt_algo_t.isverbose_sub;
    opt_phi_t.isverbose = opt_algo_t.isverbose_sub;

    // ADMM parameters
    opt_A_t.mu = opt_algo_t.mu_A;
    opt_phi_t.mu = opt_algo_t.mu_phi;

    // total number of alternation iterations
    opt_A_t.iter = opt_algo_t.A_iter;
    opt_phi_t.iter = opt_algo_t.phi_iter;

    // determine sizes
    M = I0.size();

    // validate L (fall back to a margin derived from the next power of two)
    if (opt_algo_t.L.width < 2 || opt_algo_t.L.height < 2)
    {
        printf("L unspecified or wrong; will use nearest power of two of M; L \in [2, 32].\n");
        opt_algo_t.L.width = min(max(2, (nearest_power_of_two(M.width) - M.width) /2), 32);
        opt_algo_t.L.height = min(max(2, (nearest_power_of_two(M.height) - M.height)/2), 32);
    }

    // set L
    opt_phi_t.L.width = opt_algo_t.L.width;
    opt_phi_t.L.height = opt_algo_t.L.height;

    // set N (pad the image by L on every side)
    N = M + opt_algo_t.L + opt_algo_t.L;
    printf("M = [%d, %d], N = [%d, %d] \n", M.width, M.height, N.width, N.height);

    // allocate host cv::Mat containers
    h_A = cv::Mat::zeros(M, CV_32F);
    h_phi = cv::Mat::zeros(N, CV_32F);

    // allocate device pointers
    checkCudaErrors(cudaMalloc(&d_I0, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&d_I, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&d_I0_tmp, M.area()*sizeof(real)));

    // copy from host to device
    // NOTE(review): assumes I0 is a contiguous CV_32F matrix so the row-0
    // pointer covers the whole buffer — confirm at the call site.
    checkCudaErrors(cudaMemcpy(d_I0, I0.ptr(0), M.area()*sizeof(real), cudaMemcpyHostToDevice));
    checkCudaErrors(cudaMemcpy(d_I0_tmp, d_I0, M.area()*sizeof(real), cudaMemcpyDeviceToDevice));

    // compute A-update parameters (normalize priors by the mean squared intensity)
    opt_A_t.gamma_new = opt_algo_t.gamma / mean<2>(d_I0, M.area());
    opt_A_t.tau_new = opt_algo_t.tau / mean<2>(d_I0, M.area());

    // set shrinkage value & cache to constant memory
    real shrinkage_value_array [2] = { opt_A_t.gamma_new/(2.0f*opt_algo_t.mu_A), opt_algo_t.alpha/(2.0f*opt_algo_t.mu_A) };
    cudaMemcpyToSymbol(shrinkage_value, shrinkage_value_array, 2*sizeof(real));

    // cache L to constant memory
    int L_size_array [2] = { opt_algo_t.L.width, opt_algo_t.L.height };
    cudaMemcpyToSymbol(L_SIZE, L_size_array, 2*sizeof(int));

    // allocate variable arrays
    // -- main variables
    checkCudaErrors(cudaMalloc(&A, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&phi, N.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&Delta_phi, N.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&I_warp, M.area()*sizeof(real)));

    // -- update variables (ADMM duals on the M and N grids)
    checkCudaErrors(cudaMalloc(&zeta_x_M, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&zeta_y_M, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&zeta_z_M, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&zeta_x_N, N.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&zeta_y_N, N.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&zeta_z_N, N.area()*sizeof(real)));

    // -- temp arrays (gradient/Laplacian scratch + DCT work buffers)
    checkCudaErrors(cudaMalloc(&grad_x_M, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&grad_y_M, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&lap_M, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&grad_x_N, N.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&grad_y_N, N.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&lap_N, N.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&tmp_dct_M,M.area()*sizeof(complex)));
    checkCudaErrors(cudaMalloc(&tmp_dct_N,N.area()*sizeof(complex)));

    // -- temp variables for constant coefficients caching
    checkCudaErrors(cudaMalloc(&A11, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&A12, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&A22, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&a, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&b, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&c, M.area()*sizeof(real)));

    // prepare FFT plans (one pair per grid size)
    int pH_M[1] = {M.height}, pW_M[1] = {M.width};
    int pH_N[1] = {N.height}, pW_N[1] = {N.width};
    cufft_prepare(1, pH_M, pW_M, &plan_dct_1_M, &plan_dct_2_M, NULL, NULL, NULL, NULL, NULL, NULL);
    cufft_prepare(1, pH_N, pW_N, &plan_dct_1_N, &plan_dct_2_N, NULL, NULL, NULL, NULL, NULL, NULL);

    // prepare DCT coefficients
    checkCudaErrors(cudaMalloc(&ww_1_M, M.height*sizeof(complex)));
    checkCudaErrors(cudaMalloc(&ww_2_M, M.width *sizeof(complex)));
    checkCudaErrors(cudaMalloc(&ww_1_N, N.height*sizeof(complex)));
    checkCudaErrors(cudaMalloc(&ww_2_N, N.width *sizeof(complex)));
    computeDCTweights(M.width, M.height, ww_1_M, ww_2_M);
    computeDCTweights(N.width, N.height, ww_1_N, ww_2_N);

    // prepare inversion matrices
    checkCudaErrors(cudaMalloc(&mat_A_hat, M.area()*sizeof(real)));
    checkCudaErrors(cudaMalloc(&mat_phi_hat, N.area()*sizeof(real)));
    compute_inverse_mat(opt_A_t, opt_phi_t, opt_algo_t.beta, M, N, mat_A_hat, mat_phi_hat);
}
886 |
// Destructor: release every device buffer and cuFFT plan allocated by the
// constructor (mirrors the allocation order).
CWS_A_Phi::~CWS_A_Phi()
{
    checkCudaErrors(cudaFree(d_I0));
    checkCudaErrors(cudaFree(d_I));
    checkCudaErrors(cudaFree(d_I0_tmp));

    // result variables -------------
    checkCudaErrors(cudaFree(A));
    checkCudaErrors(cudaFree(phi));
    checkCudaErrors(cudaFree(Delta_phi));
    checkCudaErrors(cudaFree(I_warp));

    // update variables -------------
    checkCudaErrors(cudaFree(zeta_x_M));
    checkCudaErrors(cudaFree(zeta_y_M));
    checkCudaErrors(cudaFree(zeta_z_M));
    checkCudaErrors(cudaFree(zeta_x_N));
    checkCudaErrors(cudaFree(zeta_y_N));
    checkCudaErrors(cudaFree(zeta_z_N));

    // temp variables -------------
    checkCudaErrors(cudaFree(grad_x_M));
    checkCudaErrors(cudaFree(grad_y_M));
    checkCudaErrors(cudaFree(lap_M));
    checkCudaErrors(cudaFree(grad_x_N));
    checkCudaErrors(cudaFree(grad_y_N));
    checkCudaErrors(cudaFree(lap_N));
    checkCudaErrors(cudaFree(tmp_dct_M));
    checkCudaErrors(cudaFree(tmp_dct_N));
    checkCudaErrors(cudaFree(A11));
    checkCudaErrors(cudaFree(A12));
    checkCudaErrors(cudaFree(A22));
    checkCudaErrors(cudaFree(a));
    checkCudaErrors(cudaFree(b));
    checkCudaErrors(cudaFree(c));

    // FFT plans -------------
    cufftDestroy(plan_dct_1_M);
    cufftDestroy(plan_dct_2_M);
    cufftDestroy(plan_dct_1_N);
    cufftDestroy(plan_dct_2_N);

    // DCT coefficients -------------
    checkCudaErrors(cudaFree(ww_1_M));
    checkCudaErrors(cudaFree(ww_2_M));
    checkCudaErrors(cudaFree(ww_1_N));
    checkCudaErrors(cudaFree(ww_2_N));

    // inversion matrices -------------
    checkCudaErrors(cudaFree(mat_A_hat));
    checkCudaErrors(cudaFree(mat_phi_hat));
}
939 |
940 | void CWS_A_Phi::setParas(opt_algo opt_algo_out)
941 | {
942 | opt_algo_t = opt_algo_out;
943 | }
944 |
945 | // =============================================================
946 | // main algorithm
947 | // =============================================================
948 | void CWS_A_Phi::solver(cv::Mat I)
949 | {
950 | // CWS_A_PHI Simutanous intensity and wavefront recovery. Solve for:
951 | // min || i(x+\nabla phi) - A i_0(x) ||_2^2 + ...
952 | //A,phi alpha || \nabla phi ||_1 + ...
953 | // beta ( || \nabla phi ||_2^2 + || \nabla^2 phi ||_2^2 ) + ...
954 | // gamma ( || \nabla A ||_1 + || \nabla^2 A ||_1 ) + ...
955 | // tau ( || \nabla A ||_2^2 + || \nabla^2 A ||_2^2 )
956 |
957 | // copy from host to device
958 | checkCudaErrors(cudaMemcpy(d_I, I.ptr(0), M.area()*sizeof(real), cudaMemcpyHostToDevice));
959 |
960 | // record variables
961 | real objval, mean_Delta_phi;
962 |
963 | // initialization
964 | setasones(A, M);
965 | checkCudaErrors(cudaMemset(phi, 0, N.area()*sizeof(real)));
966 | checkCudaErrors(cudaMemset(Delta_phi, 0, N.area()*sizeof(real)));
967 | checkCudaErrors(cudaMemcpy(I_warp, d_I, M.area()*sizeof(real), cudaMemcpyDeviceToDevice));
968 |
969 | // initial objective
970 | if (opt_algo_t.isverbose)
971 | {
972 | objval = obj_total(A, phi, I_warp, d_I0, opt_algo_t.alpha, opt_algo_t.beta, opt_algo_t.gamma, opt_algo_t.tau, M, N,
973 | grad_x_M, grad_y_M, lap_M, grad_x_N, grad_y_N, lap_N);
974 | printf("iter = %d, obj = %.4e\n", 0, objval);
975 | }
976 |
977 | // the loop
978 | real obj_min = 0x7f800000; // = Inf in float
979 | for (signed int outer_loop = 1; outer_loop <= opt_algo_t.iter; ++outer_loop)
980 | {
981 | // === A-update ===
982 | A_update(A, I_warp, d_I0,
983 | zeta_x_M, zeta_y_M, zeta_z_M,
984 | grad_x_M, grad_y_M, lap_M,
985 | tmp_dct_M, mat_A_hat, ww_1_M, ww_2_M, plan_dct_1_M, plan_dct_2_M, opt_A_t, M);
986 |
987 | // median filtering A
988 | medfilt2(A, M.width, M.height, A);
989 |
990 | // update d_I0_tmp
991 | I_warp_update(A, d_I0, M, d_I0_tmp);
992 |
993 | // === phi-update ===
994 | phi_update(Delta_phi, I_warp, d_I0_tmp, opt_algo_t.alpha, opt_algo_t.beta, zeta_x_N, zeta_y_N, grad_x_N, grad_y_N,
995 | A11, A12, A22, a, b, c, tmp_dct_N, mat_phi_hat,
996 | ww_1_N, ww_2_N, plan_dct_1_N, plan_dct_2_N, opt_phi_t, M, N);
997 |
998 | // medfilt2(Delta_phi, N.width, N.height, Delta_phi);
999 |
1000 | // === accmulate phi ===
1001 | mean_Delta_phi = mean<1>(Delta_phi, N.area());
1002 | if (mean_Delta_phi < opt_algo_t.phi_tol)
1003 | {
1004 | printf("-- mean(|Delta phi|) = %.4e < %.4e = tol; Quit. \n", mean_Delta_phi, opt_algo_t.phi_tol);
1005 | break;
1006 | }
1007 | else
1008 | {
1009 | printf("-- mean(|Delta phi|) = %.4e \n", mean_Delta_phi);
1010 | Add(phi, Delta_phi, N.area(), phi);
1011 | }
1012 |
1013 | // === records ===
1014 | if (opt_algo_t.isverbose)
1015 | {
1016 | objval = obj_total(A, phi, I_warp, d_I0, opt_algo_t.alpha, opt_algo_t.beta, opt_algo_t.gamma, opt_algo_t.tau, M, N,
1017 | grad_x_M, grad_y_M, lap_M, grad_x_N, grad_y_N, lap_N);
1018 | printf("iter = %d, obj = %.4e\n", outer_loop, objval);
1019 | if (objval > obj_min)
1020 | {
1021 | printf("Obj increasing; Quit. \n");
1022 | break;
1023 | }
1024 | else
1025 | obj_min = objval;
1026 | }
1027 | }
1028 |
1029 | // median filtering phi
1030 | medfilt2(phi, N.width, N.height, phi);
1031 | }
1032 |
1033 |
1034 | // download results from GPU to host, and crop phi
1035 | void CWS_A_Phi::download(cv::Mat &h_A, cv::Mat &h_phi)
1036 | {
1037 | // dummy variable
1038 | cv::Mat phi_tmp = cv::Mat::zeros(N, CV_32F);
1039 |
1040 | // transfer result from device to host
1041 | checkCudaErrors(cudaMemcpy(h_A.ptr(0), A, M.area()*sizeof(real), cudaMemcpyDeviceToHost));
1042 | checkCudaErrors(cudaMemcpy(phi_tmp.ptr(0), phi, N.area()*sizeof(real), cudaMemcpyDeviceToHost));
1043 |
1044 | // crop phi to preserve only the center part
1045 | h_phi = phi_tmp(cv::Rect(opt_algo_t.L.width, opt_algo_t.L.height, M.width, M.height)).clone();
1046 | }
1047 |
1048 | // wrapper function for cws
1049 | void cws_A_phi(cv::Mat I0, cv::Mat I, cv::Mat &h_A, cv::Mat &h_phi, opt_algo para_algo)
1050 | {
1051 | // define cws object
1052 | CWS_A_Phi cws_obj(I0, h_A, h_phi, para_algo);
1053 |
1054 | // record variables
1055 | real fps;
1056 | GpuTimer timer;
1057 |
1058 | // timing started
1059 | timer.Start();
1060 |
1061 | // the loop
1062 | int rep_times = 1;
1063 | for (signed int i = 0; i < rep_times; ++i)
1064 | {
1065 | // run cws solver
1066 | cws_obj.solver(I);
1067 | cws_obj.download(h_A, h_phi);
1068 |
1069 | // timing ended
1070 | timer.Stop();
1071 | fps = 1000/timer.Elapsed()*rep_times;
1072 | printf("Mean elapsed time: %.4f ms. Mean frame rate: %.4f fps. \n", timer.Elapsed()/rep_times, fps);
1073 | }
1074 | }
1075 |
--------------------------------------------------------------------------------
/cuda/src/cws_A_phi.h:
--------------------------------------------------------------------------------
1 | #ifndef CWS_A_PHI_H
2 | #define CWS_A_PHI_H
3 |
4 | #include
5 |
6 | // accuracy
7 | #ifdef USE_DOUBLES
8 | typedef double real;
9 | typedef cufftDoubleComplex complex;
10 | #else
11 | typedef float real;
12 | typedef cufftComplex complex;
13 | #endif
14 |
// default parameters for the outer CWS solver (see CWS_A_Phi::solver for the
// objective each weight belongs to)
struct opt_algo {
    bool isverbose_sub = false;   // verbose output inside the A/phi sub-solvers
    bool isverbose = false;       // print the total objective each outer iteration
    int iter = 3;                 // max outer-loop iterations
    int A_iter = 20;              // iterations for the A (intensity) sub-problem
    int phi_iter = 20;            // iterations for the phi (wavefront) sub-problem
    float alpha = 0.1f;           // weight of || \nabla phi ||_1
    float beta = 0.1f;            // weight of || \nabla phi ||_2^2 + || \nabla^2 phi ||_2^2
    float gamma = 100.0f;         // weight of || \nabla A ||_1 + || \nabla^2 A ||_1
    float tau = 5.0f;             // weight of || \nabla A ||_2^2 + || \nabla^2 A ||_2^2
    float phi_tol = 0.05f;        // early-exit tolerance on mean(|Delta phi|)
    float mu_A = 0.1f;            // proximal parameter for the A update (cf. opt_A::mu)
    float mu_phi = 100.0f;        // proximal parameter for the phi update (cf. opt_phi::mu)
    cv::Size L = cv::Size(1,1);   // padding of the phi domain (N = M + 2L)
    // (1,1) for initializers; if unspecified, will be updated in CWS_A_Phi::CWS_A_Phi as L = cv::Size(2,2)
};
32 |
33 |
34 | // =============================================================
35 | // algorithm parameter structures
36 | // =============================================================
// parameters of the A (intensity) sub-solver; gamma_new/tau_new are left
// uninitialized here and are presumably filled in from opt_algo before use
// — TODO confirm against the constructor.
struct opt_A {
    bool isverbose = false;  // verbose output of the sub-solver
    int iter = 10;           // sub-solver iterations
    real mu = 0.1;           // proximal parameter
    real gamma_new;          // derived weight (no default — set by caller)
    real tau_new;            // derived weight (no default — set by caller)
};
44 |
// parameters of the phi (wavefront) sub-solver
struct opt_phi {
    bool isverbose = false;      // verbose output of the sub-solver
    int iter = 10;               // sub-solver iterations
    real mu = 100.0;             // proximal parameter
    cv::Size L = cv::Size(2,2);  // padding of the phi domain (N = M + 2L)
};
51 | // =============================================================
52 |
53 |
// GPU solver for simultaneous intensity (A) and wavefront (phi) recovery.
// All raw pointers below are CUDA device allocations owned by this object:
// allocated in the constructor, released in the destructor.
class CWS_A_Phi
{
    // algorithm parameters for the two sub-solvers
    opt_A opt_A_t;
    opt_phi opt_phi_t;

    // device pointers: reference image, captured image, intensity-matched reference
    real *d_I0, *d_I, *d_I0_tmp;

    // update variables (splitting variables, M-sized and N-sized) -------------
    real *zeta_x_M, *zeta_y_M, *zeta_z_M;
    real *zeta_x_N, *zeta_y_N, *zeta_z_N;

    // temp variables (gradient/Laplacian scratch, DCT scratch, 2x2 system terms) -------------
    real *grad_x_M, *grad_y_M, *lap_M;
    real *grad_x_N, *grad_y_N, *lap_N;
    complex *tmp_dct_M, *tmp_dct_N;
    real *A11, *A12, *A22, *a, *b, *c;

    // FFT plans (per domain size, one pair per DCT direction) -------------
    cufftHandle plan_dct_1_M, plan_dct_2_M;
    cufftHandle plan_dct_1_N, plan_dct_2_N;

    // DCT coefficients -------------
    complex *ww_1_M, *ww_2_M, *ww_1_N, *ww_2_N;

    // inversion matrices (precomputed spectral denominators) -------------
    real *mat_A_hat, *mat_phi_hat;

public:
    // dimensions: M = image size, N = padded phi size (N = M + 2L)
    cv::Size M, N;

    // algorithm parameters
    opt_algo opt_algo_t;

    // result variables (device memory) -------------
    real *A, *phi, *Delta_phi, *I_warp;

    CWS_A_Phi(cv::Mat I0, cv::Mat &h_A, cv::Mat &h_phi, opt_algo opt_algo_out);
    ~CWS_A_Phi();
    void setParas(opt_algo opt_algo_out);           // swap in new parameters
    void solver(cv::Mat I);                         // run the recovery on one frame
    void download(cv::Mat &h_A, cv::Mat &h_phi);    // copy results to host, crop phi
};
99 |
100 | void cws_A_phi(cv::Mat I0, cv::Mat I, cv::Mat &A, cv::Mat &phi, opt_algo para_algo);
101 |
102 | #endif
103 |
--------------------------------------------------------------------------------
/cuda/src/medianfilteringKernel.cuh:
--------------------------------------------------------------------------------
1 | #include "common.h"
2 |
3 | // replicate boundary condition (for stencil size of 3x3 or 5x5)
4 | template
5 | __device__
6 | void set_bd_shared_memory(volatile float (*smem)[BLOCK_Y+stencil_size-1], const float *d_in,
7 | const int tx, const int ty, const int ix, const int iy, int w, int h)
8 | {
9 | const int py = iy*w;
10 | const int x_min = max(ix-1, 0);
11 | const int x_max = min(ix+1, w-1);
12 | const int y_min = max(iy-1, 0) *w;
13 | const int y_max = min(iy+1, h-1)*w;
14 | switch (stencil_size)
15 | {
16 | case 3:
17 | {
18 | if (tx == 1)
19 | {
20 | smem[0][ty] = d_in[x_min + py];
21 | if (ty == 1)
22 | {
23 | smem[0][0] = d_in[x_min + y_min];
24 | smem[0][BLOCK_Y+1] = d_in[x_min + y_max];
25 | }
26 | }
27 | if (tx == BLOCK_X)
28 | {
29 | smem[BLOCK_X+1][ty] = d_in[x_max + py];
30 | if (ty == 1)
31 | {
32 | smem[BLOCK_X+1][0] = d_in[x_max + y_min];
33 | smem[BLOCK_X+1][BLOCK_Y+1] = d_in[x_max + y_max];
34 | }
35 | }
36 | if (ty == 1)
37 | {
38 | smem[tx][0] = d_in[ix + y_min];
39 | }
40 | if (ty == BLOCK_Y)
41 | {
42 | smem[tx][BLOCK_Y+1] = d_in[ix + y_max];
43 | }
44 | }
45 | break;
46 |
47 | case 5:
48 | {
49 | // boundary guards
50 | const int x_min2 = max(ix-2, 0);
51 | const int x_max2 = min(ix+2, w-1);
52 | const int y_min2 = max(iy-2, 0) *w;
53 | const int y_max2 = min(iy+2, h-1)*w;
54 |
55 | if (tx == 2)
56 | {
57 | smem[0][ty] = d_in[x_min2 + py];
58 | smem[1][ty] = d_in[x_min + py];
59 | if (ty == 2)
60 | {
61 | smem[0][0] = d_in[x_min2 + y_min2];
62 | smem[1][0] = d_in[x_min + y_min2];
63 | smem[0][1] = d_in[x_min2 + y_min ];
64 | smem[1][1] = d_in[x_min + y_min ];
65 | }
66 | }
67 | if (tx == BLOCK_X+1)
68 | {
69 | smem[BLOCK_X+2][ty] = d_in[x_max + py];
70 | smem[BLOCK_X+3][ty] = d_in[x_max2 + py];
71 | if (ty == BLOCK_Y+1)
72 | {
73 | smem[BLOCK_X+2][BLOCK_Y+2] = d_in[x_max + y_max ];
74 | smem[BLOCK_X+3][BLOCK_Y+2] = d_in[x_max2 + y_max ];
75 | smem[BLOCK_X+2][BLOCK_Y+3] = d_in[x_max + y_max2];
76 | smem[BLOCK_X+3][BLOCK_Y+3] = d_in[x_max2 + y_max2];
77 | }
78 | }
79 | if (ty == 2)
80 | {
81 | smem[tx][0] = d_in[ix + y_min2];
82 | smem[tx][1] = d_in[ix + y_min];
83 | if (tx == BLOCK_X+1)
84 | {
85 | smem[BLOCK_X+2][0] = d_in[x_max + y_min2];
86 | smem[BLOCK_X+3][0] = d_in[x_max2 + y_min2];
87 | smem[BLOCK_X+2][1] = d_in[x_max + y_min ];
88 | smem[BLOCK_X+3][1] = d_in[x_max2 + y_min ];
89 | }
90 | }
91 | if (ty == BLOCK_Y+1)
92 | {
93 | smem[tx][BLOCK_Y+2] = d_in[ix + y_max];
94 | smem[tx][BLOCK_Y+3] = d_in[ix + y_max2];
95 | if (tx == 2)
96 | {
97 | smem[0][BLOCK_X+2] = d_in[x_min2 + y_max ];
98 | smem[0][BLOCK_X+3] = d_in[x_min2 + y_max2];
99 | smem[1][BLOCK_X+2] = d_in[x_min + y_max ];
100 | smem[1][BLOCK_X+3] = d_in[x_min + y_max2];
101 | }
102 | }
103 | }
104 | break;
105 | }
106 | }
107 |
108 | // Exchange trick: Morgan McGuire, ShaderX 2008
109 | #define s2(a,b) { float tmp = a; a = min(a,b); b = max(tmp,b); }
110 | #define mn3(a,b,c) s2(a,b); s2(a,c);
111 | #define mx3(a,b,c) s2(b,c); s2(a,c);
112 |
113 | #define mnmx3(a,b,c) mx3(a,b,c); s2(a,b); // 3 exchanges
114 | #define mnmx4(a,b,c,d) s2(a,b); s2(c,d); s2(a,c); s2(b,d); // 4 exchanges
115 | #define mnmx5(a,b,c,d,e) s2(a,b); s2(c,d); mn3(a,c,e); mx3(b,d,e); // 6 exchanges
116 | #define mnmx6(a,b,c,d,e,f) s2(a,d); s2(b,e); s2(c,f); mn3(a,b,c); mx3(d,e,f); // 7 exchanges
117 |
118 | __global__ void medfilt2_exch(int width, int height, float *d_out, float *d_in)
119 | {
120 | const int tx = threadIdx.x + 1;
121 | const int ty = threadIdx.y + 1;
122 | const int ix = threadIdx.x + blockIdx.x * blockDim.x;
123 | const int iy = threadIdx.y + blockIdx.y * blockDim.y;
124 | const int pos = ix + iy*width;
125 |
126 | if (ix >= width || iy >= height) return;
127 |
128 | // shared memory for 3x3 stencil
129 | __shared__ float smem[BLOCK_X+2][BLOCK_Y+2];
130 |
131 | // load in shared memory
132 | smem[tx][ty] = d_in[pos];
133 | set_bd_shared_memory<3>(smem, d_in, tx, ty, ix, iy, width, height);
134 | __syncthreads();
135 |
136 | // pull top six from shared memory
137 | float v[6] = { smem[tx-1][ty-1], smem[tx][ty-1], smem[tx+1][ty-1],
138 | smem[tx-1][ty ], smem[tx][ty ], smem[tx+1][ty ]};
139 |
140 | // with each pass, remove min and max values and add new value
141 | mnmx6(v[0], v[1], v[2], v[3], v[4], v[5]);
142 | v[5] = smem[tx-1][ty+1]; // add new contestant
143 | mnmx5(v[1], v[2], v[3], v[4], v[5]);
144 | v[5] = smem[tx][ty+1];
145 | mnmx4(v[2], v[3], v[4], v[5]);
146 | v[5] = smem[tx+1][ty+1];
147 | mnmx3(v[3], v[4], v[5]);
148 |
149 | // pick the middle one
150 | d_out[pos] = v[4];
151 | }
152 |
153 |
154 | ///////////////////////////////////////////////////////////////////////////////
155 | /// \brief 3-by-3 median filtering of an image
156 | /// \param[in] img_in input image
157 | /// \param[in] w width of input image
158 | /// \param[in] h height of input image
159 | /// \param[out] img_out output image
160 | ///////////////////////////////////////////////////////////////////////////////
161 | static
162 | void medfilt2(float *img_in, int w, int h, float *img_out)
163 | {
164 | dim3 threads(BLOCK_X, BLOCK_Y);
165 | dim3 blocks(iDivUp(w, threads.x), iDivUp(h, threads.y));
166 |
167 | medfilt2_exch<<>>(w, h, img_out, img_in);
168 |
169 | }
170 |
--------------------------------------------------------------------------------
/cuda/src/prepare_cufft_warmup.cu:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 | #include
6 | #include "assert.h"
7 |
8 | typedef cufftComplex complex;
9 |
// Run one forward C2C FFT over an n*m host buffer through the given plan:
// h_in is staged to a temporary device buffer, transformed in place, and the
// result copied back into h_out. Used here only to warm up cuFFT.
void cufft_warper(complex *h_in, int n, int m, cufftHandle plan, complex *h_out)
{
    const int data_size = n*m*sizeof(complex);

    // device memory allocation
    complex *d_temp;
    checkCudaErrors(cudaMalloc(&d_temp, data_size));

    // transfer data from host to device
    checkCudaErrors(cudaMemcpy(d_temp, h_in, data_size, cudaMemcpyHostToDevice));

    // Compute the FFT (in place; return status is intentionally ignored here
    // — this is only a warmup call)
    cufftExecC2C(plan, d_temp, d_temp, CUFFT_FORWARD);

    // transfer result from device to host (implicitly synchronizes with the FFT)
    checkCudaErrors(cudaMemcpy(h_out, d_temp, data_size, cudaMemcpyDeviceToHost));

    // cleanup
    checkCudaErrors(cudaFree(d_temp));
}
30 |
31 | void cufft_prepare(int nLevels, int *pH_N, int *pW_N,
32 | cufftHandle *plan_dct_1, cufftHandle *plan_dct_2,
33 | cufftHandle *plan_dct_3, cufftHandle *plan_dct_4,
34 | cufftHandle *plan_dct_5, cufftHandle *plan_dct_6,
35 | cufftHandle *plan_dct_7, cufftHandle *plan_dct_8)
36 | {
37 | // prepare cufft plans & warmup
38 | printf("Preparing CuFFT plans and warmups ... ");
39 |
40 | int Length1[1], Length2[1];
41 | if (nLevels >= 1)
42 | {
43 | Length1[0] = pH_N[0]; // for each FFT, the Length1 is N_height
44 | Length2[0] = pW_N[0]; // for each FFT, the Length2 is N_width
45 | cufftPlanMany(plan_dct_1, 1, Length1,
46 | Length1, pW_N[0], 1,
47 | Length1, pW_N[0], 1,
48 | CUFFT_C2C, pW_N[0]);
49 | cufftPlanMany(plan_dct_2, 1, Length2,
50 | Length2, pH_N[0], 1,
51 | Length2, pH_N[0], 1,
52 | CUFFT_C2C, pH_N[0]);
53 | }
54 | else
55 | {
56 | printf("No CuFFT plans prepared; out ... \n");
57 | }
58 |
59 | if (nLevels >= 2)
60 | {
61 | Length1[0] = pH_N[1]; // for each FFT, the Length1 is N_height
62 | Length2[0] = pW_N[1]; // for each FFT, the Length2 is N_width
63 | cufftPlanMany(plan_dct_3, 1, Length1,
64 | Length1, pW_N[1], 1,
65 | Length1, pW_N[1], 1,
66 | CUFFT_C2C, pW_N[1]);
67 | cufftPlanMany(plan_dct_4, 1, Length2,
68 | Length2, pH_N[1], 1,
69 | Length2, pH_N[1], 1,
70 | CUFFT_C2C, pH_N[1]);
71 | }
72 |
73 | if (nLevels >= 3)
74 | {
75 | Length1[0] = pH_N[2]; // for each FFT, the Length1 is N_height
76 | Length2[0] = pW_N[2]; // for each FFT, the Length2 is N_width
77 | cufftPlanMany(plan_dct_5, 1, Length1,
78 | Length1, pW_N[2], 1,
79 | Length1, pW_N[2], 1,
80 | CUFFT_C2C, pW_N[2]);
81 | cufftPlanMany(plan_dct_6, 1, Length2,
82 | Length2, pH_N[2], 1,
83 | Length2, pH_N[2], 1,
84 | CUFFT_C2C, pH_N[2]);
85 | }
86 |
87 | if (nLevels >= 4)
88 | {
89 | Length1[0] = pH_N[3]; // for each FFT, the Length1 is N_height
90 | Length2[0] = pW_N[3]; // for each FFT, the Length2 is N_width
91 | cufftPlanMany(plan_dct_7, 1, Length2,
92 | Length1, pW_N[3], 1,
93 | Length1, pW_N[3], 1,
94 | CUFFT_C2C, pW_N[3]);
95 | cufftPlanMany(plan_dct_8, 1, Length2,
96 | Length2, pH_N[3], 1,
97 | Length2, pH_N[3], 1,
98 | CUFFT_C2C, pH_N[3]);
99 | }
100 |
101 | // cufft warmup
102 | int N_width = pW_N[0];
103 | int N_height = pH_N[0];
104 | complex *h_warmup_in = new complex[N_width * N_height];
105 | complex *h_warmup_out = new complex[N_width * N_height];
106 | cufft_warper(h_warmup_in, N_width, N_height, *plan_dct_1, h_warmup_out);
107 | cufft_warper(h_warmup_in, N_width, N_height, *plan_dct_2, h_warmup_out);
108 | delete[] h_warmup_in;
109 | delete[] h_warmup_out;
110 | printf("Done.\n");
111 | }
--------------------------------------------------------------------------------
/cuda/src/prepare_cufft_warmup.h:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 |
6 | typedef cufftComplex complex;
7 |
8 | #ifndef PREPARE_CUFFT_WARMUP_H
9 | #define PREPARE_CUFFT_WARMUP_H
10 |
11 | void cufft_warper(complex *h_in, int n, int m, cufftHandle plan, complex *h_out);
12 |
13 | void cufft_prepare(int nLevels, int *pH_N, int *pW_N,
14 | cufftHandle *plan_dct_1, cufftHandle *plan_dct_2,
15 | cufftHandle *plan_dct_3, cufftHandle *plan_dct_4,
16 | cufftHandle *plan_dct_5, cufftHandle *plan_dct_6,
17 | cufftHandle *plan_dct_7, cufftHandle *plan_dct_8);
18 |
19 | #endif
20 |
--------------------------------------------------------------------------------
/cuda/src/prepare_precomputations.cu:
--------------------------------------------------------------------------------
1 | #include
2 | #include "common.h"
3 |
4 | #include
5 |
6 | // include kernels
7 | #include "computemat_x_hatKernel.cuh"
8 | #include "computeDCTweightsKernel.cuh"
9 |
10 | // for thrust functions
11 | #include
12 | #include
13 | #include
14 | #include
15 | #include
16 |
17 | ///////////////////////////////////////////////////////////////////////////////
18 | /// texture references
19 | ///////////////////////////////////////////////////////////////////////////////
20 |
21 | /// image to downscale
22 | texture texFine;
23 |
24 |
25 | static
26 | __global__
27 | void transposeKernel(float *in, int width, int height, int stride1, int stride2, float *out)
28 | {
29 | const int ix = threadIdx.x + blockIdx.x * blockDim.x;
30 | const int iy = threadIdx.y + blockIdx.y * blockDim.y;
31 |
32 | if (ix >= width || iy >= height) return;
33 |
34 | out[iy + ix * stride2] = in[ix + iy * stride1];
35 | }
36 |
37 |
// Solve one tridiagonal system in place on the vector x of length m, given the
// precomputed LU factors of the matrix: b holds the modified diagonal and mu
// the forward-elimination multipliers (see cbanal, which builds them for the
// cubic-B-spline matrix with diagonal 4 and unit off-diagonals).
static
__device__ void tridisolve(float *x, float *b, float *mu, int m)
{
    // forward elimination
    for (int i = 0; i < m-1; i++)
        x[i+1] -= mu[i] * x[i];

    // backward substitution
    x[m-1] /= b[m-1];
    for (int i = m-2; i >= 0; i--)
        x[i] = (x[i] - x[i+1]) / b[i];

    // magnify by 6 (B-spline analysis scaling factor)
    for (int i = 0; i < m; i++)
        x[i] = 6.0f * x[i];
}
54 |
55 |
// One thread per column: each thread solves the length-`height` tridiagonal
// system for one of the `width` columns stored contiguously with pitch
// stride2. NOTE(review): `img` is expected to hold the TRANSPOSED image
// produced by transposeKernel, which is why rows of the original become
// contiguous runs here — confirm against cbanal's call site.
static
__global__
void tridisolve_parallel(float *img, float *b_h, float *mu_h,
                         int width, int height, int stride2)
{
    const int iy = threadIdx.x + blockIdx.x * blockDim.x;

    if (iy >= width) return;

    tridisolve(img + iy * stride2, b_h, mu_h, height);
}
67 |
68 |
69 | void cbanal(float *in, float *out, int width, int height, int stride1, int stride2)
70 | {
71 | float *b = new float [height];
72 | float *mu = new float [height-1];
73 |
74 | // set b
75 | for (int j = 0; j < height; j++)
76 | b[j] = 4.0f;
77 |
78 | // pre-computations
79 | for (int j = 0; j < height-1; j++)
80 | {
81 | mu[j] = 1.0f / b[j];
82 | b[j+1] -= mu[j];
83 | }
84 |
85 | float *d_b, *d_mu;
86 | checkCudaErrors(cudaMalloc(&d_b, height*sizeof(float)));
87 | checkCudaErrors(cudaMalloc(&d_mu, (height-1)*sizeof(float)));
88 |
89 | checkCudaErrors(cudaMemcpy(d_b, b, height*sizeof(float), cudaMemcpyHostToDevice));
90 | checkCudaErrors(cudaMemcpy(d_mu, mu, (height-1)*sizeof(float), cudaMemcpyHostToDevice));
91 |
92 | dim3 threads_1D(256);
93 | dim3 blocks_1D(iDivUp(width, threads_1D.x));
94 |
95 | dim3 threads_2D(32, 6);
96 | dim3 blocks_2D(iDivUp(width, threads_2D.x), iDivUp(height, threads_2D.y));
97 |
98 | transposeKernel<<>>(in, width, height, stride1, stride2, out);
99 | tridisolve_parallel<<>>(out, d_b, d_mu, width, height, stride2);
100 |
101 | // cleanup
102 | checkCudaErrors(cudaFree(d_b));
103 | checkCudaErrors(cudaFree(d_mu));
104 | delete [] b;
105 | delete [] mu;
106 | }
107 |
108 |
109 |
// 2-D cubic-B-spline analysis in place: run the 1-D analysis (cbanal) along
// each dimension in turn. The first pass writes the transposed intermediate
// into `temp`; the second pass (with width/height and strides swapped)
// transposes back into `img`, so `img` ends up holding its own coefficients.
void cbanal2D(float *img, int width, int height, int stride1, int stride2)
{
    float *temp;
    checkCudaErrors(cudaMalloc(&temp, height*width*sizeof(float)));

    // compute the cubic B-spline coefficients
    cbanal(img, temp, width, height, stride1, stride2);
    cbanal(temp, img, height, width, stride2, stride1);

    // cleanup
    checkCudaErrors(cudaFree(temp));
}
122 |
123 |
124 |
// Anti-aliased downscale kernel: one thread per OUTPUT pixel. Each output
// sample is a 4x4 weighted sum of bilinear texture taps from texFine (bound to
// the fine-resolution image, normalized coordinates, mirror addressing).
// The weights are the separable outer product of [0.125 0.375 0.375 0.125].
static
__global__ void anti_weight_x_Kernel(int width, int height, int stride, float *out)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;

    if (ix >= width || iy >= height)
        return;

    // output-pixel size in normalized coordinates
    float dx = 1.0f/(float)width;
    float dy = 1.0f/(float)height;

    // center of the output pixel
    float x = ((float)ix + 0.5f) * dx;
    float y = ((float)iy + 0.5f) * dy;

    // quarter-pixel shift — presumably aligns the even-tap filter footprint
    // with the fine grid; TODO confirm against the MATLAB reference
    x -= 0.25*dx;
    y -= 0.25*dy;

    // tap spacing = half an output pixel = one fine-grid pixel
    dx /= 2.0f;
    dy /= 2.0f;

    // magic number, in MATLAB: conv2([0.125 0.375 0.375 0.125],[0.125 0.375 0.375 0.125]')
    out[ix + iy * stride] =
        0.015625f*tex2D(texFine, x-dx, y-1*dy) + 0.046875f*tex2D(texFine, x, y-1*dy) + 0.046875f*tex2D(texFine, x+dx, y-1*dy) + 0.015625f*tex2D(texFine, x+2*dx, y-1*dy) +
        0.046875f*tex2D(texFine, x-dx, y    ) + 0.140625f*tex2D(texFine, x, y    ) + 0.140625f*tex2D(texFine, x+dx, y    ) + 0.046875f*tex2D(texFine, x+2*dx, y    ) +
        0.046875f*tex2D(texFine, x-dx, y+1*dy) + 0.140625f*tex2D(texFine, x, y+1*dy) + 0.140625f*tex2D(texFine, x+dx, y+1*dy) + 0.046875f*tex2D(texFine, x+2*dx, y+1*dy) +
        0.015625f*tex2D(texFine, x-dx, y+2*dy) + 0.046875f*tex2D(texFine, x, y+2*dy) + 0.046875f*tex2D(texFine, x+dx, y+2*dy) + 0.015625f*tex2D(texFine, x+2*dx, y+2*dy);
}
153 |
154 |
155 | extern
156 | void Downscale_Anti(const float *src, int width, int height, int stride,
157 | int newWidth, int newHeight, int newStride, float *out)
158 | {
159 | dim3 threads(32, 8);
160 | dim3 blocks(iDivUp(newWidth, threads.x), iDivUp(newHeight, threads.y));
161 |
162 | // mirror if a coordinate value is out-of-range
163 | texFine.addressMode[0] = cudaAddressModeMirror;
164 | texFine.addressMode[1] = cudaAddressModeMirror;
165 | texFine.filterMode = cudaFilterModeLinear;
166 | texFine.normalized = true;
167 |
168 | cudaChannelFormatDesc desc = cudaCreateChannelDesc();
169 |
170 | checkCudaErrors(cudaBindTexture2D(0, texFine, src, width, height, stride * sizeof(float)));
171 |
172 | anti_weight_x_Kernel<<>>(newWidth, newHeight, newStride, out);
173 | }
174 |
175 |
176 |
177 |
///////////////////////////////////////////////////////////////////////////////
/// \brief method logic
///
/// handles memory allocations, control flow for the coarse-to-fine pyramid:
/// fills the per-level size tables (pW_*/pH_*/pS_*), allocates and fills the
/// per-level precomputed inversion matrices and DCT weights, builds the
/// downscaled image pyramid, and converts each level's reference image into
/// cubic-B-spline coefficients.
/// \param[in] nLevels number of pyramid levels
/// \param[in] L_width/L_height padding at the finest level (halved per level, min 2)
/// \param[in] N_width/N_height padded-domain size (used for logging only here)
/// \param[in] M_width/M_height/M_stride finest-level image size and pitch
/// \param[in] N_W unknown phase width
/// \param[in] N_H unknown phase height
/// \param[in] currentLevel index of the coarsest level to prepare (counts down to 0)
/// \param[out] pW_N,pH_N,pS_N padded-domain width/height/pitch per level
/// \param[out] pW_M,pH_M,pS_M image width/height/pitch per level
/// \param[out] pW_L,pH_L padding per level
/// \param[in,out] pI0,pI1,d_I0_coeff per-level device image buffers (finest
///                level is assumed allocated by the caller; coarser levels are
///                allocated here)
/// \param[in] alpha degree of displacement field smoothness
/// \param[in] mu proximal parameter (one per level)
/// \param[out] mat_x_hat pre-computed mat_x_hat (device, allocated here per level)
/// \param[out] ww_1 ww_1 coefficient (device, allocated here per level)
/// \param[out] ww_2 ww_2 coefficient (device, allocated here per level)
///////////////////////////////////////////////////////////////////////////////
void prepare_precomputations(int nLevels, int L_width, int L_height,
                             int N_width, int N_height,
                             int M_width, int M_height, int M_stride,
                             int N_W, int N_H, int currentLevel,
                             int *pW_N, int *pH_N, int *pS_N,
                             int *pW_M, int *pH_M, int *pS_M,
                             int *pW_L, int *pH_L,
                             float **pI0, float **pI1, float **d_I0_coeff,
                             float alpha, float *mu, const float **mat_x_hat,
                             const complex **ww_1, const complex **ww_2)
{
    // determine L value for best performance (we use mod 2 here):
    // padding halves at each coarser level, clamped to at least 2
    for (int i = nLevels - 1; i >= 0; i--) {
        pW_L[i] = L_width >> (nLevels - 1 - i);
        pH_L[i] = L_height >> (nLevels - 1 - i);
        if (pW_L[i] < 2) pW_L[i] = 2;
        if (pH_L[i] < 2) pH_L[i] = 2;
    }

    // finest level: padded domain N = M + 2L, pitch aligned via iAlignUp
    pW_N[nLevels - 1] = M_width + 2 * pW_L[nLevels - 1];
    pH_N[nLevels - 1] = M_height + 2 * pH_L[nLevels - 1];
    pS_N[nLevels - 1] = iAlignUp(pW_N[nLevels - 1]);
    pW_M[nLevels - 1] = M_width;
    pH_M[nLevels - 1] = M_height;
    pS_M[nLevels - 1] = M_stride;

    printf("initial sizes: W_N = %d, H_N = %d, S_N = %d\n", N_width, N_height, iAlignUp(pW_N[nLevels - 1]));
    printf("initial sizes: W_M = %d, H_M = %d, S_M = %d\n", M_width, M_height, M_stride);


    printf("Pre-compute the variables on GPU...\n");

    if (currentLevel != 0){ // prepare pyramid, finest to coarsest
        for (; currentLevel > 0; currentLevel--)
        {
            // next-coarser image size (halved) and aligned pitch
            int nw = pW_M[currentLevel] / 2;
            int nh = pH_M[currentLevel] / 2;
            int ns = iAlignUp(nw);

            pW_N[currentLevel - 1] = nw + 2*pW_L[currentLevel-1];
            pH_N[currentLevel - 1] = nh + 2*pH_L[currentLevel-1];
            pS_N[currentLevel - 1] = iAlignUp(pW_N[currentLevel - 1]);

            // pre-calculate mat_x_hat and store it
            // (const is cast away: these arrays are const to callers but
            // filled here, once, at setup time)
            checkCudaErrors(cudaMalloc(mat_x_hat + currentLevel,
                                       pW_N[currentLevel] * pH_N[currentLevel] * sizeof(float)));
            computemat_x_hat(mu[currentLevel], alpha, pW_N[currentLevel], pH_N[currentLevel],
                             (float *)mat_x_hat[currentLevel]);

            // pre-compute DCT weights and store them in device
            checkCudaErrors(cudaMalloc(ww_1 + currentLevel, pH_N[currentLevel]*sizeof(complex)));
            checkCudaErrors(cudaMalloc(ww_2 + currentLevel, pW_N[currentLevel]*sizeof(complex)));
            computeDCTweights(pW_N[currentLevel], pH_N[currentLevel],
                              (complex *)ww_1[currentLevel], (complex *)ww_2[currentLevel]);

            // allocate the next-coarser level's image buffers
            checkCudaErrors(cudaMalloc(pI0 + currentLevel - 1, ns * nh * sizeof(float)));
            checkCudaErrors(cudaMalloc(pI1 + currentLevel - 1, ns * nh * sizeof(float)));
            checkCudaErrors(cudaMalloc(d_I0_coeff + currentLevel - 1, ns * nh * sizeof(float)));

            // downscale (anti-aliased variant; plain Downscale kept for reference)
            // Downscale(pI0[currentLevel], pW_M[currentLevel], pH_M[currentLevel], pS_M[currentLevel],
            //           nw, nh, ns, (float *)pI0[currentLevel - 1]);
            // Downscale(d_I0_coeff[currentLevel], pW_M[currentLevel], pH_M[currentLevel], pS_M[currentLevel],
            //           nw, nh, ns, (float *)d_I0_coeff[currentLevel - 1]);
            Downscale_Anti(pI0[currentLevel], pW_M[currentLevel], pH_M[currentLevel], pS_M[currentLevel],
                           nw, nh, ns, (float *)pI0[currentLevel - 1]);
            Downscale_Anti(d_I0_coeff[currentLevel], pW_M[currentLevel], pH_M[currentLevel], pS_M[currentLevel],
                           nw, nh, ns, (float *)d_I0_coeff[currentLevel - 1]);

            // pre-compute cubic coefficients and store them in device
            // (in place, BEFORE this level's sizes are shifted down)
            cbanal2D(d_I0_coeff[currentLevel], pW_M[currentLevel], pH_M[currentLevel],
                     pS_M[currentLevel], pH_M[currentLevel]);

            pW_M[currentLevel - 1] = nw;
            pH_M[currentLevel - 1] = nh;
            pS_M[currentLevel - 1] = ns;

            printf("pW_M[%d] = %d, pH_M[%d] = %d, pS_M[%d] = %d\n",
                   currentLevel, pW_M[currentLevel],
                   currentLevel, pH_M[currentLevel],
                   currentLevel, pS_M[currentLevel]);
            printf("pW_N[%d] = %d, pH_N[%d] = %d, pS_N[%d] = %d\n",
                   currentLevel, pW_N[currentLevel],
                   currentLevel, pH_N[currentLevel],
                   currentLevel, pS_N[currentLevel]);
            printf("pW_L[%d] = %d, pH_L[%d] = %d\n",
                   currentLevel, pW_L[currentLevel],
                   currentLevel, pH_L[currentLevel]);
        }
    }

    // coarsest level (currentLevel == 0 after the loop):
    // pre-calculate mat_x_hat and store it
    checkCudaErrors(cudaMalloc(mat_x_hat + currentLevel,
                               pW_N[currentLevel] * pH_N[currentLevel] * sizeof(float)));
    computemat_x_hat(mu[currentLevel], alpha, pW_N[currentLevel], pH_N[currentLevel], (float *)mat_x_hat[currentLevel]);

    // pre-compute DCT weights and store them in device
    checkCudaErrors(cudaMalloc(ww_1 + currentLevel, pH_N[currentLevel]*sizeof(complex)));
    checkCudaErrors(cudaMalloc(ww_2 + currentLevel, pW_N[currentLevel]*sizeof(complex)));
    computeDCTweights(pW_N[currentLevel], pH_N[currentLevel],
                      (complex *)ww_1[currentLevel], (complex *)ww_2[currentLevel]);

    // pre-compute cubic coefficients and store them in device
    cbanal2D(d_I0_coeff[currentLevel], pW_M[currentLevel], pH_M[currentLevel],
             pS_M[currentLevel], pH_M[currentLevel]);

    printf("pW_M[%d] = %d, pH_M[%d] = %d, pS_M[%d] = %d\n",
           currentLevel, pW_M[currentLevel],
           currentLevel, pH_M[currentLevel],
           currentLevel, pS_M[currentLevel]);
    printf("pW_N[%d] = %d, pH_N[%d] = %d, pS_N[%d] = %d\n",
           currentLevel, pW_N[currentLevel],
           currentLevel, pH_N[currentLevel],
           currentLevel, pS_N[currentLevel]);
    printf("pW_L[%d] = %d, pH_L[%d] = %d\n",
           currentLevel, pW_L[currentLevel],
           currentLevel, pH_L[currentLevel]);
}
308 |
309 |
310 |
--------------------------------------------------------------------------------
/cuda/src/prepare_precomputations.h:
--------------------------------------------------------------------------------
1 | #ifndef PREPARE_PRECOMPUTATIONS_H
2 | #define PREPARE_PRECOMPUTATIONS_H
3 |
4 | void prepare_precomputations(int nLevels,
5 | int L_width, int L_height,
6 | int N_width, int N_height,
7 | int M_width, int M_height, int M_stride,
8 | int N_W, // unknown phase width
9 | int N_H, // unknown phase height
10 | int currentLevel,
11 | int *pW_N, int *pH_N, int *pS_N,
12 | int *pW_M, int *pH_M, int *pS_M,
13 | int *pW_L, int *pH_L,
14 | float **pI0,
15 | float **pI1,
16 | float **d_I0_coeff,
17 | float alpha, // smoothness coefficient
18 | float *mu, // proximal parameter
19 | const float **mat_x_hat, // pre-computed mat_x_hat
20 | const complex **ww_1, // ww_1 coefficient
21 | const complex **ww_2); // ww_2 coefficient
22 | #endif
23 |
--------------------------------------------------------------------------------
/cuda/src/prox_gKernel.cuh:
--------------------------------------------------------------------------------
1 | #include "common.h"
2 |
3 | ///////////////////////////////////////////////////////////////////////////////
4 | ///////////////////////////////////////////////////////////////////////////////
5 | ///////////////////////////////////////////////////////////////////////////////
6 | ////////////////////////////// L1+L2 Flow ////////////////////////////////
7 | ///////////////////////////////////////////////////////////////////////////////
8 | ///////////////////////////////////////////////////////////////////////////////
9 | ///////////////////////////////////////////////////////////////////////////////
10 |
11 | template
12 | __host__
13 | __device__
14 | float sign(T val)
15 | {
16 | return val > T(0) ? 1.0f : -1.0f; // 0 is seldom
17 | }
18 | // proximal operator for L1 operator
19 | __host__
20 | __device__
21 | float prox_l1(float u, float tau)
22 | {
23 | return sign(u) * max(fabsf(u) - tau, 0.0f);
24 | }
25 |
// LASSO in R2 (simple): approximately solve the 2-D lasso
//   min_x (a x_1 + b x_2 + c)^2 + mu ||x - u||^2 + alpha ||x||_1
// where (A11, A12, A22) is the (symmetric) inverse of the quadratic term's
// 2x2 system matrix. "Simple" strategy: take the unregularized L2 minimizer
// x0, fix the sign pattern from x0, apply the L1 shrinkage for that orthant,
// and zero any component whose sign flipped (cf. lassoR2_complete below,
// which also compares the single-axis candidates).
__device__
void lassoR2_simple(float A11, float A12, float A22,
                    float a, float b, float c,
                    float ux, float uy, float mu, float alpha,
                    float& ours_x, float& ours_y)
{
    // 1. get (x0, y0)
    // -- temp: right-hand side of the normal equations
    float tx = mu*ux - a*c;
    float ty = mu*uy - b*c;

    // -- l2 minimum (apply the precomputed inverse)
    float x0_x = A11*tx + A12*ty;
    float x0_y = A12*tx + A22*ty;

    // 2. get the sign of optimal (assumed equal to the sign of the L2 minimum)
    float sign_x = sign(x0_x);
    float sign_y = sign(x0_y);

    // 3. get the optimal: shrink x0 by alpha/2 * A * sign(x0)
    ours_x = x0_x - alpha/2.0f * (A11 * sign_x + A12 * sign_y);
    ours_y = x0_y - alpha/2.0f * (A12 * sign_x + A22 * sign_y);

    // 4. check sign and map to the range: a sign flip means the component
    // is pinned to zero by the L1 term
    ours_x = sign(ours_x) == sign_x ? ours_x : 0.0f;
    ours_y = sign(ours_y) == sign_y ? ours_y : 0.0f;
}
54 |
// LASSO in R2 (complete): per-pixel minimization of
//   (a*x1 + b*x2 + c)^2 + mu*||x - u||^2 + alpha*||x||_1
// by enumerating five candidate solutions (one per sign pattern of x,
// plus the all-zero vector) and keeping the one with the lowest cost.
__device__
void lassoR2_complete(float A11, float A12, float A22,
                      float a, float b, float c,
                      float ux, float uy, float mu, float alpha,
                      float& ours_x, float& ours_y)
{
    // temp: right-hand side of the normal equations
    float tx = mu*ux - a*c;
    float ty = mu*uy - b*c;

    // l2 minimum (unregularized solution x0 = A * t)
    float x0_x = A11*tx + A12*ty;
    float x0_y = A12*tx + A22*ty;

    // x1 (R^n) [x1 0] & x2 (R^n) [0 x2]: candidates with one coordinate zero,
    // soft-thresholded by alpha/2 in the 1-D subproblem
    float x1 = (tx > 0) ? ((tx - alpha/2.0f) / (a*a + mu)) : ((tx + alpha/2.0f) / (a*a + mu));
    float x2 = (ty > 0) ? ((ty - alpha/2.0f) / (b*b + mu)) : ((ty + alpha/2.0f) / (b*b + mu));

    // x3 (R^2n): candidate assuming sign(x1) == sign(x2); shift is A*(alpha/2)*[1;1]
    tx = alpha/2.0f * (A11 + A12);
    ty = alpha/2.0f * (A12 + A22);
    bool tb = (tx*x0_x + ty*x0_y) > 0;
    float x3_x = tb ? (x0_x - tx) : (x0_x + tx);
    float x3_y = tb ? (x0_y - ty) : (x0_y + ty);

    // x4 (R^2n): candidate assuming sign(x1) == -sign(x2); shift is A*(alpha/2)*[1;-1]
    tx = alpha/2.0f * (A11 - A12);
    ty = alpha/2.0f * (A12 - A22);
    tb = (tx*x0_x + ty*x0_y) > 0;
    float x4_x = tb ? (x0_x - tx) : (x0_x + tx);
    float x4_y = tb ? (x0_y - ty) : (x0_y + ty);

    // cost functions of the four nontrivial candidates
    float cost[4];

    // temp
    float uxx = ux*ux;
    float uyy = uy*uy;

    // cost function of x1
    float t1 = a*x1 + c;
    tx = x1 - ux;
    cost[0] = t1*t1 + mu*(tx*tx + uyy) + alpha*fabsf(x1);

    // cost function of x2
    t1 = b*x2 + c;
    ty = x2 - uy;
    cost[1] = t1*t1 + mu*(uxx + ty*ty) + alpha*fabsf(x2);

    // cost function of x3
    t1 = a*x3_x + b*x3_y + c;
    tx = x3_x - ux;
    ty = x3_y - uy;
    cost[2] = t1*t1 + mu*(tx*tx + ty*ty) + alpha*(fabsf(x3_x) + fabsf(x3_y));

    // cost function of x4
    t1 = a*x4_x + b*x4_y + c;
    tx = x4_x - ux;
    ty = x4_y - uy;
    cost[3] = t1*t1 + mu*(tx*tx + ty*ty) + alpha*(fabsf(x4_x) + fabsf(x4_y));

    // cost function of x5 (the all-zero candidate), used as initial minimum
    float cost_min = c*c + mu*(uxx + uyy); // [0 0] solution

    // find minimum over the five candidates (I indexes the winner, 1..5)
    signed int I = 5;
    for (signed int i = 0; i < 4; ++i)
    {
        if (cost[i] < cost_min)
        {
            cost_min = cost[i];
            I = i + 1;
        }
    }

    // set final solution
    switch (I)
    {
    case 1:
        ours_x = x1;
        ours_y = 0.0f;
        break;
    case 2:
        ours_x = 0.0f;
        ours_y = x2;
        break;
    case 3:
        ours_x = x3_x;
        ours_y = x3_y;
        break;
    case 4:
        ours_x = x4_x;
        ours_y = x4_y;
        break;
    case 5:
        ours_x = 0.0f;
        ours_y = 0.0f;
        break;
    }
}
156 |
157 |
// One per-pixel ADMM step for the L1+L2 flow: w-update via the R2 LASSO
// solver on the interior M-sized region, then zeta-update, then pre-storage
// of mu*(2w - (temp+zeta)) = mu*(w - zeta) for the subsequent x-update.
__global__ void prox_gL1Kernel(float *zeta_x, float *zeta_y,
                               float *temp_x, float *temp_y, float mu, float alpha,
                               float *A11, float *A12, float *A22, float *a, float *b, float *c,
                               int N_width, int N_height, int M_width, int M_height)
{
    const int L_width = (N_width - M_width) / 2;
    const int L_height = (N_height - M_height) / 2;

    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;

    // BUGFIX: bounds check must precede any global memory access; previously
    // temp_*/zeta_* were read at an out-of-range pos_N for threads outside
    // the N-sized domain.
    if (ix >= N_width || iy >= N_height) return;

    const int pos_N = ix + iy * N_width;
    const int pos_M = ix - L_width + (iy - L_height) * M_width;

    float temp_w_x = temp_x[pos_N] + zeta_x[pos_N];
    float temp_w_y = temp_y[pos_N] + zeta_y[pos_N];
    float val_x, val_y;

    // w-update
    if (ix >= L_width && ix < N_width - L_width &&
        iy >= L_height && iy < N_height - L_height)
    { // update interior flow
        // lassoR2_simple(...) may be substituted here for speed; the MATLAB
        // reference notes the difference between the two solvers is subtle.
        lassoR2_complete(A11[pos_M], A12[pos_M], A22[pos_M],
                         a[pos_M], b[pos_M], c[pos_M],
                         temp_w_x, temp_w_y, mu, alpha, val_x, val_y);
    }
    else { // keep exterior (padding) flow unchanged
        val_x = temp_w_x;
        val_y = temp_w_y;
    }

    // zeta-update
    zeta_x[pos_N] = temp_w_x - val_x;
    zeta_y[pos_N] = temp_w_y - val_y;

    // pre-store value of mu * (w - zeta)
    temp_x[pos_N] = mu * (2 * val_x - temp_w_x);
    temp_y[pos_N] = mu * (2 * val_y - temp_w_y);
}
201 |
// Host-side launcher for prox_gL1Kernel over the full N-sized (padded) grid.
static
void prox_gL1(float *zeta_x, float *zeta_y,
              float *temp_x, float *temp_y, float mu, float alpha,
              float *A11, float *A12, float *A22,
              float *a, float *b, float *c,
              int N_width, int N_height, int M_width, int M_height)
{
    dim3 threads(32, 32);
    dim3 blocks(iDivUp(N_width, threads.x), iDivUp(N_height, threads.y));

    // BUGFIX: the launch configuration was corrupted to "<<>>"; restored.
    prox_gL1Kernel<<<blocks, threads>>>(zeta_x, zeta_y, temp_x, temp_y,
                                        mu, alpha, A11, A12, A22, a, b, c,
                                        N_width, N_height, M_width, M_height);
}
216 |
217 |
218 | ///////////////////////////////////////////////////////////////////////////////
219 | ///////////////////////////////////////////////////////////////////////////////
220 | ///////////////////////////////////////////////////////////////////////////////
221 | ////////////////////////////// L2 Flow ////////////////////////////////
222 | ///////////////////////////////////////////////////////////////////////////////
223 | ///////////////////////////////////////////////////////////////////////////////
224 | ///////////////////////////////////////////////////////////////////////////////
225 |
226 |
///////////////////////////////////////////////////////////////////////////////
/// \brief compute proximal operator of g (L2 flow): per-pixel w-update via
///        pre-computed linear weights, zeta-update, and pre-storage of
///        mu*(w - zeta) for the subsequent x-update
///
/// \param[in/out] w_x      flow along x
/// \param[in/out] w_y      flow along y
/// \param[in/out] zeta_x   dual variable zeta along x
/// \param[in/out] zeta_y   dual variable zeta along y
/// \param[in/out] temp_x   temp variable (either for nabla(x) or w-zeta) along x
/// \param[in/out] temp_y   temp variable (either for nabla(x) or w-zeta) along y
/// \param[in]     mu       proximal parameter
/// \param[in]     w11, w12_or_w22, w13, w21, w23
///                         pre-computed weights
/// \param[in]     N_width  unknown width
/// \param[in]     N_height unknown height
/// \param[in]     M_width  image width
/// \param[in]     M_height image height
///////////////////////////////////////////////////////////////////////////////
__global__ void prox_gKernel(float *w_x, float *w_y, float *zeta_x, float *zeta_y,
                             float *temp_x, float *temp_y, float mu,
                             const float *w11, const float *w12_or_w22,
                             const float *w13, const float *w21, const float *w23,
                             int N_width, int N_height,
                             int M_width, int M_height)
{
    const int L_width = (N_width - M_width)/2;
    const int L_height = (N_height - M_height)/2;

    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;

    // BUGFIX: bounds check must precede any global memory access; previously
    // temp_*/zeta_* were read at an out-of-range pos_N for threads outside
    // the N-sized domain.
    if (ix >= N_width || iy >= N_height) return;

    const int pos_N = ix + iy * N_width;
    const int pos_M = (ix-L_width) + (iy-L_height) * M_width;

    float temp_w_x = temp_x[pos_N] + zeta_x[pos_N];
    float temp_w_y = temp_y[pos_N] + zeta_y[pos_N];
    float val_x, val_y;

    // w-update
    if (ix >= L_width && ix < N_width -L_width &&
        iy >= L_height && iy < N_height-L_height){ // update interior flow
        val_x = w11[pos_M] * temp_w_x + w12_or_w22[pos_M] * temp_w_y + w13[pos_M];
        val_y = w21[pos_M] * temp_w_y + w12_or_w22[pos_M] * temp_w_x + w23[pos_M];
    }
    else { // keep exterior flow unchanged
        val_x = temp_w_x;
        val_y = temp_w_y;
    }
    w_x[pos_N] = val_x;
    w_y[pos_N] = val_y;

    // zeta-update
    zeta_x[pos_N] = temp_w_x - val_x;
    zeta_y[pos_N] = temp_w_y - val_y;

    // pre-store value of mu * (w - zeta)
    temp_x[pos_N] = mu * (2*val_x - temp_w_x);
    temp_y[pos_N] = mu * (2*val_y - temp_w_y);
}
286 |
287 |
288 |
///////////////////////////////////////////////////////////////////////////////
/// \brief host-side launcher for prox_gKernel (L2 flow)
///
/// \param[in/out] w_x      flow along x
/// \param[in/out] w_y      flow along y
/// \param[in/out] zeta_x   dual variable zeta along x
/// \param[in/out] zeta_y   dual variable zeta along y
/// \param[in/out] temp_x   temp variable (either for nabla(x) or w-zeta) along x
/// \param[in/out] temp_y   temp variable (either for nabla(x) or w-zeta) along y
/// \param[in]     mu       proximal parameter
/// \param[in]     w11, w12_or_w22, w13, w21, w23
///                         pre-computed weights
/// \param[in]     N_width  unknown width
/// \param[in]     N_height unknown height
/// \param[in]     M_width  image width
/// \param[in]     M_height image height
///////////////////////////////////////////////////////////////////////////////
static
void prox_g(float *w_x, float *w_y, float *zeta_x, float *zeta_y,
            float *temp_x, float *temp_y, float mu,
            const float *w11, const float *w12_or_w22,
            const float *w13, const float *w21, const float *w23,
            int N_width, int N_height, int M_width, int M_height)
{
    dim3 threads(32, 6);
    dim3 blocks(iDivUp(N_width, threads.x), iDivUp(N_height, threads.y));

    // BUGFIX: the launch configuration was corrupted to "<<>>"; restored.
    prox_gKernel<<<blocks, threads>>>(w_x, w_y, zeta_x, zeta_y, temp_x, temp_y,
                                      mu, w11, w12_or_w22, w13, w21, w23,
                                      N_width, N_height, M_width, M_height);
}
320 |
--------------------------------------------------------------------------------
/cuda/src/x_updateKernel.cuh:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 |
6 | #include "common.h"
7 |
// Even/odd reorder of phi's rows into complex buffer y (zero imaginary part),
// the standard pre-permutation for computing a DCT along y with a complex FFT.
// Template parameter list restored (body uses both T and W).
template <typename T, typename W>
__global__
void dct_ReorderEvenKernel(T *phi, int N_width, int N_height, W *y)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;

    if (ix >= N_width || iy >= N_height) return;

    const int pos = ix + iy * N_width;

    y[pos].y = 0.0f;
    if (iy < N_height/2){
        // first half: even-indexed rows in increasing order
        y[pos].x = phi[ix + 2*iy*N_width];
    }
    else{
        // second half: odd-indexed rows in decreasing order
        y[pos].x = phi[ix + (2*(N_height-iy)-1)*N_width];
    }
}
27 |
// Multiply the FFT output y by the per-row DCT twiddle weights ww, normalize
// by N_height, and write the real result transposed into b. Row 0 carries the
// extra 1/sqrt(2) factor (DC orthonormalization).
// Template parameter list restored (body uses both T and W).
template <typename T, typename W>
__global__
void dct_MultiplyFFTWeightsKernel(int N_width, int N_height, W *y, T *b, const W *ww)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;
    if (ix >= N_width || iy >= N_height) return;

    const int pos = ix + iy * N_width;
    const int pos_tran = iy + ix * N_height; // transpose on the output b

    b[pos_tran] = (ww[iy].x * y[pos].x + ww[iy].y * y[pos].y) / (T)N_height;
    if (iy == 0)
        b[pos_tran] /= sqrtf(2);
}
43 |
// Element-wise division of phi by the pre-computed DCT-domain denominator
// mat_x_hat (the transform-domain solve of the x-update).
// Template parameter list restored (body uses only T).
template <typename T>
__global__
void divide_mat_x_hatKernel(T *phi, const T *mat_x_hat, int N_width, int N_height)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;
    if (ix >= N_width || iy >= N_height) return;

    const int pos = ix + iy * N_width;

    // NOTE(review): assumes mat_x_hat has no zero entries (the MATLAB
    // reference sets its DC term to 1 explicitly) — confirm at call sites.
    phi[pos] /= mat_x_hat[pos];
}
56 |
// Multiply real input b by the per-row inverse-DCT twiddle weights ww,
// producing the complex buffer y for the inverse FFT. Row 0 carries the
// extra 1/sqrt(2) factor (DC orthonormalization).
// Template parameter list restored (body uses both T and W).
template <typename T, typename W>
__global__
void idct_MultiplyFFTWeightsKernel(int N_width, int N_height, T *b, W *y, const W *ww)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;
    if (ix >= N_width || iy >= N_height) return;

    const int pos = ix + iy * N_width;

    y[pos].x = ww[iy].x * b[pos];
    y[pos].y = ww[iy].y * b[pos];
    if (iy == 0){
        y[pos].x /= sqrtf(2);
        y[pos].y /= sqrtf(2);
    }
}
74 |
// Inverse of the even/odd reorder: scatter the inverse-FFT result y back into
// natural row order, normalize by N_height, and write transposed into phi.
// Template parameter list restored (body uses both T and W).
template <typename T, typename W>
__global__
void idct_ReorderEvenKernel(W *y, int N_width, int N_height, T *phi)
{
    const int ix = threadIdx.x + blockIdx.x * blockDim.x;
    const int iy = threadIdx.y + blockIdx.y * blockDim.y;

    if (ix >= N_width || iy >= N_height) return;

    // transposed output index (cf. dct_MultiplyFFTWeightsKernel)
    const int pos = iy + ix * N_height;

    if ((iy & 1) == 0){ // iy is even
        phi[pos] = y[ix + iy/2*N_width].x / (T) N_height;
    }
    else{ // iy is odd
        phi[pos] = y[ix + (N_height-(iy+1)/2)*N_width].x / (T) N_height;
    }
}
94 |
95 |
// x-update: solve the transform-domain linear system by a separable 2-D DCT
// (two 1-D DCT passes via cuFFT, each transposing its output), an element-wise
// division by mat_x_hat, and the corresponding two inverse-DCT passes.
// Template parameter list and kernel launch configurations were lost in
// transcription; restored. Grid choice follows each kernel's own bounds
// check: blocks_1 covers N_width x N_height calls, blocks_2 the transposed
// N_height x N_width calls.
template <typename T, typename W>
void x_update(T *phi, W *y, const T *mat_x_hat,
              const W *ww_1, const W *ww_2,
              int N_width, int N_height, cufftHandle plan_dct_1, cufftHandle plan_dct_2)
{
    dim3 threads(32, 32);
    dim3 blocks_1(iDivUp(N_width, threads.x), iDivUp(N_height, threads.y));
    dim3 blocks_2(iDivUp(N_height, threads.x), iDivUp(N_width, threads.y));

    // first DCT (output transposed into phi)
    dct_ReorderEvenKernel<<<blocks_1, threads>>>(phi, N_width, N_height, y);
    // cufftSafeCall(cufftExecC2C(plan_dct_1, y, y, CUFFT_FORWARD));
    cufftExecC2C(plan_dct_1, y, y, CUFFT_FORWARD);
    dct_MultiplyFFTWeightsKernel<<<blocks_1, threads>>>(N_width, N_height, y, phi, ww_1);

    // second DCT (dimensions swapped: phi is now transposed)
    dct_ReorderEvenKernel<<<blocks_2, threads>>>(phi, N_height, N_width, y);
    cufftExecC2C(plan_dct_2, y, y, CUFFT_FORWARD);
    dct_MultiplyFFTWeightsKernel<<<blocks_2, threads>>>(N_height, N_width, y, phi, ww_2);

    // divided by mat_x_hat (phi is back in N_width x N_height layout)
    divide_mat_x_hatKernel<<<blocks_1, threads>>>(phi, mat_x_hat, N_width, N_height);

    // first IDCT
    idct_MultiplyFFTWeightsKernel<<<blocks_1, threads>>>(N_width, N_height, phi, y, ww_1);
    cufftExecC2C(plan_dct_1, y, y, CUFFT_INVERSE);
    idct_ReorderEvenKernel<<<blocks_1, threads>>>(y, N_width, N_height, phi);

    // second IDCT
    idct_MultiplyFFTWeightsKernel<<<blocks_2, threads>>>(N_height, N_width, phi, y, ww_2);
    cufftExecC2C(plan_dct_2, y, y, CUFFT_INVERSE);
    idct_ReorderEvenKernel<<<blocks_2, threads>>>(y, N_height, N_width, phi);
}
129 |
--------------------------------------------------------------------------------
/data/HeLa/cap.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/PhaseIntensityMicroscope/007be7a30a7cbe782027a01dae91b11b95b16909/data/HeLa/cap.tif
--------------------------------------------------------------------------------
/data/HeLa/ref.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/PhaseIntensityMicroscope/007be7a30a7cbe782027a01dae91b11b95b16909/data/HeLa/ref.tif
--------------------------------------------------------------------------------
/data/MCF-7/cap.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/PhaseIntensityMicroscope/007be7a30a7cbe782027a01dae91b11b95b16909/data/MCF-7/cap.tif
--------------------------------------------------------------------------------
/data/MCF-7/ref.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/PhaseIntensityMicroscope/007be7a30a7cbe782027a01dae91b11b95b16909/data/MCF-7/ref.tif
--------------------------------------------------------------------------------
/data/MLA-150-7AR-M/Zygo.dat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/PhaseIntensityMicroscope/007be7a30a7cbe782027a01dae91b11b95b16909/data/MLA-150-7AR-M/Zygo.dat
--------------------------------------------------------------------------------
/data/MLA-150-7AR-M/cap.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/PhaseIntensityMicroscope/007be7a30a7cbe782027a01dae91b11b95b16909/data/MLA-150-7AR-M/cap.tif
--------------------------------------------------------------------------------
/data/MLA-150-7AR-M/ref.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/PhaseIntensityMicroscope/007be7a30a7cbe782027a01dae91b11b95b16909/data/MLA-150-7AR-M/ref.tif
--------------------------------------------------------------------------------
/data/blood/cap_1.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/PhaseIntensityMicroscope/007be7a30a7cbe782027a01dae91b11b95b16909/data/blood/cap_1.tif
--------------------------------------------------------------------------------
/data/blood/cap_2.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/PhaseIntensityMicroscope/007be7a30a7cbe782027a01dae91b11b95b16909/data/blood/cap_2.tif
--------------------------------------------------------------------------------
/data/blood/ref_1.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/PhaseIntensityMicroscope/007be7a30a7cbe782027a01dae91b11b95b16909/data/blood/ref_1.tif
--------------------------------------------------------------------------------
/data/blood/ref_2.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/PhaseIntensityMicroscope/007be7a30a7cbe782027a01dae91b11b95b16909/data/blood/ref_2.tif
--------------------------------------------------------------------------------
/data/cheek/cap.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/PhaseIntensityMicroscope/007be7a30a7cbe782027a01dae91b11b95b16909/data/cheek/cap.tif
--------------------------------------------------------------------------------
/data/cheek/ref.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/PhaseIntensityMicroscope/007be7a30a7cbe782027a01dae91b11b95b16909/data/cheek/ref.tif
--------------------------------------------------------------------------------
/matlab/cpu_gpu_comparison.m:
--------------------------------------------------------------------------------
% Compare the CPU (cws) and GPU (cws_gpu_wrapper) solvers on the same
% dataset, then visualize and print per-pixel differences.
clc;clear;close all;
addpath('./utils/');

% dataset location
datapath = '../data/MLA-150-7AR-M/';

% load reference and captured images
ref = double(imread([datapath 'ref.tif']));
cap = double(imread([datapath 'cap.tif']));

% active area (output size)
M = [992 992];

% symmetric crop margins
S = (size(ref)-M)/2;

% crop both images to the active area
ref = ref(1+S(1):end-S(1),1+S(2):end-S(2));
cap = cap(1+S(1):end-S(1),1+S(2):end-S(2));

% jointly quantize both images to 8-bit range (shared peak)
peak = max( max(ref(:)), max(cap(:)) );
ref = double(uint8( 255 * ref ./ peak ));
cap = double(uint8( 255 * cap ./ peak ));

% solver options (shared by CPU and GPU runs)
opt.priors = [0.1 0.1 100 5];
opt.iter = [3 10 20];
opt.mu = [0.1 100];
opt.tol = 0.05;
opt.L = min((2.^ceil(log2(M)) - M) / 2, 256);
opt.size = M;
opt.isverbose = 1;

% run cpu algorithm (timed)
tic;
[A_cpu, phi_cpu, ~, ~, ~] = cws(ref, cap, opt);
toc

% run gpu algorithm
[A_gpu, phi_gpu] = cws_gpu_wrapper(ref, cap, opt);

% zero-mean the GPU phase (cws already mean-normalizes its phi output)
phi_gpu = phi_gpu - mean2(phi_gpu);

% show results
figure; imshow([A_cpu A_gpu A_cpu-A_gpu],[]);
title('A: CPU / GPU / Difference');
disp(['A: max diff = ' num2str(norm(A_gpu(:) - A_cpu(:),'inf'))]);
figure; imshow([phi_cpu phi_gpu phi_cpu-phi_gpu],[]);
title('phi: CPU / GPU / Difference');
disp(['phi: max diff = ' num2str(norm(phi_gpu(:) - phi_cpu(:),'inf'))]);
figure;mesh(phi_cpu - phi_gpu)
title('phi: CPU / GPU / Difference');
--------------------------------------------------------------------------------
/matlab/cws.m:
--------------------------------------------------------------------------------
function [A, phi, wavefront_lap, I_warp, objvaltotal] = cws(I0, I, opt)
% Simultaneous intensity and wavefront recovery. Solve for:
%
%   min_{A,phi} || i(x+\nabla phi) - A i_0(x) ||_2^2 + ...
%               alpha || \nabla phi ||_1 + ...
%               beta  ( || \nabla phi ||_2^2 + || \nabla^2 phi ||_2^2 ) + ...
%               gamma ( || \nabla A ||_1 + || \nabla^2 A ||_1 ) + ...
%               tau   ( || \nabla A ||_2^2 + || \nabla^2 A ||_2^2 )
%
% Inputs:
%   - I0:  reference image
%   - I:   measurement image
%   - opt: solver options
%       - priors: prior weights [alpha beta gamma tau]
%                 (default = [0.1,0.1,100,5])
%       - iter: [total_alternating_iter A-update_iter phi-update_iter]
%               (default = [3 10 20])
%       - mu: ADMM parameters (default = [0.1 100])
%       - tol: phi-update tolerance stopping criteria (default = 0.05)
%       - L: padding size [pad_width pad_height]
%            (default nearest power of 2 of out_size, each in range [2, 256])
%       - isverbose: output verbose [A_verbose phi_verbose]
%                    (default = [0 0])
% Outputs:
%   - A: intensity
%   - phi: wavefront
%   - wavefront_lap: wavefront Laplacian
%   - I_warp: warped image of I
%   - objvaltotal: objective function
%
% See also `cws_gpu_wrapper.m` for its GPU version.

% check images (promote to double for the arithmetic below)
if ~isa(I0,'double')
    I0 = double(I0);
end
if ~isa(I,'double')
    I = double(I);
end

if nargin < 3
    disp('Using default parameter settings ...')

    % tradeoff parameters
    alpha = 0.1;
    beta = 0.1;
    gamma = 100;
    tau = 5;

    % total number of alternation iterations
    iter = 3;

    % A-update parameters
    opt_A.isverbose = 0;
    opt_A.iter = 10;
    opt_A.mu_A = 1e-1;

    % phi-update parameters
    opt_phase.isverbose = 0;
    opt_phase.L = min((2.^ceil(log2(size(I0))) - size(I0)) / 2, 256); % suit to [1024 1024]
    opt_phase.mu = 100;
    opt_phase.iter = 20;

    % tolerance
    tol = 0.05;
else
    % check option, if not existed, use default parameters
    if ~isfield(opt,'priors')
        opt.priors = [0.1 0.1 100 5];
    end
    if ~isfield(opt,'iter')
        opt.iter = [3 10 20];
    end
    if ~isfield(opt,'mu')
        opt.mu = [0.1 100];
    end
    if ~isfield(opt,'tol')
        opt.tol = 0.05;
    end
    if ~isfield(opt,'L')
        opt.L = min((2.^ceil(log2(size(I0))) - size(I0)) / 2, 256); % suit to [1024 1024]
    end
    if ~isfield(opt,'isverbose')
        opt.isverbose = [0 0];
    end

    % parameters checkers
    opt.priors = max(0, opt.priors);
    if length(opt.priors) ~= 4
        error('length of opt.priors must equal to 4!');
    end
    opt.iter = round(opt.iter);
    if length(opt.iter) ~= 3
        error('length of opt.iter must equal to 3!');
    end
    opt.mu = max(0, opt.mu);
    if length(opt.mu) ~= 2
        error('length of opt.mu must equal to 2!');
    end
    opt.tol = max(0, opt.tol);
    if length(opt.tol) ~= 1
        error('length of opt.tol must equal to 1!');
    end
    opt.L = max(0, opt.L);
    if length(opt.L) ~= 2
        error('length of opt.L must equal to 2!');
    end

    % tradeoff parameters
    alpha = opt.priors(1);
    beta = opt.priors(2);
    gamma = opt.priors(3);
    tau = opt.priors(4);

    % total number of alternation iterations
    iter = opt.iter(1);

    % A-update parameters
    opt_A.isverbose = 0;
    opt_A.iter = opt.iter(2);
    opt_A.mu_A = opt.mu(1);

    % phi-update parameters
    opt_phase.isverbose = 0;
    opt_phase.L = opt.L;
    opt_phase.mu = opt.mu(2);
    opt_phase.iter = opt.iter(3);

    % tolerance
    tol = opt.tol;
end

% compute A-update parameters (priors normalized by reference image energy)
gamma_new = gamma / mean2(abs(I0).^2);
tau_new = tau / mean2(abs(I0).^2);

% define operators in spatial domain (forward differences + 5-point Laplacian)
x_k = [0 0 0; -1 1 0; 0 0 0];
y_k = [0 -1 0; 0 1 0; 0 0 0];
l_k = [0 1 0; 1 -4 1; 0 1 0];

% define sizes: M = image, L = padding, N = padded domain
M = size(I0);
L = opt_phase.L;
N = M + 2*L;

% boundary mask: crop the L-sized padding off a 2-channel array
M1 = @(u) cat(3, u(L(1)+1:end-L(1),L(2)+1:end-L(2),1), ...
    u(L(1)+1:end-L(1),L(2)+1:end-L(2),2));

% specify boundary conditions
bc = 'symmetric'; % use DCT

% define operators (gradient, Laplacian, gradient adjoint, stacked K = [nabla; nabla^2])
nabla = @(phi) cat(3, imfilter(phi, x_k, bc), imfilter(phi, y_k, bc));
nabla2 = @(phi) imfilter(phi, l_k, bc);
nablaT = @(adj_phi) imfilter(adj_phi(:,:,1), rot90(x_k,2), bc) + ...
    imfilter(adj_phi(:,:,2), rot90(y_k,2), bc);
K = @(x) cat(3, nabla(x), nabla2(x));
KT = @(u) nablaT(u(:,:,1:2)) + nabla2(u(:,:,3));

% pre-calculated variables
[~, K_mat] = prepare_DCT_basis(M);

% inversion basis in DCT domain for A-update and phi-update
mat_A_hat = 1 + (tau_new + opt_A.mu_A) * K_mat;

% define proximal operator of A (soft-thresholding)
prox_A = @(u) sign(u) .* max(abs(u) - gamma_new/(2*opt_A.mu_A), 0);

% define norms
norm2 = @(x) sum(abs(x(:)).^2);
norm1 = @(x) sum(abs(x(:)));

% define objective functions
obj_A = @(A, b) norm2(A - b./I0) + gamma_new * norm1(K(A)) + tau_new * norm2(K(A));
obj_total = @(A, b, phi) norm2(I0.*A - b) + ...
    alpha * norm1(nabla(phi)) + beta * norm2(K(phi)) + ...
    gamma * norm1(K(A)) + tau * norm2(K(A));

% define obj handle
disp_obj = @(s,i,objval) disp([s num2str(i) ', obj = ' num2str(objval,'%.4e')]);

% initialization
A = ones(M);
phi = zeros(N);
I_warp = I;

% main loop: alternate A-update (ADMM) and phi-update (ADMM)
objvaltotal = zeros(iter,1);
disp_obj('iter = ', 0, obj_total(A, I_warp, phi));
for outer_loop = 1:iter
    % === A-update ===
    B = zeros([size(A) 3]);
    zeta = zeros([size(A) 3]);
    disp('-- A-update')
    for i_A = 1:opt_A.iter
        % A-update: DCT-domain solve of the quadratic subproblem
        A = idct(idct(dct(dct( I_warp./I0 + opt_A.mu_A*KT(B-zeta) ).').' ./ mat_A_hat).').';

        % pre-cache
        u = K(A) + zeta;

        % B-update
        B = prox_A(u);

        % zeta-update
        zeta = u - B;

        % show objective function
        if opt_A.isverbose
            disp_obj('---- iter = ', i_A, obj_A(A,I_warp));
        end
    end
    % median filter A
    A = medfilt2(A, [3 3], 'symmetric');

    % === phi-update ===
    disp('-- phi-update')
    img = cat(3, A.*I0, I_warp);
    [~, Delta_phi, I_warp] = phase_update_ADMM(img, alpha, beta, opt_phase);

    disp(['-- mean(|\Delta\phi|) = ' num2str(mean(abs(Delta_phi(:))),'%.3e')])
    if mean(abs(Delta_phi(:))) < tol
        disp('-- mean(|\Delta\phi|) too small; quit');
        break;
    end
    phi = phi + Delta_phi;

    % === records ===
    objvaltotal(outer_loop) = obj_total(A, I_warp, phi);
    disp_obj('iter = ', outer_loop, objvaltotal(outer_loop));
end

% compute wavefront Laplacian (computed before phi's median filtering)
wavefront_lap = nabla2(phi);
wavefront_lap = M1(cat(3,wavefront_lap,wavefront_lap));
wavefront_lap = wavefront_lap(:,:,1);

% median filtering phi
phi = medfilt2(phi, [3 3], 'symmetric');

% return phi (cropped to image size and mean-normalized)
phi = M1(cat(3,phi,phi));
phi = phi(:,:,1);
phi = phi - mean2(phi);

return;

% ADMM solver for the phi-update subproblem.
% NOTE(review): this is a NESTED function, so identically-named variables
% (M, L, N, M1, u, zeta, mu, x) may be shared with the parent workspace per
% MATLAB nested-function scoping — verify this clobbering is intentional.
function [x, x_full, I_warp] = phase_update_ADMM(img, alpha, beta, opt)

    M = [size(img,1) size(img,2)];
    L = opt.L;
    N = M + 2*L;

    % boundary mask
    M1 = @(u) cat(3, u(L(1)+1:end-L(1),L(2)+1:end-L(2),1), ...
        u(L(1)+1:end-L(1),L(2)+1:end-L(2),2));
    [lap_mat, ~] = prepare_DCT_basis(N);
    mat_x_hat = (opt.mu+beta)*lap_mat + beta*lap_mat.^2;
    mat_x_hat(1) = 1;

    % initialization
    x = zeros(N);
    u = zeros([N 2]);
    zeta = zeros([N 2]);

    % get the matrices (temporal/spatial image derivatives)
    [gt, gx, gy] = partial_deriv(img);

    % pre-compute
    mu = opt.mu;
    gxy = gx.*gy;
    gxx = gx.^2;
    gyy = gy.^2;

    % store in memory at run-time (entries of the pre-inverted 2x2 system)
    denom = mu*(gxx + gyy + mu);
    A11 = (gyy + mu) ./ denom;
    A12 = -gxy ./ denom;
    A22 = (gxx + mu) ./ denom;

    % proximal algorithm
    objval = zeros(opt.iter,1);
    time = zeros(opt.iter,1);

    % the loop
    tic;
    for k = 1:opt.iter

        % x-update (skipped at k == 1 where x is still all-zero)
        if k > 1
            x = idct(idct(dct(dct(mu*nablaT(u-zeta)).').' ./ mat_x_hat).').';
        end

        % pre-compute nabla_x
        nabla_x = nabla(x);

        % u-update
        u = nabla_x + zeta;
        u_temp = M1(u);

        % R2 LASSO (in practice, difference between the two solutions
        % are subtle; for speed's sake, the simple version is
        % recommended)
        [w_opt_x,w_opt_y] = R2LASSO(gx,gy,gt,u_temp(:,:,1),u_temp(:,:,2),alpha,mu,A11,A12,A22,'complete');

        % update u (interior region only)
        u(L(1)+1:end-L(1), L(2)+1:end-L(2), 1) = w_opt_x;
        u(L(1)+1:end-L(1), L(2)+1:end-L(2), 2) = w_opt_y;

        % zeta-update
        zeta = zeta + nabla_x - u;

        % record
        if opt.isverbose
            if k == 1
                % define operators
                G = @(u) gx.*u(:,:,1) + gy.*u(:,:,2);

                % define objective function
                obj = @(phi) norm2(G(M1(nabla(phi)))+gt) + ...
                    alpha * norm1(nabla(phi)) + ...
                    beta * (norm2(nabla(phi)) + norm2(nabla2(phi)));
            end
            objval(k) = obj(x);
            time(k) = toc;
            disp(['---- ADMM iter: ' num2str(k) ', obj = ' num2str(objval(k),'%.4e')])
        end
    end

    % compute the warped image by Taylor expansion
    w = M1(nabla(x));
    I_warp = gx.*w(:,:,1) + gy.*w(:,:,2) + img(:,:,2);

    % return masked x (full padded copy in x_full, cropped copy in x)
    x = x - mean2(x);
    x_full = x;
    x = M1(cat(3,x,x));
    x = x(:,:,1);
end
end
343 |
344 |
function [It,Ix,Iy] = partial_deriv(images)
% Temporal and spatial derivatives of a two-frame stack.
%   images: H x W x 2 array (frame 1, frame 2)
%   It: temporal difference (frame 2 - frame 1)
%   Ix, Iy: spatial derivatives of frame 2 (5-tap central difference)

% 5-tap central-difference derivative kernel
kernel = [1 -8 0 8 -1]/12;

% temporal gradient
It = images(:,:,2) - images(:,:,1);

% spatial gradients of the second frame (replicate boundary handling)
Ix = imfilter(images(:,:,2), kernel, 'replicate');
Iy = imfilter(images(:,:,2), kernel', 'replicate');

end
358 |
359 |
function [lap_mat, K_mat] = prepare_DCT_basis(M)
% Prepare DCT-domain basis maps of size M = [H W]:
%   lap_mat: eigenvalue map of \nabla^2
%   K_mat:   eigenvalue map of \nabla^2 + \nabla^4

rows = M(1);
cols = M(2);
[x_coord, y_coord] = meshgrid(0:cols-1, 0:rows-1);
lap_mat = 4 - 2*cos(pi*x_coord/cols) - 2*cos(pi*y_coord/rows);
K_mat = lap_mat .* (lap_mat + 1);

end
372 |
373 |
function [x, y] = R2LASSO(a,b,c,ux,uy,alpha,mu,A11,A12,A22,option)
% Dispatch to the chosen R2 LASSO solver: 'simple' or 'complete'.
if strcmp(option, 'simple')
    [x, y] = R2LASSO_simple(a,b,c,ux,uy,alpha,mu,A11,A12,A22);
elseif strcmp(option, 'complete')
    [x, y] = R2LASSO_complete(a,b,c,ux,uy,alpha,mu,A11,A12,A22);
else
    error('invalid option for R2LASSO solver!')
end
end
384 |
385 |
function [x, y] = R2LASSO_simple(a,b,c,ux,uy,alpha,mu,A11,A12,A22)
% This function attempts to solve the R2 LASSO problem, in a fast but not accurate sense:
%   min  (a x + b y + c)^2 + \mu [(x - ux)^2 + (y - uy)^2] + \alpha (|x| + |y|)
%   x,y
%
% See also R2LASSO_complete for a more accurate but slower one.
%
% A11, A12 and A22 can be computed as:
%   A_tmp = 1/(mu*(a*a + b*b + mu)) .* [b*b + mu -a*b; -a*b a*a + mu];
%   A11 = A_tmp(1);
%   A12 = A_tmp(2);
%   A22 = A_tmp(4);

if ~exist('A11','var') || ~exist('A12','var') || ~exist('A22','var')
    denom = 1 ./ ( mu * (a.*a + b.*b + mu) );
    % BUGFIX: use element-wise .* instead of matrix * so array-valued a, b
    % work correctly (consistent with R2LASSO_complete).
    A11 = denom .* (b.*b + mu);
    A12 = -a.*b .* denom;
    A22 = denom .* (a.*a + mu);
end

% 1. get (x0, y0)

% temp: right-hand side of the normal equations
tx = mu*ux - a.*c;
ty = mu*uy - b.*c;

% l2 minimum
x0_x = A11.*tx + A12.*ty;
x0_y = A12.*tx + A22.*ty;

% 2. get the sign of optimal
sign_x = sign(x0_x);
sign_y = sign(x0_y);

% 3. get the optimal
x = x0_x - alpha/2 * (A11 .* sign_x + A12 .* sign_y);
y = x0_y - alpha/2 * (A12 .* sign_x + A22 .* sign_y);

% 4. check sign and map to the range
x( sign(x) ~= sign_x ) = 0;
y( sign(y) ~= sign_y ) = 0;

end
429 |
430 |
function [x, y] = R2LASSO_complete(a,b,c,ux,uy,alpha,mu,A11,A12,A22)
% This function attempts to solve the R2 LASSO problem in a most accurate sense:
% min (a x + b y + c)^2 + \mu [(x - ux)^2 + (y - uy)^2] + \alpha (|x| + |y|)
% x,y
%
% Strategy: enumerate five candidate support/sign patterns of the optimum
% ([x1 0], [0 x2], two full-support candidates x3/x4, and [0 0]), evaluate
% the objective for each, and keep the cheapest one. All of a, b, c, ux,
% uy (and the optional A's) may be equally-sized arrays; the problem is
% then solved elementwise. alpha and mu are scalars.
%
% See also R2LASSO_simple for a faster but less accurate one.
%
% A11, A12 and A22 can be computed as:
% A_tmp = 1/(mu*(a*a + b*b + mu)) .* [b*b + mu -a*b; -a*b a*a + mu];
% A11 = A_tmp(1);
% A12 = A_tmp(2);
% A22 = A_tmp(4);

% compute the inverse system matrix entries if not supplied by the caller
if ~exist('A11','var') || ~exist('A12','var') || ~exist('A22','var')
    denom = 1 ./ ( mu * (a.*a + b.*b + mu) );
    A11 = denom .* (b.*b + mu);
    A12 = -a.*b .* denom;
    A22 = denom .* (a.*a + mu);
end

% temp: right-hand side of the quadratic (l2-only) subproblem
tx = mu*ux - a.*c;
ty = mu*uy - b.*c;

% l2 minimum: [x0; y0] = A * [tx; ty]
x0_x = A11.*tx + A12.*ty;
x0_y = A12.*tx + A22.*ty;

% x1 (R^n) [x1 0] & x2 (R^n) [0 x2]: optimum restricted to one axis.
% The vectorized code below implements this scalar logic elementwise:
% if (tx > 0)
%     x1 = (tx - alpha/2) ./ (a.*a + mu);
% else
%     x1 = (tx + alpha/2) ./ (a.*a + mu);
% end
% if (ty > 0)
%     x2 = (ty - alpha/2) ./ (b.*b + mu);
% else
%     x2 = (ty + alpha/2) ./ (b.*b + mu);
% end
x1 = (tx + alpha/2) ./ (a.*a + mu);
tmp = (tx - alpha/2) ./ (a.*a + mu);
ind = tx > 0;
x1(ind) = tmp(ind);

x2 = (ty + alpha/2) ./ (b.*b + mu);
tmp = (ty - alpha/2) ./ (b.*b + mu);
ind = ty > 0;
x2(ind) = tmp(ind);

% x3 (R^2n): full-support candidate, shrinking x0 along A*[1;1].
% Scalar logic:
% tx = alpha/2 * (A11 + A12);
% ty = alpha/2 * (A12 + A22);
% if (tx*x0_x + ty*x0_y) > 0
%     x3_x = x0_x - tx;
%     x3_y = x0_y - ty;
% else
%     x3_x = x0_x + tx;
%     x3_y = x0_y + ty;
% end
tx = alpha/2 * (A11 + A12);
ty = alpha/2 * (A12 + A22);
x3_x = x0_x + tx;
x3_y = x0_y + ty;
tmp_x = x0_x - tx;
tmp_y = x0_y - ty;
ind = (tx.*x0_x + ty.*x0_y) > 0;
x3_x(ind) = tmp_x(ind);
x3_y(ind) = tmp_y(ind);

% x4 (R^2n): full-support candidate, shrinking x0 along A*[1;-1].
% Scalar logic:
% tx = alpha/2 * (A11 - A12);
% ty = alpha/2 * (A12 - A22);
% if (tx*x0_x + ty*x0_y) > 0
%     x4_x = x0_x - tx;
%     x4_y = x0_y - ty;
% else
%     x4_x = x0_x + tx;
%     x4_y = x0_y + ty;
% end
tx = alpha/2 * (A11 - A12);
ty = alpha/2 * (A12 - A22);
x4_x = x0_x + tx;
x4_y = x0_y + ty;
tmp_x = x0_x - tx;
tmp_y = x0_y - ty;
ind = (tx.*x0_x + ty.*x0_y) > 0;
x4_x(ind) = tmp_x(ind);
x4_y(ind) = tmp_y(ind);

% cost functions: objective value of each of the 5 candidates, per element
cost = inf([size(a) 5]);

% x1: candidate [x1 0] (the y term contributes mu*uy^2)
t1 = a.*x1 + c;
tx = x1 - ux;
cost(:,:,1) = t1.*t1 + mu*(tx.*tx + uy.*uy) + alpha*abs(x1);

% x2: candidate [0 x2]
t1 = b.*x2 + c;
ty = x2 - uy;
cost(:,:,2) = t1.*t1 + mu*(ux.*ux + ty.*ty) + alpha*abs(x2);

% x3
t1 = a.*x3_x + b.*x3_y + c;
tx = x3_x - ux;
ty = x3_y - uy;
cost(:,:,3) = t1.*t1 + mu*(tx.*tx + ty.*ty) + alpha*(abs(x3_x) + abs(x3_y));

% x4
t1 = a.*x4_x + b.*x4_y + c;
tx = x4_x - ux;
ty = x4_y - uy;
cost(:,:,4) = t1.*t1 + mu*(tx.*tx + ty.*ty) + alpha*(abs(x4_x) + abs(x4_y));

% x5
cost(:,:,5) = c.*c + mu*(ux.*ux + uy.*uy); % [0 0] solution

% find minimum cost among the five candidates, per element
cost_min = min(cost, [], 3);
x_can = cat(3, x1, zeros(size(a)), x3_x, x4_x, zeros(size(a)));
y_can = cat(3, zeros(size(a)), x2, x3_y, x4_y, zeros(size(a)));
% NOTE(review): if two candidates tie for the minimum, the sum below adds
% BOTH candidate values instead of picking one — confirm ties cannot occur
% (or add a tie-break) before relying on exact minima.
ind = cost == cost_min;

% return the (elementwise) winning candidate
x = sum(ind .* x_can, 3);
y = sum(ind .* y_can, 3);

end
559 |
--------------------------------------------------------------------------------
/matlab/cws_gpu_wrapper.m:
--------------------------------------------------------------------------------
function [A, phi] = cws_gpu_wrapper(I0, I, opt)
% This is a MATLAB wrapper for GPU solver `cws`.
% It writes the image pair to temporary PNGs, shells out to the compiled
% CUDA binary, and reads the result back from a Middlebury .flo file.
%
% Inputs:
% - I0: reference image
% - I: measurement image
% - opt: solver options
%        - priors: prior weights [alpha beta gamma beta]
%                  (default = [0.1,0.1,100,5])
%                  NOTE(review): the repeated "beta" above may be a typo in
%                  the original naming — confirm against the CUDA solver docs.
%        - iter: [total_alternating_iter A-update_iter phi-update_iter]
%                (default = [3 10 20])
%        - mu: ADMM parameters (default = [0.1 100])
%        - tol: phi-update tolerance stopping criteria (default = 0.05)
%        - L: padding size [pad_width pad_height]
%             (default nearest power of 2 of out_size, each in range [2, 256])
%        - isverbose: output verbose [A_verbose phi_verbose]
%                     (default = [0 0])
%        - size: output size (this is unavailable in the CPU version)
%
% Outputs:
% - A: intensity
% - phi: wavefront
%
% See also `cws.m` for its CPU version, and `cpu_gpu_comparison` as demo.

% get cws solver binary: mfilename('fullpath') ends with
% '/matlab/cws_gpu_wrapper' (23 characters), so stripping them yields the
% repository root, under which the binary lives at cuda/bin/cws
tpath = mfilename('fullpath');
if ismac || isunix
    cws_path = [tpath(1:end-23) '/cuda/bin/cws'];
elseif ispc
    % quote the Windows path so spaces survive the shell
    cws_path = ['"' tpath(1:end-23) '\cuda\bin\cws' '"'];
else
    disp('Platform not supported.');
end

% check options
if nargin == 3
    % check option, if not existed, use default parameters
    if ~isfield(opt,'priors')
        opt.priors = [0.1 0.1 100 5];
    end
    if ~isfield(opt,'iter')
        opt.iter = [3 10 20];
    end
    if ~isfield(opt,'mu')
        opt.mu = [0.1 100];
    end
    if ~isfield(opt,'tol')
        opt.tol = 0.05;
    end
    if ~isfield(opt,'L')
        opt.L = min((2.^ceil(log2(size(I0))) - size(I0)) / 2, 256); % suit to [1024 1024]
    end
    if ~isfield(opt,'isverbose')
        opt.isverbose = 0;
    end

    % parameters checkers: clamp to valid ranges, reject wrong lengths
    opt.priors = max(0, opt.priors);
    if length(opt.priors) ~= 4
        error('length of opt.priors must equal to 4!');
    end
    opt.iter = round(opt.iter);
    if length(opt.iter) ~= 3
        error('length of opt.iter must equal to 3!');
    end
    opt.mu = max(0, opt.mu);
    if length(opt.mu) ~= 2
        error('length of opt.mu must equal to 2!');
    end
    opt.tol = max(0, opt.tol);
    if length(opt.tol) ~= 1
        error('length of opt.tol must equal to 1!');
    end
    opt.L = max(0, opt.L);
    if length(opt.L) ~= 2
        error('length of opt.L must equal to 2!');
    end
    if length(opt.isverbose) ~= 1
        error('length of opt.isverbose must equal to 1!');
    end
else % use default parameters
    opt.priors = [0.1 0.1 100 5];
    opt.iter = [3 10 20];
    opt.mu = [0.1 100];
    opt.tol = 0.05;
    opt.L = min((2.^ceil(log2(size(I0))) - size(I0)) / 2, 256); % suit to [1024 1024]
    opt.isverbose = 0;
end

% output-size flag for the solver; the size is forced even (2*round(./2))
if isfield(opt,'size')
    opt.size = max(0, 2*round(opt.size/2));
    issize = [' -s ' num2str(opt.size(1)) ' -s ' num2str(opt.size(2))];
else % default the same size as inputs
    issize = [' -s ' num2str(size(I0,1)) ' -s ' num2str(size(I0,2))];
end

% any nonzero isverbose turns on the solver's --verbose flag
if abs(sign(opt.isverbose))
    isverbose_str = ' --verbose';
else
    isverbose_str = ' ';
end

% write input data to disk; the solver reads 8-bit PNGs, so non-uint8
% inputs are rescaled to [0, 255] by the global maximum
J = cat(3, I0, I);
if ~isa(J,'uint8')
    J = uint8(255 * J / max(J(:)));
end
% NOTE(review): temp file names are keyed by today's date only, so two
% runs on the same day reuse (and overwrite) the same files — confirm
% this is acceptable for your workflow.
I_path = {['I1_' char(datetime('today')) '.png']; ...
          ['I2_' char(datetime('today')) '.png']};
arrayfun(@(i) imwrite(J(:,:,i), I_path{i}), 1:2);

% output name
out_name = ['out_tmp_' char(datetime('today')) '.flo'];

% cat solver info: assemble the full command line
cws_run = [cws_path ...
    ' -p ' num2str(opt.priors(1)) ' -p ' num2str(opt.priors(2)) ...
    ' -p ' num2str(opt.priors(3)) ' -p ' num2str(opt.priors(4)) ...
    ' -i ' num2str(opt.iter(1)) ' -i ' num2str(opt.iter(2)) ...
    ' -i ' num2str(opt.iter(3)) ...
    ' -m ' num2str(opt.mu(1)) ' -m ' num2str(opt.mu(2)) ...
    ' -t ' num2str(opt.tol) isverbose_str issize ...
    ' -l ' num2str(opt.L(1)) ' -l ' num2str(opt.L(2))...
    ' -o ' out_name ' -f "' I_path{1} '" -f "' I_path{2} '"'];

% run GPU solver
system(cws_run);

% read and load the results: band 1 is the intensity A, band 2 the wavefront
test = readFlowFile(out_name);
A = test(:,:,1);
phi = test(:,:,2);

% remove output file and the temporary input PNGs
if ismac || isunix
    system(['rm ' out_name]);
    arrayfun(@(i) system(['rm ' I_path{i}]), 1:numel(I_path));
elseif ispc
    system(['del ' out_name]);
    arrayfun(@(i) system(['del ' I_path{i}]), 1:numel(I_path));
else
    disp('Platform not supported.');
end

end
147 |
148 |
function img = readFlowFile(filename)

% readFlowFile read a flow file FILENAME into 2-band image IMG
% (Middlebury .flo format: a float32 tag, int32 width/height, then
% row-major interleaved float32 band pairs).

% According to the c++ source code of Daniel Scharstein
% Contact: schar@middlebury.edu

% Author: Deqing Sun, Department of Computer Science, Brown University
% Contact: dqsun@cs.brown.edu
% $Date: 2007-10-31 16:45:40 (Wed, 31 Oct 2006) $

% Copyright 2007, Deqing Sun.
%
% All Rights Reserved
%
% Permission to use, copy, modify, and distribute this software and its
% documentation for any purpose other than its incorporation into a
% commercial product is hereby granted without fee, provided that the
% above copyright notice appear in all copies and that both that
% copyright notice and this permission notice appear in supporting
% documentation, and that the name of the author and Brown University not be used in
% advertising or publicity pertaining to distribution of the software
% without specific, written prior permission.
%
% THE AUTHOR AND BROWN UNIVERSITY DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE,
% INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY
% PARTICULAR PURPOSE. IN NO EVENT SHALL THE AUTHOR OR BROWN UNIVERSITY BE LIABLE FOR
% ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
% WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
% ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
% OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

TAG_FLOAT = 202021.25;  % check for this when READING the file

% sanity check
if isempty(filename) == 1
    error('readFlowFile: empty filename');
end

% locate the last '.' to validate the extension
idx = strfind(filename, '.');
idx = idx(end);

if length(filename(idx:end)) == 1
    error('readFlowFile: extension required in filename %s', filename);
end

if strcmp(filename(idx:end), '.flo') ~= 1
    error('readFlowFile: filename %s should have extension ''.flo''', filename);
end

fid = fopen(filename, 'r');
if (fid < 0)
    error('readFlowFile: could not open %s', filename);
end

% header: magic tag, then image dimensions
tag = fread(fid, 1, 'float32');
width = fread(fid, 1, 'int32');
height = fread(fid, 1, 'int32');

% sanity check

if (tag ~= TAG_FLOAT)
    error('readFlowFile(%s): wrong tag (possibly due to big-endian machine?)', filename);
end

if (width < 1 || width > 99999)
    error('readFlowFile(%s): illegal width %d', filename, width);
end

if (height < 1 || height > 99999)
    error('readFlowFile(%s): illegal height %d', filename, height);
end

nBands = 2;

% arrange into matrix form; the two bands are interleaved per pixel
% ([b1 b2 b1 b2 ...] along each row), so de-interleave by taking every
% other column after the reshape/transpose
tmp = fread(fid, inf, 'float32');
tmp = reshape(tmp, [width*nBands, height]);
tmp = tmp';
img(:,:,1) = tmp(:, (1:width)*nBands-1);
img(:,:,2) = tmp(:, (1:width)*nBands);

fclose(fid);

end
234 |
235 |
--------------------------------------------------------------------------------
/matlab/main_wavefront_solver.m:
--------------------------------------------------------------------------------
function phi = main_wavefront_solver(img, beta, opt, warping_iter)
% [in] img the stacked reference and captured image pair
% [in] beta smoothness parameter \beta
% [in] opt solver options:
%          opt.mu [vector] proximal parameter at each pyramid level
%          opt.L [cell] additional boundary unknowns at each pyramid level
%                (better larger than [2 2] for each cell element)
%          opt.isverbose [bool] is the solver verbose or not
%          opt.ls [char] which linear solver to use; 'ADMM' or 'CG'
%          opt.iter [double] linear solver iterations
% warping_iter [vector] warping iterations at each pyramid level
% [out] phi the recovered wavefront
%
% From https://github.com/vccimaging/MegapixelAO.

% set default values
if ~isfield(opt,'mu')
    opt.mu = 100;
end
if ~isfield(opt,'L')
    opt.L = repmat({[2 2]}, [1 numel(opt.mu)]);
end
if ~isfield(opt,'isverbose')
    % BUGFIX: this branch used to assign opt.ls = 0, which left
    % opt.isverbose unset (later reads of it would error) and corrupted
    % the linear-solver choice.
    opt.isverbose = 0;
end
% BUGFIX: the previous check `(opt.ls ~= 'ADMM') && (opt.ls ~= 'CG')`
% errors for any user-supplied char option (elementwise ~= returns a
% vector, which && rejects; chars of unequal length do not compare at
% all). Use strcmp for string comparison instead.
if ~isfield(opt,'ls') || ~any(strcmp(opt.ls, {'ADMM','CG'}))
    opt.ls = 'ADMM';
    opt.iter = 10;
end
if ~isfield(opt,'iter')
    switch opt.ls
        case 'ADMM'
            opt.iter = 10;
        case 'CG'
            opt.iter = 1000;
    end
end
if ~exist('warping_iter','var')
    warping_iter = repmat(2, [1 numel(opt.mu)]);
end

% check parameter length
if ~((numel(warping_iter) == numel(opt.L)) && (numel(opt.L) == numel(opt.mu)))
    error('numel(warping_iter) == numel(opt.L) == numel(opt.mu) not satisfied!');
end

% get number of pyramid levels
pyramid_level = numel(warping_iter)-1;

% define scale for x and y dimensions
scale = [2 2];
dim_org = [size(img,1) size(img,2)];

% check size of L
if length(opt.L) ~= pyramid_level + 1
    error('length(opt.L) ~= pyramid_level + 1!');
end

% the warping scheme: coarse-to-fine over the image pyramid
for i = pyramid_level:-1:0 % (pyramid warping)
    img_pyramid = imresize(img, dim_org./scale.^i, 'bilinear');
    dim = size(img_pyramid(:,:,2));
    [x, y] = meshgrid(1:dim(2),1:dim(1));

    % pre-cache cubic spline coefficients (separable: rows then columns)
    c = cbanal(cbanal(img_pyramid(:,:,1)).').';

    % options for this pyramid level
    opt_small = opt;
    opt_small.L = opt.L{i+1};
    opt_small.mu = opt.mu(i+1);

    % Matrix M: crop the padded estimate back to the data window
    M = @(x) x(1+opt_small.L(1):end-opt_small.L(1), ...
               1+opt_small.L(2):end-opt_small.L(2));

    for j = 1:warping_iter(i+1) % (in-level warping)
        if exist('phi','var')
            % warp the reference by the current wavefront gradient
            wx = imfilter(phi, [-1 1 0], 'replicate');
            wy = imfilter(phi, [-1 1 0]', 'replicate');
            x_new = -M(wx) + x;
            y_new = -M(wy) + y;
            temp_img = cat(3, cbinterp(c,y_new,x_new), img_pyramid(:,:,2));
        else
            temp_img = img_pyramid;
        end
        switch opt_small.ls % linear solver
            case 'ADMM'
                [~, phi_delta] = main_ADMM_fast(temp_img, beta, [], opt_small);
            case 'CG'
                [~, phi_delta] = main_cg_fast(temp_img, beta, opt_small);
        end

        % check if the mean of phi_delta is too small; for early termination
        % NOTE: with the `< 0` threshold this branch never triggers (the
        % mean of absolute values is non-negative); the commented value is
        % the intended threshold.
        mean_phi = mean(abs(phi_delta(:)));
        if mean_phi < 0 % 0.1/prod(scale)^i % ( [0.314 0.628 1.257 2.094] = 2*pi./[20 10 5 3] )
            disp(['Pyr ' num2str(i) ', Warp ' num2str(j) ...
                ', Mean of delta phi = ' num2str(mean_phi) ' < eps: ' ...
                'Early termination'])
            if i == pyramid_level && j == 1
                phi_delta = zeros(dim+2*opt_small.L);
                phi = phi_delta;
            end
            break;
        else
            disp(['Pyr ' num2str(i) ', Warp ' num2str(j) ...
                ', Mean of delta phi = ' num2str(mean_phi)])
        end

        % update phi (initialize on the very first estimate)
        if i == pyramid_level && j == 1
            phi = phi_delta;
        else
            phi = phi_delta + phi;
        end
    end

    % upsample (and rescale) the estimate for the next finer level
    if i > 0
        dim_next = scale.*dim + 2*opt.L{i};
        phi = prod( dim_next./(size(phi)) ) * imresize(phi, dim_next, 'bilinear');
    end
end

% singular case
if nnz(phi) == 0
    disp('All estimations were rejected; set output to zero')
    phi = zeros(size(phi_delta));
end

% get center part
phi = phi(1+opt.L{1}(1):end-opt.L{1}(1), 1+opt.L{1}(2):end-opt.L{1}(2));

end
134 |
135 |
136 | % ADMM linear solver
function [x, x_full] = main_ADMM_fast(img, beta, x, opt)
% ADMM solver for one linearized wavefront update.
% [in]  img  H x W x 2 image pair (reference, captured)
% [in]  beta smoothness weight
% [in]  x    warm-start solution on the padded grid ([] to start at zero)
% [in]  opt  options: opt.L (padding factor), opt.mu (ADMM penalty),
%            opt.iter (iterations), opt.isverbose (diagnostics)
% [out] x      solution cropped to the image size
% [out] x_full solution on the padded grid

%% define constants and operators

% define operators in spatial domain
nabla_x_kern = [0 0 0; -1 1 0; 0 0 0];
nabla_y_kern = [0 -1 0; 0 1 0; 0 0 0];
nabla2_kern = [0 1 0; 1 -4 1; 0 1 0];

% define sizes: N is the padded unknown size (L1 extra pixels per side)
M = size(img(:,:,1));
L1 = (size(nabla_x_kern) - 1) / 2; L1 = L1.*opt.L;
N = M + 2*L1;

% boundary mask: crop a padded 2-channel field back to the data window
M1 = @(u) cat(3, u(L1(1)+1:end-L1(1),L1(2)+1:end-L1(2),1), ...
                 u(L1(1)+1:end-L1(1),L1(2)+1:end-L1(2),2));


%% formulate the problem

% specify boundary conditions
bc = 'symmetric'; % use DCT

% define forward operators
nabla = @(phi) cat(3, imfilter(phi, nabla_x_kern, bc), ...
                      imfilter(phi, nabla_y_kern, bc));

% define adjoint operators (correlation with the 180-degree-rotated kernels)
nablaT = @(grad_phi) imfilter(grad_phi(:,:,1), rot90(nabla_x_kern,2), bc) + ...
                     imfilter(grad_phi(:,:,2), rot90(nabla_y_kern,2), bc);


%% run proximal algorithm

% initialization
if isempty(x)
    x = zeros(N);
end
zeta = zeros([N 2]);  % scaled dual variable for the constraint nabla x = u

% get the matrices (temporal / spatial derivatives of the image pair)
[gt, gx, gy] = partial_deriv(img);

% pre-compute
mu = opt.mu;
gxy = gx.*gy;
gxx = gx.^2;
gyy = gy.^2;
denom = gxx + gyy + mu/2;

% store in memory at run-time: closed-form coefficients of the per-pixel
% 2x2 u-subproblem (data term plus mu/2 proximal term)
w11 = (mu/2 + gyy) ./ denom;
w12 = - gxy ./ denom;
w13 = - gx.*gt ./ denom;
w21 = (mu/2 + gxx) ./ denom;
w22 = - gxy ./ denom;
w23 = - gy.*gt ./ denom;

% proximal algorithm
if opt.isverbose
    disp('start ADMM iteration ...')
end
objval = zeros(opt.iter,1);
res = zeros(opt.iter,1);
time = zeros(opt.iter,1);

tic;
for k = 1:opt.iter
    % x-update: solve the Poisson-like system in the DCT domain.
    % On the first pass only the DCT symbol mat_x_hat is built (x keeps
    % its initial value); later passes perform the actual solve.
    if k == 1
        H = N(1);
        W = N(2);
        [x_coord,y_coord] = meshgrid(0:W-1,0:H-1);
        mat_x_hat = - (mu+2*beta) * ...
            (2*cos(pi*x_coord/W) + 2*cos(pi*y_coord/H) - 4);
        mat_x_hat(1) = 1;  % pin the DC term to avoid division by zero
    else
        x = idct(idct(dct(dct(mu*nablaT(u-zeta)).').' ./ mat_x_hat).').';
    end

    % pre-compute nabla_x
    nabla_x = nabla(x);

    % u-update: only the interior data window is modified below
    u = nabla_x + zeta;
    u_temp = M1(u);
    Mu_x = u_temp(:,:,1);
    Mu_y = u_temp(:,:,2);

    % update u: closed-form solution of the per-pixel 2x2 system
    u(L1(1)+1:end-L1(1), L1(2)+1:end-L1(2), 1) = ...
        w11.*Mu_x + w12.*Mu_y + w13;
    u(L1(1)+1:end-L1(1), L1(2)+1:end-L1(2), 2) = ...
        w21.*Mu_y + w22.*Mu_x + w23;

    % zeta-update: scaled dual ascent on the constraint nabla x = u
    zeta = zeta + nabla_x - u;

    % record diagnostics (operators are built lazily at the first iteration)
    if opt.isverbose
        if k == 1
            % define operators
            G = @(u) gx.*u(:,:,1) + gy.*u(:,:,2);
            nabla2 = @(phi) imfilter(phi, nabla2_kern, bc);
            GT = @(Gu) cat(3, gx.*Gu, gy.*Gu);
            M1T = @(Mu) padarray(Mu, L1);

            % define objective function
            obj = @(phi) sum(sum(abs(G(M1(nabla(phi)))+gt).^2)) + ...
                beta * sum(sum(sum(abs(nabla(phi)).^2)));

            % define its gradient
            grad = @(phi) nablaT(M1T(GT(G(M1(nabla(phi)))+gt))) + ...
                2*beta*nabla2(phi);
        end
        objval(k) = obj(x);
        res(k) = sum(sum(abs(grad(x)).^2));
        time(k) = toc;
        disp(['ADMM iter: ' num2str(k) ...
            ', obj = ' num2str(objval(k),'%e') ...
            ', res = ' num2str(res(k),'%e')])
    end
end

% do median filtering at the last step, then a final DCT-domain x-solve
temp = cat(3, medfilt2( u(:,:,1) - zeta(:,:,1), [3 3] ), ...
              medfilt2( u(:,:,2) - zeta(:,:,2), [3 3] ));
x = idct(idct( dct(dct( mu*nablaT(temp) ).').' ./ mat_x_hat ).').';
if opt.isverbose
    disp(['final objective: ' num2str(obj(x),'%e')])
end
toc

% return masked x (crop the padding off)
x_full = x;
x = M1(cat(3,x,x));
x = x(:,:,1);

end
277 |
278 |
279 | % CG linear solver
function [x, x_full] = main_cg_fast(img, beta, opt)
% Conjugate-gradient solver for one linearized wavefront update.
% [in]  img  H x W x 2 image pair (reference, captured)
% [in]  beta smoothness weight
% [in]  opt  options: opt.L (padding factor), opt.iter (CG iterations),
%            opt.isverbose (diagnostics)
% [out] x      solution cropped to the image size
% [out] x_full solution on the padded grid
%
% BUGFIX: the CG direction coefficient was previously stored in a variable
% also named `beta`, clobbering the smoothness weight. The diagnostic
% objective (an anonymous function created AFTER the overwrite) therefore
% reported wrong values; the coefficient is now `beta_cg` and `obj` is
% built once, before the loop. (The solution itself was unaffected,
% because A and b capture the correct `beta` when they are defined.)

%% define consts and operators

% define operators in spatial domain
nabla_x_kern = [0 0 0; 1 -1 0; 0 0 0];
nabla_y_kern = [0 1 0; 0 -1 0; 0 0 0];
nabla2_kern = -[0 1 0; 1 -4 1; 0 1 0];

% define sizes: N is the padded unknown size
M = size(img(:,:,1));
L1 = (size(nabla_x_kern) - 1) / 2; L1 = L1.*opt.L;
N = M + 2*L1;


%% formulate the problem

% specify boundary conditions
boundary_cond = 'symmetric'; % use DCT

% get the matrices (temporal / spatial derivatives of the image pair)
[gt, gx, gy] = partial_deriv(img);

% define forward operators
nabla = @(phi) cat(3, imfilter(phi, nabla_x_kern, boundary_cond), ...
                      imfilter(phi, nabla_y_kern, boundary_cond));

% define adjoint operators
nablaT = @(grad_phi) imfilter(grad_phi(:,:,1), rot90(nabla_x_kern,2), boundary_cond) + ...
                     imfilter(grad_phi(:,:,2), rot90(nabla_y_kern,2), boundary_cond);

% define Laplacian operator
nabla2 = @(phi) imfilter(phi, nabla2_kern, boundary_cond);

% boundary masks (crop to the data window / zero-pad back)
M1 = @(u) cat(3, u(L1(1)+1:end-L1(1),L1(2)+1:end-L1(2),1), ...
                 u(L1(1)+1:end-L1(1),L1(2)+1:end-L1(2),2));
MT = @(Mu) padarray(Mu, L1);


%% run CG

% define G and GT
G = @(x) gx.*x(:,:,1) + gy.*x(:,:,2);
GT = @(x) cat(3, gx.*x, gy.*x);

% define A and b (normal equations of the linearized problem)
A = @(x) nablaT(MT(GT(G(M1(nabla(x)))))) + beta*nabla2(x);
b = -nablaT(MT(GT(gt)));

% diagnostic objective, defined once while `beta` is untouched
obj = @(phi) sum(sum(abs(G(M1(nabla(phi)))+gt).^2)) + ...
    beta * sum(sum(sum(abs(nabla(phi)).^2)));

% initialize
x = zeros(N);

% get initial r and p
r = b - A(x);
p = r;

% GGT
disp('start CG iteration ...')
objval = zeros(opt.iter,1);
res = zeros(opt.iter,1);
time = zeros(opt.iter,1);

tic;
for k = 1:opt.iter
    % apply A only once per iteration (was applied twice before)
    Ap = A(p);
    step = sum(abs(r(:)).^2) / sum(sum(conj(p).*Ap));
    x = x + step*p;
    r_new = r - step*Ap;
    beta_cg = sum(abs(r_new(:)).^2) / sum(abs(r(:)).^2);
    r = r_new;
    p = r + beta_cg*p;

    if opt.isverbose
        objval(k) = obj(x);
        res(k) = sum(abs(r(:)));
        time(k) = toc;
        disp(['CG iter: ' num2str(k) ...
            ', obj = ' num2str(objval(k),'%e') ...
            ', res = ' num2str(res(k),'%e')])
    else
        if ~mod(k,100)
            disp(['CG iter: ' num2str(k)])
        end
    end
end

% record final objective
disp(['final objective: ' num2str(obj(x),'%e')])
toc

% return masked x (crop the padding off)
x_full = x;
x = M1(cat(3,x,x));
x = x(:,:,1);

end
383 |
384 |
385 | % function to get gx, gy and gt
function [It,Ix,Iy] = partial_deriv(images)
% Temporal (It) and blended spatial (Ix, Iy) derivatives of an image pair.
% Spatial gradients of both frames are computed and mixed with a fixed
% blend ratio; the temporal gradient is the plain frame difference.

% 5-tap central-difference stencil
kern = [1 -8 0 8 -1]/12;

% blend ratio between the two frames' gradients
ratio = 0.5;

% split the pair
ref = images(:,:,1);
cap = images(:,:,2);

% temporal difference
It = cap - ref;

% per-frame spatial derivatives (replicated border handling)
gx2 = imfilter(cap, kern,  'replicate');
gy2 = imfilter(cap, kern.', 'replicate');
gx1 = imfilter(ref, kern,  'replicate');
gy1 = imfilter(ref, kern.', 'replicate');

% blend the two frames' gradients
Ix = ratio*gx2 + (1-ratio)*gx1;
Iy = ratio*gy2 + (1-ratio)*gy1;

end
414 |
415 |
416 | % cubic spline interpolation function (backward: coefficient -> image)
function out = cbinterp(c,x,y)
% Evaluate a bicubic B-spline surface with coefficient grid c at the
% (possibly fractional) coordinates (x, y). Mirrors a GPU texture-lookup
% kernel: coefficients are fetched with nearest-neighbor interp2 (0 when
% outside the grid) and blended with the four cubic B-spline weights.

% calculate movements: integer cell index and fractional offset
px = floor(x);
fx = x - px;
py = floor(y);
fy = y - py;

% define device functions: the four cubic B-spline basis weights, and the
% 4-tap blend r along one axis
w0 = @(a) (1/6)*(a.*(a.*(-a + 3) - 3) + 1);
w1 = @(a) (1/6)*(a.*a.*(3*a - 6) + 4);
w2 = @(a) (1/6)*(a.*(a.*(-3*a + 3) + 3) + 1);
w3 = @(a) (1/6)*(a.*a.*a);
r = @(x,c0,c1,c2,c3) c0.*w0(x) + c1.*w1(x) + c2.*w2(x) + c3.*w3(x);

% define texture function; note the deliberately swapped argument order
% (interp2 takes column first), with 0 returned for out-of-range fetches
tex = @(x,y) interp2(c,y,x,'nearest',0);

% elementwise lookup: blend 4 rows, each a horizontal blend of 4 taps
out = r(fy, ...
    r(fx, tex(px-1,py-1), tex(px,py-1), tex(px+1,py-1), tex(px+2,py-1)), ...
    r(fx, tex(px-1,py ), tex(px,py ), tex(px+1,py ), tex(px+2,py )), ...
    r(fx, tex(px-1,py+1), tex(px,py+1), tex(px+1,py+1), tex(px+2,py+1)), ...
    r(fx, tex(px-1,py+2), tex(px,py+2), tex(px+1,py+2), tex(px+2,py+2)));
end
442 |
443 |
444 | % cubic spline interpolation function (forward: image -> coefficient)
function c = cbanal(img)
% Column-wise cubic B-spline analysis: convert image samples to spline
% coefficients by solving a tridiagonal system per column. The spline
% matrix is toeplitz([4 1 0 ...])/6; the 1/6 factor is folded into the
% factor 6 applied to each tridiagonal solution.

[rows, cols] = size(img);
c = zeros(rows, cols);

% one tridiagonal solve per column
for col = 1:cols
    c(:, col) = 6 * tridisolve(img(:, col));
end

end
458 |
459 |
460 | % triangular linear solver
function x = tridisolve(d)
% Thomas algorithm for the tridiagonal system T*x = d, where T has 4 on
% the main diagonal and 1 on both off-diagonals.

% solve in place, starting from the right-hand side
x = d;
n = length(x);

% running main-diagonal pivots (all start at 4)
piv = 4*ones(n,1);

% forward elimination: fold each unit sub-diagonal into the next pivot
for k = 1:n-1
    f = 1/piv(k);
    piv(k+1) = piv(k+1) - f;
    x(k+1) = x(k+1) - f*x(k);
end

% back substitution (super-diagonal entries are all 1)
x(n) = x(n)/piv(n);
for k = n-1:-1:1
    x(k) = (x(k) - x(k+1))/piv(k);
end

end
486 |
--------------------------------------------------------------------------------
/matlab/speckle_pattern_baseline.m:
--------------------------------------------------------------------------------
function [A, phi, D] = speckle_pattern_baseline(I0, I)
% Speckle-tracking baseline: recover amplitude, phase and dark-field
% signals from a reference/measurement speckle image pair.
% A: amplitude
% phi: phase image
% D: dark-field scattering

% window size
n = [3 3];
h = ones(n);

% calculate A2: local mean intensity ratio (amplitude squared)
A = meanfilt(I,h) ./ meanfilt(I0,h);
% A = medfilt2(A, n, 'symmetric');

% calculate D: local standard-deviation ratio, normalized by the
% intensity ratio
D = stdfilt(I,h) ./ stdfilt(I0,h);
D = D ./ A;

% calculate local pixel shifts (wavefront slopes) between the measurement
% and the intensity-matched reference
[w, ~] = imregdemons(I, A.*I0);

% integrate phi from w, then remove the mean (piston) term
phi = poisson_solver(w(:,:,1), w(:,:,2));
phi = phi - mean2(phi);

% square root of the intensity ratio gives the amplitude
A = sqrt(A);

end
29 |
30 |
function J = meanfilt(I, h)
%MEANFILT Local mean of image.
%   Normalizes the window H to unit sum and applies it with replicated
%   borders, i.e. a local weighted average of I.

kernel = h / sum(h(:));
J = imfilter(I, kernel, 'replicate');

end
37 |
--------------------------------------------------------------------------------
/matlab/utils/LoadMetroProData.m:
--------------------------------------------------------------------------------
1 | %
2 | % function [ dat, xl, yl ] = LoadMetroProData( filename )
3 | %
4 | % Read MetroPro binary data file and returns in dat. If failed, dat is empty.
5 | %
6 | % x and y coordinates of data points are calculated as (see NB below)
7 | % xl = dxy * ( -(size(dat,2)-1)/2 : (size(dat,2)-1)/2 );
8 | % yl = dxy * ( -(size(dat,1)-1)/2 : (size(dat,1)-1)/2 );
9 | %
10 | % dat(iy,jx) is the value at the location (x,y) = ( xl(jx), yl(iy) )
11 | % which is the convention of matlab
12 | %
13 | % The following commands will show data
14 | % plot( xl, dat(floor(end/2), :), yl, dat(:, floor(end/2) )
15 | % mesh( xl, yl, dat*1e9 )
16 | %
17 | % NB) absolute values of the coordinates are not useful,
18 | %     because the measurement origin (0,0) is located at the top-left corner
19 | % and not at the center of the mirror.
20 | %
function [ dat, xl, yl ] = LoadMetroProData( filename )

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%
% March 20, 2014
% Zernike term/tilt handling has been removed, and this function works without other functions
%
% September 17, 2013
% Third output argument, xl, added
%
% September 9, 2013 HY
% zernike term removal added
%
% August 30, 2013 HY
% 1) piston/tilt removal is added
% 2) matrix transposed to follow matlab convention
%
% MetroPro format stores data row major, starting from top-left
% I.e., in units of spacing dxy, data at the following locations are stored in sequence
% (x,y) = (1,1), (2,1), (3,1)...
% When this series of data is read in into matlab array dat, the data at (i,j) is loaded at dat(i,j)
% To use the matlab convention, i.e., x axis horizontal or second column index, matrix is transposed.
% Now dat(iy, jx) is the value at physical location (jx, iy)*dxy.
%
% When using a nominal x-y convention, positive directions of x and y axes are right and up.
% The data need to be reversed in the y-direction. Because ...
% The first data at top left is at (x,y) = (1,1) and the last data at bot-right is at (N,N).
% For x axis, this is the correct ordering (1 to N from left to right),
% but, for y, it is opposite (1 to N from top to bottom).
% To change the y ordering, data(iy, jx) needs to be data(N-iy+1, jx) or flipud.
% This makes the physical location of y for dat(1,j) below that of dat(2,j).
%
% August 19, 2013 Hiro Yamamoto
% Based on MetroPro Reference Guide,
% Section 12 "Data Format and Conversion" which covers
% header format 2 and 3 and phase_res format 0, 1 and 2.
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% try to open the file (MetroPro data is big-endian, hence 'b')
fid = fopen( filename, 'r', 'b' );
if fid == -1
    error('File ''%s'' does not exist', filename);
end;

% read the header information
hData = readerHeader( fid );
if hData.hFormat < 0
    fclose(fid);
    error('Format unknown');
end;

% read the phasemap data
% skip the header and intensity data
fseek( fid, hData.hSize + hData.IntNBytes, -1 );
[dat, count] = fread( fid, hData.XN*hData.YN, 'int32' );
fclose( fid );

if count ~= hData.XN*hData.YN
    error('data could not fully read');
end

% mark unmeasured data as NaN
dat( dat >= hData.invalid ) = NaN;
% scale data to unit of meter
dat = dat * hData.convFactor;
% reshape data to XN x YN matrix
dat = reshape( dat, hData.XN, hData.YN );
% transpose to make the matrix (NY, NX)
dat = dat';
% change the y-axis direction
dat = flipud( dat );

% auxiliary data to return
dxy = hData.CameraRes;

% centered coordinate vectors, in units of the camera resolution
if nargout >= 2
    xl = dxy * ( -(size(dat,2)-1)/2 : (size(dat,2)-1)/2 );
end
if nargout >= 3
    yl = dxy * ( -(size(dat,1)-1)/2 : (size(dat,1)-1)/2 );
end

return
105 |
106 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
107 |
108 | % read header information, and hData.hFormat = -1 when failed.
109 | % information of the header segment is from MetroPro Reference Guide
110 | % The manual covers format 1, 2 and 3, and this function fails if the data has unknown format.
111 |
function hData = readerHeader( fid )
% Read the MetroPro binary header; hData.hFormat is set to -1 on failure.
% Byte offsets below come from the MetroPro Reference Guide (formats 1-3).

% first, check the format information to make sure this is MetroPro binary data
hData.hFormat = -1;
[magicNum, count] = fread( fid, 1, 'uint32' );
if count == 0
    return;
end;

[hData.hFormat, count] = readI16( fid );
if count == 0
    return;
end;

[hData.hSize, count] = readI32( fid );
if count == 0
    return;
end;

% check if the magic string and format are known ones
% (the magic number equals 0x881B036E plus the format number)
if (hData.hFormat >= 1) && (hData.hFormat<=3) && (magicNum-hData.hFormat == hex2dec('881B036E'))
    % sprintf('MetroPro format %d with header size %d', hData.hFormat, hData.hSize )
else
    % sprintf('====> warning format unknown : %d\n', hData.hFormat);
    hData.hFormat = -1;
    return;
end

% read necessary data
% sentinel: raw values at or above this mark unmeasured pixels
hData.invalid = hex2dec('7FFFFFF8');

% intensity data byte count, which we will skip over
hData.IntNBytes = readI32( fid, 61-1 );

% top-left coordinate, which are useless
hData.X0 = readI16( fid, 65-1 );
hData.Y0 = readI16( fid, 67-1 );

% number of data points along x and y
hData.XN = readI16( fid, 69-1 );
hData.YN = readI16( fid, 71-1 );

% total data, 4 * XN * YN
hData.PhaNBytes = readI32( fid, 73-1 );

% scale factor is determined by phase resolution tag
phaseResTag = readI16( fid, 219-1 );
switch phaseResTag
    case 0,
        phaseResVal = 4096;

    case 1,
        phaseResVal = 32768;

    case 2,
        phaseResVal = 131072;

    otherwise
        % unknown tag: the division below then yields a non-finite
        % convFactor rather than a hard error
        phaseResVal = 0;
end

hData.waveLength = readReal( fid, 169-1 );
% Eq. in p12-6 in MetroPro Reference Guide
hData.convFactor = readReal( fid, 165-1 ) * readReal( fid, 177-1 ) * hData.waveLength / phaseResVal;

% bin size of each measurement
hData.CameraRes = readReal( fid, 185-1 );

return
181 |
182 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
183 |
184 | % utility to read data, which are stored in big-endian format
function [val, count] = readI16( fid, offset )
% Read one int16 from fid; if offset is given, seek there (from file start) first.
if nargin > 1
    % 'bof' is equivalent to the numeric origin -1 (beginning of file)
    fseek( fid, offset, 'bof' );
end
[val, count] = fread( fid, 1, 'int16' );
return
191 |
function [val, count] = readI32( fid, offset )
% Read one int32 from fid; if offset is given, seek there (from file start) first.
if nargin > 1
    % 'bof' is equivalent to the numeric origin -1 (beginning of file)
    fseek( fid, offset, 'bof' );
end
[val, count] = fread( fid, 1, 'int32' );
return
198 |
function [val, count] = readReal( fid, offset )
% Read one single-precision float from fid; if offset is given, seek there first.
if nargin > 1
    % 'bof' is equivalent to the numeric origin -1 (beginning of file)
    fseek( fid, offset, 'bof' );
end
[val, count] = fread( fid, 1, 'float' );
return
--------------------------------------------------------------------------------
/matlab/utils/poisson_solver.m:
--------------------------------------------------------------------------------
function rec = poisson_solver(gx,gy)
% Integrate a surface from its gradients with a DCT-based Poisson solver.
%
% Inputs:
%   gx - gradient along x
%   gy - gradient along y
%
% Output:
%   rec - reconstructed surface (zero-mean)

% zero-pad by one pixel on every side
pad = 1;
gx = padarray(gx, [pad pad]);
gy = padarray(gy, [pad pad]);

% forward-difference kernels defining the gradient operator
kern_dx = [0 0 0; -1 1 0; 0 0 0];
kern_dy = [0 -1 0; 0 1 0; 0 0 0];

% symmetric (reflective) boundary condition, consistent with the DCT
boundary = 'symmetric';

% apply the adjoint of the gradient: correlate with the 180-degree
% rotated kernels and sum the two components (a divergence)
div_g = imfilter(gx, rot90(kern_dx,2), boundary) + ...
        imfilter(gy, rot90(kern_dy,2), boundary);

% eigenvalues of the discrete Laplacian in the DCT basis
[H, W] = size(gx);
[u, v] = meshgrid(0:W-1, 0:H-1);
lap_eig = 2*cos(pi*u/W) + 2*cos(pi*v/H) - 4;
lap_eig(1) = 1;  % avoid division by zero at the DC term

% inverse filtering in the DCT domain, then undo the padding
rec = idct2( dct2(div_g) ./ -lap_eig );
rec = rec(1+pad:end-pad, 1+pad:end-pad);

% redeem the boundary: drop the outermost ring and replicate inward values
rec([1 end],:) = [];
rec(:,[1 end]) = [];
rec = padarray(rec, [1 1], 'replicate');

% zero-normalize
rec = rec - mean2(rec);
46 |
--------------------------------------------------------------------------------
/matlab/utils/tilt_removal.m:
--------------------------------------------------------------------------------
function phi = tilt_removal(phi)
% Remove the best-fit plane (tilt) from a 2D map and shift its minimum to zero.

% shift to non-negative values before fitting
phi = phi - min(phi(:));

% fit an affine plane over the pixel grid and subtract it
[cols, rows] = meshgrid(1:size(phi,2), 1:size(phi,1));
[~, ~, ~, plane] = affine_fit(cols, rows, phi);
phi = phi - plane;

% re-anchor the detilted map at zero
phi = phi - min(phi(:));

end
10 |
11 |
function [n,V,p,S] = affine_fit(x,y,z)
% Least-squares affine plane fit to the surface z(x,y) (minimizing the
% normal distance to the plane), via principal component analysis.
%
% Inputs:
%   x - 2D x coordinates
%   y - 2D y coordinates
%   z - input 2D array to be fit
%
% Outputs:
%   n - unit (column) vector normal to the plane
%   V - matrix whose columns form an orthonormal basis of the plane
%   p - a point belonging to the plane (the sample centroid)
%   S - the fitted affine plane evaluated on the (x,y) grid
%
% Based on code by Adrien Leygue (August 30, 2013).

% stack the samples as rows of an N-by-3 matrix
pts = [x(:) y(:) z(:)];

% the centroid of the samples belongs to the plane
p = mean(pts, 1);

% center the samples and extract the principal directions;
% NOTE(review): this relies on eig returning the eigenvalues of the
% symmetric matrix in ascending order so that column 1 is the direction
% of least variance (the plane normal) -- confirm for the target MATLAB
centered = bsxfun(@minus, pts, p);
[vecs, ~] = eig(centered' * centered);
n = vecs(:, 1);
V = vecs(:, 2:end);

% evaluate the plane z = (dot(n,p) - n(1)*x - n(2)*y) / n(3) on the grid
S = - (n(1)/n(3)*x + n(2)/n(3)*y - dot(n,p)/n(3));
end
49 |
--------------------------------------------------------------------------------
/scripts/Figure2.m:
--------------------------------------------------------------------------------
% Reproduces Figure 2: reconstructs the wavefront of the MLA-150-7AR-M
% microlens array with four methods (slope tracking, Wang et al., the
% speckle-pattern baseline, and our CWS solver), then compares
% cross-sections against Zygo interferometer data.
clc;clear;close all;
addpath('../matlab/');
addpath('../matlab/utils/');

% input data path
dpath = '../data/MLA-150-7AR-M/';
r = imread([dpath 'ref.tif']); % reference pattern
s = imread([dpath 'cap.tif']); % captured pattern
n = 1.46; % refractive index

% parameters
pixel_size = 6.45; % [um]
z = 1.43e3; % [um]
scale_factor = pixel_size^2/z; % converts solver output to physical units
map = jet(256);

% read data: normalize by 2^14 and rescale to [0, 255]
r = double(r)/2^14 * 255;
s = double(s)/2^14 * 255;


%% Methods

%%% Slope-tracking (Berto et al. 2017): demons registration for the
%%% displacement field, then Poisson integration
[w, ~] = imregdemons(s, r, 200);
phi_tracking = poisson_solver(w(:,:,1), w(:,:,2));
phi_tracking = phi_tracking - mean2(phi_tracking);
phi_tracking = tilt_removal(phi_tracking/(n-1)*scale_factor);


%%% Wang et al. 2017
opt.isverbose = 0;
opt.L = {[20 20]};
opt.mu = 100;
opt.iter = 30;
warping_iter = 2;
beta = 1;
phi_wang = main_wavefront_solver(cat(3, r, s), beta, opt, warping_iter);
phi_wang = tilt_removal(phi_wang/(n-1)*scale_factor);


%%% Baseline (Berujon et al. 2015)
[A_base, phi_base, D_base] = speckle_pattern_baseline(r, s);
phi_base = tilt_removal(phi_base/(n-1)*scale_factor);


%%% Ours
opt_cws.priors = [0.5 0.5 100 5];
[A_ours, phi, wavefront_lap, I_warp] = cws(r, s, opt_cws);
% correct the amplitude with the wavefront Laplacian, then take the sqrt
A_ours = A_ours .* (1 + pixel_size/z*wavefront_lap);
A_ours = sqrt(A_ours);
I = A_ours; % amplitude
phi = tilt_removal(phi/(n-1)*scale_factor);

% denoise a little bit ...
phi = medfilt2(phi, [3 3], 'symmetric');


%% Show results

% normalize to start from 0
phi_tracking = phi_tracking - min(phi_tracking(:));
phi_wang = phi_wang - min(phi_wang(:));
phi_base = phi_base - min(phi_base(:));
phi = phi - min(phi(:));

% show results
figure;imshow(phi_tracking, []);
axis tight ij;colormap(map);pause(0.2);
title('Berto et al. 2017');

figure;imshow(phi_wang, []);
axis tight ij;colormap(map);pause(0.2)
title('Wang et al. 2017');

figure;imshow(phi_base, []);
axis tight ij;colormap(map);pause(0.2)
title('Berujon et al. 2015');

figure;imshow(phi, []);
axis tight ij;colormap(map);pause(0.2)
title('Ours');


%% Cross-sections

% for MLA
% set parameters: hand-picked points (x, y) in pixels
x1=253; y1=360;
x2=670; y2=585;
x3=1082; y3=800;

% set two points for cross-sectioning
xx = [x2 x3];
yy = [y2 y3];

% get the linear indexes of the pixels on the line between the two points
a = (yy(2)-yy(1))/(xx(2)-xx(1));
b = yy(1) - a*xx(1);
x = xx(1):xx(2);
y = round(a*x+b);
ind = sub2ind(size(phi),y,x);

% show the points
figure; imshow(phi, []); axis tight;
hold on; line([xx(1) xx(2)],[yy(1) yy(2)],'Color',[1 0 0]);

% get cross-section (column vector) along the selected line
get_c = @(phi) phi(ind)';

% get the maxs (echoed to the console)
max(phi_tracking(:))
max(phi_wang(:))
max(phi_base(:))
max(phi(:))
max_phi = round(max(phi(:)),2);

% get cross-sections
t_tracking = get_c(phi_tracking);
t_wang = get_c(phi_wang);
t_base = get_c(phi_base);
t_ours = get_c(phi);

% min-normalized
t_tracking = t_tracking - min(t_tracking);
t_wang = t_wang - min(t_wang);
t_base = t_base - min(t_base);
t_ours = t_ours - min(t_ours);

% x coordinates, centered at zero
tt = 1:length(t_tracking);
tt = tt - mean(tt);
% NOTE(review): the /20 and /0.9 factors look like magnification and
% misalignment compensation -- confirm against the optical setup
tt = tt * pixel_size/20 / 0.9;

% show plots
figure; plot(tt, [t_tracking t_wang t_base t_ours],'LineWidth',2);
axis tight;
legend('Berto et al. 2017','Wang et al. 2017','Berujon et al. 2015','Ours');


%% Comparison with Zygo data

% read data (heights in meters; unmeasured points are NaN)
[dat, xl, yl] = LoadMetroProData([dpath 'Zygo.dat']);
dat(isnan(dat)) = mean(dat(~isnan(dat))); % fill NaNs with the mean height

% tilt removal (x1e6: meters -> micrometers)
dat = tilt_removal(dat*1e6);

% show raw data ('i','f' presumably abbreviate
% 'InitialMagnification','fit' via partial matching -- verify)
figure; imshow(dat,[],'i','f'); axis tight on
colormap(map);colorbar;
title('Zygo raw');

% corresponding region in the Zygo data
xx = [188 222];
yy = [157 155];

% get the indices along the corresponding line
a = (yy(2)-yy(1))/(xx(2)-xx(1));
b = yy(1) - a*xx(1);
x = xx(1):xx(2);
y = round(a*x+b);
ind = sub2ind(size(dat),y,x);
t_gt = dat(ind);
t_gt = t_gt - min(t_gt);

% x coordinates
tt = 1:length(t_tracking);
tt = tt - mean(tt);
tt = tt * pixel_size/20 / 0.9; % compensation for small misalignment

% x coordinates, resampled to match the Zygo cross-section length
tt = imresize(tt, [1 numel(t_gt)]);

% plot
figure; plot(tt, t_gt, 'o', 'LineWidth',2);
title('Zygo plot')

% calculate RMS against the resampled Zygo cross-section
t_gt = imresize(t_gt', [numel(t_tracking) 1]);
calc_rms = @(x) sqrt(mean(abs(x - t_gt).^2));

% show RMS for each method
disp('RMS is:');
disp(['Berto et al. 2017: ' num2str(calc_rms(t_tracking)) ' um']);
disp(['Wang et al. 2017: ' num2str(calc_rms(t_wang)) ' um']);
disp(['Berujon et al. 2015: ' num2str(calc_rms(t_base)) ' um']);
disp(['Ours: ' num2str(calc_rms(t_ours)) ' um']);

--------------------------------------------------------------------------------
/scripts/Figure3.m:
--------------------------------------------------------------------------------
% Reproduces Figure 3: reconstructs the optical path difference (OPD) of a
% blood sample with four methods (slope tracking, Wang et al., the
% speckle-pattern baseline, and our CWS solver) and plots cross-sections.
clc;clear;close all;
addpath('../matlab/');
addpath('../matlab/utils/');

% input data path
dpath = '../data/blood/';
r = imread([dpath 'ref_1.tif']); % reference pattern
s = imread([dpath 'cap_1.tif']); % captured pattern
n = 2; % we are doing OPD here

% parameters
pixel_size = 6.45; % [um]
z = 1.43e3; % [um]
scale_factor = pixel_size^2/z; % converts solver output to physical units
map = jet(256);

% read data: normalize by 2^14 and rescale to [0, 255]
r = double(r)/2^14 * 255;
s = double(s)/2^14 * 255;


%% Methods

%%% Slope-tracking (Berto et al. 2017): demons registration for the
%%% displacement field, then Poisson integration
[w, ~] = imregdemons(s, r, 200);
phi_tracking = poisson_solver(w(:,:,1), w(:,:,2));
phi_tracking = phi_tracking - mean2(phi_tracking);
phi_tracking = tilt_removal(phi_tracking/(n-1)*scale_factor);


%%% Wang et al. 2017
opt.isverbose = 0;
opt.L = {[20 20]};
opt.mu = 100;
opt.iter = 30;
warping_iter = 2;
beta = 1;
phi_wang = main_wavefront_solver(cat(3, r, s), beta, opt, warping_iter);
phi_wang = tilt_removal(phi_wang/(n-1)*scale_factor);


%%% Baseline (Berujon et al. 2015)
[A_base, phi_base, D_base] = speckle_pattern_baseline(r, s);
phi_base = tilt_removal(phi_base/(n-1)*scale_factor);


%%% Ours
opt_cws.priors = [0.5 0.5 100 5];
[A_ours, phi, wavefront_lap, I_warp] = cws(r, s, opt_cws);
% correct the amplitude with the wavefront Laplacian, then take the sqrt
A_ours = A_ours .* (1 + pixel_size/z*wavefront_lap);
A_ours = sqrt(A_ours);
I = A_ours; % amplitude
phi = tilt_removal(phi/(n-1)*scale_factor);

% denoise a little bit ...
phi = medfilt2(phi, [3 3], 'symmetric');


%% Show results

% normalize to start from 0
phi_tracking = phi_tracking - min(phi_tracking(:));
phi_wang = phi_wang - min(phi_wang(:));
phi_base = phi_base - min(phi_base(:));
phi = phi - min(phi(:));

% show results
figure;imshow(phi_tracking, []);
axis tight ij;colormap(map);pause(0.2);
title('Berto et al. 2017');

figure;imshow(phi_wang, []);
axis tight ij;colormap(map);pause(0.2)
title('Wang et al. 2017');

figure;imshow(phi_base, []);
axis tight ij;colormap(map);pause(0.2)
title('Berujon et al. 2015');

figure;imshow(phi, []);
axis tight ij;colormap(map);pause(0.2)
title('Ours');


%% Cross-sections

% function handles: fixed cross-section (row 191, columns 390-550)
% and min-normalization
get_c = @(phi) phi(191,390:550)';
min_n = @(x) x - min(x(:));

% get the maxs (echoed to the console)
max(phi_tracking(:))
max(phi_wang(:))
max(phi_base(:))
max(phi(:))

% get cross-sections
t_tracking = get_c(phi_tracking);
t_wang = get_c(phi_wang);
t_base = get_c(phi_base);
t_ours = get_c(phi);

% show plots
figure;
plot([min_n(t_tracking) min_n(t_wang) min_n(t_base) min_n(t_ours)],'LineWidth',2);
axis tight;
legend('Berto et al. 2017','Wang et al. 2017','Berujon et al. 2015','Ours');

--------------------------------------------------------------------------------
/scripts/Figure4.m:
--------------------------------------------------------------------------------
% Reproduces Figure 4: run the CWS solver on four biological samples and
% collect the amplitude and OPD maps for each sub-figure.
clear;close all;
addpath('../matlab/');
addpath('../matlab/utils/');

% data paths, one per sample
samples = {'blood', 'cheek', 'HeLa', 'MCF-7'};
dpaths = cellfun(@(x) ['../data/' x '/'], samples, 'un',0);

% imaging parameters
pixel_size = 6.45; % [um]
z = 1.43e3; % [um]
scale_factor = pixel_size^2/z;

% preallocate one result slot per sub-figure
num_figs = 4;
A_final = cell(num_figs,1);
phi_final = cell(num_figs,1);

for k = 1:num_figs
    disp(['Running sub-figure ' num2str(k) '/' num2str(num_figs)])
    dpath = dpaths{k};

    % the first (blood) sample uses its second capture pair
    if k == 1
        ref_name = 'ref_2.tif';
        cap_name = 'cap_2.tif';
    else
        ref_name = 'ref.tif';
        cap_name = 'cap.tif';
    end
    r = imread([dpath ref_name]);
    s = imread([dpath cap_name]);

    % normalize by 2^14 and rescale to [0, 255]
    r = double(r)/2^14 * 255;
    s = double(s)/2^14 * 255;

    % our solver: amplitude corrected by the wavefront Laplacian, OPD detilted
    [A_ours, phi, wavefront_lap, I_warp] = cws(r, s);
    A_ours = sqrt(A_ours .* (1 + pixel_size/z*wavefront_lap));
    A_final{k} = A_ours; % amplitude
    phi = tilt_removal(phi*scale_factor); % OPD
    phi_final{k} = phi - min(phi(:));
end

--------------------------------------------------------------------------------