├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── environment.yaml
├── sample_images
├── clean_1000iters.png
└── noisy.png
└── src
├── cscnet.py
├── example.yaml
├── fmd_dataloader.py
├── inference.py
├── model.py
├── neighbor2neighbor.py
├── run_folder.py
└── training.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 | data/
162 | results/
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM continuumio/anaconda3
2 |
3 | # Install system dependencies
4 | RUN apt-get update \
5 | && DEBIAN_FRONTEND=noninteractive apt-get install -y \
6 | build-essential \
7 | curl \
8 | git \
9 | && apt-get clean
10 |
11 | # Create the conda environment (p2s) from environment.yaml
     # NOTE(review): base image is continuumio/anaconda3 (conda lives at /opt/conda);
     # MINICONDA_HOME=/opt/miniconda below does not exist in this image — confirm the PATH entry is intended.
12 | ENV MINICONDA_HOME="/opt/miniconda"
13 | ENV PATH="${MINICONDA_HOME}/bin:${PATH}"
14 | COPY environment.yaml environment.yaml
15 | RUN conda env create --name p2s --file=environment.yaml
16 |
17 | # Copy the project sources into the image and set the working directory
18 | WORKDIR /Poisson2Sparse
19 | COPY src /Poisson2Sparse
20 |
21 | ENTRYPOINT [ "bash"]
22 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc.
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 |     <one line to give the program's name and a brief idea of what it does.>
635 |     Copyright (C) <year>  <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 |     along with this program.  If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 |     <program>  Copyright (C) <year>  <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License.  But first, please read
674 | <https://www.gnu.org/philosophy/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Poisson2Sparse
2 | Official Implementation of Poisson2Sparse - MICCAI 2022
3 |
4 | WARNING: THIS REPO IS STILL UNDER DEVELOPMENT AND NEEDS TO BE CLEANED UP
5 |
6 |
7 | # Docker Instructions
8 | After building and running the container activate the conda environment
9 | `conda activate p2s` before running the inference.py script.
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: deep-image-prior
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - cudatoolkit=11.3.1=h9edb442_10
8 | - numpy=1.22.4=py39hc58783e_0
9 | - pillow=9.1.1=py39hae2aec6_1
10 | - pip=22.1.2=pyhd8ed1ab_0
11 | - python=3.9.13=h9a8a25e_0_cpython
12 | - pytorch=1.11.0=py3.9_cuda11.3_cudnn8.2.0_0
13 | - pytorch-mutex=1.0=cuda
14 | - torchvision=0.12.0=py39_cu113
15 | - pip:
16 | - icecream==2.1.2
17 | - imageio==2.19.3
18 | - kornia==0.6.5
19 | - matplotlib==3.5.2
20 | - pywavelets==1.3.0
21 | - pyyaml==6.0
22 | - scikit-image==0.19.2
23 | - scipy==1.8.1
24 | - six==1.16.0
25 | - tifffile==2022.5.4
26 | - tqdm==4.64.0
27 | prefix: /home/cta/anaconda3/envs/deep-image-prior
28 |
--------------------------------------------------------------------------------
/sample_images/clean_1000iters.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tacalvin/Poisson2Sparse/f2b18acf755a204f44964cdbd8f1467cbcd8f991/sample_images/clean_1000iters.png
--------------------------------------------------------------------------------
/sample_images/noisy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tacalvin/Poisson2Sparse/f2b18acf755a204f44964cdbd8f1467cbcd8f991/sample_images/noisy.png
--------------------------------------------------------------------------------
/src/cscnet.py:
--------------------------------------------------------------------------------
1 | # Credit goes to https://github.com/drorsimon/CSCNet
2 | # Simon, Dror, and Michael Elad. "Rethinking the CSC model for natural images." Advances in Neural Information Processing Systems 32 (2019).
3 |
4 | from collections import namedtuple
5 |
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | ListaParams = namedtuple('ListaParams', ['kernel_size', 'num_filters', 'stride', 'unfoldings', 'channels'])
def calc_pad_sizes(I: torch.Tensor, kernel_size: int, stride: int):
    """Compute (left, right, top, bottom) padding for stride-shifted tiling.

    Left and top always get one full stride of padding.  Right and bottom
    get whatever remainder makes the padded extent tile exactly with
    ``kernel_size`` patches at the given stride, plus one extra stride so
    every shifted copy of the image still fits.
    """
    def _remainder_pad(extent: int) -> int:
        # Padding needed so (extent + stride - kernel_size) is a multiple of stride.
        rem = (extent + stride - kernel_size) % stride
        return 0 if rem == 0 else stride - rem

    left_pad = stride
    top_pad = stride
    right_pad = _remainder_pad(I.shape[3]) + stride
    bot_pad = _remainder_pad(I.shape[2]) + stride
    return left_pad, right_pad, top_pad, bot_pad
21 |
class SoftThreshold(nn.Module):
    """Element-wise soft-thresholding with a learnable per-filter threshold.

    Values in [-threshold, threshold] map to zero; values outside that band
    are shrunk toward zero by the threshold magnitude.
    """

    def __init__(self, size, init_threshold=1e-3):
        super(SoftThreshold, self).__init__()
        # One threshold per filter channel, broadcast over batch and space.
        self.threshold = nn.Parameter(init_threshold * torch.ones(1, size, 1, 1))

    def forward(self, x):
        above = (x > self.threshold).float()
        below = (x < -self.threshold).float()
        return above * (x - self.threshold) + below * (x + self.threshold)
33 |
34 |
35 |
36 |
37 |
class ConvLista_T(nn.Module):
    """Convolutional LISTA network (unrolled convolutional sparse coding).

    Three learned convolutions implement the unrolled ISTA iterations:
    ``apply_B`` encodes the image into a sparse code, ``apply_A`` decodes it
    back inside each iteration to form a residual, and ``apply_C`` is the
    final decoder.  For stride > 1 the input is replicated once per spatial
    shift (stride**2 copies) and the reconstructions are averaged so the
    result does not depend on the patch-grid alignment.
    """

    def __init__(self, params: ListaParams, A=None, B=None, C=None, threshold=1e-2, norm=False):
        """Build the network.

        :param params: ListaParams(kernel_size, num_filters, stride, unfoldings, channels).
        :param A, B, C: optional initial dictionaries; when A is None a random
            dictionary is drawn and rescaled so D^T D has unit spectral norm,
            and B/C default to copies of A.
        :param threshold: initial soft-threshold value.
        :param norm: accepted for interface compatibility; currently unused.
        """
        super(ConvLista_T, self).__init__()
        if A is None:
            A = torch.randn(params.num_filters, params.channels, params.kernel_size, params.kernel_size)
            # Normalize by the operator's largest singular value so the
            # unrolled iterations are stable (power method on a 128x128 domain).
            l = conv_power_method(A, [128, 128], num_iters=20, stride=params.stride)
            A /= torch.sqrt(l)
        if B is None:
            B = torch.clone(A)
        if C is None:
            C = torch.clone(A)
        self.apply_A = torch.nn.ConvTranspose2d(params.num_filters, params.channels, kernel_size=params.kernel_size,
                                                stride=params.stride, bias=False)
        self.apply_B = torch.nn.Conv2d(params.channels, params.num_filters, kernel_size=params.kernel_size, stride=params.stride, bias=False)
        self.apply_C = torch.nn.ConvTranspose2d(params.num_filters, params.channels, kernel_size=params.kernel_size,
                                                stride=params.stride, bias=False)
        # Install the (shared-initialization) dictionaries as conv weights.
        self.apply_A.weight.data = A
        self.apply_B.weight.data = B
        self.apply_C.weight.data = C
        self.soft_threshold = SoftThreshold(params.num_filters, threshold)
        self.params = params
        # Number of unrolled ISTA iterations actually run in forward().
        self.num_iter = params.unfoldings

    def _split_image(self, I):
        """Replicate I once per (row, col) shift in [0, stride)^2.

        Returns the stacked padded copies (folded into the batch dimension)
        and a matching validity mask that is 1 on original pixels and 0 on
        padding, used later to crop reconstructions back to the input size.
        """
        if self.params.stride == 1:
            # No shifts needed; every pixel is valid.
            return I, torch.ones_like(I)
        left_pad, right_pad, top_pad, bot_pad = calc_pad_sizes(I, self.params.kernel_size, self.params.stride)
        I_batched_padded = torch.zeros(I.shape[0], self.params.stride ** 2, I.shape[1], top_pad + I.shape[2] + bot_pad,
                                       left_pad + I.shape[3] + right_pad).type_as(I)
        valids_batched = torch.zeros_like(I_batched_padded)
        for num, (row_shift, col_shift) in enumerate([(i, j) for i in range(self.params.stride) for j in range(self.params.stride)]):
            # Reflect-pad the image for this shift; constant-pad the mask so
            # reflected (synthetic) pixels are marked invalid.
            I_padded = F.pad(I, pad=(
                left_pad - col_shift, right_pad + col_shift, top_pad - row_shift, bot_pad + row_shift), mode='reflect')
            valids = F.pad(torch.ones_like(I), pad=(
                left_pad - col_shift, right_pad + col_shift, top_pad - row_shift, bot_pad + row_shift), mode='constant')
            I_batched_padded[:, num, :, :, :] = I_padded
            valids_batched[:, num, :, :, :] = valids
        # Fold the shift axis into the batch axis for plain conv processing.
        I_batched_padded = I_batched_padded.reshape(-1, *I_batched_padded.shape[2:])
        valids_batched = valids_batched.reshape(-1, *valids_batched.shape[2:])
        return I_batched_padded, valids_batched

    def disable_warmup(self):
        """Run the full number of unrolled iterations (normal operation)."""
        self.num_iter = self.params.unfoldings

    def enable_warmup(self):
        """Run a single iteration only (warm-up mode)."""
        self.num_iter = 1

    def forward(self, I):
        """Denoise/reconstruct I; returns an output clamped to [0, 1]."""
        I_batched_padded, valids_batched = self._split_image(I)
        conv_input = self.apply_B(I_batched_padded)  # encode
        gamma_k = self.soft_threshold(conv_input)
        # Unrolled ISTA: decode, form the encoded residual, take a
        # thresholded gradient step on the sparse code.
        for k in range(self.num_iter - 1):
            x_k = self.apply_A(gamma_k)  # decode
            r_k = self.apply_B(x_k - I_batched_padded)  # encode residual
            gamma_k = self.soft_threshold(gamma_k - r_k)
        output_all = self.apply_C(gamma_k)
        # Keep only pixels that correspond to the original (unpadded) image,
        # regrouping the stride**2 shifted reconstructions per sample.
        output_cropped = torch.masked_select(output_all, valids_batched.bool()).reshape(I.shape[0], self.params.stride ** 2, *I.shape[1:])
        # Average over the shifted copies for a shift-invariant estimate.
        output = output_cropped.mean(dim=1, keepdim=False)
        return torch.clamp(output, 0.0, 1.0)
110 |
111 |
112 |
def conv_power_method(D, image_size, num_iters=100, stride=1):
    """Estimate the maximal eigenvalue of D^T D via the power method.

    Here D is applied as a strided (transposed) convolution, so the
    estimated value is the squared spectral norm of the convolutional
    synthesis operator built from the filter bank.

    :param D: filter bank of shape (num_filters, channels, kh, kw).
    :param image_size: (height, width) of the image domain the operator acts on.
    :param num_iters: number of power-method iterations.
    :param stride: convolution stride.
    :return: scalar tensor holding the eigenvalue estimate.
    """
    # Spatial size of the code domain ("needles") for this image/stride.
    needles_shape = [int(((image_size[0] - D.shape[-2]) / stride) + 1),
                     int(((image_size[1] - D.shape[-1]) / stride) + 1)]
    x = torch.randn(1, D.shape[0], *needles_shape).type_as(D)
    for _ in range(num_iters):
        # Normalize, then apply D followed by D^T: one power iteration.
        x = x / torch.norm(x.reshape(-1))
        y = F.conv_transpose2d(x, D, stride=stride)
        x = F.conv2d(y, D, stride=stride)
    # ||D^T D x|| of a unit vector converges to the top eigenvalue.
    return torch.norm(x.reshape(-1))
131 |
132 |
133 |
134 |
135 |
--------------------------------------------------------------------------------
/src/example.yaml:
--------------------------------------------------------------------------------
1 | dev: false
2 | experiment_cfg:
3 | LAM: 2
4 | cuda: true
5 | dataset:
6 | dataset_path: ./data/PINCAT10
7 | extension: png
8 | gtandraw: true
9 | resize: false
10 | input_type: noise
11 | l1: 1.0e-05 #l1 regularization
12 | lr: 0.0001
13 | num_iter: 5500
14 | optimizer: Adam
15 | poisson_loss: true
16 | experiment_pipeline: ours
17 | model_cfg:
18 | channels: 1
19 | kernel_size: 3
20 | norm: false
21 | num_filters: 512
22 | num_iter: 10
23 | stride: 1
24 | threshold: 0.01
25 | output_dir: ./results/PINCAT10/
26 |
--------------------------------------------------------------------------------
/src/fmd_dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from PIL import Image
4 |
5 | import torch
6 | import torchvision
7 | from torchvision import transforms
8 | import numpy as np
9 | from torchvision.transforms.transforms import CenterCrop
10 |
11 | from glob import glob
12 |
13 |
class FMDDataset():
    """Paired ground-truth / noisy-raw image dataset.

    Expects the layout ``<path>/gt/*`` and ``<path>/raw/*`` where the i-th
    file of each directory (in sorted order) forms a (gt, raw) pair.
    """

    def __init__(self, path):
        """Index the gt/raw file pairs under *path*.

        :raises ValueError: if the two directories contain different numbers
            of files (positional pairing would otherwise be silently wrong).
        """
        self.path = path
        self.gt_root = os.path.join(path, 'gt/')
        self.raw_root = os.path.join(path, 'raw/')
        self.transforms = transforms.Compose([
            transforms.ToTensor()
        ])

        # Pairing is positional, so both listings must sort identically.
        gt_elements = sorted(glob(self.gt_root + '*'))
        raw_elements = sorted(glob(self.raw_root + '*'))
        if len(gt_elements) != len(raw_elements):
            raise ValueError(
                "gt/raw count mismatch in %s: %d gt vs %d raw files"
                % (path, len(gt_elements), len(raw_elements))
            )
        self.data = [{'gt': g, 'raw': r} for g, r in zip(gt_elements, raw_elements)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return the (noisy_raw, ground_truth) tensor pair at *index*."""
        img = self.transforms(Image.open(self.data[index]['raw']))
        gt = self.transforms(Image.open(self.data[index]['gt']))
        return img, gt
53 |
--------------------------------------------------------------------------------
/src/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.nn as nn
4 | import torchvision.io as io
5 | import torchvision.transforms.functional as TF
6 | import torch.nn.functional as F
7 | from tqdm import trange
8 |
9 | import os
10 | from neighbor2neighbor import generate_mask_pair, generate_subimages
11 | from model import build_model
12 | from kornia.metrics import psnr
13 | import yaml
14 |
# Command-line interface: a model/hyperparameter config plus input and
# output image paths.  (Fixes the "Hyperparamter" typo in the help text.)
parser = argparse.ArgumentParser()
parser.add_argument(
    "--cfg_path", default="src/example.yaml", help="Model and Hyperparameter Config"
)
parser.add_argument("--input_path", required=True, help="Path to image to denoise")
parser.add_argument("--output_path", required=True, help="Path to save denoised image")
21 |
22 |
23 |
class Loader(yaml.SafeLoader):
    """SafeLoader with support for ``!include other.yaml`` directives.

    Included paths are resolved relative to the directory of the file being
    loaded, not the process working directory.
    """

    def __init__(self, stream):
        # Remember the directory of the top-level config so relative
        # !include paths resolve against it.
        self._root = os.path.split(stream.name)[0]

        super(Loader, self).__init__(stream)

    def include(self, node):
        """Load and return the YAML document referenced by *node*."""
        # BUG FIX: self._root was stored but never used, so relative
        # includes resolved against the CWD instead of the config's
        # directory.  os.path.join leaves absolute paths untouched.
        filename = os.path.join(self._root, self.construct_scalar(node))

        with open(filename, 'r') as f:
            return yaml.load(f, Loader)


Loader.add_constructor('!include', Loader.include)
41 |
42 |
def main(noisy, config, experiment_cfg):
    """Self-supervised denoising of a single noisy image.

    Trains the model on Neighbor2Neighbor sub-image pairs drawn from the
    one input image and returns an exponential moving average of the full
    reconstructions.

    :param noisy: (1, C, H, W) noisy input tensor with values in [0, 1].
    :param config: full config dict (passed to build_model).
    :param experiment_cfg: the 'experiment_cfg' sub-dict (optimizer, lr,
        num_iter, LAM, and optional loss-selection keys).
    :return: EMA of model outputs on the CPU, same shape as *noisy*.
    """
    model = build_model(config)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    print(
        "Number of params: ",
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    )
    # optimizer
    # NOTE(review): `optimizer` is only bound when optimizer == "Adam";
    # any other config value raises NameError at optimizer.step() below.
    if experiment_cfg["optimizer"] == "Adam":
        LR = experiment_cfg["lr"]
        optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

    # EMA decay for the running average of reconstructions.
    exp_weight = 0.99

    out_avg = None

    noisy_in = noisy
    noisy_in = noisy_in.to(device)

    # H/W stay None here; the commented-out padding path that set them for
    # non-square inputs has been removed, so the crop below never triggers.
    H = None
    W = None

    t = trange(experiment_cfg["num_iter"])
    pll = nn.PoissonNLLLoss(log_input=False, full=True)
    last_net = None
    psrn_noisy_last = 0.0
    for i in t:

        # Random disjoint sub-sampling masks (Neighbor2Neighbor pairs).
        mask1, mask2 = generate_mask_pair(noisy_in)
        mask1 = mask1.to(device)
        mask2 = mask2.to(device)
        # Full-image denoised estimate used for the regularization term.
        with torch.no_grad():
            noisy_denoised = model(noisy_in)
            noisy_denoised = torch.clamp(noisy_denoised, 0.0, 1.0)

        noisy_in_aug = noisy_in.clone()
        noisy_sub1 = generate_subimages(noisy_in_aug, mask1)
        noisy_sub2 = generate_subimages(noisy_in_aug, mask2)

        noisy_sub1_denoised = generate_subimages(noisy_denoised, mask1)
        noisy_sub2_denoised = generate_subimages(noisy_denoised, mask2)

        noisy_output = model(noisy_sub1)
        noisy_output = torch.clamp(noisy_output, 0.0, 1.0)
        noisy_target = noisy_sub2

        Lambda = experiment_cfg["LAM"]
        diff = noisy_output - noisy_target
        exp_diff = noisy_sub1_denoised - noisy_sub2_denoised

        # Optional L1 weight regularization on all model parameters.
        if "l1" in experiment_cfg.keys():
            l1_regularization = 0.0
            for param in model.parameters():
                l1_regularization += param.abs().sum()
            total_loss = experiment_cfg["l1"] * l1_regularization
        # Reconstruction loss selection by config key (first match wins).
        if "poisson_loss" in experiment_cfg.keys():
            loss1 = pll(noisy_output, noisy_target)
            loss2 = F.l1_loss(noisy_output, noisy_target)
            loss1 += loss2
        elif "poisson_loss_only" in experiment_cfg.keys():
            loss1 = pll(noisy_output, noisy_target)
        elif "l1_loss" in experiment_cfg.keys():
            loss1 = F.l1_loss(noisy_output, noisy_target)

        elif "mse" in experiment_cfg.keys():
            loss1 = torch.mean(diff ** 2)
        else:
            loss1 = F.l1_loss(noisy_output, noisy_target)
        # Neighbor2Neighbor consistency regularizer.
        loss2 = Lambda * torch.mean((diff - exp_diff) ** 2)

        loss = loss1 + loss2
        if "l1" in experiment_cfg.keys():
            loss += total_loss
        loss.backward()

        # Track full-image output and its EMA; PSNR vs the noisy input is
        # used as a divergence detector for the fallback below.
        with torch.no_grad():
            out_full = model(noisy_in).detach().cpu()
            if H is not None:
                out_full = out_full[:, :, :H, :W]
            if out_avg is None:
                out_avg = out_full.detach().cpu()
            else:
                out_avg = out_avg * exp_weight + out_full * (1 - exp_weight)
                out_avg = out_avg.detach().cpu()
            noisy_psnr = psnr(out_full, noisy_in.detach().cpu(), max_val=1.0).item()

        # NOTE(review): `(i + 1) % 50` is truthy for 49 of every 50
        # iterations, so this runs on all but every 50th step — likely the
        # intent was `% 50 == 0` (check every 50 iterations); confirm
        # before changing.
        if (i + 1) % 50:
            # If PSNR drops sharply, restore the last saved parameters and
            # skip this update (deep-image-prior style fallback).
            if noisy_psnr - psrn_noisy_last < -4 and last_net is not None:
                print("Falling back to previous checkpoint.")

                for new_param, net_param in zip(last_net, model.parameters()):
                    net_param.data.copy_(new_param.cuda())

                # NOTE(review): `total_loss` is only bound when "l1" is in
                # the config, so this line can raise NameError otherwise.
                total_loss = total_loss * 0
                optimizer.zero_grad()
                torch.cuda.empty_cache()
                continue
            else:
                # Snapshot parameters (on CPU) as the fallback checkpoint.
                last_net = [x.detach().cpu() for x in model.parameters()]
                psrn_noisy_last = noisy_psnr

        optimizer.step()
        optimizer.zero_grad()

    # Final EMA update after the last optimization step.
    with torch.no_grad():
        out_full = model(noisy_in).detach().cpu()
        if H is not None:
            out_full = out_full[:, :, :H, :W]
        if out_avg is None:
            out_avg = out_full.detach().cpu()
        else:
            out_avg = out_avg * exp_weight + out_full * (1 - exp_weight)
            out_avg = out_avg.detach().cpu()

    return out_avg
174 |
175 |
176 | if __name__ == "__main__":
177 | args = parser.parse_args()
178 |
179 | with open(args.cfg_path, "r") as f:
180 | cfg = yaml.load(f, Loader=Loader)
181 |
182 | noisy = io.read_image(args.input_path).unsqueeze(0)/255
183 |
184 | out_image = main(noisy, cfg, cfg['experiment_cfg']) * 255
185 | out_image = out_image.type(torch.uint8).squeeze(0)
186 | io.write_png(out_image, args.output_path)
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
1 | from cscnet import ConvLista_T, ListaParams
2 |
def build_model(cfg):
    """Construct a ConvLista_T network from the 'model_cfg' section of *cfg*."""
    m = cfg['model_cfg']
    params = ListaParams(m['kernel_size'], m['num_filters'], m['stride'],
                         m['num_iter'], m['channels'])
    return ConvLista_T(params, threshold=m['threshold'], norm=m['norm'])
8 |
--------------------------------------------------------------------------------
/src/neighbor2neighbor.py:
--------------------------------------------------------------------------------
1 | # Credit goes to https://github.com/TaoHuang2018/Neighbor2Neighbor
2 | # Huang, Tao, et al. "Neighbor2neighbor: Self-supervised denoising from single noisy images." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2021.
3 | from operator import sub
4 | import torch
5 | import numpy as np
6 |
# Monotonically increasing seed so successive random draws differ but the
# whole sequence is reproducible from the start of the run.
operation_seed_counter = 0
def get_generator(device="cuda"):
    """Return a freshly seeded torch.Generator.

    Each call increments a global counter and uses it as the seed, giving
    different but deterministic randomness per call.

    :param device: device for the generator.  Defaults to "cuda" for
        backward compatibility with existing callers; pass "cpu" on
        CUDA-less hosts (the hard-coded "cuda" previously crashed there).
    """
    global operation_seed_counter
    operation_seed_counter += 1
    generator = torch.Generator(device=device)
    generator.manual_seed(operation_seed_counter)
    return generator
14 |
class AugmentNoise(object):
    """Synthetic noise generator configured from a compact style string.

    Recognized styles: ``gauss<std>`` / ``gauss<min>_<max>`` (std given on a
    0-255 scale), ``poisson<lam>`` / ``poisson<min>_<max>``, and
    ``gain<lam>`` / ``gain<min>_<max>``.  A single value selects the *_fix
    variant; two underscore-separated values select the *_range variant
    with per-sample uniform draws.
    """
    def __init__(self, style):
        if style.startswith('gauss'):
            # Gaussian std is specified on a 0-255 scale; rescale to [0, 1].
            self.params = [float(p) / 255.0 for p in style.replace('gauss', '', 1).split('_')]
            if len(self.params) == 1:
                self.style = "gauss_fix"
            elif len(self.params) == 2:
                self.style = "gauss_range"
        elif style.startswith('poisson'):
            self.params = [float(p) for p in style.replace('poisson', '', 1).split('_')]
            if len(self.params) == 1:
                self.style = "poisson_fix"
            elif len(self.params) == 2:
                self.style = "poisson_range"
            print(self.params)
        elif style.startswith('gain'):
            self.params = [float(p) for p in style.replace('gain', '', 1).split('_')]
            if len(self.params) == 1:
                self.style = "gain_fix"
            elif len(self.params) == 2:
                self.style = "gain_range"
    def add_train_noise(self, x):
        """Add noise to a batched tensor x of shape (N, C, H, W).

        NOTE(review): the Gaussian branches allocate torch.cuda.FloatTensor
        and all branches use the CUDA generator from get_generator(), so
        this method requires a CUDA device — confirm before running on CPU.
        """
        shape = x.shape
        if self.style == "gauss_fix":
            std = self.params[0]
            # Broadcast the scalar std per sample.
            std = std * torch.ones((shape[0], 1, 1, 1), device=x.device)
            noise = torch.cuda.FloatTensor(shape, device=x.device)
            torch.normal(mean=0.0, std=std, generator=get_generator(), out=noise)
            return x + noise
        elif self.style == "gauss_range":
            # Per-sample std drawn uniformly from [min_std, max_std].
            min_std, max_std = self.params
            std = torch.rand(size=(shape[0], 1, 1, 1), device=x.device) * (max_std - min_std) + min_std
            noise = torch.cuda.FloatTensor(shape, device=x.device)
            torch.normal(mean=0, std=std, generator=get_generator(), out=noise)
            return x + noise
        elif self.style == "poisson_fix":
            lam = self.params[0]
            lam = lam * torch.ones((shape[0], 1, 1, 1), device=x.device)
            # Scale to photon counts, sample, then rescale back and clamp.
            noised = torch.poisson(lam * x, generator=get_generator()) / lam
            noised = torch.clamp(noised, 0.0,1.0)
            return noised
        elif self.style == "poisson_range":
            # Per-sample rate drawn uniformly from [min_lam, max_lam].
            min_lam, max_lam = self.params
            lam = torch.rand(size=(shape[0], 1, 1, 1), device=x.device) * (max_lam - min_lam) + min_lam
            noised = torch.poisson(lam * x, generator=get_generator()) / lam
            return noised
        elif self.style == "gain_fix":
            # Gain model: divide by the gain before sampling, multiply after.
            lam = self.params[0]
            lam = lam * torch.ones((shape[0], 1, 1, 1), device=x.device)
            noised = torch.poisson(x / lam, generator=get_generator()) * lam
            return noised
        elif self.style == "gain_range":
            min_lam, max_lam = self.params
            lam = torch.rand(size=(shape[0], 1, 1, 1), device=x.device) * (max_lam - min_lam) + min_lam
            noised = torch.poisson(x/lam, generator=get_generator()) * lam
            return noised
    def add_valid_noise(self, x):
        """NumPy counterpart of add_train_noise for validation data.

        Returns a float32 ndarray; note there is no branch for the gain_*
        styles here (returns None for them).
        """
        shape = x.shape
        if self.style == "gauss_fix":
            std = self.params[0]
            return np.array(x + np.random.normal(size=shape) * std, dtype=np.float32)
        elif self.style == "gauss_range":
            min_std, max_std = self.params
            std = np.random.uniform(low=min_std, high=max_std, size=(1, 1, 1))
            return np.array(x + np.random.normal(size=shape) * std, dtype=np.float32)
        elif self.style == "poisson_fix":
            lam = self.params[0]
            return np.array(np.random.poisson(lam * x) / lam, dtype=np.float32)
        elif self.style == "poisson_range":
            min_lam, max_lam = self.params
            lam = np.random.uniform(low=min_lam, high=max_lam, size=(1, 1, 1))
            return np.array(np.random.poisson(lam * x) / lam, dtype=np.float32)
87 |
88 |
def space_to_depth(x, block_size):
    """Pixel-unshuffle: fold each block_size x block_size spatial block into channels.

    Input of shape (N, C, H, W) becomes (N, C * block_size**2,
    H // block_size, W // block_size).
    """
    batch, chans, height, width = x.size()
    # unfold extracts non-overlapping blocks as columns of shape
    # (N, C * block_size**2, num_blocks)
    cols = torch.nn.functional.unfold(x, block_size, stride=block_size)
    return cols.view(batch,
                     chans * block_size ** 2,
                     height // block_size,
                     width // block_size)
94 |
def generate_mask_pair(img):
    """Draw a random pair of disjoint sub-sampling masks for Neighbor2Neighbor.

    Each 2x2 cell of ``img`` contributes exactly one pixel to mask1 and a
    different pixel to mask2, chosen from the 8 ordered neighbor pairs.
    Returns two flat boolean masks of length N * (H/2) * (W/2) * 4.
    """
    n, c, h, w = img.shape
    num_cells = n * h // 2 * w // 2  # one entry per 2x2 cell
    mask1 = torch.zeros(size=(num_cells * 4, ),
                        dtype=torch.bool,
                        device=img.device)
    mask2 = torch.zeros(size=(num_cells * 4, ),
                        dtype=torch.bool,
                        device=img.device)
    # the 8 ordered pairs of distinct neighboring positions within a cell
    idx_pair = torch.tensor(
        [[0, 1], [0, 2], [1, 3], [2, 3], [1, 0], [2, 0], [3, 1], [3, 2]],
        dtype=torch.int64,
        device=img.device)
    # sample one pair index per cell with the shared RNG
    rd_idx = torch.zeros(size=(num_cells, ),
                         dtype=torch.int64,
                         device=img.device)
    torch.randint(low=0,
                  high=8,
                  size=(num_cells, ),
                  generator=get_generator(),
                  out=rd_idx)
    # translate per-cell positions (0..3) into flat indices by adding the
    # cell's offset (4 entries per cell)
    cell_offsets = torch.arange(start=0,
                                end=num_cells * 4,
                                step=4,
                                dtype=torch.int64,
                                device=img.device).reshape(-1, 1)
    rd_pair_idx = idx_pair[rd_idx] + cell_offsets
    mask1[rd_pair_idx[:, 0]] = True
    mask2[rd_pair_idx[:, 1]] = True
    return mask1, mask2
127 |
def generate_subimages(img, mask):
    """Extract the half-resolution sub-image selected by a flat boolean ``mask``.

    ``mask`` must select exactly one pixel per 2x2 cell (as produced by
    generate_mask_pair); the result has shape (N, C, H/2, W/2).
    """
    n, c, h, w = img.shape
    half_h, half_w = h // 2, w // 2
    subimage = torch.zeros(n,
                           c,
                           half_h,
                           half_w,
                           dtype=img.dtype,
                           layout=img.layout,
                           device=img.device)
    # process one channel at a time so the flat mask lines up with the
    # space-to-depth ordering of a single channel
    for ch in range(c):
        depth = space_to_depth(img[:, ch:ch + 1, :, :], block_size=2)
        flat = depth.permute(0, 2, 3, 1).reshape(-1)
        picked = flat[mask].reshape(n, half_h, half_w, 1)
        subimage[:, ch:ch + 1, :, :] = picked.permute(0, 3, 1, 2)
    return subimage
--------------------------------------------------------------------------------
/src/run_folder.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import pickle
4 | import glob
5 | import random
6 |
7 | import kornia
8 | from kornia.losses import psnr
9 |
10 | import numpy as np
11 | from numpy.core.fromnumeric import resize, transpose
12 | import torch.nn.functional as F
13 | from torch.utils.data import DataLoader
14 | import torchvision.datasets as datasets
15 | from torchvision.transforms import transforms
16 | import torchvision.transforms.functional as T
17 | from training import train
18 | from fmd_dataloader import FMDDataset
19 |
20 | # from tqdm import tqdm
21 | from tqdm import trange
22 |
23 | import torch
24 | import torch.optim
25 |
26 | from skimage.metrics import peak_signal_noise_ratio
27 | from skimage.util import random_noise
28 | import yaml
29 | import datetime
30 |
31 | from PIL import Image
32 |
33 | # from utils.denoising_utils import *
34 |
35 | torch.manual_seed(123)
36 | np.random.seed(123)
37 |
38 |
39 |
40 |
def load_dataset(cfg_experiment):
    """Build a batch-size-1 DataLoader for the dataset described in the config.

    Supports three modes: paired ground-truth/raw FMD data, a random
    32-image MNIST subset, and a generic image folder.
    """
    ds_cfg = cfg_experiment['dataset']
    path = ds_cfg['dataset_path']

    # paired gt/raw data: the FMD dataset yields both images itself
    if 'gtandraw' in ds_cfg.keys() and ds_cfg['gtandraw']:
        print("path", path)
        return DataLoader(FMDDataset(path), batch_size=1, shuffle=False)

    steps = []
    if ds_cfg['resize']:
        steps.append(transforms.Resize((256, 256)))
    if ds_cfg['greyscale']:
        steps.append(transforms.Grayscale())
    steps.append(transforms.ToTensor())
    transform = transforms.Compose(steps)

    if 'mnist' in ds_cfg.keys():
        # random 32-image subset of the MNIST test split
        dataset = datasets.MNIST("~", train=False, transform=transform)
        subset_idx = np.random.choice(np.arange(len(dataset)), size=32)
        dataset = torch.utils.data.Subset(dataset, subset_idx)
        return DataLoader(dataset, batch_size=1, shuffle=True)

    dataset = datasets.ImageFolder(os.path.dirname(path), transform)
    return DataLoader(dataset, batch_size=1, shuffle=False)
70 |
71 |
72 | #
73 |
class Loader(yaml.SafeLoader):
    """SafeLoader extended with an ``!include other.yaml`` tag.

    Included files are parsed with this same Loader, so includes can nest.
    """

    def __init__(self, stream):
        # Remember the directory of the file being parsed so !include
        # paths can be resolved relative to it.
        self._root = os.path.split(stream.name)[0]

        super(Loader, self).__init__(stream)

    def include(self, node):
        """Construct the value of an ``!include`` tag by loading that file."""
        filename = self.construct_scalar(node)

        # FIX: self._root was computed but never used, so relative include
        # paths only worked when the process cwd happened to be the config's
        # directory. Resolve against the including file instead (join is a
        # no-op for absolute paths, so those keep working).
        with open(os.path.join(self._root, filename), 'r') as f:
            return yaml.load(f, Loader)


Loader.add_constructor('!include', Loader.include)
91 |
92 | # def p_noise(img, PEAK):
93 |
94 | # return random_noise(img, 'poisson')
95 |
96 | # def p_noise(img, PEAK):
97 | # # print(img, PEAK)
98 | # Q = np.max(np.max(img)) / PEAK
99 | # rate = img / Q
100 | # noisy = np.random.poisson(rate) * Q
101 |
102 | # # print(noisy)
103 | # # quit()
104 | # return noisy
105 |
106 |
def p_noise(img, PEAK):
    """Apply Poisson (shot) noise with peak photon count ``PEAK``.

    The image is scaled up to photon counts, sampled, and scaled back,
    so the output stays on the input's intensity scale.
    """
    counts = np.random.poisson(np.multiply(img, PEAK))
    return np.divide(counts, PEAK)
112 |
113 |
def g_noise(img, sigma):
    """Apply additive Gaussian noise; ``sigma`` is on the 0-255 intensity scale."""
    variance = float((sigma / 255.0) ** 2)
    return random_noise(img, 'gaussian', var=variance)
116 |
117 |
def gp_noise(img, PEAK, sigma):
    """Apply Poisson noise followed by additive Gaussian noise.

    FIX: previously called random_noise with ``var=sigma/255``, which is
    inconsistent with g_noise (which uses variance ``(sigma/255.0)**2``) —
    sigma was effectively treated as a variance on the wrong scale, and
    integer sigma would also hit integer division semantics. Delegate to
    g_noise so both Gaussian paths agree.
    """
    return g_noise(p_noise(img, PEAK), sigma)
120 |
121 |
def corrupt_dataset(imgs, cfg):
    """Corrupt every image with the configured noise, clipping results to [0, 1].

    ``cfg['noise']`` selects Poisson ('p'), Gaussian ('g'), or both ('gp');
    any other value returns the images untouched.
    """
    kind = cfg['noise']
    if kind == 'p':
        noised = (p_noise(img, cfg['peak']) for img in imgs)
    elif kind == 'g':
        noised = (g_noise(img, cfg['sigma']) for img in imgs)
    elif kind == 'gp':
        noised = (gp_noise(img, cfg['peak'], cfg['sigma']) for img in imgs)
    else:
        return imgs
    return [np.clip(img, 0, 1.0) for img in noised]
132 |
133 | # return imgs
134 |
135 |
def denoise_experiment(cfg):
    """Load the configured dataset and run the training/denoising loop on it.

    Expects ``cfg['experiment_cfg']`` to describe the dataset; the output
    directory is consumed downstream by ``train`` via the full cfg.
    """
    # FIX: removed unused local (`output_path = cfg['output_dir']`) — the
    # value was read here but never used.
    experiment_cfg = cfg['experiment_cfg']
    imgs = load_dataset(experiment_cfg)

    train(imgs, cfg)
144 |
145 | # print(noisy_imgs[0])
146 |
147 |
def create_result_dir(cfg, dev=True):
    """Create a new numbered results directory and write a copy of the config.

    The directory name is a zero-padded counter based on how many entries
    already exist under ``cfg['output_dir']``.

    Args:
        cfg: experiment configuration; must contain 'output_dir'.
        dev: unused, kept for interface compatibility with callers.

    Returns:
        Path of the created results directory.
    """
    print(cfg)
    base = cfg['output_dir']
    # FIX: replaced bare `except:` with the specific error raised when the
    # base directory does not exist yet.
    try:
        existing = os.listdir(base)
    except FileNotFoundError:
        existing = []
    output_path = os.path.join(base, "{:04d}".format(len(existing) + 1))
    print(output_path)
    # FIX: exist_ok=True replaces the bare try/except-pass around makedirs,
    # so real errors (e.g. permissions) are no longer silently swallowed.
    os.makedirs(output_path, exist_ok=True)

    # create copy of cfg into dir
    with open(os.path.join(output_path, 'cfg.yaml'), 'w') as yaml_file:
        yaml.dump(cfg, yaml_file, default_flow_style=False)

    return output_path
171 |
172 |
def experiment(cfg):
    """Run one full experiment: create the result dir, then train, with timing logs."""
    start_time = datetime.datetime.now()
    # FIX: corrected typo "Begining" in the log message.
    print("Beginning Experiment {}".format(start_time))
    output_dir = create_result_dir(cfg, dev=cfg['dev'])
    if output_dir is not None:
        # downstream code reads the concrete run directory from the cfg
        cfg['output_dir'] = output_dir

    # run experiment here
    denoise_experiment(cfg)

    end_time = datetime.datetime.now()
    print("Ending Experiment {}".format(end_time))
185 |
186 |
def main(cfg_path):
    """Parse the experiment YAML at ``cfg_path`` and launch the experiment."""
    with open(cfg_path, 'r') as stream:
        config = yaml.load(stream, Loader=Loader)
    experiment(config)
193 |
194 |
195 | if __name__ == "__main__":
196 | parser = argparse.ArgumentParser(description='Run Denoising Experiment')
197 | parser.add_argument('--cfg_path', help='Path to experiment config')
198 |
199 | args = parser.parse_args()
200 |
201 | main(args.cfg_path)
202 |
--------------------------------------------------------------------------------
/src/training.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import glob
4 | import random
5 |
6 | import kornia
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | # from kornia.losses import psnr_loss as psnr
11 | from kornia.metrics import psnr
12 | import torchvision
13 |
14 | from torchvision.transforms import functional as TF
15 | from model import build_model
16 | from icecream import ic
17 |
18 |
19 | from tqdm import trange
20 |
21 |
22 | from neighbor2neighbor import AugmentNoise, generate_mask_pair, generate_subimages
23 | # from utils.denoising_utils import *
24 |
def build_loss(cfg):
    """Build the training loss selected in ``cfg['experiment_cfg']``.

    Returns MSELoss when 'mse' is truthy, else PoissonNLLLoss when 'pll'
    is truthy.

    Raises:
        ValueError: if neither loss is enabled in the config.
    """
    exp_cfg = cfg['experiment_cfg']
    # FIX: use .get so a missing key does not raise KeyError, and raise a
    # clear error instead of the previous UnboundLocalError (`loss` was
    # returned without ever being assigned when neither flag was set).
    if exp_cfg.get('mse'):
        return torch.nn.MSELoss()
    if exp_cfg.get('pll'):
        return torch.nn.PoissonNLLLoss()
    raise ValueError("experiment_cfg must enable either 'mse' or 'pll'")
31 |
def train(dloader, config):
    """Run single-image denoising over every image yielded by ``dloader``.

    For each image a fresh model is optimized via ``train_helper`` and the
    ground-truth, noisy, and denoised images plus per-iteration metrics are
    written under ``config['output_dir']/<idx>/``. Prints the average PSNR
    at the end.

    Args:
        dloader: DataLoader yielding either (noisy, clean) pairs (when the
            dataset config has a truthy 'gtandraw') or (image, label) pairs.
        config: full experiment config; reads 'experiment_cfg' and
            'output_dir'.
    """
    experiment_cfg = config['experiment_cfg']
    output_path = config['output_dir']

    # the tensor type doubles as the device selector (CUDA vs CPU)
    if config['experiment_cfg']['cuda']:
        dtype = torch.cuda.FloatTensor
    else:
        dtype = torch.FloatTensor

    running_psnr_avg = 0.0
    # NOTE(review): running_ssim_avg is never updated below, so the final
    # "SSIM" average always prints 0.0 — confirm whether that is intended.
    running_ssim_avg = 0.0
    print(len(dloader))
    for idx, img in enumerate(dloader):
        if 'gtandraw' in experiment_cfg['dataset'].keys() and experiment_cfg['dataset']['gtandraw']:
            # dataset yields real (noisy, clean) pairs directly
            noisy, img = img
            print(type(img), len(img))
            img = img.type(dtype)
            noisy = noisy.type(dtype)
        else:
            # clean-only dataset: synthesize the noisy observation
            noise_adder = AugmentNoise(style=experiment_cfg['noise'])
            img, _ = img
            img = img.type(dtype)

            # noisy image
            noisy = noise_adder.add_train_noise(img).type(dtype)

        # optimize a fresh model on this single image pair
        results = train_helper(img, noisy, dtype, config, experiment_cfg)
        denoised, clean_psnr, psnr_list, loss_list, lpips_list, ssims_list = results

        image_results_path = os.path.join(output_path, str(idx))
        try:
            print(image_results_path)
            os.mkdir(image_results_path)
        except:  # directory may already exist from a previous run
            pass
        # write clean and noisy image (denoised filename embeds its PSNR)
        torchvision.utils.save_image(img, os.path.join(image_results_path,"{}_gt.png".format(idx)))
        torchvision.utils.save_image(denoised, os.path.join(image_results_path,"{}_{:.3f}_out.png".format(idx, clean_psnr)))
        torchvision.utils.save_image(noisy, os.path.join(image_results_path,"{}_noisy.png".format(idx)))

        # persist per-iteration metric curves for later analysis
        with open(os.path.join(image_results_path, 'metrics_psnr.pkl'), 'wb') as file:
            pickle.dump({'loss':loss_list, 'psnr':psnr_list, 'lpips':lpips_list, 'ssims': ssims_list}, file)
        running_psnr_avg += clean_psnr

        # free GPU memory before the next per-image model is built
        torch.cuda.empty_cache()

    print("#############################\n Final Average PSNR: {} SSIM:{}".format(running_psnr_avg/ len(dloader), running_ssim_avg/ len(dloader)))
103 |
104 |
105 |
106 |
107 |
def train_helper(img, noisy, dtype, config, experiment_cfg):
    """Dispatch to the single-image training strategy.

    Currently only the Neighbor2Neighbor augmentation loop exists, so this
    simply delegates to it.
    """
    return nb2nb_aug_helper(img, noisy, dtype, config, experiment_cfg)
111 |
def nb2nb_aug_helper( img, noisy, dtype, config, experiment_cfg):
    """Optimize a freshly-built model on one noisy image with the
    Neighbor2Neighbor objective (loss between two random sub-samplings,
    plus a regularizer against the stop-gradient denoised output).

    Args:
        img: clean reference image (used only for PSNR/SSIM reporting).
        noisy: the noisy observation the model is trained on.
        dtype: tensor type selecting the device (cuda/cpu FloatTensor).
        config: full experiment config (model construction).
        experiment_cfg: experiment section; keys act as feature flags
            ('rotate', 'flip', 'cutnoise', loss selection, 'l1', 'LAM', ...).

    Returns:
        (out_avg, final_psnr, psnr_list, loss_list, lpips_list, ssims_list)
        where out_avg is the exponentially averaged denoised output on CPU.
    """
    model = build_model(config)
    model.type(dtype)

    print("Number of params: ", sum(p.numel() for p in model.parameters() if p.requires_grad))
    # optimizer
    if experiment_cfg['optimizer'] == 'Adam':
        LR = experiment_cfg['lr']
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=LR)

    if 'lr_sched' in experiment_cfg.keys():
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=.99)

    psnr_list = []
    loss_list = []
    lpips_list =[]
    ssims_list = []
    grad_hist = []
    exp_weight = .99  # EMA decay for the averaged output

    out_avg = None  # exponential moving average of the full denoised output
    # optimize single image

    noisy_in = noisy
    # WITH CSCNET did nothing somehow
    if 'rotate' in experiment_cfg.keys():
        # augment the batch with 90/180/270-degree rotations
        noisy_in = torch.cat((noisy, TF.rotate(noisy, 90), TF.rotate(noisy, 180), TF.rotate(noisy, 270)))

    if 'flip' in experiment_cfg.keys():
        # horizontal and vertical flip
        noisy_in = torch.cat((noisy_in, TF.hflip(noisy_in), TF.vflip(noisy_in)))
    # we need to pad to a square if the image is not already a square
    H = None
    W = None
    if noisy.shape[2] != noisy.shape[3]:
        H = noisy.shape[2]
        W = noisy.shape[3]
        # round the padded side up to a multiple of 32
        val_size = (max(H, W) + 31) // 32 * 32
        # NOTE(review): this overwrites any rotate/flip augmentation above
        # since it pads `noisy`, not `noisy_in` — confirm intended.
        noisy_in = TF.pad(noisy, (0, 0, val_size-noisy.shape[3], val_size-noisy.shape[2]), padding_mode='reflect')
    t = trange(experiment_cfg['num_iter'])
    pll = nn.PoissonNLLLoss(log_input=False, full=True)
    last_net = None  # parameter snapshot used for PSNR-drop rollback
    grad_hist = []
    warmup = True
    warmup_counter = 0
    psrn_noisy_last =0.0
    for i in t:
        # fresh pair of disjoint sub-sampling masks every iteration
        mask1, mask2 = generate_mask_pair(noisy_in)

        # stop-gradient denoised estimate used by the regularizer term
        with torch.no_grad():
            noisy_denoised = model(noisy_in)
            noisy_denoised = torch.clamp(noisy_denoised, 0.0, 1.0)

        if 'cutnoise' in experiment_cfg.keys():
            noisy_denoised, noisy_in_aug = cutNoise(noisy_denoised.clone(), noisy_in.clone())
        else:
            noisy_in_aug = noisy_in.clone()
        # g1(y) / g2(y): the two sub-sampled views of the noisy input
        noisy_sub1 = generate_subimages(noisy_in_aug, mask1)
        noisy_sub2 = generate_subimages(noisy_in_aug, mask2)

        # same sub-samplings of the detached denoised estimate
        noisy_sub1_denoised = generate_subimages(noisy_denoised, mask1)
        noisy_sub2_denoised = generate_subimages(noisy_denoised, mask2)
        noisy_output = model(noisy_sub1)
        noisy_output = torch.clamp(noisy_output, 0.0, 1.0)
        noisy_target = noisy_sub2

        Lambda = experiment_cfg['LAM']  # regularizer weight
        diff = noisy_output - noisy_target
        exp_diff = noisy_sub1_denoised - noisy_sub2_denoised
        # optional L1 weight penalty on the model parameters
        if "l1" in experiment_cfg.keys():
            l1_regularization = 0.
            for param in model.parameters():
                l1_regularization += param.abs().sum()
            total_loss = (experiment_cfg['l1'] * l1_regularization)
        # fidelity term: selected by config flag, default L1
        if 'poisson_loss' in experiment_cfg.keys():
            loss1 = pll(noisy_output, noisy_target)
            loss2 = F.l1_loss(noisy_output, noisy_target)
            loss1 += loss2
        elif 'poisson_loss_only' in experiment_cfg.keys():
            loss1 = pll(noisy_output, noisy_target)
        elif 'l1_loss' in experiment_cfg.keys():
            loss1 = F.l1_loss(noisy_output, noisy_target)
        elif 'sure_loss' in experiment_cfg.keys():
            # Monte-Carlo divergence estimate added to the L1 fidelity
            n = torch.randn((noisy_output.shape), requires_grad=True).type(dtype)
            div = (n@noisy_output).sum()
            div = torch.autograd.grad(div, noisy_output, retain_graph=True)[0]

            loss1 = F.l1_loss(noisy_output, noisy_target) + (n @ div).mean()

        elif 'wsure' in experiment_cfg.keys():
            # SURE-style loss with a finite-difference divergence estimate;
            # assumes sigma = 50/255 — TODO confirm against configs
            fidelity_loss = torch.mean(diff**2)
            epsilon = 1e-3
            eta = noisy_sub1.clone().normal_()
            net_input_perturbed = noisy_sub1.clone() + (eta * epsilon)
            out_perturbed = model (net_input_perturbed)
            dx = out_perturbed - noisy_output
            eta_dx = torch.sum(eta * dx)
            MCdiv = eta_dx / epsilon
            div_term = 2. * (50/255) ** 2 * MCdiv / torch.numel(noisy_sub1)
            loss1 = fidelity_loss - (50/255) **2 + div_term
        elif "mse" in experiment_cfg.keys():
            loss1 = torch.mean(diff**2)
        else:
            loss1 = F.l1_loss(noisy_output, noisy_target)
        # Neighbor2Neighbor regularizer: penalize deviation of the
        # sub-image difference from the denoised-estimate difference
        loss2 = Lambda * torch.mean((diff - exp_diff)**2)

        loss = loss1 + loss2
        if "l1" in experiment_cfg.keys():
            loss += total_loss
        loss.backward()

        # evaluate the current full-image output (no gradients)
        with torch.no_grad():
            out_full = model(noisy).detach().cpu()
            if H is not None:
                # crop away the square/multiple-of-32 padding
                out_full = out_full[:,:, :H, :W]
            if out_avg is None:
                out_avg = out_full.detach().cpu()
            else:
                out_avg = out_avg * exp_weight + out_full * (1 - exp_weight)
                out_avg = out_avg.detach().cpu()
            clean_psnr = psnr(out_full, img.detach().cpu(), max_val=1.0).item()
            noisy_psnr = psnr(out_full, noisy.detach().cpu(), max_val=1.0).item()
            clean_psnr_avg = psnr(out_avg, img.detach().cpu(), max_val=1.0).item()

        # NOTE(review): `(i+1) % 50` is truthy on 49 of every 50 iterations;
        # if the intent was "every 50th iteration" this should be `== 0`.
        if (i+1) % 50:
            if noisy_psnr - psrn_noisy_last < -4 and last_net is not None:
                # large PSNR drop vs the noisy image: roll the parameters
                # back to the last snapshot and skip this update
                print('Falling back to previous checkpoint.')

                for new_param, net_param in zip(last_net, model.parameters()):
                    net_param.data.copy_(new_param.cuda())

                total_loss = total_loss*0
                optimizer.zero_grad()
                torch.cuda.empty_cache()
                continue
            else:
                # snapshot parameters for a possible future rollback
                last_net = [x.detach().cpu() for x in model.parameters()]
                psrn_noisy_last = noisy_psnr

        optimizer.step()
        if 'param_noise_sigma' in experiment_cfg.keys():
            # optional parameter-space noise injection after each step
            add_noise(model, experiment_cfg
                      ['param_noise_sigma'], learning_rate=LR, dtype=dtype)
        optimizer.zero_grad()

        if 'lr_sched' in experiment_cfg.keys():
            scheduler.step()

        # recompute metrics after the optimizer step for logging
        with torch.no_grad():
            out_full = model(noisy).detach().cpu()
            if H is not None:
                out_full = out_full[:,:, :H, :W]
            if out_avg is None:
                out_avg = out_full.detach().cpu()
            else:
                out_avg = out_avg * exp_weight + out_full * (1 - exp_weight)
                out_avg = out_avg.detach().cpu()
            clean_psnr = psnr(out_full, img.detach().cpu(), max_val=1.0).item()
            clean_psnr_avg = psnr(out_avg, img.detach().cpu(), max_val=1.0).item()
            from skimage.metrics import structural_similarity as ssim
            # assumes a single-image, single-channel batch (squeeze twice)
            # — TODO confirm for multi-channel runs
            clean_ssim = ssim(out_avg.detach().cpu().numpy().squeeze(0).squeeze(0), img.detach().cpu().numpy().squeeze(0).squeeze(0))
            ssims_list.append(clean_ssim)
        t.set_description("PSNR:{:.5f} db | AVG:{:.5f} | | Loss: {:.5f} | SSIM: {:.5f}".format(clean_psnr, clean_psnr_avg, loss.item(), clean_ssim))
        psnr_list.append(clean_psnr)
        loss_list.append(loss.item())

    # LPIPS scoring is disabled; keep the slot so callers still unpack it
    lpips_list = [0.0]
    clean_psnr = psnr(out_avg, img.detach().cpu(), max_val=1.0)
    return out_avg, clean_psnr.item(), psnr_list, loss_list, lpips_list, ssims_list
378 |
379 |
380 |
def add_noise(model, param_noise_sigma, learning_rate, dtype):
    """Perturb every 4-D (convolutional) parameter with Gaussian noise.

    The noise is scaled by ``param_noise_sigma * learning_rate``; non-4-D
    parameters (biases, norm weights) are left untouched.
    """
    for param in model.parameters():
        if len(param.size()) != 4:
            continue
        perturbation = torch.randn(param.size()) * param_noise_sigma * learning_rate
        param.data = param.data + perturbation.type(dtype)
386 |
387 |
--------------------------------------------------------------------------------