├── .gitattributes
├── .github
└── workflows
│ └── pypi-publish.yml
├── .gitignore
├── .readthedocs.yaml
├── README.md
├── docs
├── Makefile
├── README.md
├── make.bat
├── requirements.txt
└── source
│ ├── _static
│ ├── css
│ │ └── custom.css
│ ├── favicon.ico
│ ├── image
│ │ ├── example_deconv.png
│ │ ├── optic_results.png
│ │ └── psf.png
│ ├── logo.svg
│ └── pipeline_dprox.gif
│ ├── api
│ ├── algo.md
│ ├── index.md
│ ├── linalg.md
│ ├── linop.md
│ ├── primitive.md
│ ├── proxfn.md
│ └── utils.md
│ ├── citation.md
│ ├── conf.py
│ ├── index.md
│ ├── started
│ ├── index.md
│ ├── install.md
│ └── quicktour.nblink
│ └── tutorials
│ ├── computational_optics.nblink
│ ├── csmri.nblink
│ ├── deraining.nblink
│ ├── differentiable_linear_solver.nblink
│ ├── energy_system.nblink
│ ├── image_restoration.nblink
│ ├── index.md
│ ├── learn_the_basic.nblink
│ ├── linear_operator.nblink
│ └── training.nblink
├── dprox
├── __init__.py
├── algo
│ ├── __init__.py
│ ├── admm.py
│ ├── base.py
│ ├── hqs.py
│ ├── invert.py
│ ├── lp
│ │ ├── __init__.py
│ │ ├── solvers.py
│ │ └── utils.py
│ ├── opt
│ │ ├── __init__.py
│ │ ├── absorb.py
│ │ ├── equil.py
│ │ └── merge.py
│ ├── pc.py
│ ├── pgd.py
│ ├── primitives.py
│ ├── problem.py
│ ├── specialization
│ │ ├── __init__.py
│ │ ├── deq
│ │ │ ├── __init__.py
│ │ │ ├── solver.py
│ │ │ ├── training.py
│ │ │ └── utils
│ │ │ │ ├── __init__.py
│ │ │ │ ├── jacobian.py
│ │ │ │ ├── layer_utils.py
│ │ │ │ ├── optimizations.py
│ │ │ │ ├── radam.py
│ │ │ │ └── solvers.py
│ │ ├── rl
│ │ │ ├── __init__.py
│ │ │ └── solver.py
│ │ └── unroll.py
│ └── tune
│ │ ├── __init__.py
│ │ ├── dpir.py
│ │ └── learnable.py
├── contrib
│ ├── __init__.py
│ ├── csmri.py
│ ├── derain.py
│ ├── energy_system.py
│ ├── optic
│ │ ├── __init__.py
│ │ ├── common.py
│ │ ├── doe_model.py
│ │ ├── doe_model_hybrid.py
│ │ ├── unet.py
│ │ └── utils.py
│ └── restoration.py
├── linalg
│ ├── __init__.py
│ ├── custom.py
│ └── solve
│ │ ├── __init__.py
│ │ ├── solver_cg.py
│ │ ├── solver_minres.py
│ │ └── solver_plss.py
├── linop
│ ├── __init__.py
│ ├── base.py
│ ├── blackbox.py
│ ├── comp_graph.py
│ ├── constaints.py
│ ├── constant.py
│ ├── conv.py
│ ├── edge.py
│ ├── grad.py
│ ├── mul.py
│ ├── placeholder.py
│ ├── scale.py
│ ├── subsample.py
│ ├── sum.py
│ ├── variable.py
│ └── vstack.py
├── proxfn
│ ├── __init__.py
│ ├── base.py
│ ├── fast
│ │ ├── __init__.py
│ │ ├── cs.py
│ │ ├── csmri.py
│ │ ├── pr.py
│ │ ├── spi.py
│ │ └── sr.py
│ ├── nlm
│ │ ├── __init__.py
│ │ ├── nlm.py
│ │ └── patch_nlm.py
│ ├── nonneg.py
│ ├── norm.py
│ ├── pnp
│ │ ├── __init__.py
│ │ ├── denoisers
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── composite.py
│ │ │ ├── models
│ │ │ │ ├── TV_denoising.py
│ │ │ │ ├── __init__.py
│ │ │ │ ├── basicblock.py
│ │ │ │ ├── network_dncnn.py
│ │ │ │ ├── network_ffdnet.py
│ │ │ │ ├── network_unet.py
│ │ │ │ ├── qrnn
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── conv.py
│ │ │ │ │ ├── grunet.py
│ │ │ │ │ ├── layer.py
│ │ │ │ │ ├── qrnn3d.py
│ │ │ │ │ └── sync_batchnorm
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ ├── batchnorm.py
│ │ │ │ │ │ ├── comm.py
│ │ │ │ │ │ ├── replicate.py
│ │ │ │ │ │ └── unittest.py
│ │ │ │ └── unet
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── basicblock.py
│ │ │ │ │ └── unet.py
│ │ │ └── wrapper.py
│ │ └── prior.py
│ ├── sum_square.py
│ └── unrolling
│ │ ├── __init__.py
│ │ ├── dgu.py
│ │ └── prior.py
└── utils
│ ├── __init__.py
│ ├── containar.py
│ ├── huggingface.py
│ ├── init
│ ├── __init__.py
│ ├── mosaic.py
│ └── sr.py
│ ├── io.py
│ ├── metrics.py
│ ├── misc.py
│ └── psf2otf.py
├── examples
├── README.md
├── applications
│ ├── csmri.py
│ ├── deconv.py
│ ├── demosaic.py
│ ├── joint_demosaic_deconv.py
│ └── super_resolution.py
└── papers
│ ├── README.md
│ ├── deltaprox_siggraph_2023
│ ├── computional_optics
│ │ ├── README.md
│ │ ├── e2e_optics_dprox.py
│ │ ├── e2e_optics_dprox_joint.py
│ │ ├── e2e_optics_unet.py
│ │ └── pnp_optics.py
│ ├── csmri
│ │ ├── README.md
│ │ ├── deq_tfpnp.py
│ │ ├── deq_unet.py
│ │ ├── download_datasets.py
│ │ ├── pnp_drunet.py
│ │ ├── pnp_unet.py
│ │ ├── rl_unet.py
│ │ ├── rl_unet_train.py
│ │ └── unroll_unet.py
│ └── deraining
│ │ ├── README.md
│ │ ├── derain
│ │ ├── __init__.py
│ │ ├── data.py
│ │ ├── dataset.py
│ │ ├── restormer.py
│ │ ├── unroll.py
│ │ └── unroll_share.py
│ │ ├── evaluate_PSNR_SSIM.m
│ │ ├── evaluate_PSNR_SSIM.py
│ │ ├── test_unroll.py
│ │ └── test_unroll_share.py
│ ├── dgunet_cvpr_2021
│ ├── main.py
│ ├── network.py
│ └── ops.py
│ ├── dphsir_neurcomputing_2022
│ ├── degrades
│ │ ├── __init__.py
│ │ ├── blur.py
│ │ ├── cs.py
│ │ ├── general.py
│ │ ├── inpaint.py
│ │ ├── noise.py
│ │ ├── sr.py
│ │ └── utils.py
│ ├── hsi_compress_sensing.py
│ ├── hsi_deblur.py
│ ├── hsi_inpainting.py
│ ├── hsi_misr.py
│ └── hsi_sisr.py
│ ├── dpir_tpami_2020
│ ├── rgb_demosaic.py
│ └── utils.py
│ └── tfpnp_icml_2020
│ └── csmri.py
├── notebooks
├── README.md
├── computational_optics.ipynb
├── csmri.ipynb
├── deraining.ipynb
├── differentiable_linear_solver.ipynb
├── energy_system_planning.ipynb
├── image_restoration.ipynb
├── learn_the_basic.ipynb
├── linear_operator.ipynb
├── primitive.ipynb
├── quickstart.ipynb
└── training.ipynb
├── requirements.txt
├── setup.py
└── tests
├── README.md
├── linalg
├── test_linear_solver.py
├── test_linear_solver_batch.py
├── test_linear_solver_grad.py
├── test_linear_solver_torch.py
└── test_pcg.py
├── paper
├── test_csmri.py
├── test_derain.py
├── test_energy.py
└── test_optics.py
├── problem
├── test_deraining.py
├── test_energy_system.py
├── test_inverse_problems.py
├── test_jd23.py
└── test_ml_problems.py
├── test_algorithms.py
├── test_grad.py
├── test_linop.py
├── test_linop_primitive.py
└── test_primitive.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | notebooks/** linguist-vendored
--------------------------------------------------------------------------------
/.github/workflows/pypi-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: Upload Python Package
5 |
6 | on:
7 | release:
8 | types: [created]
9 |
10 | jobs:
11 | deploy:
12 |
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - uses: actions/checkout@v2
17 | - name: Set up Python
18 | uses: actions/setup-python@v2
19 | with:
20 | python-version: '3.x'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install setuptools wheel twine
25 | - name: Build and publish
26 | env:
27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
29 | run: |
30 | python setup.py sdist bdist_wheel
31 | twine upload dist/*
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.txt
2 | *.pth
3 | *.mat
4 | simple_cep_model_20220916
5 | data
6 | abc
7 | datasets
8 | upload.sh
9 | examples/applications/computional_optics/optics
10 | *.png
11 | # *.mat
12 | # Byte-compiled / optimized / DLL files
13 | __pycache__/
14 | *.py[cod]
15 | *$py.class
16 | # *.png
17 | # *.jpg
18 | *.tif
19 | log
20 | ckpt
21 | # C extensions
22 | *.so
23 |
24 | # Distribution / packaging
25 | .Python
26 | build/
27 | develop-eggs/
28 | dist/
29 | downloads/
30 | eggs/
31 | .eggs/
32 | lib/
33 | lib64/
34 | parts/
35 | sdist/
36 | var/
37 | wheels/
38 | pip-wheel-metadata/
39 | share/python-wheels/
40 | *.egg-info/
41 | .installed.cfg
42 | *.egg
43 | MANIFEST
44 |
45 | # PyInstaller
46 | # Usually these files are written by a python script from a template
47 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
48 | *.manifest
49 | *.spec
50 |
51 | # Installer logs
52 | pip-log.txt
53 | pip-delete-this-directory.txt
54 |
55 | # Unit test / coverage reports
56 | htmlcov/
57 | .tox/
58 | .nox/
59 | .coverage
60 | .coverage.*
61 | .cache
62 | nosetests.xml
63 | coverage.xml
64 | *.cover
65 | *.py,cover
66 | .hypothesis/
67 | .pytest_cache/
68 |
69 | # Translations
70 | *.mo
71 | *.pot
72 |
73 | # Django stuff:
74 | *.log
75 | local_settings.py
76 | db.sqlite3
77 | db.sqlite3-journal
78 |
79 | # Flask stuff:
80 | instance/
81 | .webassets-cache
82 |
83 | # Scrapy stuff:
84 | .scrapy
85 |
86 | # Sphinx documentation
87 | docs/_build/
88 |
89 | # PyBuilder
90 | target/
91 |
92 | # Jupyter Notebook
93 | .ipynb_checkpoints
94 |
95 | # IPython
96 | profile_default/
97 | ipython_config.py
98 |
99 | # pyenv
100 | .python-version
101 |
102 | # pipenv
103 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
104 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
105 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
106 | # install all needed dependencies.
107 | #Pipfile.lock
108 |
109 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
110 | __pypackages__/
111 |
112 | # Celery stuff
113 | celerybeat-schedule
114 | celerybeat.pid
115 |
116 | # SageMath parsed files
117 | *.sage.py
118 |
119 | # Environments
120 | .env
121 | .venv
122 | env/
123 | venv/
124 | ENV/
125 | env.bak/
126 | venv.bak/
127 |
128 | # Spyder project settings
129 | .spyderproject
130 | .spyproject
131 |
132 | # Rope project settings
133 | .ropeproject
134 |
135 | # mkdocs documentation
136 | /site
137 |
138 | # mypy
139 | .mypy_cache/
140 | .dmypy.json
141 | dmypy.json
142 |
143 | # Pyre type checker
144 | .pyre/
145 | .vscode
146 | .rsyncignore
147 | .ssh
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yaml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | # Set the version of Python and other tools you might need
9 | build:
10 | os: ubuntu-20.04
11 | tools:
12 | python: "3.9"
13 | # You can also specify other tool versions:
14 | # nodejs: "16"
15 | # rust: "1.55"
16 | # golang: "1.17"
17 |
18 | # Build documentation in the docs/ directory with Sphinx
19 | sphinx:
20 | configuration: docs/source/conf.py
21 |
22 | # If using Sphinx, optionally build your docs in additional formats such as PDF
23 | # formats:
24 | # - pdf
25 |
26 | # Optionally declare the Python requirements required to build your docs
27 | python:
28 | install:
29 | - requirements: docs/requirements.txt
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # Documentation
2 |
3 | Source of the documentation.
4 |
5 | - Build html.
6 |
7 | ```bash
8 | make html
9 | ```
10 |
11 | - Preview at `build/html/index.html`
12 |
13 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.https://www.sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | myst-parser
2 | sphinx-rtd-theme
3 | furo
4 | sphinx-copybutton
5 | sphinx-inline-tabs
6 | nbsphinx
7 | nbsphinx_link
8 | linkify-it-py
9 | linkify
10 | ipython
11 |
12 | torch
13 | imageio
14 | scikit_image
15 | matplotlib
16 | munch
17 | tfpnp
18 | cvxpy
19 | torchlights
20 | tensorboardX
21 | termcolor
22 | proximal
23 | opencv-python
24 | huggingface_hub
25 | torchvision
--------------------------------------------------------------------------------
/docs/source/_static/css/custom.css:
--------------------------------------------------------------------------------
1 |
2 | .sidebar-logo {
3 | display: block;
4 | margin: 0;
5 | max-width: 50%;
6 | }
7 |
8 | .nbsphinx-gallery {
9 | display: grid;
10 | grid-template-columns: repeat(auto-fill, minmax(200px, 1fr));
11 | gap: 5px;
12 | margin-top: 1em;
13 | margin-bottom: 1em;
14 | }
15 |
16 | h1 {
17 | font-size: 2em
18 | }
19 |
20 | h2 {
21 | font-size: 1.5em
22 | }
23 |
24 | h3 {
25 | font-size: 1.25em
26 | }
27 |
28 | h4 {
29 | font-size: 1.125em
30 | }
31 |
32 | h5 {
33 | font-size: 1.07em
34 | }
35 |
36 | h6 {
37 | font-size: 1em
38 | }
--------------------------------------------------------------------------------
/docs/source/_static/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Delta-Prox/15fc489e99c4de2676dc538d9c0552338f6fb894/docs/source/_static/favicon.ico
--------------------------------------------------------------------------------
/docs/source/_static/image/example_deconv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Delta-Prox/15fc489e99c4de2676dc538d9c0552338f6fb894/docs/source/_static/image/example_deconv.png
--------------------------------------------------------------------------------
/docs/source/_static/image/optic_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Delta-Prox/15fc489e99c4de2676dc538d9c0552338f6fb894/docs/source/_static/image/optic_results.png
--------------------------------------------------------------------------------
/docs/source/_static/image/psf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Delta-Prox/15fc489e99c4de2676dc538d9c0552338f6fb894/docs/source/_static/image/psf.png
--------------------------------------------------------------------------------
/docs/source/_static/logo.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/docs/source/_static/pipeline_dprox.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Delta-Prox/15fc489e99c4de2676dc538d9c0552338f6fb894/docs/source/_static/pipeline_dprox.gif
--------------------------------------------------------------------------------
/docs/source/api/algo.md:
--------------------------------------------------------------------------------
1 | # Proximal Algorithms
2 |
3 | Extending ∇-Prox for more proximal algorithms is straightforward. We define a new algorithm class that inherits from the base class `Algorithm`. The required methods to be implemented are `partition` and `_iter`, representing the problem partition and a single algorithm iteration.
4 |
5 | The `partition` takes a list of proxable functions and returns their splits as a list of `psi_fn` and `omega_fn`. For `_iter`, it is a single iteration of the proximal algorithm that takes an input of the state and two parameters, `rho` for the penalty strength on multipliers and `lam` for proximal operators.
6 |
7 | The state is generally a list of variables, including the auxiliary ones that an algorithm creates. ∇-Prox provides a state by returning the output of previous executions of `_iter` or the initial state provided by the `initialize` method.
8 |
9 | ```python
10 | class new_algorithm(Algorithm):
11 | def partition(cls, prox_fns: List[ProxFn]):
12 | # Perform problem partition according to the algorithm's need.
13 |
14 | def __init__(...):
15 | # Custom initialization code.
16 |
17 | def _iter(self, state, rho, lam):
18 | # Code to perform a single iteration of the algorithm and return the updated state.
19 | return ...
20 |
21 | def initialize(self, x0):
22 | # Return the initial state.
23 | return ...
24 |
25 | def nparams(self):
26 | # (Optional) Return the number of hyperparameters of
27 | # this algorithm.
28 | return ...
29 |
30 | def state_split(self):
31 | # (Optional) Return the split size of the packed state.
32 | # Useful for deep equilibrium/reinforcement learning.
33 | return ...
34 | ```
35 |
36 | Implementing `partition`, `initialize`, and `_iter` is generally sufficient to evaluate the proximal algorithm for a given problem.
37 |
38 | To integrate the new algorithm with deep equilibrium learning (DEQ) and deep reinforcement learning (RL), users have to implement two additional helper methods, i.e., `nparams` for counting the number of hyperparameters and `state_split` for describing the structure of the state that is returned by `_iter`.
39 |
40 | For example, assuming `_iter` returns the state as nested arrays such as `[x,[v1,v2],[u1,u2]]`, the output of `state_split` should be `[1,[2],[2]]`. ∇-Prox exploits these properties to perform the necessary packing and unpacking for the iteration states to achieve a unified interface for the internal DEQ and RL implementations.
41 |
--------------------------------------------------------------------------------
/docs/source/api/index.md:
--------------------------------------------------------------------------------
1 | # API Documentation
2 |
3 | ```{toctree}
4 | :hidden:
5 |
6 | primitive
7 | algo
8 | linop
9 | proxfn
10 | linalg
11 | ```
12 |
13 | ∇-Prox is a domain-specific language (DSL) and compiler that transforms optimization problems into differentiable proximal solvers.
14 |
15 | ## Index
16 |
17 | - [Proximal Algorithms](algo.md)
18 | - [Proximal Functions](proxfn.md)
19 | - [Linear Operators](linop.md)
20 | - [Linear System Solver](linalg.md)
21 |
22 |
23 |
--------------------------------------------------------------------------------
/docs/source/api/linalg.md:
--------------------------------------------------------------------------------
1 | # Linear System Solver
2 |
3 |
4 | ## Unified Interface
5 |
6 | ```{eval-rst}
7 | .. autofunction:: dprox.linalg.linear_solve
8 | .. autoclass:: dprox.linalg.LinearSolveConfig
9 | ```
10 |
11 | ## Linear Solvers
12 |
13 | ```{eval-rst}
14 | .. autofunction:: dprox.linalg.solve.solver_cg.cg
15 | .. autofunction:: dprox.linalg.solve.solver_cg.pcg
16 | .. autofunction:: dprox.linalg.solve.solver_minres.minres
17 | .. autofunction:: dprox.linalg.solve.solver_plss.plss
18 | ```
--------------------------------------------------------------------------------
/docs/source/api/linop.md:
--------------------------------------------------------------------------------
1 | # Linear Operator
2 |
3 | Defining new linear operators mainly involves the definition of the forward and adjoint routines. The following code shows the template for defining them. Similar to a [proxable function](proxfn.md), the operator is defined as a class inheriting from the base class `LinOp`.
4 |
5 | ```python
6 | class new_linop(LinOp):
7 | def __init__(...):
8 | """ Custom initialization code.
9 | """
10 |
11 | def forward(self, inputs):
12 | """The forward operator. Compute x -> Kx
13 | """
14 |
15 | def adjoint(self, inputs):
16 | """The adjoint operator. Compute x -> K^Tx
17 | """
18 |
19 | def is_diag(self, freq):
20 | """(Optional) Check if the linear operator is diagonalizable or
21 | diagonalizable in the frequency domain.
22 | """
23 |
24 | def get_diag(self, x, freq):
25 | """(Optional) Return the diagonal/frequency diagonal matrix that
26 | matches the shape of input x.
27 | """
28 | ```
29 |
30 | By default, the linear operator is not diagonal. To introduce a diagonal linear operator, one must implement the `is_diag` and `get_diag` for checking the diagonalizability and acquiring the diagonal matrix. These methods allow ∇-Prox to construct more efficient solvers, e.g., ADMM with a closed-form matrix inverse for the least-square update.
31 |
32 | ## Sanity Check
33 |
34 | Typically, it is not always easy to correctly implement the forward and adjoint operations of the linear operator. To facilitate the testing of these operators, ∇-Prox provides an implementation of the **dot-product test** for verifying that the `forward` and `adjoint` are adjoint to each other.
35 |
36 | Basically, the idea of the dot-product test comes from the associative property of linear algebra, which gives the following equation:
37 |
38 | $$
39 | y^T(Ax) = (A^Ty)^Tx
40 | $$
41 |
42 | Here, $x$ and $y$ are randomly generated data, and $A$ and $A^T$ denote the forward and adjoint of the linear operator. ∇-Prox makes use of this property and generates a large number of random data arguments to check if this equation always holds for a given precision.
43 |
44 | To use this utility, users can call `validate(linop, tol=1e-6)` and specify the tolerance on the difference between the two sides of the equation.
45 |
46 | ```python
47 | import dprox as dp
48 | from dprox.utils.examples import fspecial_gaussian
49 |
50 | x = dp.Variable()
51 | psf = fspecial_gaussian(15, 5)
52 | op = dp.conv(x, psf)
53 | assert dp.validate(op)
54 | ```
55 |
56 |
57 | ## Indices
58 |
59 | ```{eval-rst}
60 | .. autoclass:: dprox.linop.base.LinOp
61 | :members: forward, adjoint, is_gram_diag, is_diag, get_diag, variables, constants, is_constant, value, offset, norm_bound, T, gram, clone
62 | .. autoclass:: dprox.linop.conv.conv
63 | :members: forward, adjoint
64 | .. autoclass:: dprox.linop.conv.conv_doe
65 | :members: forward, adjoint
66 | .. autoclass:: dprox.linop.sum.sum
67 | :members: forward, adjoint
68 | .. autoclass:: dprox.linop.sum.copy
69 | :members: forward, adjoint
70 | .. autoclass:: dprox.linop.vstack.vstack
71 | :members: forward, adjoint
72 | .. autoclass:: dprox.linop.vstack.split
73 | :members: forward, adjoint
74 | .. autoclass:: dprox.linop.variable.Variable
75 | :members: forward, adjoint
76 | .. autoclass:: dprox.linop.subsample.mosaic
77 | :members: forward, adjoint
78 | .. autoclass:: dprox.linop.scale.scale
79 | :members: forward, adjoint
80 | .. autoclass:: dprox.linop.placeholder.Placeholder
81 | :members: forward, adjoint
82 | .. autoclass:: dprox.linop.grad.grad
83 | :members: forward, adjoint
84 | ```
--------------------------------------------------------------------------------
/docs/source/api/primitive.md:
--------------------------------------------------------------------------------
1 | # Primitives
2 |
3 | ```{eval-rst}
4 | .. autofunction:: dprox.compile
5 | .. autofunction:: dprox.specialize
6 | ```
--------------------------------------------------------------------------------
/docs/source/api/proxfn.md:
--------------------------------------------------------------------------------
1 | # Proximal Functions
2 |
3 | The following code shows a template for defining a new proxable function. As previously mentioned, we define the function as a class inheriting from the base class `ProxFn`, and implement all the required methods. Then, ∇-Prox will properly handle all other necessary steps so that the new proxable function can work with operators, algorithms, and training utilities of the existing system.
4 |
5 | ```python
6 | class new_func(ProxFn):
7 | def __init__(...):
8 | """ Custom initialization code.
9 | """
10 |
11 | def _prox(self, tau, v):
12 | """ Code to compute the function's proximal operator.
13 | """
14 | return ...
15 |
16 | def _eval(self, v):
17 | """ (Optional) Code to evaluate the function.
18 | """
19 | return ...
20 |
21 | def _grad(self, v):
22 | """ (Optional) Code to compute the analytic gradient.
23 | """
24 | return ...
25 | ```
26 |
27 | More specifically, defining a new function only requires a method `_prox` to be implemented, which evaluates the proximal operator of the given function.
28 |
29 | Users can optionally implement the `_grad` function to provide a routine for computing the analytic gradient of the proxable function. This facilitates the algorithms that partially rely on the gradient evaluation, e.g., proximal gradient descent.
30 |
31 | Alternatively, users can implement the `_eval` method, which computes the forward result of the proxable function, when this is possible. ∇-Prox takes the `_eval` routine and computes the gradient with auto-diff if `_grad` is not implemented.
32 |
33 |
34 | ```{eval-rst}
35 | .. autoclass:: dprox.proxfn.base.ProxFn
36 | ```
--------------------------------------------------------------------------------
/docs/source/api/utils.md:
--------------------------------------------------------------------------------
1 | # Utility
2 |
3 | ## Metrics
4 |
5 |
6 |
7 |
8 | ## IO
9 |
10 | ```{eval-rst}
11 | .. autofunction:: dprox.utils.io.imread_rgb
12 | .. autofunction:: dprox.utils.io.imshow
13 | .. autofunction:: dprox.utils.io.imread
14 | ```
--------------------------------------------------------------------------------
/docs/source/citation.md:
--------------------------------------------------------------------------------
1 | # Citation
2 |
3 | The following publications discuss the ideas behind ∇-Prox:
4 |
5 | > **∇-Prox: Differentiable Proximal Algorithm Modeling for Large-Scale Optimization**
6 | > Zeqiang Lai, Kaixuan Wei, Ying Fu, Philipp Härtel, and Felix Heide.
7 | > ACM Transactions on Graphics, SIGGRAPH, 2023.
8 |
9 | > **ProxImaL: Efficient Image Optimization Using Proximal Algorithms**
10 | > F. Heide, S. Diamond, M. Niessner, J. Ragan-Kelley, W. Heidrich, and G. Wetzstein.
11 | > ACM Transactions on Graphics, SIGGRAPH, 2016.
12 |
13 | ```bibtex
14 | @article{deltaprox2023,
15 | title = {∇-Prox: Differentiable Proximal Algorithm Modeling for Large-Scale Optimization},
16 | author = {Lai, Zeqiang and Wei, Kaixuan and Fu, Ying and H\"{a}rtel, Philipp and Heide, Felix},
17 | journal={ACM Transactions on Graphics (TOG)},
18 | volume = {42},
19 | number = {4},
20 | articleno = {105},
21 | pages = {1--19},
22 | year={2023},
23 | publisher = {Association for Computing Machinery},
24 | address = {New York, NY, USA},
25 | doi = {10.1145/3592144},
26 | }
27 |
28 | @article{heide2016proximal,
29 | title={Proximal: Efficient image optimization using proximal algorithms},
30 | author={Heide, Felix and Diamond, Steven and Nie{\ss}ner, Matthias and Ragan-Kelley, Jonathan and Heidrich, Wolfgang and Wetzstein, Gordon},
31 | journal={ACM Transactions on Graphics (TOG)},
32 | volume={35},
33 | number={4},
34 | pages={1--15},
35 | year={2016},
36 | publisher = {Association for Computing Machinery},
37 | address = {New York, NY, USA},
38 | doi={10.1145/2897824.2925875},
39 | }
40 | ```
41 |
--------------------------------------------------------------------------------
/docs/source/index.md:
--------------------------------------------------------------------------------
1 | ---
2 | hide-toc: true
3 | ---
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 | ```{toctree}
15 | :maxdepth: 3
16 | :hidden:
17 |
18 | Get Started
19 | tutorials/index
20 | api/index
21 | citation
22 | ```
23 |
24 |
25 | ```{toctree}
26 | :caption: Useful Links
27 | :hidden:
28 | PyPI Page
29 | GitHub Repository
30 | Project Page
31 | Paper
32 | ```
33 |
34 |
35 |
36 | 🎉 ∇-Prox is a domain-specific language (DSL) and compiler that transforms optimization problems into differentiable proximal solvers.
37 |
38 | 🎉 ∇-Prox allows for rapid prototyping of learning-based bi-level optimization problems for a diverse range of applications, by [optimized algorithm unrolling](api/primitive), [deep equilibrium learning](api/primitive), and [deep reinforcement learning](api/primitive).
39 |
40 | The library includes the following major components:
41 |
42 | - A library of differentiable [proximal algorithms](api/algo), [proximal operators](api/proxfn), and [linear operators](api/linop).
43 | - Interchangeable [specialization](api/primitive) strategies for balancing trade-offs between speed and memory.
44 | - Out-of-the-box [training utilities](api/primitive) for learning-based bi-level optimization with a few lines of code.
45 |
46 | ```{nbgallery}
47 | ```
48 |
49 |
--------------------------------------------------------------------------------
/docs/source/started/index.md:
--------------------------------------------------------------------------------
1 | # Get Started
2 |
3 | ```{toctree}
4 | :hidden:
5 |
6 | quicktour
7 | install
8 | ```
9 |
10 | ## Installation
11 |
12 | To get started with ∇-Prox, please follow the [Installation Documentation](install) for detailed instructions on how to install the library.
13 |
14 | ## Quick Tour
15 |
16 | - Take a [Quick Tour](quicktour) to get familiar with the features and functionalities of ∇-Prox.
17 |
18 | - Explore the [API Reference](../api/index) for a complete list of classes and functions.
19 |
20 | - For advanced topics and best practices, refer to the [tutorials](../tutorials/index).
21 |
22 |
23 | Happy coding with ∇-Prox! 🎉
--------------------------------------------------------------------------------
/docs/source/started/install.md:
--------------------------------------------------------------------------------
1 | # Installation
2 |
3 |
4 | ∇-Prox works with PyTorch. To install PyTorch, please follow the [PyTorch installation instructions](https://pytorch.org/get-started/locally/).
5 |
6 |
7 | **Install with pip**
8 |
9 | ```bash
10 | pip install dprox
11 | ```
12 |
13 | **Install from source**
14 |
15 | ```bash
16 | pip install git+https://github.com/princeton-computational-imaging/Delta-Prox.git
17 | ```
18 |
19 | **Editable installation**
20 |
21 | You will need an editable install if you would like to:
22 |
23 | 1. Use the main version of the source code.
24 | 2. Test changes in the code.
25 |
26 | To do so, clone the repository and install 🎉 Delta Prox with the following commands:
27 |
28 | ```
29 | git clone https://github.com/princeton-computational-imaging/Delta-Prox.git
30 | cd Delta-Prox
31 | pip install -e .
32 | ```
33 |
34 | ```{caution}
35 | Note that you must keep the Delta-Prox folder for editable installation if you want to keep using the library.
36 | ```
37 |
--------------------------------------------------------------------------------
/docs/source/started/quicktour.nblink:
--------------------------------------------------------------------------------
1 | {
2 | "path": "../../../notebooks/quickstart.ipynb"
3 | }
--------------------------------------------------------------------------------
/docs/source/tutorials/computational_optics.nblink:
--------------------------------------------------------------------------------
1 | {
2 | "path": "../../../notebooks/computational_optics.ipynb"
3 | }
--------------------------------------------------------------------------------
/docs/source/tutorials/csmri.nblink:
--------------------------------------------------------------------------------
1 | {
2 | "path": "../../../notebooks/csmri.ipynb"
3 | }
--------------------------------------------------------------------------------
/docs/source/tutorials/deraining.nblink:
--------------------------------------------------------------------------------
1 | {
2 | "path": "../../../notebooks/deraining.ipynb"
3 | }
--------------------------------------------------------------------------------
/docs/source/tutorials/differentiable_linear_solver.nblink:
--------------------------------------------------------------------------------
1 | {
2 | "path": "../../../notebooks/differentiable_linear_solver.ipynb"
3 | }
--------------------------------------------------------------------------------
/docs/source/tutorials/energy_system.nblink:
--------------------------------------------------------------------------------
1 | {
2 | "path": "../../../notebooks/energy_system_planning.ipynb"
3 | }
--------------------------------------------------------------------------------
/docs/source/tutorials/image_restoration.nblink:
--------------------------------------------------------------------------------
1 | {
2 | "path": "../../../notebooks/image_restoration.ipynb"
3 | }
--------------------------------------------------------------------------------
/docs/source/tutorials/index.md:
--------------------------------------------------------------------------------
1 | # Tutorials
2 |
3 | ## Basics
4 | ```{nbgallery}
5 | learn_the_basic
6 | differentiable_linear_solver
7 | linear_operator
8 | ```
9 |
10 | ## Tasks
11 |
12 | ```{nbgallery}
13 | computational_optics
14 | csmri
15 | deraining
16 | energy_system
17 | image_restoration
18 | ```
--------------------------------------------------------------------------------
/docs/source/tutorials/learn_the_basic.nblink:
--------------------------------------------------------------------------------
1 | {
2 | "path": "../../../notebooks/learn_the_basic.ipynb"
3 | }
--------------------------------------------------------------------------------
/docs/source/tutorials/linear_operator.nblink:
--------------------------------------------------------------------------------
1 | {
2 | "path": "../../../notebooks/linear_operator.ipynb"
3 | }
--------------------------------------------------------------------------------
/docs/source/tutorials/training.nblink:
--------------------------------------------------------------------------------
1 | {
2 | "path": "../../../notebooks/training.ipynb"
3 | }
--------------------------------------------------------------------------------
/dprox/__init__.py:
--------------------------------------------------------------------------------
1 | from .linop import *
2 | from .proxfn import *
3 | from .algo import *
4 | from .utils.containar import array, tensor
5 | from . import linalg
6 | from .utils.huggingface import CACHE_DIR
7 |
8 | __version__ = '0.1.3'
9 | __cache_dir__ = CACHE_DIR
10 |
--------------------------------------------------------------------------------
/dprox/algo/__init__.py:
--------------------------------------------------------------------------------
1 | from .problem import Problem
2 | from .admm import ADMM, ADMM_vxu, LinearizedADMM
3 | from .hqs import HQS
4 | from .pc import PockChambolle
5 | from .pgd import ProximalGradientDescent
6 | from .base import Algorithm
7 | from .tune.dpir import log_descent
8 | from .specialization import AutoTuneSolver, DEQSolver, UnrolledSolver
9 | from .primitives import compile, specialize, visualize, optimize, train
10 |
--------------------------------------------------------------------------------
/dprox/algo/hqs.py:
--------------------------------------------------------------------------------
1 | from .admm import ADMM
2 |
3 |
class HQS(ADMM):
    """HQS solver variant that reuses the ADMM setup.

    The state is the pair ``(x, z)``: ``x`` is the current estimate and
    ``z`` holds one auxiliary entry per function in ``self.psi_fns``.
    """

    @property
    def state_split(self):
        """Describe the state layout: one ``x`` entry, then one ``z`` entry per psi function."""
        return [1, [len(self.psi_fns)]]

    def initialize(self, x0):
        """Return the initial state ``(x0, K(x0))`` with ``K(x0)`` as a list."""
        z = self.K.forward(x0, return_list=True)
        return x0, z

    def _iter(self, state, rho, lam):
        """Run one iteration: least-squares update of ``x``, then prox update of each ``z[i]``.

        ``z`` is updated in place and returned together with the new ``x``.
        """
        _, z = state
        x = self.least_square.solve(z, rho)
        # Evaluate K once and reuse the per-function outputs below.
        Kx = self.K.forward(x, return_list=True)
        for idx, psi in enumerate(self.psi_fns):
            z[idx] = psi.prox(Kx[idx], lam=lam[psi])
        return x, z
21 |
--------------------------------------------------------------------------------
/dprox/algo/invert.py:
--------------------------------------------------------------------------------
1 | from dprox.proxfn import least_squares, ext_sum_squares, least_squares
2 | from dprox.linop import Variable
3 |
4 |
def get_least_square_solver(psi_fns, omega_fns, try_diagonalize, try_freq_diagonalize, linear_solve_config):
    """Pick the least-squares solver used for the x-update.

    If some ``ext_sum_squares`` term in ``omega_fns`` is the only function
    whose linear operator is not a bare ``Variable``, the solver returned by
    ``ext_sum_squares.setup`` is used; otherwise a generic ``least_squares``
    object is built from all functions.
    """
    all_fns = psi_fns + omega_fns
    ext_sq = [fn for fn in omega_fns if isinstance(fn, ext_sum_squares)]

    for candidate in ext_sq:
        others = [f for f in all_fns if f is not candidate]
        if all(isinstance(f.linop, Variable) for f in others):
            offsets = [f.b for f in omega_fns if f is not candidate and f not in ext_sq]
            # NOTE(review): setup is invoked on the first ext_sum_squares term,
            # not on ``candidate`` -- confirm this is intended when several exist.
            return ext_sq[0].setup(offsets)

    return least_squares(
        omega_fns,
        psi_fns,
        try_diagonalize,
        try_freq_diagonalize,
        linear_solve_config=linear_solve_config,
    )
15 |
--------------------------------------------------------------------------------
/dprox/algo/lp/__init__.py:
--------------------------------------------------------------------------------
1 | from .solvers import LPSolverADMM, LPConvergenceLoss, LPProblem
--------------------------------------------------------------------------------
/dprox/algo/opt/__init__.py:
--------------------------------------------------------------------------------
1 | from . import absorb
2 | from . import merge
3 |
--------------------------------------------------------------------------------
/dprox/algo/opt/absorb.py:
--------------------------------------------------------------------------------
1 | # Absorb linear operators into proximal operators.
2 |
3 | from dprox.linop import scale, mosaic
4 | from dprox.proxfn import (sum_squares, weighted_sum_squares)
5 |
6 | WEIGHTED = {sum_squares: weighted_sum_squares}
7 |
8 |
def absorb_all_linops(prox_funcs):
    """Repeatedly absorb linear operators into the given prox functions.

    Each function is passed to ``absorb_linop``; results that differ from the
    input are queued again, results that come back unchanged are kept.
    Returns the list of fully processed functions.
    """
    pending = prox_funcs[:]
    finished = []
    while len(pending) > 0:
        current = pending.pop(0)
        result = absorb_linop(current)
        unchanged = len(result) == 1 and result[0] == current
        if unchanged:
            finished.append(result[0])
        else:
            # Something was absorbed: re-examine the produced functions.
            pending += result
    return finished
22 |
23 |
def absorb_linop(prox_fn):
    """Move the top-level linear operator into the prox function when possible.

    * ``sum_squares`` over a ``mosaic`` operator becomes a
      ``weighted_sum_squares`` on the operator's first input node.
    * A ``scale`` operator is folded into ``prox_fn.beta`` (``prox_fn`` is
      modified in place).

    Always returns a one-element list.
    """
    linop = prox_fn.linop

    if isinstance(linop, mosaic) and isinstance(prox_fn, sum_squares):
        return [weighted_sum_squares(linop.input_nodes[0], linop, prox_fn.offset)]

    if isinstance(linop, scale):
        # Fold the scalar into the function itself.
        factor = linop.scalar
        prox_fn.linop = linop.input_nodes[0]
        prox_fn.beta = prox_fn.beta * factor

    # Either updated in place above or left untouched.
    return [prox_fn]
43 |
--------------------------------------------------------------------------------
/dprox/algo/opt/equil.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import math
4 |
5 |
def equil(K, iters=50, gamma=1e-1, M=math.log(1e4)):
    """Compute diagonal scalings D, E so that D K E is approximately equilibrated.

    ``K`` must expose ``output_size``, ``input_size`` and ``forward`` /
    ``adjoint`` methods that write their result into the buffer passed as the
    second argument. The iteration uses random +/-1 probe vectors, so results
    depend on NumPy's global random state.

    Returns ``(exp(ubar), exp(vbar))`` where ``ubar`` / ``vbar`` are the
    running averages of the log-scalings.
    """
    n_rows = K.output_size
    n_cols = K.input_size
    alpha, beta = get_alpha_beta(n_rows, n_cols)

    u, v = np.zeros(n_rows), np.zeros(n_cols)
    u_avg, v_avg = u.copy(), v.copy()

    col_buf = np.zeros(n_cols)  # receives K^T (.)
    row_buf = np.zeros(n_rows)  # receives K (.)
    signs = [-1, 1]

    for t in range(1, iters + 1):
        step = 2 / (gamma * (t + 1))

        # Stochastic gradient estimate with respect to u.
        s = np.random.choice(signs, size=n_cols)
        K.forward(np.exp(v) * s, row_buf)
        grad_u = np.exp(2 * u) * np.square(row_buf) - alpha**2 + gamma * u

        # Stochastic gradient estimate with respect to v.
        w = np.random.choice(signs, size=n_rows)
        K.adjoint(np.exp(u) * w, col_buf)
        grad_v = np.exp(2 * v) * np.square(col_buf) - beta**2 + gamma * v

        # Projected gradient step onto the box [-M, M].
        u = project(u - step * grad_u, M)
        v = project(v - step * grad_v, M)

        # Weighted running averages of the iterates.
        u_avg = 2 * u / (t + 2) + t * u_avg / (t + 2)
        v_avg = 2 * v / (t + 2) + t * v_avg / (t + 2)

    return np.exp(u_avg), np.exp(v_avg)
41 |
42 |
def get_alpha_beta(m, n):
    """Return ``((n / m) ** 0.25, (m / n) ** 0.25)`` for an ``m`` x ``n`` operator."""
    exponent = 0.25
    alpha = (n / m) ** exponent
    beta = (m / n) ** exponent
    return alpha, beta
45 |
46 |
def project(x, M):
    """Clamp ``x`` onto the box [-M, M]^n in place and return it."""
    np.maximum(x, -M, out=x)  # lower bound
    np.minimum(M, x, out=x)   # upper bound
    return x
51 |
52 | # Comparison method.
53 |
54 |
def f(A, u, v, gamma, p=2):
    """Evaluate the regularised equilibration objective for matrix ``A``.

    The value is the sum of a coupling term in ``exp(p*u)``, ``|A|^p`` and
    ``exp(p*v)``, linear terms in ``u`` and ``v`` weighted by ``alpha**p`` /
    ``beta**p``, and a ``gamma / 2`` squared-norm penalty on ``u`` and ``v``.
    """
    m, n = A.shape
    alpha, beta = get_alpha_beta(m, n)

    row_scale = np.exp(p * u)
    col_scale = np.exp(p * v)
    coupling = (1. / p) * row_scale.T.dot(np.power(np.abs(A), p)).dot(col_scale)

    linear = -alpha**p * u.sum() - beta**p * v.sum()
    ridge = (gamma / 2) * ((u * u).sum() + (v * v).sum())

    return np.sum(coupling + (linear + ridge))
61 |
62 |
def get_grad(A, u, v, gamma, p=2):
    """Return search directions and gradients for the equilibration objective.

    Returns ``(du, dv, grad_u, grad_v)`` where ``du`` / ``dv`` are the
    gradients divided elementwise by ``2 * tmp + gamma`` and negated.
    """
    m, n = A.shape
    alpha, beta = get_alpha_beta(m, n)

    exp_u = np.exp(p * u)
    exp_v = np.exp(p * v)

    row_term = np.diag(exp_u).dot((A * A)).dot(exp_v)
    grad_u = row_term - alpha**p + gamma * u
    du = -grad_u / (2 * row_term + gamma)

    col_term = np.diag(exp_v).dot((A.T * A.T)).dot(exp_u)
    grad_v = col_term - beta**p + gamma * v
    dv = -grad_v / (2 * col_term + gamma)

    return du, dv, grad_u, grad_v
76 |
77 |
def newton_equil(A, gamma, max_iters):
    """Equilibrate ``A`` with scaled gradient steps and a backtracking line search.

    Returns ``(exp(u), exp(v))`` after ``max_iters`` outer iterations.
    """
    sufficient_decrease = 0.25  # line-search slope factor
    shrink = 0.5                # step-size reduction factor
    m, n = A.shape
    u = np.zeros(m)
    v = np.zeros(n)

    for _ in range(max_iters):
        du, dv, grad_u, grad_v = get_grad(A, u, v, gamma)
        obj = f(A, u, v, gamma)
        slope = np.sum(sufficient_decrease * (grad_u.dot(du) + grad_v.dot(dv)))

        # Backtracking line search: halve t until the decrease test passes.
        t = 1
        while f(A, u + t * du, v + t * dv, gamma) > obj + t * slope:
            t = shrink * t

        u = u + t * du
        v = v + t * dv

    return np.exp(u), np.exp(v)
99 |
--------------------------------------------------------------------------------
/dprox/algo/opt/merge.py:
--------------------------------------------------------------------------------
1 | # Merge proximal operators together.
2 |
3 | # from proximal.prox_fns import sum_squares, zero_prox
4 | import numpy as np
5 |
6 |
def merge_all(prox_fns):
    """Merge as many prox functions as possible.

    Passes over all pairs are repeated until a full pass produces no merge.
    Within a pass each function takes part in at most one merge.
    """
    while True:
        used = []
        combined = []
        for i, left in enumerate(prox_fns):
            for right in prox_fns[i + 1:]:
                if left in used or right in used:
                    continue
                if not can_merge(left, right):
                    continue
                used += [left, right]
                combined.append(merge_fns(left, right))
        if not combined:
            break
        # Merged results first, then everything that was left untouched.
        prox_fns = combined + [fn for fn in prox_fns if fn not in used]

    return prox_fns
26 |
27 |
def can_merge(lh_prox, rh_prox):
    """Tell whether ``lh_prox`` and ``rh_prox`` can be merged into one function.

    Both must share the same ``lin_op`` and at least one of them must be a
    ``zero_prox`` or a ``sum_squares``.
    """
    # Lin ops must be the same.
    if not (lh_prox.lin_op == rh_prox.lin_op):
        return False

    # NOTE(review): ``zero_prox`` and ``sum_squares`` are not imported in this
    # module (the import at the top is commented out), so reaching this point
    # raises NameError until that import is restored.
    kinds = (type(lh_prox), type(rh_prox))
    return zero_prox in kinds or sum_squares in kinds
39 |
40 |
def merge_fns(lh_prox, rh_prox):
    """Merge two compatible prox functions into a single function.

    The ``c``, ``gamma`` and ``d`` attributes of both inputs are summed and
    passed to ``copy`` on the function that is kept. When one input is a
    ``sum_squares`` term, its ``alpha``, ``beta`` and ``b`` are folded into
    those values.

    Raises:
        ValueError: if neither input is a ``zero_prox`` or ``sum_squares``.
    """
    assert can_merge(lh_prox, rh_prox)

    c_sum = lh_prox.c + rh_prox.c
    gamma_sum = lh_prox.gamma + rh_prox.gamma
    d_sum = lh_prox.d + rh_prox.d

    pair = [lh_prox, rh_prox]
    kinds = [type(lh_prox), type(rh_prox)]

    # Merge a linear term into the other proxable function.
    if zero_prox in kinds:
        keep = pair[1 - kinds.index(zero_prox)]
        return keep.copy(c=c_sum, gamma=gamma_sum, d=d_sum)

    # Merge a sum-of-squares term into the other proxable function.
    if sum_squares in kinds:
        pos = kinds.index(sum_squares)
        sq_fn = pair[pos]
        keep = pair[1 - pos]
        coeff = sq_fn.alpha * sq_fn.beta
        return keep.copy(
            c=c_sum - 2 * coeff * sq_fn.b,
            gamma=gamma_sum + coeff * sq_fn.beta,
            d=d_sum + np.square(sq_fn.b).sum(),
        )

    raise ValueError("Unknown merge strategy.")
65 |
--------------------------------------------------------------------------------
/dprox/algo/pc.py:
--------------------------------------------------------------------------------
1 | from ..linop import adjoint
2 | from .admm import ADMM
3 | from .base import expand
4 |
5 |
class PockChambolle(ADMM):
    """Pock-Chambolle style solver that reuses the ADMM setup.

    The state is the triple ``(x, z, xbar)``: the current estimate, one
    auxiliary entry per function in ``self.psi_fns``, and an extrapolated
    copy of ``x``.
    """

    @property
    def state_split(self):
        """Describe the state layout: ``x``, one ``z`` entry per psi function, ``xbar``."""
        return [1, [len(self.psi_fns)], 1]

    def initialize(self, x0):
        """Return the initial state ``(x0, K(x0), clone of x0)``."""
        xbar = x0.clone()
        z = self.K.forward(x0, return_list=True)
        return x0, z, xbar

    def _iter(self, state, rho, lam):
        """Run one iteration; ``z`` is updated in place."""
        x, z, xbar = state

        # Update z from the extrapolated point.
        Kxbar = self.K.forward(xbar, return_list=True)
        for idx, psi in enumerate(self.psi_fns):
            step = expand(lam[psi])
            shifted = z[idx] + step * Kxbar[idx]
            z[idx] = shifted - step * psi.prox(shifted, lam=step)

        # Update x using the adjoint of each psi function's operator.
        Ktz = [adjoint(psi.linop, z[idx]) for idx, psi in enumerate(self.psi_fns)]
        candidates = [x - kt for kt in Ktz]
        if len(self.omega_fns) > 0:
            x_next = self.least_square.solve(candidates, rho)
        else:
            x_next = sum(candidates)

        # Extrapolate for the next iteration.
        xbar = x_next + x_next - x

        return x_next, z, xbar
41 |
--------------------------------------------------------------------------------
/dprox/algo/pgd.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from dprox.proxfn import ProxFn
4 |
5 | from .base import Algorithm, expand
6 |
7 |
class ProximalGradientDescent(Algorithm):
    """Proximal gradient descent for one differentiable and one proxable function.

    The state is the single-element list ``[x]``.
    """

    @classmethod
    def partition(cls, prox_fns: List[ProxFn]):
        """Split ``prox_fns`` into ``(psi_fns, omega_fns)``.

        Functions exposing a ``grad`` attribute go to ``omega_fns``; the
        remaining ones go to ``psi_fns``.

        Raises:
            ValueError: if ``prox_fns`` does not contain exactly two
                functions, or if none of them has a ``grad`` attribute.
        """
        # Adjacent string literals are used instead of a backslash continuation
        # inside the literal, which would embed the source indentation in the
        # error message.
        if len(prox_fns) != 2:
            raise ValueError(
                'Proximal gradient descent only supports '
                'two proximal functions for now.'
            )

        omega_fns = [fn for fn in prox_fns if hasattr(fn, 'grad')]
        psi_fns = [fn for fn in prox_fns if fn not in omega_fns]

        if len(omega_fns) == 0:
            raise ValueError(
                'Proximal gradient descent requires '
                'at least one differentiable proximal function.'
            )

        return psi_fns, omega_fns

    def __init__(
        self,
        psi_fns: List[ProxFn],
        omega_fns: List[ProxFn],
        *args,
        **kwargs
    ):
        """Keep the first differentiable and the first proxable function.

        Extra positional and keyword arguments are accepted and ignored.
        """
        super().__init__(psi_fns, omega_fns)
        self.diff_fn = omega_fns[0]
        self.prox_fn = psi_fns[0]

    def _iter(self, state, rho, lam):
        """Take a gradient step of size ``rho`` and apply the prox of ``prox_fn``."""
        x = state[0]
        v = x - expand(rho) * self.diff_fn.grad(x)
        x = self.prox_fn.prox(v, lam[self.prox_fn])
        return [x]

    def initialize(self, x0):
        """Return the initial state ``[x0]``."""
        return [x0]

    @property
    def state_split(self):
        """Describe the state layout: a single ``x`` entry."""
        return [1]

    @property
    def nparams(self):
        """Return ``len(self.psi_fns) + 1``."""
        return len(self.psi_fns) + 1
55 |
--------------------------------------------------------------------------------
/dprox/algo/specialization/__init__.py:
--------------------------------------------------------------------------------
1 | from .deq import DEQSolver, train_deq
2 | from .unroll import UnrolledSolver, build_unrolled_solver
3 | from .rl import AutoTuneSolver
--------------------------------------------------------------------------------
/dprox/algo/specialization/deq/__init__.py:
--------------------------------------------------------------------------------
1 | from .solver import DEQSolver
2 | from .training import train_deq
--------------------------------------------------------------------------------
/dprox/algo/specialization/deq/solver.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 |
7 | from dprox.algo.base import Algorithm, move, auto_convert_to_tensor
8 | from dprox.utils import to_torch_tensor
9 |
10 | from .utils.solvers import anderson
11 |
12 |
class DEQ(nn.Module):
    """Fixed-point layer with an implicit backward pass.

    ``fn(z, x, *args)`` is iterated by ``self.solver`` until a fixed point
    ``z*`` is found. In training mode, ``fn`` is applied once more with
    gradient tracking and a backward hook replaces the incoming gradient with
    the solution of a second fixed-point problem involving the Jacobian of
    ``fn`` at ``z*``.

    Args:
        fn: callable ``fn(z, x, *args)`` returning the next iterate.
        f_thres: ``threshold`` passed to the solver in the forward pass.
        b_thres: ``threshold`` passed to the solver in the backward pass.
    """

    def __init__(self, fn, f_thres=40, b_thres=40):
        super().__init__()
        self.fn = fn

        self.f_thres = f_thres
        self.b_thres = b_thres

        self.solver = anderson
        # Handle of the most recently registered backward hook, if any.
        self.hook = None

    def forward(self, x, *args, **kwargs):
        """Return the fixed point of ``z -> fn(z, x, *args)`` starting from ``x``.

        ``f_thres`` / ``b_thres`` may be overridden through ``kwargs``.
        """
        f_thres = kwargs.get('f_thres', self.f_thres)
        b_thres = kwargs.get('b_thres', self.b_thres)

        z0 = x

        # Forward pass: find the fixed point without building a graph.
        with torch.no_grad():
            out = self.solver(lambda z: self.fn(z, x, *args), z0, threshold=f_thres)
            z_star = out['result']
            new_z_star = z_star

        # Prepare the backward pass: one tracked application of fn at z_star.
        if self.training:
            new_z_star = self.fn(z_star.requires_grad_(), x, *args)

            def backward_hook(grad):
                if self.hook is not None:
                    # Remove the hook first so the autograd calls below do not
                    # re-enter it (avoids infinite recursion).
                    self.hook.remove()
                    # Only synchronise when the gradient lives on a CUDA
                    # device; an unconditional call fails on CPU-only hosts.
                    if grad.is_cuda:
                        torch.cuda.synchronize()
                # Compute the fixed point of yJ + grad, where J is the
                # Jacobian of fn at z_star.
                out = self.solver(lambda y: torch.autograd.grad(new_z_star, z_star, y, retain_graph=True)[0] + grad,
                                  torch.zeros_like(grad), threshold=b_thres)
                new_grad = out['result']
                return new_grad

            self.hook = new_z_star.register_hook(backward_hook)

        return new_z_star
55 |
56 |
class DEQSolver(nn.Module):
    """Run an ``Algorithm`` as a fixed-point (DEQ) layer.

    One step of ``solver.iter`` on the packed state is used as the map whose
    fixed point is computed by ``DEQ``.

    Args:
        solver: the wrapped algorithm.
        learned_params: when true, two scalar parameters ``r`` and ``l``
            (initialised to 1) rescale ``rho`` and every ``lam`` entry.
        rhos, lams: stored on the instance as given.
    """

    def __init__(self, solver: Algorithm, learned_params=False, rhos=None, lams=None):
        super().__init__()
        self.internal = solver

        def fixed_point_map(z, x, *args):
            # ``x`` is part of the DEQ calling convention and is unused here.
            unpacked = solver.unpack(z)
            stepped = solver.iter(unpacked, *args)
            return solver.pack(stepped)

        self.solver = DEQ(fixed_point_map)

        self.learned_params = learned_params
        if learned_params:
            self.r = nn.parameter.Parameter(torch.tensor(1.))
            self.l = nn.parameter.Parameter(torch.tensor(1.))

        self.rhos = rhos
        self.lams = lams

    @auto_convert_to_tensor(['x0', 'rhos', 'lams'], batchify=['x0'])
    def solve(
        self,
        x0: Union[torch.Tensor, np.ndarray] = None,
        rhos: Union[float, torch.Tensor, np.ndarray] = None,
        lams: Union[float, torch.Tensor, np.ndarray, dict] = None,
        **kwargs
    ):
        """Solve for the fixed point and return the first state entry.

        Only the first element along the last axis of ``rhos`` and of each
        ``lams`` value is used.
        """
        x0, rhos, lams, _ = self.internal.defaults(x0, rhos, lams, 1)
        x0, rhos, lams = move(x0, rhos, lams, device=self.internal.device)

        device = x0.device
        lam = {fn: to_torch_tensor(value)[..., 0].to(device) for fn, value in lams.items()}
        rho = to_torch_tensor(rhos)[..., 0].to(device)

        if self.learned_params:
            rho = self.r * rho
            lam = {fn: value * self.l for fn, value in lam.items()}

        packed = self.internal.pack(self.internal.initialize(x0))
        packed = self.solver(packed, rho, lam)
        return self.internal.unpack(packed)[0]

    def forward(
        self,
        **kwargs
    ):
        """Forward keyword arguments to ``solve``."""
        return self.solve(**kwargs)

    def load(self, state_dict, strict=True):
        """Load weights from ``state_dict['solver']`` and restore optional ``rhos`` / ``lams``."""
        self.load_state_dict(state_dict['solver'], strict=strict)
        self.rhos = state_dict.get('rhos')
        self.lams = state_dict.get('lams')
110 |
--------------------------------------------------------------------------------
/dprox/algo/specialization/deq/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Delta-Prox/15fc489e99c4de2676dc538d9c0552338f6fb894/dprox/algo/specialization/deq/utils/__init__.py
--------------------------------------------------------------------------------
/dprox/algo/specialization/deq/utils/jacobian.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | import numpy as np
5 |
6 |
def jac_loss_estimate(f0, z0, vecs=2, create_graph=True):
    """Estimate tr(J^T J) = tr(J J^T) with a Hutchinson estimator.

    Args:
        f0 (torch.Tensor): Output of the function f whose Jacobian J is analysed.
        z0 (torch.Tensor): Input of the function f.
        vecs (int, optional): Number of random Gaussian probe vectors. Defaults to 2.
        create_graph (bool, optional): Whether to build a backward graph so the
            estimate can itself be used as a training loss. Defaults to True.

    Returns:
        torch.Tensor: The estimate averaged over the probes and divided by the
        number of elements of ``z0``.
    """
    total = 0
    for _ in range(vecs):
        probe = torch.randn(*z0.shape).to(z0)
        vjp = torch.autograd.grad(f0, z0, probe, retain_graph=True, create_graph=create_graph)[0]
        total += vjp.norm()**2
    n_elements = np.prod(z0.shape)
    return total / vecs / n_elements
27 |
def power_method(f0, z0, n_iters=200):
    """Estimate the spectral radius of the Jacobian J with the power method.

    Args:
        f0 (torch.Tensor): Output of the function f whose Jacobian J is analysed.
        z0 (torch.Tensor): Input of the function f; the first axis is treated
            as the batch axis.
        n_iters (int, optional): Number of power iterations. Defaults to 200.

    Returns:
        tuple: (estimated leading eigenvector, absolute value of the
        estimated leading eigenvalue), computed per batch entry.
    """
    evector = torch.randn_like(z0)
    bsz = evector.shape[0]
    last = n_iters - 1
    for step in range(n_iters):
        # The graph is only needed again if another iteration follows.
        vTJ = torch.autograd.grad(f0, z0, evector, retain_graph=(step < last), create_graph=False)[0]
        numerator = (vTJ * evector).reshape(bsz, -1).sum(1, keepdim=True)
        denominator = (evector * evector).reshape(bsz, -1).sum(1, keepdim=True)
        evalue = numerator / denominator
        flat = vTJ.reshape(bsz, -1)
        evector = (flat / flat.norm(dim=1, keepdim=True)).reshape_as(z0)
    return (evector, torch.abs(evalue))
--------------------------------------------------------------------------------
/dprox/algo/specialization/deq/utils/layer_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 |
5 |
def list2vec(z1_list):
    """Flatten each tensor to ``(batch, -1, 1)`` and concatenate them along axis 1."""
    bsz = z1_list[0].size(0)
    columns = []
    for elem in z1_list:
        columns.append(elem.reshape(bsz, -1, 1))
    return torch.cat(columns, dim=1)
10 |
11 |
def vec2list(z1, cutoffs):
    """Split a batched vector back into a list of tensors.

    Args:
        z1: tensor whose first axis is the batch axis and whose second axis
            holds the concatenated entries.
        cutoffs: sequence of per-entry shapes (without the batch axis). Shapes
            of any rank are supported; an empty sequence yields an empty list.

    Returns:
        list: one view of ``z1`` per entry of ``cutoffs``, shaped
        ``(batch, *shape)``.
    """
    bsz = z1.shape[0]
    pieces = []
    offset = 0
    for shape in cutoffs:
        # Number of elements of this entry, for any number of dimensions.
        size = 1
        for dim in shape:
            size *= dim
        pieces.append(z1[:, offset:offset + size].view(bsz, *shape))
        offset += size
    return pieces
23 |
24 |
def conv3x3(in_planes, out_planes, stride=1, bias=False):
    """Build a 3x3 ``nn.Conv2d`` with padding 1."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=bias,
    )
28 |
def conv5x5(in_planes, out_planes, stride=1, bias=False):
    """Build a 5x5 ``nn.Conv2d`` with padding 2."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=5,
        stride=stride,
        padding=2,
        bias=bias,
    )
32 |
33 |
def norm_diff(new, old, show_list=False):
    """Measure the difference between two equally long lists of tensors.

    Args:
        new, old: sequences of tensors compared entry by entry.
        show_list: when true, return the per-entry norms as a list.

    Returns:
        The list of per-entry norms of ``new[i] - old[i]`` when ``show_list``
        is true, otherwise the square root of the sum of their squares as a
        float.
    """
    # Imported locally: this module does not import numpy, so the previous
    # ``np.sqrt`` call raised NameError.
    import math

    norms = [(new[i] - old[i]).norm().item() for i in range(len(new))]
    if show_list:
        return norms
    return math.sqrt(sum(value**2 for value in norms))
--------------------------------------------------------------------------------
/dprox/algo/specialization/rl/__init__.py:
--------------------------------------------------------------------------------
1 | from .solver import *
--------------------------------------------------------------------------------
/dprox/algo/specialization/unroll.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from functools import partial
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | from ..base import auto_convert_to_tensor, move
8 |
9 |
def clone(x, nums, share):
    """Return a list of ``nums`` entries: ``x`` itself when ``share`` is true, deep copies otherwise."""
    if share:
        return [x] * nums
    return [copy.deepcopy(x) for _ in range(nums)]
12 |
13 |
def build_unrolled_solver(solver, share=True, **kwargs):
    """Turn ``solver`` into an unrolled solver.

    With ``share == True`` the solver itself is returned after binding
    ``kwargs`` to its ``solve`` method; otherwise it is wrapped in an
    ``UnrolledSolver`` built with ``share`` and ``kwargs``.
    """
    shared = share == True
    if not shared:
        return UnrolledSolver(solver, share=share, **kwargs)
    # Pre-bind the keyword arguments on the existing solver.
    solver.solve = partial(solver.solve, **kwargs)
    return solver
19 |
20 |
class UnrolledSolver(nn.Module):
    """Run a fixed number of solver iterations, optionally with per-step solvers and parameters.

    Args:
        solver: the algorithm to unroll.
        max_iter: number of unrolled iterations.
        share: when ``False`` each iteration gets its own deep copy of
            ``solver``; otherwise the same instance is reused.
        learned_params: when true, per-iteration ``rhos`` and ``lams`` are
            learnable parameters initialised to ones.
    """

    def __init__(self, solver, max_iter, share=False, learned_params=False):
        super().__init__()
        independent = share == False
        if independent:
            self.solvers = nn.ModuleList(clone(solver, max_iter, share=share))
        else:
            self.solver = solver
            self.solvers = [self.solver for _ in range(max_iter)]

        self.max_iter = max_iter
        self.share = share

        self.learned_params = learned_params
        if learned_params:
            self.rhos = nn.parameter.Parameter(torch.ones(max_iter))
            self.lams = {}
            for fn in solver.psi_fns:
                lam = nn.parameter.Parameter(torch.ones(max_iter))
                # Register the parameter on the module under the function's string form.
                setattr(self, str(fn), lam)
                self.lams[fn] = lam

    @auto_convert_to_tensor(['x0', 'rhos', 'lams'], batchify=['x0'])
    def solve(self, x0=None, rhos=None, lams=None, max_iter=None):
        """Run the unrolled iterations from ``x0`` and return the first state entry.

        When ``learned_params`` is set, the learned ``rhos`` / ``lams``
        replace the ones passed in. ``max_iter`` defaults to the value given
        at construction.
        """
        x0, rhos, lams = move(x0, rhos, lams, device=self.solvers[0].device)

        if self.learned_params:
            rhos, lams = self.rhos, self.lams

        if max_iter is None:
            max_iter = self.max_iter

        state = self.solvers[0].initialize(x0)

        for i in range(max_iter):
            step_solver = self.solvers[i]
            rho = rhos[..., i:i + 1]
            # NOTE(review): every entry is keyed on the first psi function of
            # the step's solver, so with several psi functions only the last
            # value survives -- confirm this is intended.
            lam = {step_solver.psi_fns[0]: v[..., i:i + 1] for v in lams.values()}
            state = step_solver.iters(state, rho, lam, 1, False)

        return state[0]
59 |
--------------------------------------------------------------------------------
/dprox/algo/tune/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Delta-Prox/15fc489e99c4de2676dc538d9c0552338f6fb894/dprox/algo/tune/__init__.py
--------------------------------------------------------------------------------
/dprox/algo/tune/dpir.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
def get_rho_sigma_admm(sigma=2.55 / 255, iter_num=15, modelSigma1=49.0, modelSigma2=2.55, w=1.0, lam=0.23):
    """Build a decreasing sigma schedule and the matching rho values.

    The schedule blends a log-spaced and a linearly spaced sequence from
    ``modelSigma1`` to ``modelSigma2`` with weight ``w`` and divides by 255.
    Each rho is ``lam * sigma**2 / s**2`` for the schedule value ``s``.

    Returns:
        tuple: (list of rhos, float32 array of sigmas).
    """
    log_part = np.logspace(np.log10(modelSigma1), np.log10(modelSigma2), iter_num).astype(np.float32)
    lin_part = np.linspace(modelSigma1, modelSigma2, iter_num).astype(np.float32)
    sigmas = (log_part * w + lin_part * (1 - w)) / 255.
    rhos = [lam * (sigma**2) / (s**2) for s in sigmas]
    return rhos, sigmas
11 |
12 |
def log_descent(upper, lower, iter=24, sigma=0.255 / 255,
                w=1.0, lam=0.23, sqrt=False):
    """
    Generate per-iteration rhos and sigmas with a logarithmic descent schedule.

    :param upper: first value of the schedule (before division by 255)
    :param lower: last value of the schedule (before division by 255)
    :param iter: number of schedule entries, defaults to 24
    :param sigma: noise level used in the rho formula
    :param w: blending weight between the log-spaced (``w``) and the linearly
    spaced (``1 - w``) schedule
    :param lam: scale factor of the rho formula ``lam * sigma**2 / s**2``
    :param sqrt: when false (default) the returned sigmas are the squared
    schedule values; when true they are returned unsquared
    :return: two float tensors: `rhos` and `sigmas`.
    """
    log_part = np.logspace(np.log10(upper), np.log10(lower), iter).astype(np.float32)
    lin_part = np.linspace(upper, lower, iter).astype(np.float32)
    sigmas = (log_part * w + lin_part * (1 - w)) / 255.
    rhos = [lam * (sigma**2) / (s**2) for s in sigmas]
    if not sqrt:
        sigmas = list(sigmas**2)
    return torch.tensor(rhos).float(), torch.tensor(sigmas).float()
40 |
41 |
def f(params):
    """Map each value ``p`` in ``params`` to ``sqrt(p) * 255`` and return the list."""
    scaled = []
    for value in params:
        scaled.append(np.sqrt(value) * 255)
    return scaled
43 | # this can be 70.94
44 |
45 |
def log_descent2(upper, lower, iter=24, sigma=0.255 / 255, w=1.0, lam=0.23):
    """Variant of ``log_descent`` working on noise-normalised squared sigmas.

    The blended schedule is squared and divided by ``sigma**2``; rhos are
    ``lam * sigma**2 / r**2`` for each such ratio ``r`` and the returned
    sigmas are the reciprocals ``1 / r``.

    Returns:
        tuple: (list of rhos, list of sigmas).
    """
    log_part = np.logspace(np.log10(upper), np.log10(lower), iter).astype(np.float32)
    lin_part = np.linspace(upper, lower, iter).astype(np.float32)
    blended = (log_part * w + lin_part * (1 - w)) / 255.
    ratios = list((blended**2) / (sigma**2))
    rhos = [lam * (sigma**2) / (r**2) for r in ratios]
    sigmas = [1 / r for r in ratios]
    return rhos, sigmas
54 |
55 |
def log_descent_origin(upper, lower, iter=24,
                       sigma=0.255 / 255, w=1.0, lam=0.23):
    """Return only the list of rhos ``lam * sigma**2 / s**2`` for the blended log/linear schedule."""
    log_part = np.logspace(np.log10(upper), np.log10(lower), iter).astype(np.float32)
    lin_part = np.linspace(upper, lower, iter).astype(np.float32)
    sigmas = (log_part * w + lin_part * (1 - w)) / 255.
    return [lam * (sigma**2) / (s**2) for s in sigmas]
63 |
--------------------------------------------------------------------------------
/dprox/algo/tune/learnable.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
class LearnableParamProvider(nn.Module):
    """Provide a per-step scalar value.

    The constructor arguments are currently unused and no parameters are
    registered: ``forward`` returns fixed constants.
    """

    def __init__(self, steps, default_value=0.5):
        super().__init__()

    def forward(self, step):
        """Return the constant for ``step`` as a 0-dim tensor."""
        if step == 0:
            value = 0.3442
        elif step == 6:
            value = 0.6111
        else:
            value = 0.3168
        return torch.tensor(value)
--------------------------------------------------------------------------------
/dprox/contrib/__init__.py:
--------------------------------------------------------------------------------
1 | from .restoration import *
2 | from . import csmri, derain, optic
--------------------------------------------------------------------------------
/dprox/contrib/derain.py:
--------------------------------------------------------------------------------
1 | # Learnable Linear Operator from
2 | # https://github.com/MC-E/Deep-Generalized-Unfolding-Networks-for-Image-Restoration/blob/main/Deraining/DGUNet.py
3 | # Deep Generalized Unfolding Networks for Image Restoration
4 |
5 | import torch.nn as nn
6 |
7 |
class ResBlock(nn.Module):
    """Residual block: conv -> [BN] -> activation -> conv -> [BN], plus the input.

    Args:
        conv: factory called as ``conv(in_channels, out_channels, kernel_size, bias=...)``.
        n_feats: number of input and output channels.
        kernel_size: kernel size passed to ``conv``.
        bias: forwarded to ``conv``.
        bn: when true, a ``BatchNorm2d`` follows each convolution.
        act: activation module placed after the first convolution.
        res_scale: factor applied to the body output before adding the input.
    """

    # NOTE(review): the default ``act`` is created once when the class is
    # defined, so every block built with the default shares the same PReLU
    # module -- confirm this sharing is intended.
    def __init__(
        self, conv, n_feats, kernel_size,
        bias=True, bn=False, act=nn.PReLU(), res_scale=1
    ):
        super(ResBlock, self).__init__()
        hidden = 64  # channel width between the two convolutions

        layers = [conv(n_feats, hidden, kernel_size, bias=bias)]
        if bn:
            # Must match the first convolution's output width, not n_feats.
            layers.append(nn.BatchNorm2d(hidden))
        layers.append(act)
        layers.append(conv(hidden, n_feats, kernel_size, bias=bias))
        if bn:
            layers.append(nn.BatchNorm2d(n_feats))

        self.body = nn.Sequential(*layers)
        self.res_scale = res_scale

    def forward(self, x):
        """Return ``body(x) * res_scale + x``."""
        res = self.body(x).mul(self.res_scale)
        res += x
        return res
32 |
33 |
def default_conv(in_channels, out_channels, kernel_size, stride=1, bias=True):
    """Build an ``nn.Conv2d`` padded by ``kernel_size // 2``."""
    half_kernel = kernel_size // 2
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=half_kernel,
        bias=bias,
    )
38 |
39 |
class LearnableDegOp(nn.Module):
    """Step-dependent learnable operator made of ``ResBlock`` branches.

    Three families of blocks are held: ``phi_*`` (forward), ``phit_*``
    (adjoint) and, when ``diag=True``, ``phid_*`` (diag). Each family has
    one block for step 0, one for step ``max_step + 1`` and one shared by
    every other step.
    """

    def __init__(self, diag=False):
        super().__init__()
        families = ['phi', 'phit']
        if diag:
            families.append('phid')
        # Blocks are created in family order, then tag order 0, 1, 6.
        for family in families:
            for tag in (0, 1, 6):
                setattr(self, f'{family}_{tag}', ResBlock(default_conv, 3, 3))

        self.max_step = 5
        self.step = 0

    def _select(self, family, step):
        """Pick the block of ``family`` for ``step`` (``self.step`` if None)."""
        if step is None:
            step = self.step
        if step == 0:
            tag = 0
        elif step == self.max_step + 1:
            tag = 6
        else:
            tag = 1
        return getattr(self, f'{family}_{tag}')

    def forward(self, x, step=None):
        """Apply the forward branch selected by ``step``."""
        return self._select('phi', step)(x)

    def adjoint(self, x, step=None):
        """Apply the adjoint branch selected by ``step``."""
        return self._select('phit', step)(x)

    def diag(self, x, step=None):
        """Apply the diag branch selected by ``step`` (needs ``diag=True``)."""
        return self._select('phid', step)(x)
84 |
--------------------------------------------------------------------------------
/dprox/contrib/energy_system.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import io
3 |
4 | from dprox.utils.huggingface import load_path
5 |
6 |
def load_simple_cep_model():
    """Load ``esm_instance.mat`` and split its constraints by sense.

    Rows whose sense is ``'<'`` form the inequality system and rows whose
    sense is ``'='`` form the equality system; the two groups must account
    for every constraint row.

    Returns:
        tuple: ``(c, A_ub, A_eq, b_ub, b_eq)``.
    """
    mat = io.loadmat(load_path("energy_system/simple_cep_model_20220916/esm_instance.mat"))
    n_con, n_var = mat["A"].shape
    print("Number of linear constraints (w/o bound constraints):", n_con)
    print("Number of decision variables:", n_var)

    A = mat["A"].astype(np.float64)
    b = mat["rhs"].astype(np.float64)
    types = mat["sense"]

    is_ub = types == '<'
    A_ub = A[is_ub]
    b_ub = b[is_ub][:, 0]
    n1 = sum(is_ub)
    print('n1, A_ub, b_ub:', n1, A_ub.shape, b_ub.shape)

    is_eq = types == '='
    A_eq = A[is_eq]
    b_eq = b[is_eq][:, 0]
    n2 = sum(is_eq)
    print('n2, A_eq, b_eq:', n2, A_eq.shape, b_eq.shape)
    assert n1 + n2 == n_con

    c = mat["obj"][:, 0]
    print('c:', c.shape)

    return c, A_ub, A_eq, b_ub, b_eq
32 |
--------------------------------------------------------------------------------
/dprox/contrib/optic/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import *
2 | from .doe_model import *
3 | # from .doe_model2 import *
4 | from .unet import U_Net
--------------------------------------------------------------------------------
/dprox/contrib/restoration.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.misc
3 | import scipy.ndimage
4 | import torch
5 |
6 | from dprox.utils import to_ndarray, to_torch_tensor
7 |
def _load_sample_image(name):
    """Load the sample image ``name`` ("face" or "ascent") shipped with SciPy.

    ``scipy.misc`` is tried first so behaviour is unchanged on SciPy
    versions that still provide the loaders there; when the loader is
    missing (recent SciPy releases dropped it), ``scipy.datasets`` is used.
    """
    loader = getattr(scipy.misc, name, None)
    if loader is None:
        import scipy.datasets

        loader = getattr(scipy.datasets, name)
    return loader()


# Sample images, loaded once at import time and keyed by name.
samples = {
    "face": _load_sample_image("face"),
    "ascent": _load_sample_image("ascent"),
}
12 |
13 |
def sample(name="face", return_tensor=True):
    """Return a float32 copy of the named sample image divided by 255.

    With ``return_tensor`` the array is converted through
    ``to_torch_tensor(..., batch=True)`` and cast to float; otherwise the
    NumPy array is returned.
    """
    image = samples[name].copy().astype("float32") / 255
    if not return_tensor:
        return image
    return to_torch_tensor(image, batch=True).float()
19 |
20 |
def point_spread_function(ksize, sigma):
    """Return a normalised Gaussian kernel as a float32 ``(ksize, ksize, 1)`` array."""
    kernel = fspecial_gaussian(ksize, sigma)
    return kernel[:, :, np.newaxis].astype("float32")
23 |
24 |
def blurring(img, psf):
    """Convolve ``img`` with ``psf`` using wrap-around boundaries.

    The work is done in NumPy via ``scipy.ndimage.convolve``; the result is
    converted back to a batched tensor on the device ``img`` came from.
    """
    target_device = img.device
    img_np = to_ndarray(img, debatch=True)
    psf_np = to_ndarray(psf)
    blurred = scipy.ndimage.convolve(img_np, psf_np, mode="wrap")
    return to_torch_tensor(blurred, batch=True).to(target_device)
32 |
33 |
def fspecial_gaussian(hsize, sigma):
    """Return a normalised ``hsize x hsize`` Gaussian kernel.

    Args:
        hsize: side length of the square kernel.
        sigma: standard deviation of the Gaussian.

    Returns:
        numpy.ndarray: kernel whose entries sum to 1 (unless every entry
        was zeroed out, in which case it is returned as is).
    """
    half = (hsize - 1.0) / 2.0
    coords = np.arange(-half, half + 1)
    x, y = np.meshgrid(coords, coords)
    h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
    # Drop entries that are negligible relative to the peak. np.finfo is
    # used directly: the scipy.finfo alias is not available in current SciPy.
    h[h < np.finfo(float).eps * h.max()] = 0
    total = h.sum()
    if total != 0:
        h = h / total
    return h
46 |
47 |
def downsampling(img, psf, sf):
    """Blur ``img`` with ``psf``, subsample by ``sf`` and build a bicubic upsampling.

    Args:
        img: input tensor; its device is used for the outputs.
        psf: blur kernel.
        sf: integer subsampling factor applied to the first two array axes.

    Returns:
        tuple: ``(downed, x0)`` where ``downed`` is the blurred and subsampled
        image and ``x0`` is ``downed`` resized back by a factor ``sf`` with
        bicubic interpolation; both are float tensors on ``img``'s device.
    """
    import cv2

    device = img.device
    img = to_ndarray(img, debatch=True)
    psf = to_ndarray(psf)
    # scipy.ndimage.filters is a deprecated namespace; call the public
    # function directly, as blurring() does.
    blurred = scipy.ndimage.convolve(img, psf, mode="wrap")
    downed = blurred[0::sf, 0::sf, ...]
    # cv2.resize takes the target size as (width, height).
    target_size = (downed.shape[1] * sf, downed.shape[0] * sf)
    x0 = cv2.resize(downed, target_size, interpolation=cv2.INTER_CUBIC)

    x0 = to_torch_tensor(x0, batch=True).to(device).float()
    downed = to_torch_tensor(downed, batch=True).to(device).float()
    return downed, x0
61 |
62 |
def masks_CFA_Bayer(shape):
    """Return boolean ``(R, G, B)`` masks of an RGGB Bayer pattern for ``shape``."""
    # (row, column) offsets inside each 2x2 tile, per colour.
    tile_offsets = {
        "R": [(0, 0)],
        "G": [(0, 1), (1, 0)],
        "B": [(1, 1)],
    }
    masks = []
    for color in "RGB":
        plane = np.zeros(shape)
        for row0, col0 in tile_offsets[color]:
            plane[row0::2, col0::2] = 1
        masks.append(plane.astype(bool))
    return tuple(masks)
69 |
70 |
def mosaicing(img):
    """Apply the RGGB Bayer mask to a tensor image and return a batched tensor.

    The image is converted to NumPy (batch dimension removed), multiplied by
    the stacked channel masks, and converted back on the original device.
    """
    target_device = img.device
    img_np = to_ndarray(img, debatch=True)
    masks = masks_CFA_Bayer(img_np.shape[:2])
    bayer_mask = np.stack(masks, axis=-1)
    mosaiced = bayer_mask * img_np
    return to_torch_tensor(mosaiced, batch=True).to(target_device)
80 |
81 |
def mosaicing_np(img):
    """Apply the RGGB Bayer mask to an ``(H, W, 3)`` NumPy image."""
    masks = masks_CFA_Bayer(img.shape[:2])
    bayer_mask = np.stack(masks, axis=-1)
    return bayer_mask * img
88 |
89 |
def mosaicing_torch(img):
    """Apply the RGGB Bayer mask to a tensor whose last two axes are spatial.

    The mask is built as ``(H, W, 3)``, moved to ``(1, 3, H, W)`` and placed
    on ``img``'s device before the element-wise product.
    """
    masks = masks_CFA_Bayer(img.shape[-2:])
    mask_hwc = torch.from_numpy(np.stack(masks, axis=-1))
    mask_nchw = mask_hwc.permute(2, 0, 1).unsqueeze(0).to(img.device)
    return mask_nchw * img
98 |
--------------------------------------------------------------------------------
/dprox/linalg/__init__.py:
--------------------------------------------------------------------------------
1 | from . import solve
2 | from .custom import LinearSolve, LinearSolveConfig, linear_solve
--------------------------------------------------------------------------------
/dprox/linalg/custom.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from functools import partial
3 |
4 | import torch
5 |
6 | from .solve import SOLVERS
7 |
8 |
@dataclass
class LinearSolveConfig:
    """Defines default configuration parameters for solving linear equations.

    Args:
        rtol (float): The relative tolerance level for convergence, default to 1e-6.
        max_iters (int): The maximum number of iterations allowed for convergence.
        verbose (bool): whether to print progress updates during the solving process.
        solver_type (str): The type of solver to use (e.g. conjugate gradient);
            it is the key used to look the solver up in ``SOLVERS``.
        solver_kwargs (dict): additional keyword arguments to pass to the solver function
        use_analytic_grad (bool): when true, ``linear_solve`` goes through the
            ``LinearSolve`` autograd function; otherwise the solver is called directly.
    """

    rtol: float = 1e-6
    max_iters: int = 100
    verbose: bool = False
    solver_type: str = "cg"
    # default_factory gives every config its own dict instead of a shared one.
    solver_kwargs: dict = field(default_factory=dict)
    use_analytic_grad: bool = True
27 |
28 |
def _build_solver(config: LinearSolveConfig):
    """Return the solver named by ``config`` with its options pre-bound."""
    return partial(
        SOLVERS[config.solver_type],
        rtol=config.rtol,
        max_iters=config.max_iters,
        verbose=config.verbose,
        **config.solver_kwargs,
    )
33 |
34 |
35 | def _trainable_parameters(module):
36 | return [p for p in module.parameters() if p.requires_grad]
37 |
38 |
class LinearSolve(torch.autograd.Function):
    """Autograd function that solves ``A x = b`` with a hand-written backward.

    The backward pass solves a second linear system with ``A.T`` instead of
    differentiating through the solver iterations.
    """

    @staticmethod
    def forward(ctx, A, b, config, *Aparams):
        """Solve ``A x = b`` with the solver selected by ``config``.

        ``Aparams`` are the trainable parameters of ``A``; they are passed as
        explicit inputs so that ``backward`` can return a gradient for each.
        """
        # A and the solver are not tensors, so they are kept as plain
        # attributes on ctx rather than through save_for_backward.
        ctx.A = A
        ctx.linear_solver = _build_solver(config)
        x = ctx.linear_solver(A, b)
        ctx.save_for_backward(x, *Aparams)
        return x

    @staticmethod
    def backward(ctx, grad_x):
        """Return gradients for ``(A, b, config, *Aparams)``."""
        # Gradient w.r.t. b: solve the system of A.T with the incoming gradient.
        # NOTE(review): presumably A.T is the transposed/adjoint operator — confirm.
        grad_B = ctx.linear_solver(ctx.A.T, grad_x)

        # Use a detached copy of the solution so the graph built below only
        # connects the output of A(x) to A's parameters.
        x = ctx.saved_tensors[0]
        x = x.detach().clone()

        # NOTE(review): relies on A providing clone(); the gradients are taken
        # w.r.t. the trainable parameters of that clone.
        A = ctx.A.clone()
        with torch.enable_grad():
            loss = -A(x)
        # Vector-Jacobian product of -A(x) w.r.t. A's parameters, weighted by grad_B.
        grad_Aparams = torch.autograd.grad(
            (loss,), _trainable_parameters(A), grad_outputs=(grad_B,), create_graph=torch.is_grad_enabled(), allow_unused=True
        )

        # A and config get None; b gets grad_B; then one entry per parameter.
        return (None, grad_B, None, *grad_Aparams)
63 |
64 |
def linear_solve(A: torch.nn.Module, b: torch.Tensor, config: LinearSolveConfig = LinearSolveConfig()):
    """Solve the linear system ``A x = b``.

    Args:
        A (torch.nn.Module): the linear operator, callable as ``A(x)``.
        b (torch.Tensor): right-hand side of the system.
        config (LinearSolveConfig): solver selection and options (solver
            type, tolerance, iteration limit, verbosity, extra keyword
            arguments) and whether the analytic gradient is used.

    Returns:
        The solution of Ax = b. With ``config.use_analytic_grad`` it is
        produced by the ``LinearSolve`` autograd function; otherwise the
        configured solver is called directly.
    """
    if not config.use_analytic_grad:
        plain_solver = _build_solver(config)
        return plain_solver(A, b)
    trainable = _trainable_parameters(A)
    return LinearSolve.apply(A, b, config, *trainable)
83 |
84 |
def pcg(A: torch.nn.Module, b: torch.Tensor, rtol: float = 1e-6, max_iters: int = 100, verbose: bool = False, **kwargs):
    """Solve ``A x = b`` with the "pcg" solver through ``LinearSolve``.

    Extra keyword arguments are forwarded to the solver as ``solver_kwargs``.
    """
    pcg_config = LinearSolveConfig(
        rtol=rtol,
        max_iters=max_iters,
        verbose=verbose,
        solver_kwargs=kwargs,
        solver_type="pcg",
    )
    trainable = _trainable_parameters(A)
    return LinearSolve.apply(A, b, pcg_config, *trainable)
88 |
--------------------------------------------------------------------------------
/dprox/linalg/solve/__init__.py:
--------------------------------------------------------------------------------
1 | from .solver_cg import cg, cg2, pcg
2 | from .solver_plss import plss, plssw
3 | from .solver_minres import minres
4 |
5 |
# Names of the solver functions exported by this package.
# ``available_solvers`` is bound to the same list object as ``__all__``.
__all__ = available_solvers = [
    'cg',
    'cg2',
    'pcg',
    'plss',
    'plssw',
    'minres',
]

# Lookup table from solver name to solver function.
SOLVERS = {
    'cg': cg,
    'cg2': cg2,
    'pcg': pcg,
    'plss': plss,
    'plssw': plssw,
    'minres': minres,
}
23 |
--------------------------------------------------------------------------------
/dprox/linop/__init__.py:
--------------------------------------------------------------------------------
1 | from .blackbox import LinOpFactory, BlackBox
2 | from .conv import conv, conv_doe
3 | from .constant import Constant
4 | from .comp_graph import CompGraph, est_CompGraph_norm, eval, adjoint, gram, validate
5 | from .scale import scale
6 | from .subsample import mosaic
7 | from .sum import sum, copy
8 | from .variable import Variable
9 | from .base import LinOp
10 | from .vstack import vstack, split
11 | from .placeholder import Placeholder
12 | from .grad import grad
13 | from .mul import mul_color, mul_elementwise
--------------------------------------------------------------------------------
/dprox/linop/blackbox.py:
--------------------------------------------------------------------------------
1 | from .base import LinOp
2 |
3 |
def LinOpFactory(forward, adjoint, diag=None, norm_bound=None):
    """Returns a function to generate a custom LinOp.

    Parameters
    ----------
    forward : function
        Applies the operator to its input(s); forwarded to ``BlackBox``.
    adjoint : function
        Applies the adjoint operator to its input(s); forwarded to ``BlackBox``.
    diag : function, optional
        Callable forwarded to ``BlackBox`` as ``diag``.
    norm_bound : float, optional
        An upper bound on the spectral norm of the operator.

    Returns
    -------
    function
        ``get_black_box(*args)``, which builds a ``BlackBox`` from ``args``
        with the callables given above.
    """
    def get_black_box(*args):
        return BlackBox(*args, forward=forward, adjoint=adjoint, diag=diag, norm_bound=norm_bound)
    return get_black_box
23 |
24 |
class BlackBox(LinOp):
    """A black-box lin op specified by the user.

    The user supplies ``forward`` and ``adjoint`` callables and, optionally,
    a ``diag`` callable and a ``norm_bound``. Every callable receives the
    current ``self.step``.
    NOTE(review): ``self.step`` is not set in this class; presumably it is
    provided by ``LinOp`` — confirm.
    """

    def __init__(self, *args, forward=None, adjoint=None, diag=None, norm_bound=None):
        self._forward = forward
        self._adjoint = adjoint
        self._norm_bound = norm_bound
        self._diag = diag
        super(BlackBox, self).__init__(args)

    def forward(self, *inputs):
        """The forward operator.

        Calls the user-supplied ``forward`` with all inputs and the current step.
        """
        return self._forward(*inputs, step=self.step)

    def adjoint(self, *inputs):
        """The adjoint operator.

        Calls the user-supplied ``adjoint`` with all inputs and the current step.
        """
        return self._adjoint(*inputs, step=self.step)

    def norm_bound(self, input_mags):
        """Gives an upper bound on the magnitudes of the outputs given inputs.

        Parameters
        ----------
        input_mags : list
            List of magnitudes of inputs.

        Returns
        -------
        float
            Magnitude of outputs: the user-supplied bound times the first
            input magnitude, or the parent class's bound when none was given.
        """
        if self._norm_bound is None:
            return super(BlackBox, self).norm_bound(input_mags)
        return self._norm_bound * input_mags[0]

    def is_gram_diag(self, freq=False):
        """Is the lin op's Gram matrix diagonal (in the frequency domain)?

        True exactly when a ``diag`` callable was supplied.
        """
        return self._diag is not None

    def get_diag(self, x, freq=False):
        """Return the user-supplied ``diag`` evaluated at ``x`` and the current step."""
        return self._diag(x, self.step)
79 |
--------------------------------------------------------------------------------
/dprox/linop/constaints.py:
--------------------------------------------------------------------------------
1 |
class matmul:
    """Symbolic pairing of a variable with a matrix ``A``.

    Comparing an instance with ``==`` or ``<=`` does not return a boolean;
    it builds an ``equality`` or ``less`` constraint object instead.
    """

    def __init__(self, var, A):
        self.A, self.var = A, var

    def __eq__(self, other):
        """Build an ``equality`` constraint with ``other`` on the right."""
        return equality(self, other)

    def __le__(self, other):
        """Build a ``less`` constraint with ``other`` on the right."""
        return less(self, other)
12 |
13 |
class equality:
    """Container for the two sides of an ``==`` constraint."""

    def __init__(self, left: matmul, right):
        self.left, self.right = left, right
18 |
19 |
class less:
    """Container for the two sides of a ``<=`` constraint."""

    def __init__(self, left: matmul, right):
        self.left, self.right = left, right
24 |
--------------------------------------------------------------------------------
/dprox/linop/constant.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from .base import LinOp
5 |
6 |
class Constant(LinOp):
    """A constant: a lin op without inputs that holds a fixed value."""

    def __init__(self, value):
        super().__init__([])
        # Anything that is neither None nor already a tensor is wrapped.
        needs_wrap = value is not None and not isinstance(value, torch.Tensor)
        self._value = torch.tensor(value) if needs_wrap else value

    # ---------------------------------------------------------------------------- #
    #                                  Computation                                 #
    # ---------------------------------------------------------------------------- #

    def forward(self, *value, **kwargs):
        """The forward operator: ignores its arguments and returns the stored value."""
        return self.value

    def adjoint(self, value, **kwargs):
        """The adjoint operator: returns the stored value multiplied by zero."""
        return self.value*0

    # ---------------------------------------------------------------------------- #
    #                                   Diagonal                                   #
    # ---------------------------------------------------------------------------- #

    def is_diag(self, freq=False):
        """Is the lin op diagonal (in the frequency domain)? Always true."""
        return True

    def get_diag(self, freq=False):
        """Return the diagonal representation, which is an empty dict here.

        Parameters
        ----------
        freq : bool
            Is the diagonal representation in the frequency domain?

        Returns
        -------
        dict
            Always empty.
        """
        return {}

    # ---------------------------------------------------------------------------- #
    #                                   Property                                   #
    # ---------------------------------------------------------------------------- #

    @property
    def variables(self):
        """A constant contains no variables."""
        return []

    @property
    def constants(self):
        """The only constant in this node is the node itself."""
        return [self]

    @property
    def value(self):
        """The stored value moved to ``self.device``."""
        return self._value.to(self.device)

    def norm_bound(self, input_mags):
        """Gives an upper bound on the magnitudes of the outputs given inputs.

        Parameters
        ----------
        input_mags : list
            List of magnitudes of inputs.

        Returns
        -------
        float
            Always ``0.0``.
        """
        return 0.0

    # ---------------------------------------------------------------------------- #
    #                                 Python Magic                                 #
    # ---------------------------------------------------------------------------- #

    def __repr__(self):
        if self._value is None:
            return 'Constant(value=None)'
        return 'Constant(value=somevalue)'
--------------------------------------------------------------------------------
/dprox/linop/edge.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
class Edge:
    """The edge between two lin ops, optionally carrying data."""

    def __init__(self, start, end):
        self.start, self.end = start, end
        self.data = None
        self.mag = None  # Used to get norm bounds.

    @property
    def size(self):
        """Number of elements in ``data`` (product of its shape)."""
        shape = self.data.shape
        return np.prod(shape)

    def __repr__(self) -> str:
        has_data = self.data is not None
        return f'Edge(id={id(self)}, start={self.start}, end={self.end}, data={has_data})'
--------------------------------------------------------------------------------
/dprox/linop/grad.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from dprox.utils.misc import to_torch_tensor
4 |
5 | from .conv import conv
6 |
7 |
class grad(conv):
    """Gradient along one axis, implemented as a convolution with ``[1, -1]``.

    ``dim`` selects the axis: 0 (height), 1 (width) or 2 (channel).
    """

    def __init__(self, arg, dim=1):
        if dim not in (0, 1, 2):
            raise ValueError('dim must be 0(Height) or 1(Width) or 2 (Channel)')

        # Add two leading axes to the two-tap kernel, then swap the tap axis
        # with `dim` so the difference is taken along the requested axis.
        kernel = to_torch_tensor([1, -1])
        kernel = kernel.unsqueeze(0).unsqueeze(0)
        kernel = kernel.transpose(dim, -1)

        super().__init__(arg, kernel=kernel)

    def get_dims(self):
        """Return the dimensionality of the gradient.

        NOTE(review): ``self.dims`` is not assigned in this class; presumably
        it comes from ``conv`` — confirm.
        """
        return self.dims

    def norm_bound(self, input_mags):
        """Gives an upper bound on the magnitudes of the outputs given inputs.

        Parameters
        ----------
        input_mags : list
            List of magnitudes of inputs.

        Returns
        -------
        float
            Magnitude of outputs.
        """
        # A 1D gradient has spectral norm 2; an ND gradient stacks one such
        # operator per axis, hence the factor 2 * sqrt(dims).
        spectral_norm = 2 * np.sqrt(self.dims)
        return spectral_norm * input_mags[0]
47 |
--------------------------------------------------------------------------------
/dprox/linop/mul.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 |
7 | from dprox.utils.misc import to_torch_tensor
8 |
9 | from .base import LinOp
10 | from .placeholder import Placeholder
11 |
12 |
class mul_color(LinOp):
    """Lin op that mixes the channels of its input with the matrix ``srf``.

    ``srf`` is either given directly (and stored as a parameter) or provided
    through a ``Placeholder``, in which case the parameter is created when the
    placeholder notifies its watchers of a value.
    """

    def __init__(
        self,
        arg: LinOp,
        srf: Union[Placeholder, torch.Tensor, np.ndarray],
    ):
        super().__init__([arg])
        # srf: [C, C_2]
        self._srf = srf

        if isinstance(srf, Placeholder):
            # self.srf does not exist until the placeholder is given a value.
            def on_change(val):
                self.srf = nn.parameter.Parameter(val)
            self._srf.change(on_change)
        else:
            # NOTE(review): batch=True presumably adds a leading batch axis;
            # check how that interacts with the `.T` calls below.
            self.srf = nn.parameter.Parameter(to_torch_tensor(srf, batch=True))

    def forward(self, x, **kwargs):
        """Mix the channels of ``x`` with ``srf``."""
        return self.apply(x, self.srf)

    def adjoint(self, x, **kwargs):
        """Mix the channels of ``x`` with the transpose of ``srf``."""
        return self.apply(x, self.srf.T)

    def apply(self, x, srf):
        """Left-multiply the flattened spatial view of ``x`` by ``srf.T``.

        ``x`` must be 4-D ``(N, C, H, W)``; the result is reshaped back to
        ``(N, -1, H, W)``.
        """
        N, C, H, W = x.shape
        x = x.reshape(N, C, H * W)  # N,C,HW
        out = srf.T @ x  # N,C2,HW
        out = out.reshape(N, -1, H, W)
        return out
42 |
43 |
class mul_elementwise(LinOp):
    """Lin op that multiplies its input element-wise by the weight ``w``.

    ``w`` is either given directly (and stored as a parameter) or provided
    through a ``Placeholder``, in which case the parameter is created when the
    placeholder notifies its watchers of a value.
    """

    def __init__(
        self,
        arg: LinOp,
        w: Union[Placeholder, torch.Tensor, np.ndarray],
    ):
        super().__init__([arg])
        self._w = w

        if isinstance(w, Placeholder):
            # self.w does not exist until the placeholder is given a value.
            def on_change(val):
                self.w = nn.parameter.Parameter(val)
            self._w.change(on_change)
        else:
            self.w = nn.parameter.Parameter(to_torch_tensor(w, batch=True))

    def forward(self, x, **kwargs):
        """Return ``w * x`` with ``w`` moved to the device of ``x``."""
        w = self.w.to(x.device)
        return w * x

    def adjoint(self, x, **kwargs):
        """The adjoint is the same element-wise product as ``forward``."""
        return self.forward(x)

    def is_diag(self, freq=False):
        """Diagonal only outside the frequency domain and if the input is diagonal."""
        return not freq and self.input_nodes[0].is_diag(freq)

    def get_diag(self, x, freq=False):
        """Return ``w`` on the device of ``x``, or ``None`` when ``freq`` is true."""
        if freq:
            return None
        return self.w.to(x.device)
74 |
--------------------------------------------------------------------------------
/dprox/linop/placeholder.py:
--------------------------------------------------------------------------------
1 | from .constant import Constant
2 |
3 |
class Placeholder(Constant):
    """A constant whose value can be assigned later and observed by watchers."""

    def __init__(self, default=None):
        super().__init__(default)
        # Callbacks invoked with the new value on every assignment.
        self.watchers = []

    @property
    def value(self):
        """The stored value moved to ``self.device``."""
        return self._value.to(self.device)

    @value.setter
    def value(self, val):
        """Store ``val`` and notify every registered watcher with it."""
        self._value = val
        for notify in self.watchers:
            notify(val)

    def change(self, fn):
        """Register ``fn`` to be called whenever a value is assigned."""
        self.watchers.append(fn)
23 |
--------------------------------------------------------------------------------
/dprox/linop/scale.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from .base import LinOp
5 |
6 |
class scale(LinOp):
    """Multiplication ``scalar * X`` with a fixed scalar."""

    def __init__(self, scalar, arg):
        assert np.isscalar(scalar)
        self.scalar = scalar
        super().__init__([arg])

    # ---------------------------------------------------------------------------- #
    #                                  Computation                                 #
    # ---------------------------------------------------------------------------- #

    def forward(self, input, **kwargs):
        """The forward operator: multiply the input by the scalar."""
        return input * self.scalar

    def adjoint(self, input, **kwargs):
        """The adjoint operator: identical to ``forward``."""
        return self.forward(input)

    # ---------------------------------------------------------------------------- #
    #                                   Diagonal                                   #
    # ---------------------------------------------------------------------------- #

    def is_gram_diag(self, freq=False):
        """Is the Gram matrix diagonal? Delegates to the input node."""
        child = self.input_nodes[0]
        return child.is_gram_diag(freq)

    def is_diag(self, freq=False):
        """Is the lin op diagonal? Delegates to the input node."""
        child = self.input_nodes[0]
        return child.is_diag(freq)

    def get_diag(self, ref, freq=False):
        """Return the scaled diagonal of the input times its complex conjugate.

        Parameters
        ----------
        ref :
            Forwarded unchanged to the input node's ``get_diag``.
        freq : bool
            Is the diagonal representation in the frequency domain?
        """
        scaled = self.input_nodes[0].get_diag(ref, freq) * self.scalar
        return scaled * torch.conj(scaled)

    # ---------------------------------------------------------------------------- #
    #                                   Property                                   #
    # ---------------------------------------------------------------------------- #

    def norm_bound(self, input_mags):
        """Gives an upper bound on the magnitudes of the outputs given inputs.

        Parameters
        ----------
        input_mags : list
            List of magnitudes of inputs.

        Returns
        -------
        float
            ``|scalar|`` times the first input magnitude.
        """
        factor = abs(self.scalar)
        return factor * input_mags[0]
81 |
--------------------------------------------------------------------------------
/dprox/linop/subsample.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from dprox.utils.misc import to_nn_parameter, to_torch_tensor
4 |
5 | from .base import LinOp
6 |
7 |
class mosaic(LinOp):
    """Bayer (RGGB) mosaicing: multiplies its input by a per-channel mask."""

    def __init__(self, arg):
        super(mosaic, self).__init__([arg])
        # Masks keyed by spatial size (H, W); filled lazily by `_mask`.
        self.cache = {}

    # ---------------------------------------------------------------------------- #
    #                                  Computation                                 #
    # ---------------------------------------------------------------------------- #

    def forward(self, input, **kwargs):
        """The forward operator: mask the input with the Bayer pattern."""
        mask = self._mask(input.shape).to(input.device)
        return mask * input

    def adjoint(self, input, **kwargs):
        """The adjoint operator: masking again, identical to ``forward``."""
        return self.forward(input)

    @staticmethod
    def masks_CFA_Bayer(shape):
        """Return boolean ``(R, G, B)`` masks of an RGGB pattern for ``shape``."""
        pattern = 'RGGB'
        channels = dict((channel, np.zeros(shape)) for channel in 'RGB')
        for channel, (y, x) in zip(pattern, [(0, 0), (0, 1), (1, 0), (1, 1)]):
            channels[channel][y::2, x::2] = 1
        return tuple(channels[c].astype(bool) for c in 'RGB')

    def _mask(self, shape):
        """Return the cached mask for the last two (spatial) entries of ``shape``.

        The lookup and the store both use the spatial size as key, so inputs
        that differ only in their leading dimensions share one mask.
        """
        key = tuple(shape[-2:])
        if key not in self.cache:
            R_m, G_m, B_m = self.masks_CFA_Bayer(key)
            mask = np.concatenate((R_m[..., None], G_m[..., None], B_m[..., None]), axis=-1)
            self.cache[key] = to_nn_parameter(to_torch_tensor(mask.astype('float32'), batch=True))
        return self.cache[key]

    # ---------------------------------------------------------------------------- #
    #                                   Diagonal                                   #
    # ---------------------------------------------------------------------------- #

    def is_gram_diag(self, freq=False):
        """Is the lin op's Gram matrix diagonal (in the frequency domain)?"""
        return self.is_self_diag(freq) and self.input_nodes[0].is_diag(freq)

    def is_self_diag(self, freq=False):
        """The mask itself is diagonal only outside the frequency domain."""
        return not freq

    def get_diag(self, x, freq=False):
        """Return the mask for the shape of ``x`` on ``self.device``.

        Parameters
        ----------
        x :
            Reference whose ``shape`` selects the mask.
        freq : bool
            Must be false; the frequency domain is not supported.
        """
        assert not freq
        return self._mask(x.shape).to(self.device)

    # ---------------------------------------------------------------------------- #
    #                                   Property                                   #
    # ---------------------------------------------------------------------------- #

    def norm_bound(self, input_mags):
        """Gives an upper bound on the magnitudes of the outputs given inputs.

        Parameters
        ----------
        input_mags : list
            List of magnitudes of inputs.

        Returns
        -------
        float
            The first input magnitude, unchanged.
        """
        return input_mags[0]
100 |
--------------------------------------------------------------------------------
/dprox/linop/sum.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .base import LinOp
4 |
5 |
class sum(LinOp):
    """Sums its inputs.

    NOTE: this class shadows the builtin ``sum`` inside this module, so the
    builtin cannot be called by its bare name here.
    """

    def __init__(self, input_nodes):
        super(sum, self).__init__(input_nodes)

    def forward(self, *inputs, **kwargs):
        """Sum all the inputs; they should all have the same shape."""
        output = torch.zeros_like(inputs[0])
        for term in inputs:
            output += term.to(output.device)
        return output

    def adjoint(self, input, **kwargs):
        """Spread the input to every child.

        Returns a ``LinOp.MultOutput`` with one entry per input node, or the
        bare input when there is a single input node.
        """
        outputs = LinOp.MultOutput()
        for _ in self.input_nodes:
            outputs.append(input)
        if len(outputs) > 1:
            return outputs
        return outputs[0]

    def is_diag(self, freq=False):
        """Is the lin op diagonal (in the frequency domain)?"""
        return all([arg.is_diag(freq) for arg in self.input_nodes])

    def is_gram_diag(self, freq=False):
        """Is the lin op's Gram matrix diagonal (in the frequency domain)?"""
        return all([arg.is_gram_diag(freq) for arg in self.input_nodes])

    def get_diag(self, ref, freq=False):
        """Return the diagonal representation of the first input node.

        Parameters
        ----------
        ref :
            Forwarded unchanged to the input node's ``get_diag``.
        freq : bool
            Is the diagonal representation in the frequency domain?
        """
        # NOTE(review): only the first input's diagonal is returned; the
        # diagonals of the other inputs are not accumulated.
        return self.input_nodes[0].get_diag(ref, freq)

    def norm_bound(self, input_mags):
        """Gives an upper bound on the magnitudes of the outputs given inputs.

        Parameters
        ----------
        input_mags : list or torch.Tensor
            Magnitudes of inputs.

        Returns
        -------
        float or torch.Tensor
            Sum of the input magnitudes.
        """
        if isinstance(input_mags, torch.Tensor):
            return torch.sum(input_mags)
        # torch.sum rejects plain lists, so accumulate by hand.
        total = 0.0
        for mag in input_mags:
            total = total + mag
        return total
75 |
76 |
class copy(sum):
    """Copy of a single input: the transpose of ``sum``.

    Its forward pass is ``sum.adjoint`` and its adjoint is ``sum.forward``.
    """

    def __init__(self, arg):
        super().__init__([arg])

    def forward(self, inputs, **kwargs):
        """The forward operator: delegates to the parent's ``adjoint``."""
        return super().adjoint(inputs, **kwargs)

    def adjoint(self, *inputs, **kwargs):
        """The adjoint operator: delegates to the parent's ``forward``."""
        return super().forward(*inputs, **kwargs)

    def norm_bound(self, input_mags):
        """Gives an upper bound on the magnitudes of the outputs given inputs.

        Parameters
        ----------
        input_mags : list
            List of magnitudes of inputs.

        Returns
        -------
        float
            The first input magnitude, unchanged.
        """
        first_mag = input_mags[0]
        return first_mag
110 |
--------------------------------------------------------------------------------
/dprox/linop/variable.py:
--------------------------------------------------------------------------------
1 | import uuid
2 |
3 | import torch
4 |
5 | from .base import LinOp
6 |
7 |
class Variable(LinOp):
    """A variable: a lin op without inputs that acts as the identity."""

    def __init__(self, shape=None, value=None, name=None):
        super().__init__([])
        self.uuid = uuid.uuid1()
        self._value = value
        self.shape = shape
        self.varname = name
        self.initval = None

    # ---------------------------------------------------------------------------- #
    #                                  Computation                                 #
    # ---------------------------------------------------------------------------- #

    def forward(self, inputs, **kwargs):
        """The forward operator: returns its input unchanged."""
        return inputs

    def adjoint(self, inputs, **kwargs):
        """The adjoint operator: returns its input unchanged."""
        return inputs

    # ---------------------------------------------------------------------------- #
    #                                   Diagonal                                   #
    # ---------------------------------------------------------------------------- #

    def is_diag(self, freq=False):
        """Is the lin op diagonal (in the frequency domain)? Always true."""
        return True

    def get_diag(self, ref, freq=False):
        """Return a tensor of ones with the shape of ``ref``.

        Parameters
        ----------
        ref :
            Reference whose ``shape`` determines the shape of the result.
        freq : bool
            Is the diagonal representation in the frequency domain?
        """
        target_shape = ref.shape
        return torch.ones(target_shape)

    # ---------------------------------------------------------------------------- #
    #                                   Property                                   #
    # ---------------------------------------------------------------------------- #

    @property
    def variables(self):
        """The only variable in this node is the node itself."""
        return [self]

    @property
    def value(self):
        """The stored value moved to ``self.device``."""
        return self._value.to(self.device)

    @value.setter
    def value(self, val):
        """Assign a value to the variable."""
        self._value = val

    def norm_bound(self, input_mags):
        """Gives an upper bound on the magnitudes of the outputs given inputs.

        Parameters
        ----------
        input_mags : list
            List of magnitudes of inputs.

        Returns
        -------
        float
            Always ``1.0``.
        """
        return 1.0

    # ---------------------------------------------------------------------------- #
    #                                 Python Magic                                 #
    # ---------------------------------------------------------------------------- #

    def __repr__(self):
        value_state = "None" if self._value is None else "somevalue"
        return f'Variable(id={self.uuid}, shape={self.shape}, value={value_state})'
--------------------------------------------------------------------------------
/dprox/proxfn/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import ProxFn
2 | from .sum_square import sum_squares, weighted_sum_squares, least_squares, ext_sum_squares
3 | from .pnp import deep_prior
4 | from .nonneg import nonneg
5 | from .norm import norm1, norm2
6 | from .fast import *
7 | from .nlm import patch_nlm
8 | from .unrolling import unrolled_prior
--------------------------------------------------------------------------------
/dprox/proxfn/base.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from dprox.linop import LinOp, Placeholder, CompGraph
4 | from dprox.utils import to_torch_tensor
5 | import torch.nn as nn
6 |
7 |
def exists(x):
    """Return ``True`` unless ``x`` is ``None`` (falsy values such as 0 still exist)."""
    missing = x is None
    return not missing
10 |
11 |
def prox_scaled(prox, alpha):
    """Wrap ``prox`` so that its ``lam`` argument is multiplied by ``alpha``."""
    def _prox(v, lam):
        scaled_lam = lam * alpha
        return prox(v, scaled_lam)

    return _prox
16 |
17 |
18 | def prox_affine(prox, beta):
19 | def _prox(v, lam):
20 | return 1. / beta * prox(beta * v, beta * beta * lam)
21 | return _prox
22 |
23 |
def prox_translated(prox, b):
    """Wrap ``prox`` so it is applied to ``v - b`` and the result is shifted back by ``b``."""
    def _prox(v, lam):
        shifted = v - b
        return prox(shifted, lam) + b

    return _prox
28 |
29 |
class ProxFn(nn.Module):
    """Abstract base class for functions that expose a proximal operator.

    For a function f, the proximal operator is

        prox(v, lam) = argmin_x f(x) + 1/(2*lam) * ||x - v||_2^2

    Subclasses implement ``_prox``; ``prox`` wraps it with the scaling
    (``alpha``), affine (``beta``) and translation (``offset``) adapters.
    """

    def __init__(self, linop: "LinOp", alpha=1, beta=1):
        """Store the operator and build its computation graph.

        Parameters
        ----------
        linop : LinOp
            Linear operator the function is applied to.
        alpha : scalar
            Multiplier applied to ``lam`` in ``prox`` (set by ``ProxFn * scalar``).
        beta : scalar
            Scale passed to ``prox_affine`` in ``prox``.
        """
        super().__init__()
        self.linop = linop
        self.alpha = alpha
        self.beta = beta
        self.step = 0
        self.dag = CompGraph(linop, zero_out_constant=True)

    @property
    def offset(self):
        """Negated offset of the wrapped linear operator."""
        return -self.linop.offset

    def unwrap(self, value):
        """Return the value of a ``Placeholder``; otherwise convert ``value`` to a
        batched torch tensor on the operator's device."""
        if isinstance(value, Placeholder):
            return value.value
        return to_torch_tensor(value, batch=True).to(self.linop.device)

    def eval(self, v):
        """Evaluate the function at ``v``; must be implemented by subclasses.

        NOTE(review): this name shadows ``nn.Module.eval()`` (which takes no
        argument); kept unchanged for interface compatibility.

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError

    def prox(self, v, lam):
        """Apply the proximal operator.

        v: [B,C,H,W], lam: [B]. A 1-D ``lam`` is reshaped to ``[B,1,1,1]`` so it
        broadcasts against ``v``.
        """
        if len(lam.shape) == 1:
            lam = lam.view(lam.shape[0], 1, 1, 1)

        fn = self._prox
        fn = prox_scaled(fn, self.alpha)
        fn = prox_affine(fn, self.beta)
        fn = prox_translated(fn, self.offset)
        return fn(v, lam)

    def convex_conjugate_prox(self, v, lam):
        """Proximal operator of the convex conjugate, computed from ``prox``.

        NOTE(review): the textbook Moreau identity reads
        ``v - lam * prox_{f/lam}(v / lam)``; this implementation omits the
        ``lam`` factor and passes ``lam`` rather than ``1/lam`` -- confirm
        against the caller's convention before relying on it.
        """
        # use Moreau's identity
        return v - self.prox(v / lam, lam)

    def _prox(self, v, lam):
        """Unwrapped proximal operator; must be implemented by subclasses.

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError

    # def grad(self, x):
    #     x_ = x.detach().requires_grad_(True)
    #     self.eval(x_).backward()
    #     return x_.grad

    def __mul__(self, other):
        """ProxFn * positive scalar: set ``alpha`` in place and return ``self``.

        Raises
        ------
        TypeError
            If ``other`` is not a positive scalar.
        """
        if np.isscalar(other) and other > 0:
            self.alpha = other
            return self
        raise TypeError("Can only multiply by a positive scalar.")

    def __rmul__(self, other):
        """Called for Number * ProxFn.
        """
        return self * other

    def __add__(self, other):
        """ProxFn + ProxFn(s): returns a plain list of ProxFns.
        """
        if isinstance(other, ProxFn):
            return [self, other]
        if isinstance(other, list):
            return [self] + other
        return NotImplemented

    def __radd__(self, other):
        """Called for list + ProxFn.
        """
        if isinstance(other, list):
            return other + [self]
        return NotImplemented

    def __str__(self):
        return f'{self.__class__.__name__}'
109 |
--------------------------------------------------------------------------------
/dprox/proxfn/fast/__init__.py:
--------------------------------------------------------------------------------
1 | from .sr import sisr, misr
2 | from .cs import compress_sensing
3 | from .csmri import csmri
--------------------------------------------------------------------------------
/dprox/proxfn/fast/cs.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from ..sum_square import ext_sum_squares
4 |
5 |
class compress_sensing(ext_sum_squares):
    """Data term whose measurement operator sums the masked input over dim 1.

    After ``_reload``, ``A(x) = sum(x * mask, dim=1, keepdim=True)`` and
    ``At(x) = x * mask``.
    """

    def __init__(self, linop, mask, y):
        super().__init__(linop, y)
        self.y = self.to_parameter(y)
        self.mask = self.to_parameter(mask)

    def _reload(self, shape):
        """Rebuild ``phi``, ``A`` and ``At`` from the current mask value."""
        sensing_mask = self.mask.value.float()

        def apply_forward(x):
            return torch.sum(x * sensing_mask, dim=1, keepdim=True)

        def apply_adjoint(x):
            return x * sensing_mask

        # phi = sum over dim 1 of mask**2
        self.phi = torch.sum(sensing_mask**2, dim=1, keepdim=True)
        self.A, self.At = apply_forward, apply_adjoint

    def _prox(self, v, lam):
        """Closed-form update using the ``phi``, ``A``, ``At`` set by ``_reload``."""
        measurement = self.y.value
        weight = self.I

        residual = weight * measurement - self.A(v)
        correction = self.At(residual / (self.phi + weight * lam))
        return (v + correction) / weight
28 |
--------------------------------------------------------------------------------
/dprox/proxfn/fast/csmri.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.fft
3 |
4 | from ..sum_square import ext_sum_squares
5 | from dprox.utils import fft2, ifft2
6 |
7 |
class csmri(ext_sum_squares):
    """Data term for compressed-sensing MRI, applied in the 2-D Fourier domain."""

    def __init__(self, linop, mask, y):
        super().__init__(linop)
        self.mask = mask
        self.y = y

    def _prox(self, v, lam, num_psi):
        """Blend the spectrum of ``v`` with ``y`` at the masked frequencies.

        Where the (boolean) mask is set, the spectrum is replaced by
        ``(lam * fft2(v) + y) / (1 + lam * num_psi)``; the result is transformed
        back with ``ifft2``. A 1-D ``lam`` is reshaped to ``[B,1,1,1]``.
        """
        if lam.ndim == 1:
            lam = lam.view(lam.shape[0], 1, 1, 1)

        measured = self.unwrap(self.y)
        sampled = self.unwrap(self.mask).bool()

        spectrum = fft2(v)
        blended = (lam * spectrum.clone() + measured) / (1 + lam * num_psi)
        spectrum[sampled] = blended[sampled]

        return ifft2(spectrum)
26 |
--------------------------------------------------------------------------------
/dprox/proxfn/fast/pr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .. import ProxFn
3 | from ..sum_square import ext_sum_squares
4 |
5 | from dprox.utils.misc import to_torch_tensor
6 |
7 |
class phase_ret(ProxFn):
    """Placeholder proximal function for phase retrieval; ``prox`` is not implemented."""

    def __init__(self, linop):
        super().__init__(linop)

    def prox(self, v, lam):
        """Unimplemented: the body is only a commented-out sketch, so this returns ``None``."""
        # Az = cdp_forward(z, mask) # Az
        # y_hat = torch.abs(Az) # |Az|
        # meas_err = y_hat - y0
        # gradient_forward = torch.stack((meas_err/y_hat*Az[...,0], meas_err/y_hat*Az[...,1]), -1)
        # gradient = cdp_backward(gradient_forward, mask)
        # z = z - _tau * (gradient + _mu * (z - (x + u)))
        pass
20 |
21 |
def cdp_forward(data, mask):
    """
    Compute the forward model of cdp.

    The image is tiled along dim 1 once per mask, multiplied by the masks and
    transformed with an orthonormal 2-D FFT.

    Args:
        data (torch.Tensor): image data (batch_size*1*height*width), complex
        mask (torch.Tensor): masks (batch_size*sampling_rate*height*width), complex

    Returns:
        forward_data (torch.Tensor): complex field (batch_size*sampling_rate*height*width)
    """
    num_masks = mask.shape[1]
    tiled = data.repeat(1, num_masks, 1, 1)
    return torch.fft.fft2(tiled * mask, norm='ortho')
38 |
39 |
def cdp_backward(data, mask):
    """
    Compute the backward model of cdp.

    Applies an orthonormal inverse 2-D FFT, multiplies by the conjugate masks
    and averages over dim 1.
    NOTE(review): this inverts ``cdp_forward`` only if the masks have unit
    modulus -- confirm.

    Args:
        data (torch.Tensor): field data (batch_size*sampling_rate*height*width)
        mask (torch.Tensor): masks (batch_size*sampling_rate*height*width)

    Returns:
        backward_data (torch.Tensor): complex field (batch_size*1*height*width)
    """
    field = torch.fft.ifft2(data, norm='ortho')
    demodulated = field * torch.conj(mask)
    return torch.mean(demodulated, dim=1, keepdim=True)
54 |
--------------------------------------------------------------------------------
/dprox/proxfn/fast/spi.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from ..sum_square import ext_sum_squares
4 |
5 |
class spi(ext_sum_squares):
    """Data term for single photon imaging, solved with ``spi_inverse``."""

    def __init__(self, linop, K, y):
        super().__init__(linop, y)
        self.K = self.to_parameter(K)
        self.x0 = self.to_parameter(y)

    def _prox(self, v, lam):
        """Run ``spi_inverse`` on ``v``; only supported when ``self.I`` is all ones.

        Raises
        ------
        AssertionError
            If any element of ``self.I`` differs from 1.
        """
        # Reduce with torch.all: asserting on an element-wise comparison raises
        # "Boolean value of Tensor with more than one value is ambiguous" as
        # soon as I has more than one element.
        assert bool(torch.all(self.I == torch.ones_like(self.I))), \
            'spi ext_sum_square only support I=1'

        # NOTE(review): the factor 10 is hard-coded in the original; its
        # meaning is not visible here.
        K = self.K.value * 10
        K1 = self.x0.value * (K ** 2)

        out = spi_inverse(v, K1, K, lam)
        return out
21 |
22 |
def kron(a, b):
    """
    Kronecker product of matrices a and b with leading batch dimensions.
    The batch dimensions are broadcast against each other.
    :type a: torch.Tensor
    :type b: torch.Tensor
    :rtype: torch.Tensor
    """
    out_rows = a.shape[-2] * b.shape[-2]
    out_cols = a.shape[-1] * b.shape[-1]

    # blocks[..., i, k, j, l] = a[..., i, j] * b[..., k, l]
    blocks = a.unsqueeze(-1).unsqueeze(-3) * b.unsqueeze(-2).unsqueeze(-4)
    batch_shape = tuple(blocks.shape[:-4])

    return blocks.reshape(batch_shape + (out_rows, out_cols))
36 |
37 |
def spi_forward(x, K, alpha, q):
    """Simulate a binary observation of ``x``.

    ``x`` is expanded with ``kron`` against a ``K x K`` block of ones, scaled by
    ``alpha / K**2``, used as the rate of a Poisson draw, and thresholded at
    ``q``. Returns a float tensor of zeros and ones.
    """
    block = torch.ones(1, 1, K, K).to(x.device)
    rate = alpha * kron(x, block) / (K**2)
    counts = torch.poisson(rate)

    threshold = torch.ones_like(counts) * q
    return (counts >= threshold).float()
45 |
46 |
def spi_inverse(ztilde, K1, K, mu, num_iters=10):
    r"""
    Proximal operator "Prox\_{\frac{1}{\mu} D}" for single photon imaging
    assert alpha == K and q == 1

    Where ``K1 == 0`` the result is ``ztilde - K0 / mu`` with ``K0 = K**2 - K1``.
    Elsewhere a root of ``func`` is bracketed in ``[1e-5, 1.1]`` and refined by
    bisection (labelled "differentiable binary search" by the original author).
    The result is clamped to ``[0, 1]``.

    Parameters
    ----------
    ztilde : torch.Tensor
        Point at which the operator is evaluated.
    K1 : torch.Tensor
        Used as a boolean index into ``ztilde``, so it must match its shape.
    K : scalar or torch.Tensor
        Only ``K**2`` is used.
    mu : scalar or torch.Tensor
        Penalty weight; must broadcast against ``ztilde``.
    num_iters : int
        Number of bisection steps (default 10, the previously hard-coded value).
    """
    z = torch.zeros_like(ztilde)

    K0 = K**2 - K1
    indices_0 = (K1 == 0)
    observed = (K1 != 0)  # K1 is never modified, so this mask can be built once

    z[indices_0] = ztilde[indices_0] - (K0 / mu)[indices_0]

    def func(y):
        return K1 / (torch.exp(y) - 1) - mu * y - K0 + mu * ztilde

    indices_1 = torch.logical_not(indices_0)

    # differentiable binary search
    bmin = 1e-5 * torch.ones_like(ztilde)
    bmax = 1.1 * torch.ones_like(ztilde)

    bave = (bmin + bmax) / 2.0

    for _ in range(num_iters):
        tmp = func(bave)
        indices_pos = torch.logical_and(tmp > 0, indices_1)
        indices_neg = torch.logical_and(tmp < 0, indices_1)
        indices_zero = torch.logical_and(tmp == 0, indices_1)
        # Entries that hit the root exactly are frozen from now on.
        indices_0 = torch.logical_or(indices_0, indices_zero)
        indices_1 = torch.logical_not(indices_0)

        bmin[indices_pos] = bave[indices_pos]
        bmax[indices_neg] = bave[indices_neg]
        bave[indices_1] = (bmin[indices_1] + bmax[indices_1]) / 2.0

    z[observed] = bave[observed]
    return torch.clamp(z, 0.0, 1.0)
82 |
--------------------------------------------------------------------------------
/dprox/proxfn/nlm/__init__.py:
--------------------------------------------------------------------------------
1 | from .patch_nlm import patch_nlm
--------------------------------------------------------------------------------
/dprox/proxfn/nlm/patch_nlm.py:
--------------------------------------------------------------------------------
1 | from .nlm import NonLocalMeansFast
2 | from .. import ProxFn
3 |
4 |
class patch_nlm(ProxFn):
    """Proximal function backed by a ``NonLocalMeansFast`` denoiser."""

    def __init__(self, linop):
        super().__init__(linop)
        self.denoiser = NonLocalMeansFast()

    def _prox(self, v, lam):
        """Denoise ``v``; the denoiser receives ``sqrt(lam)`` as its second argument."""
        strength = lam.sqrt()
        return self.denoiser(v, strength)
14 |
--------------------------------------------------------------------------------
/dprox/proxfn/nonneg.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from . import ProxFn
4 |
5 |
class nonneg(ProxFn):
    """Non-negativity constraint: the prox clips negative entries to zero."""

    def __init__(self, linop=None):
        super().__init__(linop)

    def _prox(self, v, lam):
        """Return the element-wise maximum of ``v`` and zero (``lam`` is unused)."""
        floor = torch.zeros_like(v)
        return torch.maximum(v, floor)
12 |
--------------------------------------------------------------------------------
/dprox/proxfn/norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from . import ProxFn
4 |
5 |
def soft_threshold(v, lam):
    """ ref: https://www.tensorflow.org/probability/api_docs/python/tfp/math/soft_threshold

    argmin_x lam * |x|_1 + 0.5 * (x-v)^2

    Shrinks the magnitude of every entry of ``v`` by ``lam`` (never below zero)
    and keeps the sign.
    """
    shrunk = torch.maximum(torch.abs(v) - lam, torch.zeros_like(v))
    return torch.sign(v) * shrunk
12 |
13 |
class norm1(ProxFn):
    """L1 norm; its prox is element-wise soft thresholding."""

    def __init__(self, linop=None):
        super().__init__(linop)

    def _prox(self, v, lam):
        """Soft-threshold ``v`` by ``lam``."""
        thresholded = soft_threshold(v, lam)
        return thresholded
20 |
21 |
class norm2(ProxFn):
    """Quadratic penalty whose prox rescales its input.

    NOTE(review): ``v / (1 + 2 * lam)`` is the prox of the *squared* L2 norm,
    presumably what this class is meant to represent -- confirm.
    """

    def __init__(self, linop=None):
        super().__init__(linop)

    def _prox(self, v, lam):
        """Return ``v / (1 + 2 * lam)``."""
        denominator = 1 + 2 * lam
        return v / denominator
28 |
--------------------------------------------------------------------------------
/dprox/proxfn/pnp/__init__.py:
--------------------------------------------------------------------------------
1 | from .prior import deep_prior
--------------------------------------------------------------------------------
/dprox/proxfn/pnp/denoisers/__init__.py:
--------------------------------------------------------------------------------
1 | from .wrapper import (TVDenoiser,
2 | FFDNet3DDenoiser, FFDNetDenoiser, FFDNetColorDenoiser,
3 | IRCNNDenoiser, DRUNetDenoiser,
4 | QRNN3DDenoiser, GRUNetDenoiser, GRUNetTVDenoiser,
5 | UNetDenoiser)
6 |
7 | from .composite import Augment, DeepTVDenoiser
--------------------------------------------------------------------------------
/dprox/proxfn/pnp/denoisers/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import torch
3 | import torch.nn as nn
4 |
class Denoiser(nn.Module):
    """Base class for denoisers: subclasses implement ``_denoise``."""

    def denoise(self, input: torch.Tensor, sigma: torch.Tensor):
        """Denoise ``input`` ([NCHW]) at noise level ``sigma``.

        ``sigma`` is reshaped to ``(-1, 1, 1, 1)`` before being handed to
        ``_denoise``.
        """
        noise_level = sigma.view(-1, 1, 1, 1)
        return self._denoise(input, noise_level)

    @abc.abstractmethod
    def _denoise(self, x, sigma):
        """Denoise ``x`` given the reshaped ``sigma``; to be overridden."""
        raise NotImplementedError
15 |
16 |
class Denoiser2D(Denoiser):
    """Denoiser that processes each channel of the input independently."""

    def denoise(self, input: torch.Tensor, sigma: torch.Tensor):
        """Denoise ``input`` ([NCHW]) one band (dim 1) at a time.

        ``sigma`` is reshaped to ``(-1, 1, 1, 1)``; the per-band results are
        concatenated back along dim 1.
        """
        noise_level = sigma.view(-1, 1, 1, 1)
        denoised = [self._denoise(band, noise_level)
                    for band in input.split(1, dim=1)]
        return torch.cat(denoised, dim=1)
26 |
--------------------------------------------------------------------------------
/dprox/proxfn/pnp/denoisers/composite.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 |
class Augment(nn.Module):
    """Wrap a denoiser so each call sees a flipped/rotated copy of the input.

    The transform cycles through eight modes (``iter % 8``); after denoising,
    the matching inverse transform is applied.
    """

    def __init__(self, base_denoiser):
        super().__init__()
        self.base_denoiser = base_denoiser
        self.iter = 0

    def denoise(self, x: torch.Tensor, sigma: torch.Tensor):
        """Transform ``x``, denoise it, undo the transform and advance the counter."""
        mode = self.iter % 8
        # Modes 3 and 5 are pure rotations that undo each other; every other
        # mode is applied a second time to get back.
        inverse_mode = 8 - mode if mode in (3, 5) else mode

        x = self.augment(x, mode)
        x = self.base_denoiser.denoise(x, sigma)
        x = self.augment(x, inverse_mode)

        self.iter += 1
        return x

    def reset(self):
        """Restart the mode cycle."""
        self.iter = 0

    @staticmethod
    def augment(img, mode=0):
        """Apply transform ``mode`` (0-7) over dims 2 and 3 of ``img``.

        Each mode is a number of 90-degree rotations in the (2, 3) plane,
        optionally followed by a flip along dim 2. Returns ``None`` for any
        other mode.
        """
        # index = mode -> (rotations, flip afterwards)
        recipes = (
            (0, False), (1, True), (0, True), (3, False),
            (2, True), (1, False), (2, False), (3, True),
        )
        for candidate, (turns, flipped) in enumerate(recipes):
            if mode == candidate:
                out = img.rot90(turns, [2, 3]) if turns else img
                return out.flip([2]) if flipped else out
        return None
48 |
49 |
class DeepTVDenoiser:
    """Combine a deep denoiser and a TV denoiser run at several strengths.

    ``denoise`` computes one estimate per entry of ``deep_hypara_list`` and
    ``tv_hypara_list``, solves a small quadratic program (cvxpy) for
    non-negative weights that sum to one within each family and minimise the
    squared distance between the two weighted combinations, and returns the
    average of those two combinations.
    """

    def __init__(self, deep_denoise, tv_denoising,
                 deep_hypara_list=(40., 20., 10., 5.), tv_hypara_list=(10, 0.01)):
        """
        Parameters
        ----------
        deep_denoise : callable
            Called as ``deep_denoise(x, sigma)`` with ``sigma = level / 255``.
        tv_denoising : callable
            Called as ``tv_denoising(x_hwc, level, 5)``.
        deep_hypara_list, tv_hypara_list : sequence
            Strength levels. They are copied into fresh lists so instances
            never share (or mutate) the defaults or the caller's sequence.
        """
        self.deep_hypara_list = list(deep_hypara_list)
        self.tv_hypara_list = list(tv_hypara_list)
        self.tv_denoising = tv_denoising
        self.deep_denoise = deep_denoise

    def denoise(self, x):
        """Denoise ``x`` (e.g. 1,31,512,512) and return a tensor of shape (1, C, H, W).

        Raises
        ------
        RuntimeError
            If the weight problem could not be solved (``w.value`` is ``None``).
        """
        import cvxpy as cp  # imported lazily: only this method needs it

        def as_row(t):
            # Flatten one (H, W, C) estimate into a float64 row vector.
            return t.cpu().numpy().reshape(-1).astype(np.float64)

        # x: 1,31,512,512
        deep_num = len(self.deep_hypara_list)
        tv_num = len(self.tv_hypara_list)
        total = deep_num + tv_num

        deep_list = [self.deep_denoise(x, torch.tensor(level/255.).to(x.device))
                     for level in self.deep_hypara_list]
        deep_list = [tmp.squeeze().permute(1, 2, 0) for tmp in deep_list]

        x_hwc = x.squeeze().permute(1, 2, 0)
        tv_list = [self.tv_denoising(x_hwc, level, 5).clamp(0, 1)
                   for level in self.tv_hypara_list]

        ffdnet_mat = np.stack([as_row(t) for t in deep_list], axis=0)
        tv_mat = np.stack([as_row(t) for t in tv_list], axis=0)

        # w^T P w = || ffdnet_mat^T w_deep - tv_mat^T w_tv ||^2
        w = cp.Variable(total)
        P = np.zeros((total, total))
        P[:deep_num, :deep_num] = ffdnet_mat @ ffdnet_mat.T
        P[:deep_num, deep_num:] = -ffdnet_mat @ tv_mat.T
        P[deep_num:, :deep_num] = -tv_mat @ ffdnet_mat.T
        P[deep_num:, deep_num:] = tv_mat @ tv_mat.T
        one_vector_ffdnet = np.ones((1, deep_num))
        one_vector_tv = np.ones((1, tv_num))
        objective = cp.quad_form(w, P)
        problem = cp.Problem(
            cp.Minimize(objective),
            [one_vector_ffdnet @ w[:deep_num] == 1,
             one_vector_tv @ w[deep_num:] == 1,
             w >= 0])
        problem.solve()

        w_value = w.value
        if w_value is None:
            raise RuntimeError(
                f'DeepTVDenoiser: weight problem not solved (status: {problem.status})')

        x_ffdnet, x_tv = 0, 0
        for weight, estimate in zip(w_value[:deep_num], deep_list):
            x_ffdnet += weight * estimate
        for weight, estimate in zip(w_value[deep_num:], tv_list):
            x_tv += weight * estimate

        v = 0.5 * (x_ffdnet + x_tv)
        v = v.permute(2, 0, 1).unsqueeze(0)
        return v
99 |
--------------------------------------------------------------------------------
/dprox/proxfn/pnp/denoisers/models/TV_denoising.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
def TV_denoising(y0, lamda, iteration=100):
    """Total-variation denoising of a 3-D tensor along its first two axes.

    Iterates on two dual variables (differences along dim 1 and dim 0), each
    clipped to ``[-lamda/2, lamda/2]``, with a fixed step of ``1/5``.

    Parameters
    ----------
    y0 : torch.Tensor
        Input of shape ``(w, h, b)``.
    lamda : scalar
        Regularisation strength.
    iteration : int
        Number of iterations. With ``iteration <= 0`` the input is returned
        unchanged (previously this raised ``UnboundLocalError``).

    Returns
    -------
    torch.Tensor
        The denoised estimate, same shape as ``y0``.
    """
    device = y0.device
    w, h, b = y0.shape
    # Dual variables are always float32, whatever the dtype of y0.
    zh = torch.zeros([w, h-1, b], device=device, dtype=torch.float32)
    zv = torch.zeros([w-1, h, b], device=device, dtype=torch.float32)
    alpha = 5
    x0 = y0
    for _ in range(iteration):
        x0h = y0 - dht_3d(zh)
        x0v = y0 - dvt_3d(zv)
        x0 = (x0h + x0v) / 2
        zh = clip(zh + 1/alpha*dh(x0), lamda/2)
        zv = clip(zv + 1/alpha*dv(x0), lamda/2)
    return x0
17 |
def TV_denoising3d(y0, lamda, iteration=100):
    """Total-variation denoising of a 3-D tensor along all three axes.

    Same scheme as the two-axis version with a third dual variable for
    differences along dim 2; the three partial estimates are averaged.

    Parameters
    ----------
    y0 : torch.Tensor
        Input of shape ``(w, h, b)``.
    lamda : scalar
        Regularisation strength.
    iteration : int
        Number of iterations. With ``iteration <= 0`` the input is returned
        unchanged (previously this raised ``UnboundLocalError``).

    Returns
    -------
    torch.Tensor
        The denoised estimate, same shape as ``y0``.
    """
    device = y0.device
    w, h, b = y0.shape
    # Dual variables are always float32, whatever the dtype of y0.
    zh = torch.zeros([w, h-1, b], device=device, dtype=torch.float32)
    zv = torch.zeros([w-1, h, b], device=device, dtype=torch.float32)
    zt = torch.zeros([w, h, b-1], device=device, dtype=torch.float32)
    alpha = 5
    x0 = y0
    for _ in range(iteration):
        x0h = y0 - dht_3d(zh)
        x0v = y0 - dvt_3d(zv)
        x0t = y0 - dtt_3d(zt)
        x0 = (x0h + x0v + x0t) / 3
        zh = clip(zh + 1/alpha*dh(x0), lamda/2)
        zv = clip(zv + 1/alpha*dv(x0), lamda/2)
        zt = clip(zt + 1/alpha*dt(x0), lamda/2)
    return x0
35 |
def clip(x, thres):
    """Clamp every entry of ``x`` to the interval ``[-thres, thres]``."""
    lower, upper = -thres, thres
    return torch.clamp(x, min=lower, max=upper)
38 |
def dht_3d(x):
    """Transpose of the forward difference along dim 1; output has one more entry on that axis."""
    head = -x[:, 0:1, :]
    interior = x[:, :-1, :] - x[:, 1:, :]
    tail = x[:, -1:, :]
    return torch.cat([head, interior, tail], 1)
41 |
def dvt_3d(x):
    """Transpose of the forward difference along dim 0; output has one more entry on that axis."""
    head = -x[0:1, :, :]
    interior = x[:-1, :, :] - x[1:, :, :]
    tail = x[-1:, :, :]
    return torch.cat([head, interior, tail], 0)
44 |
def dtt_3d(x):
    """Transpose of the forward difference along dim 2; output has one more entry on that axis."""
    head = -x[:, :, 0:1]
    interior = x[:, :, :-1] - x[:, :, 1:]
    tail = x[:, :, -1:]
    return torch.cat([head, interior, tail], 2)
47 |
def dh(x):
    """Forward difference along dim 1 (next entry minus current)."""
    following = x[:, 1:, :]
    current = x[:, :-1, :]
    return following - current
50 |
def dv(x):
    """Forward difference along dim 0 (next entry minus current)."""
    following = x[1:, :, :]
    current = x[:-1, :, :]
    return following - current
53 |
def dt(x):
    """Forward difference along dim 2 (next entry minus current)."""
    following = x[:, :, 1:]
    current = x[:, :, :-1]
    return following - current
--------------------------------------------------------------------------------
/dprox/proxfn/pnp/denoisers/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Delta-Prox/15fc489e99c4de2676dc538d9c0552338f6fb894/dprox/proxfn/pnp/denoisers/models/__init__.py
--------------------------------------------------------------------------------
/dprox/proxfn/pnp/denoisers/models/qrnn/__init__.py:
--------------------------------------------------------------------------------
1 | from .qrnn3d import QRNNREDC3D
2 | from .grunet import GRUnet
3 |
4 | """Define commonly used architecture"""
5 |
6 |
def qrnn3d():
    """Build ``QRNNREDC3D(1, 16, 5, [1, 3], has_ad=True, bn=True)`` with
    ``use_2dconv`` and ``bandwise`` set to ``False``."""
    model = QRNNREDC3D(1, 16, 5, [1, 3], has_ad=True, bn=True)
    model.use_2dconv, model.bandwise = False, False
    return model
12 |
13 |
def qrnn2d():
    """Build ``QRNNREDC3D(1, 16, 5, [1, 3], has_ad=True, is_2d=True)`` with
    ``use_2dconv`` and ``bandwise`` set to ``False``."""
    model = QRNNREDC3D(1, 16, 5, [1, 3], has_ad=True, is_2d=True)
    model.use_2dconv, model.bandwise = False, False
    return model
19 |
20 |
def qrnn3d_masked():
    """Build ``QRNNREDC3D(2, 16, 5, [1, 3], has_ad=True)`` (first argument 2 instead of 1)
    with ``use_2dconv`` and ``bandwise`` set to ``False``."""
    model = QRNNREDC3D(2, 16, 5, [1, 3], has_ad=True)
    model.use_2dconv, model.bandwise = False, False
    return model
26 |
27 |
def grunet_masked():
    """Build a ``GRUnet`` with two input channels, one output channel and the noise map enabled."""
    config = dict(in_ch=2, out_ch=1, use_noise_map=True)
    return GRUnet(**config)
30 |
31 |
def grunet_masked_nobn():
    """Build a ``GRUnet`` with two input channels, one output channel, the noise map
    enabled and ``bn=False``."""
    config = dict(in_ch=2, out_ch=1, use_noise_map=True, bn=False)
    return GRUnet(**config)
34 |
35 |
def grunet():
    """Build a ``GRUnet`` with one input channel, one output channel and no noise map."""
    config = dict(in_ch=1, out_ch=1, use_noise_map=False)
    return GRUnet(**config)
38 |
39 |
def grunet_nobn():
    """Build a ``GRUnet`` with one input channel, one output channel, no noise map
    and ``bn=False``."""
    config = dict(in_ch=1, out_ch=1, use_noise_map=False, bn=False)
    return GRUnet(**config)
42 |
--------------------------------------------------------------------------------
/dprox/proxfn/pnp/denoisers/models/qrnn/conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional
4 | from .sync_batchnorm import SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
5 |
6 | BatchNorm3d = SynchronizedBatchNorm3d
7 |
8 |
class BNReLUConv3d(nn.Sequential):
    """Batch norm -> ReLU -> bias-free ``Conv3d``, registered as ``bn``, ``relu``, ``conv``."""

    def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=False):
        super().__init__()
        stages = (
            ('bn', BatchNorm3d(in_channels)),
            ('relu', nn.ReLU(inplace=inplace)),
            ('conv', nn.Conv3d(in_channels, channels, k, s, p, bias=False)),
        )
        for label, layer in stages:
            self.add_module(label, layer)
15 |
16 |
class BNReLUDeConv3d(nn.Sequential):
    """Batch norm -> ReLU -> bias-free ``ConvTranspose3d``, registered as ``bn``, ``relu``, ``deconv``."""

    def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=False):
        super().__init__()
        stages = (
            ('bn', BatchNorm3d(in_channels)),
            ('relu', nn.ReLU(inplace=inplace)),
            ('deconv', nn.ConvTranspose3d(in_channels, channels, k, s, p, bias=False)),
        )
        for label, layer in stages:
            self.add_module(label, layer)
23 |
24 |
class BNReLUUpsampleConv3d(nn.Sequential):
    """Batch norm -> ReLU -> bias-free ``UpsampleConv3d``, registered as ``bn``, ``relu``,
    ``upsample_conv``."""

    def __init__(self, in_channels, channels, k=3, s=1, p=1, upsample=(1,2,2), inplace=False):
        super().__init__()
        stages = (
            ('bn', BatchNorm3d(in_channels)),
            ('relu', nn.ReLU(inplace=inplace)),
            ('upsample_conv',
             UpsampleConv3d(in_channels, channels, k, s, p, bias=False, upsample=upsample)),
        )
        for label, layer in stages:
            self.add_module(label, layer)
31 |
32 |
class UpsampleConv3d(torch.nn.Module):
    """UpsampleConvLayer
    Optionally upsamples the input (trilinear, ``align_corners=True``) and then
    applies a 3-D convolution. This method gives better results compared to
    ConvTranspose2d.
    ref: http://distill.pub/2016/deconv-checkerboard/
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=True, upsample=None):
        super().__init__()
        self.upsample = upsample
        # The upsampling layer only exists when a (truthy) scale factor is given.
        if upsample:
            self.upsample_layer = torch.nn.Upsample(
                scale_factor=upsample, mode='trilinear', align_corners=True)

        self.conv3d = torch.nn.Conv3d(
            in_channels, out_channels, kernel_size, stride, padding, bias=bias)

    def forward(self, x):
        """Upsample ``x`` when configured to, then convolve."""
        features = self.upsample_layer(x) if self.upsample else x
        return self.conv3d(features)
54 |
55 |
class BasicConv3d(nn.Sequential):
    """Optional batch norm (``bn``) followed by a ``Conv3d`` (``conv``)."""

    def __init__(self, in_channels, channels, k=3, s=1, p=1, bias=False, bn=True):
        super().__init__()
        stages = [('bn', BatchNorm3d(in_channels))] if bn else []
        stages.append(('conv', nn.Conv3d(in_channels, channels, k, s, p, bias=bias)))
        for label, layer in stages:
            self.add_module(label, layer)
62 |
63 |
class BasicDeConv3d(nn.Sequential):
    """Optional batch norm (``bn``) followed by a ``ConvTranspose3d`` (``deconv``)."""

    def __init__(self, in_channels, channels, k=3, s=1, p=1, bias=False, bn=True):
        super().__init__()
        stages = [('bn', BatchNorm3d(in_channels))] if bn else []
        stages.append(('deconv', nn.ConvTranspose3d(in_channels, channels, k, s, p, bias=bias)))
        for label, layer in stages:
            self.add_module(label, layer)
70 |
71 |
class BasicUpsampleConv3d(nn.Sequential):
    """Optional batch norm (``bn``) followed by a bias-free ``UpsampleConv3d`` (``upsample_conv``)."""

    def __init__(self, in_channels, channels, k=3, s=1, p=1, upsample=(1,2,2), bn=True):
        super().__init__()
        stages = [('bn', BatchNorm3d(in_channels))] if bn else []
        stages.append(
            ('upsample_conv',
             UpsampleConv3d(in_channels, channels, k, s, p, bias=False, upsample=upsample)))
        for label, layer in stages:
            self.add_module(label, layer)
78 |
--------------------------------------------------------------------------------
/dprox/proxfn/pnp/denoisers/models/qrnn/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
13 |
--------------------------------------------------------------------------------
/dprox/proxfn/pnp/denoisers/models/qrnn/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = [
16 | 'CallbackContext',
17 | 'execute_replication_callbacks',
18 | 'DataParallelWithCallback',
19 | 'patch_replication_callback'
20 | ]
21 |
22 |
class CallbackContext:
    """Empty attribute holder handed to the replication callbacks."""
25 |
26 |
def execute_replication_callbacks(modules):
    """
    Execute a replication callback `__data_parallel_replicate__` on each module created by original replication.

    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Note that, as all modules are isomorphic, each sub-module is assigned a context
    (shared among multiple copies of this module on different devices).
    Through this context, different copies can share some information.

    The callback on the master copy (the first copy) is guaranteed to be called ahead of
    the callback of any slave copies.
    """
    master_copy = modules[0]
    # One context per sub-module position of the master copy.
    contexts = [CallbackContext() for _ in master_copy.modules()]

    for copy_id, replica in enumerate(modules):
        for position, submodule in enumerate(replica.modules()):
            if not hasattr(submodule, '__data_parallel_replicate__'):
                continue
            submodule.__data_parallel_replicate__(contexts[position], copy_id)
48 |
49 |
class DataParallelWithCallback(DataParallel):
    """
    Data Parallel with a replication callback.

    A replication callback `__data_parallel_replicate__` of each module will be invoked
    after the replicas are created by the original `replicate` function.
    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        """Replicate as the parent does, then run the replication callbacks."""
        replicas = super().replicate(module, device_ids)
        execute_replication_callbacks(replicas)
        return replicas
68 |
69 |
def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object. Add the replication callback.
    Useful when you have a customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    original_replicate = data_parallel.replicate

    @functools.wraps(original_replicate)
    def replicate_with_callbacks(module, device_ids):
        replicas = original_replicate(module, device_ids)
        execute_replication_callbacks(replicas)
        return replicas

    # Instance attribute shadows the bound method for this object only.
    data_parallel.replicate = replicate_with_callbacks
95 |
--------------------------------------------------------------------------------
/dprox/proxfn/pnp/denoisers/models/qrnn/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 |
13 | import numpy as np
14 | from torch.autograd import Variable
15 |
16 |
def as_numpy(v):
    """Return ``v`` as a NumPy array on the CPU (``.data`` is used for ``Variable`` inputs)."""
    tensor = v.data if isinstance(v, Variable) else v
    return tensor.cpu().numpy()
21 |
22 |
class TorchTestCase(unittest.TestCase):
    """``unittest.TestCase`` with a tensor closeness assertion."""

    def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
        """Assert that ``a`` and ``b`` are element-wise close.

        Both tolerances are forwarded to ``np.allclose``. (``rtol`` used to be
        accepted but ignored, so NumPy's default of 1e-5 was applied instead.)

        Parameters
        ----------
        a, b : tensors accepted by ``as_numpy``
        atol : float
            Absolute tolerance.
        rtol : float
            Relative tolerance.
        """
        npa, npb = as_numpy(a), as_numpy(b)
        # rdiff is relative to |a|, floored at 1e-5 so the division stays finite
        # (the magnitude is used so negative references are not floored away).
        self.assertTrue(
            np.allclose(npa, npb, atol=atol, rtol=rtol),
            'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(np.abs(npa), 1e-5)).max())
        )
30 |
--------------------------------------------------------------------------------
/dprox/proxfn/pnp/denoisers/models/unet/__init__.py:
--------------------------------------------------------------------------------
1 | from .unet import UNet
--------------------------------------------------------------------------------
/dprox/proxfn/pnp/prior.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from dprox.utils import hf, safe_sqrt
7 |
8 | from ..base import ProxFn
9 | from .denoisers import (DRUNetDenoiser, FFDNetColorDenoiser, FFDNetDenoiser,
10 | GRUNetDenoiser, IRCNNDenoiser, UNetDenoiser)
11 | from .denoisers.composite import Augment
12 |
13 |
def get_denoiser(type):
    """Instantiate a pretrained denoiser by name.

    The checkpoint path is resolved with ``hf.load_path(..., repo_type='models')``
    and passed to the matching denoiser class.

    Parameters
    ----------
    type : str
        One of ``'ffdnet'``, ``'ffdnet_color'``, ``'drunet_color'``, ``'drunet'``,
        ``'ircnn'``, ``'grunet'``, ``'unet'``.

    Raises
    ------
    ValueError
        If ``type`` is not a known name (previously ``None`` was returned
        silently and the failure surfaced later, far from the cause).
    """
    # name -> (checkpoint file, factory taking the resolved local path).
    # Factories are lambdas so a class is only looked up when it is selected.
    registry = {
        'ffdnet': ('pnp_denoisers/ffdnet_gray.pth', lambda path: FFDNetDenoiser(path)),
        'ffdnet_color': ('pnp_denoisers/ffdnet_color.pth', lambda path: FFDNetColorDenoiser(path)),
        'drunet_color': ('pnp_denoisers/drunet_color.pth', lambda path: DRUNetDenoiser(3, path)),
        'drunet': ('pnp_denoisers/drunet_gray.pth', lambda path: DRUNetDenoiser(1, path)),
        'ircnn': ('pnp_denoisers/ircnn_gray.pth', lambda path: IRCNNDenoiser(1, path)),
        'grunet': ('pnp_denoisers/unet_qrnn3d.pth', lambda path: GRUNetDenoiser(path)),
        'unet': ('pnp_denoisers/unet-nm.pt', lambda path: UNetDenoiser(path)),
    }
    if type not in registry:
        raise ValueError(
            f'Unknown denoiser {type!r}; expected one of {sorted(registry)}')

    checkpoint, build = registry[type]
    model_path = hf.load_path(checkpoint, repo_type='models')
    return build(model_path)
36 |
37 |
def clone(x, nums, shared):
    """Return a list of ``nums`` entries derived from ``x``.

    With ``shared`` truthy every entry is the very same object ``x``;
    otherwise each entry is an independent ``copy.deepcopy`` of it.
    """
    if shared:
        return [x] * nums
    return [copy.deepcopy(x) for _ in range(nums)]
40 |
41 |
class deep_prior(ProxFn):
    """Proximal function whose prox step is performed by a denoiser.

    Args:
        linop: forwarded unchanged to ``ProxFn.__init__``.
        denoiser: a name understood by ``get_denoiser`` or an already built
            denoiser object exposing ``denoise(input, sigma)``.
        x8: wrap the denoiser in ``Augment``.
        clamp: clamp the prox input to ``[0, 1]`` before denoising.
        trainable: when False the denoiser is switched to eval mode and its
            gradients are disabled.
        unroll_step: when given, that many independent deep copies of the
            denoiser are kept and indexed by ``self.step`` in ``_prox``
            (``self.step`` is not set in this class — presumably by the
            base class or solver; TODO confirm).
        sqrt: feed ``safe_sqrt(lam)`` instead of ``lam`` to the denoiser.
    """

    def __init__(self, linop, denoiser='ffdnet', x8=False, clamp=False, trainable=False, unroll_step=None, sqrt=False):
        super().__init__(linop)
        self.name = denoiser

        # Accept either a registry name or a ready-made denoiser object.
        self.denoiser = get_denoiser(denoiser) if isinstance(denoiser, str) else denoiser

        self.x8 = x8
        self.clamp = clamp
        self.sqrt = sqrt
        if self.x8:
            self.denoiser = Augment(self.denoiser)

        if not trainable:
            self.denoiser.eval()
            self.denoiser.requires_grad_(False)

        self.unroll = unroll_step is not None
        if self.unroll:
            # One independent copy per unrolled step (weights are not shared).
            per_step = clone(self.denoiser, unroll_step, shared=False)
            self.denoisers = nn.ModuleList(per_step)

    def _reload(self, shape=None):
        """Reset the ``Augment`` wrapper when ``x8`` is enabled; ``shape`` is unused."""
        if not self.x8:
            return
        self.denoiser.reset()

    def eval(self, v):
        """Always raises: the prior has no explicit function value."""
        raise NotImplementedError('deep prior cannot be explictly evaluated')

    def _prox(self, v: torch.Tensor, lam: torch.Tensor):
        """ v: [N, C, H, W] or [N, H, W]
            lam: [1]
        """
        if self.sqrt:
            sigma = safe_sqrt(lam)
        else:
            sigma = lam

        if self.clamp:
            v = v.clamp(0, 1)
        if torch.is_complex(v):
            v = v.real

        # 3-D inputs get an extra dimension at index 1 before denoising.
        batch = v.unsqueeze(1) if len(v.shape) == 3 else v

        model = self.denoisers[self.step] if self.unroll else self.denoiser
        restored = model.denoise(batch, sigma)

        # Match the dtype/device and shape of ``v`` as it is after the
        # optional clamp / real-part steps above.
        return restored.type_as(v).reshape(*v.shape)

    def __repr__(self):
        return f'deep_prior(denoiser="{self.name}", unroll={self.unroll})'
90 |
--------------------------------------------------------------------------------
/dprox/proxfn/unrolling/__init__.py:
--------------------------------------------------------------------------------
1 | from .prior import unrolled_prior
--------------------------------------------------------------------------------
/dprox/proxfn/unrolling/prior.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .. import ProxFn
4 | from .dgu import Denoiser
5 |
6 |
class unrolled_prior(ProxFn):
    """Proximal function backed by a step-indexed denoiser.

    Args:
        linop: forwarded unchanged to ``ProxFn.__init__``.
        denoiser: optional zero-argument callable returning the denoiser;
            the default ``Denoiser`` is instantiated when omitted.
    """

    def __init__(self, linop, denoiser=None):
        super().__init__(linop)
        factory = Denoiser if denoiser is None else denoiser
        self.denoiser = factory()

    def eval(self, v):
        """Always raises: the prior has no explicit function value."""
        raise NotImplementedError('deep prior cannot be explictly evaluated')

    def _prox(self, v: torch.Tensor, lam: torch.Tensor=None):
        """ v: [N, C, H, W]
            lam: [1]  (accepted for interface parity; not used here)
        """
        # The denoiser receives the current step index, not ``lam``.
        return self.denoiser(v, self.step)
--------------------------------------------------------------------------------
/dprox/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .misc import to_ndarray, to_torch_tensor, fft2, ifft2, outlier_correct, crop_center_region, safe_sqrt
2 | from .io import imshow, imread_rgb, imread, filter_ckpt, is_image_file, list_image_files
3 | from .metrics import psnr, sam, ssim, mpsnr, mpsnr_max, mssim
4 | from . import huggingface as hf
--------------------------------------------------------------------------------
/dprox/utils/containar.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
def is_dp_array(x):
    """
    Tell whether ``x`` carries the delta-prox array flag.

    :param x: any object
    :return: False when ``x`` has no ``is_dp_array`` attribute, otherwise the
     result of comparing that attribute to ``True`` with ``==``.
    """
    # ``== True`` (not ``is True``) is kept on purpose so truthy-equal values
    # such as 1 are still accepted.
    return getattr(x, 'is_dp_array', False) == True
14 |
15 |
class _DPArray(np.ndarray):
    """ndarray subclass whose instances can hold extra attributes.

    Plain ``numpy.ndarray`` objects have no instance ``__dict__``, so setting
    ``is_dp_array`` on them raises ``AttributeError``; a Python-level subclass
    does not have that restriction.
    """


def array(*args, **kwargs):
    """
    create a numpy array and set an attribute to indicate that it is a dp_array.

    All arguments are forwarded to ``numpy.array``.

    :return: a numpy array (an ``ndarray`` subclass instance sharing the
     created data) with an additional attribute `is_dp_array` set to `True`.
     The flag lives on this instance only; arrays derived from it do not
     inherit it.
    """
    # View-cast so the flag can be attached; the original code tried to set
    # the attribute on a plain ndarray, which always failed.
    out = np.array(*args, **kwargs).view(_DPArray)
    out.is_dp_array = True
    return out
25 |
26 |
def is_dp_tensor(x):
    """
    Tell whether ``x`` carries the delta-prox tensor flag.

    :param x: any object
    :return: False when ``x`` has no ``is_dp_tensor`` attribute, otherwise the
     result of comparing that attribute to ``True`` with ``==``.
    """
    # ``== True`` (not ``is True``) is kept on purpose so truthy-equal values
    # such as 1 are still accepted.
    return getattr(x, 'is_dp_tensor', False) == True
38 |
39 |
def tensor(*args, **kwargs):
    """
    Build a tensor with ``torch.tensor`` and mark it as a delta-prox tensor.

    All arguments are forwarded to ``torch.tensor``.
    :return: the new tensor, with the attribute `is_dp_tensor` set to `True`.
    """
    flagged = torch.tensor(*args, **kwargs)
    setattr(flagged, 'is_dp_tensor', True)
    return flagged
49 |
--------------------------------------------------------------------------------
/dprox/utils/huggingface.py:
--------------------------------------------------------------------------------
1 | import os
2 | import urllib.request
3 | import huggingface_hub
4 |
5 | from tqdm import tqdm
6 |
7 | CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache/dprox")
8 |
9 |
class DownloadProgressBar(tqdm):
    """tqdm progress bar with an ``update_to`` hook for download callbacks.

    ``update_to`` takes (block count, block size, total size), the arguments
    ``urllib.request.urlretrieve`` passes to its ``reporthook``.
    """

    def update_to(self, b=1, bsize=1, tsize=None):
        """Advance the bar to ``b * bsize`` units; set the total if ``tsize`` is given."""
        if tsize is not None:
            self.total = tsize
        transferred = b * bsize
        # ``update`` expects an increment, so subtract what is already shown.
        self.update(transferred - self.n)
17 |
18 |
def download_url(url: str, output_path: str) -> None:
    """
    download a file from a given URL and save it to a specified output path while
    displaying a progress bar.

    The data is first written to ``output_path + ".part"`` and only renamed to
    ``output_path`` once the transfer has finished, so an interrupted download
    never leaves a truncated file that later looks like a valid cached copy.

    Args:
        url (str): The URL of the file to be downloaded
        output_path (str): output_path is a string representing the file path where the downloaded file
            will be saved. It should include the file name and extension

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises; the partial file is
        removed before the exception propagates.
    """
    partial_path = output_path + ".part"
    with DownloadProgressBar(
        unit="B", unit_scale=True, miniters=1, desc=url.split("/")[-1]
    ) as t:
        try:
            urllib.request.urlretrieve(url, filename=partial_path, reporthook=t.update_to)
        except BaseException:
            # Also covers KeyboardInterrupt; always re-raised after cleanup.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
    # Atomic on the same filesystem; overwrites an existing destination.
    os.replace(partial_path, output_path)
33 |
34 |
def load_path(base_path: str, repo_type="datasets", user_id="delta-prox") -> str:
    """
    check if a file exists locally and download it from huggingface if it
    doesn't exist.

    Lookup order: ``base_path`` as given, then ``CACHE_DIR/base_path``; if
    neither exists the file is downloaded to the latter.

    Args:
        base_path (str): ``"<repo_id>/<path inside the repo>"`` using forward
            slashes, or a path that already exists locally.
        repo_type (str): ``"datasets"`` adds the ``/datasets`` URL prefix; any
            other value uses the bare host URL.
        user_id (str): user or organisation that owns the repository.

    Return:
        a string which is the path to the file specified by the input parameter `base_path`.

    Raises:
        ValueError: if a download is needed but ``base_path`` has no
            ``<repo_id>/`` prefix (this used to surface as an opaque
            ``TypeError`` from ``os.path.join``).
    """
    if os.path.exists(base_path):
        return base_path

    save_path = os.path.join(CACHE_DIR, base_path)
    if os.path.exists(save_path):
        return save_path

    repo_id, *rest = base_path.split("/")
    # URL paths always use "/", so join explicitly instead of os.path.join,
    # which would insert backslashes on Windows. Empty components are skipped.
    path = "/".join(part for part in rest if part)
    if not path:
        raise ValueError(
            f"cannot derive a download URL from {base_path!r}: "
            "expected '<repo_id>/<path inside the repo>'"
        )

    base_url = "https://huggingface.co"
    if repo_type == "datasets":
        base_url += "/" + repo_type
    url = f"{base_url}/{user_id}/{repo_id}/resolve/main/{path}"
    print(f"{base_path} not found")
    print("Try to download from huggingface: ", url)
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    # Defensive: normalise any backslash that came in through the arguments.
    url = url.replace("\\", "/")
    download_url(url, save_path)
    print("Downloaded to ", save_path)
    return save_path
67 |
68 |
def load_image(path, user_id="delta-prox"):
    """Resolve ``path`` with ``load_path`` and read it as a float32 image.

    Returns the pixel data converted to ``float32`` and divided by 255.
    """
    import imageio
    import numpy as np

    local_file = load_path(path, user_id=user_id)
    pixels = imageio.imread(local_file)
    return np.float32(pixels) / 255
76 |
77 |
def load_checkpoint(path, user_id="delta-prox"):
    """Resolve ``path`` in a ``models`` repo via ``load_path`` and ``torch.load`` it."""
    import torch

    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    return torch.load(load_path(path, repo_type="models", user_id=user_id))
83 |
84 |
def download_dataset(path, user_id="delta-prox", local_dir=None, force_download=False):
    """Fetch the dataset repo ``{user_id}/{path}`` with ``snapshot_download``.

    Args:
        path: repository name under ``user_id``.
        user_id: owner of the repository.
        local_dir: destination directory; defaults to ``CACHE_DIR/path``.
        force_download: download even if the destination already exists.

    Returns:
        The destination directory.
    """
    target_dir = os.path.join(CACHE_DIR, path) if local_dir is None else local_dir
    already_present = os.path.exists(target_dir)
    if force_download or not already_present:
        huggingface_hub.snapshot_download(
            repo_id=f"{user_id}/{path}", local_dir=target_dir, repo_type="dataset"
        )
    return target_dir
94 |
--------------------------------------------------------------------------------
/dprox/utils/init/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Delta-Prox/15fc489e99c4de2676dc538d9c0552338f6fb894/dprox/utils/init/__init__.py
--------------------------------------------------------------------------------
/dprox/utils/init/sr.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
def shift_pixel(x, sf, upper_left=True):
    """shift pixel for super-resolution with different scale factors

    The image is resampled with bilinear interpolation at the pixel grid
    displaced by ``(sf - 1) / 2`` along both axes; sample positions are
    clipped to the image border.

    Args:
        x: HxWxC or HxW array (rows first, as read from ``x.shape[:2]``)
        sf: scale factor
        upper_left: shift direction (sample at ``+shift`` when True, at
            ``-shift`` otherwise)

    Returns:
        A new array. 2-D input yields the interpolator's float output; 3-D
        input yields an array with the dtype and shape of ``x``. The input is
        left untouched (the 3-D case used to overwrite it in place). Arrays of
        any other rank >= 2 are returned unchanged.
    """
    # interp2d is deprecated/removed in recent SciPy; a degree-1
    # RectBivariateSpline is the equivalent bilinear interpolant on a
    # regular grid.
    from scipy.interpolate import RectBivariateSpline

    h, w = x.shape[:2]
    shift = (sf - 1) * 0.5
    rows, cols = np.arange(0, h, 1.0), np.arange(0, w, 1.0)
    offset = shift if upper_left else -shift
    new_rows = np.clip(rows + offset, 0, h - 1)
    new_cols = np.clip(cols + offset, 0, w - 1)

    def resample(plane):
        # Axis 0 is rows (y), axis 1 is columns (x), so no transpose is needed.
        return RectBivariateSpline(rows, cols, plane, kx=1, ky=1)(new_rows, new_cols)

    if x.ndim == 2:
        return resample(x)
    if x.ndim == 3:
        shifted = np.empty_like(x)
        for i in range(x.shape[-1]):
            shifted[:, :, i] = resample(x[:, :, i])
        return shifted
    return x
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # Welcome to the $\nabla$-Prox examples
2 |
3 | Examples of using $\nabla$-Prox for various applications.
4 |
5 | | Application |
6 | | -- |
7 | | [Computational Optics](papers/deltaprox_siggraph_2023/computional_optics) |
8 | | [Compressive-MRI](papers/deltaprox_siggraph_2023/csmri) |
9 | | [Image Deraining](papers/deltaprox_siggraph_2023/deraining) |
10 | | [Image Deconvolution](applications/deconv.py) |
11 | | [Image Demosaicing](applications/demosaic.py) |
12 | | [Image Super-Resolution](applications/super_resolution.py) |
13 | | [Joint Image Deconvolution and Demosaicing](applications/joint_demosaic_deconv.py) |
14 |
15 |
16 |
17 |
18 | > Contributions are welcome.
--------------------------------------------------------------------------------
/examples/applications/csmri.py:
--------------------------------------------------------------------------------
1 | from dprox import *
2 | from dprox.utils import *
3 | from dprox import contrib
4 |
# Example script: solve a csmri problem with ADMM and a 'unet' deep prior.

# Sample data bundled with dprox.contrib.csmri.
x0, y0, gt, mask = contrib.csmri.sample()

# Objective = csmri data term + deep denoiser prior on the same variable x.
x = Variable()
y = Placeholder()
data_term = csmri(x, mask, y)
reg_term = deep_prior(x, denoiser='unet')
prob = Problem(data_term + reg_term)

# Bind the measurement to the placeholder before solving.
y.value = y0
max_iter = 24
# presumably per-iteration schedules for rho and the denoiser strength — confirm in log_descent
rhos, sigmas = log_descent(30, 20, max_iter)
prob.solve(
    method='admm',
    device='cuda',
    x0=x0, rhos=rhos, lams={reg_term: sigmas}, max_iter=max_iter, pbar=True
)
# The solution is read back from the variable; only its real part is kept.
out = x.value.real

print(psnr(out, gt))  # 43
imshow(out)
--------------------------------------------------------------------------------
/examples/applications/deconv.py:
--------------------------------------------------------------------------------
1 | from dprox import *
2 | from dprox.utils import *
3 | from dprox.contrib import *
4 |
# Example script: image deconvolution with ADMM and an 'ffdnet_color' deep prior.

# Build the blurred observation b from a sample image and a PSF.
img = sample()
psf = point_spread_function(15, 5)
b = blurring(img, psf)

# Objective = squared error between conv(x, psf) and b, plus the deep prior.
x = Variable()
data_term = sum_squares(conv(x, psf) - b)
reg_term = deep_prior(x, denoiser='ffdnet_color')
prob = Problem(data_term + reg_term)
# The blurred image is used as the starting point.
prob.solve(method='admm', x0=b)

print(psnr(x.value, img))  # 35
imshow(x.value)
--------------------------------------------------------------------------------
/examples/applications/demosaic.py:
--------------------------------------------------------------------------------
1 | from dprox import *
2 | from dprox.utils import *
3 | from dprox.contrib import *
4 |
# Example script: image demosaicing with ADMM and an 'ffdnet_color' deep prior.

# Build the mosaiced observation from the 'face' sample image.
img = sample('face')
offset = mosaicing(img)

# Objective = sum_squares(mosaic(x), offset) + deep prior; merging is disabled.
x = Variable()
data_term = sum_squares(mosaic(x), offset)
reg_term = deep_prior(x, denoiser='ffdnet_color')
prob = Problem(data_term + reg_term, merge=False)

max_iter = 30
# presumably per-iteration schedules for rho and the denoiser strength — confirm in log_descent
rhos, sigmas = log_descent(35, 5, max_iter, sqrt=True)
# NOTE(review): max_iter is not passed to solve() here, unlike the other examples — confirm intended.
out = prob.solve(method='admm', x0=offset, pbar=True,
                 rhos=rhos, lams={reg_term: sigmas})

print(psnr(out, img))  # 39
imshow(out)
--------------------------------------------------------------------------------
/examples/applications/joint_demosaic_deconv.py:
--------------------------------------------------------------------------------
1 | from dprox import *
2 | from dprox.utils import *
3 | from dprox.contrib import *
4 |
# Example script: joint deconvolution + demosaicing with ADMM and a deep prior.

# Observation = sample image, blurred with the PSF and then mosaiced.
img = sample()
img = to_torch_tensor(img, batch=True).float()
psf = point_spread_function(15, 5)
offset = blurring(img, psf)
offset = mosaicing(offset)

# The data term composes the same two operators: mosaic(conv(x, psf)).
x = Variable()
data_term = sum_squares(mosaic(conv(x, psf)), offset)
reg_term = deep_prior(x, denoiser='ffdnet_color')
prob = Problem(data_term + reg_term, absorb=True)

out = prob.solve(method='admm', x0=offset, max_iter=24, pbar=True)


imshow(out)
print(psnr(out, img))  # 25.9
21 |
--------------------------------------------------------------------------------
/examples/applications/super_resolution.py:
--------------------------------------------------------------------------------
1 | from dprox import *
2 | from dprox.utils import *
3 | from dprox.contrib import *
4 |
# Example script: 2x single-image super-resolution with ADMM and a deep prior.

# Low-resolution observation y and initial estimate x0 from the 'face' sample.
img = sample('face')
psf = point_spread_function(5, 3)
y, x0 = downsampling(img, psf, sf=2)

# Objective = sisr data term (same kernel and scale factor) + deep prior.
x = Variable()
data_term = sisr(x, y, kernel=psf, sf=2)
reg_term = deep_prior(x, denoiser='ffdnet_color')
prob = Problem(data_term + reg_term)

max_iter = 24
# presumably per-iteration schedules for rho and the denoiser strength — confirm in log_descent
rhos, sigmas = log_descent(35, 35, max_iter)

# The literal 24 below matches max_iter above.
out = prob.solve(method='admm', x0=x0, rhos=rhos, lams={reg_term: sigmas}, max_iter=24, pbar=True)

print(psnr(out, img))  # 32.9
imshow(out)
--------------------------------------------------------------------------------
/examples/papers/README.md:
--------------------------------------------------------------------------------
1 | # Papers with $\nabla$-Prox
2 |
3 | A list of papers that are implemented / replicated in $\nabla$-Prox.
4 |
5 | > If you want to add your project to the list, contact the maintainers or file a pull request.
6 |
7 | | Paper | Code |
8 | | ---- | --- |
9 | |[∇-Prox: Differentiable Proximal Algorithm Modeling for Large-Scale Optimization](https://dl.acm.org/doi/abs/10.1145/3592144). SIGGRAPH'23 | [Link](deltaprox_siggraph_2023/) |
10 | | [Deep Generalized Unfolding Networks for Image Restoration](https://arxiv.org/abs/2204.13348). CVPR'21 | [Link](deltaprox_siggraph_2023/) |
11 | | [Deep Plug-and-Play Prior for Hyperspectral Image Restoration](http://arxiv.org/abs/2209.08240). Neurocomputing'22 | [Link](dphsir_neurcomputing_2022/) |
12 | | [Plug-and-Play Image Restoration with Deep Denoiser Prior](https://arxiv.org/abs/2008.13751). TPAMI'20 | [Link](dpir_tpami_2020/) |
13 | | [Tuning-free Plug-and-Play Proximal Algorithm for Inverse Imaging Problems](https://arxiv.org/abs/2002.09611). ICML'20 | [Link](tfpnp_icml_2020/) |
--------------------------------------------------------------------------------
/examples/papers/deltaprox_siggraph_2023/computional_optics/README.md:
--------------------------------------------------------------------------------
1 | # End-to-End Computational Optics
2 |
3 | The application of Delta-Prox for end-to-end computational optics.
4 |
5 | ## Prepare Data and Checkpoints
6 |
7 | - Download [BSD500](https://huggingface.co/datasets/delta-prox/BSD500) for training, and [CBSD68](https://huggingface.co/datasets/delta-prox/CBSD68) and [Urban100](https://huggingface.co/datasets/delta-prox/Urban100) for evaluation.
8 | - The pretrained models are hosted at [Huggingface](https://huggingface.co/delta-prox/computational_optics).
9 |
10 | ## Training
11 |
12 | - Training Delta-Prox
13 |
14 | ```bash
15 | python e2e_optics_dprox.py train
16 | ```
17 |
18 | - Training DeepOpticsUnet
19 |
20 | ```bash
21 | python e2e_optics_unet.py train
22 | ```
23 |
24 | ## Evaluation
25 |
26 | - Evaluate Delta-Prox
27 |
28 | ```bash
29 | python e2e_optics_dprox.py test
30 | ```
31 |
32 | - Evaluate DeepOpticsUnet
33 |
34 | ```bash
35 | python e2e_optics_unet.py test
36 | ```
37 |
38 | - Evaluate DPIR
39 |
40 | ```bash
41 | python pnp_optics.py
42 | ```
--------------------------------------------------------------------------------
/examples/papers/deltaprox_siggraph_2023/computional_optics/pnp_optics.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torchlight as tl
5 |
6 | from dprox import *
7 | from dprox import Variable
8 | from dprox.contrib.optic import (DOEModelConfig, build_baseline_profile,
9 | build_doe_model, img_psf_conv,
10 | load_sample_img)
11 | from dprox.linop.conv import conv_doe
12 | from dprox.utils import *
13 |
# -------------------- Define Model --------------------- #

# Build the DOE model on GPU when available and derive the baseline phase profile.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config = DOEModelConfig()
rgb_collim_model = build_doe_model(config).to(device)
circular = config.circular
fresnel_phase_c = build_baseline_profile(rgb_collim_model)

# -------------------- Define Solver --------------------- #

# y (observation) and PSF are placeholders filled per image in step_fn.
x = Variable()
y = Placeholder()
PSF = Placeholder()
data_term = sum_squares(conv_doe(x, PSF, circular=circular), y)
reg_term = deep_prior(x, denoiser='ffdnet_color')
solver = compile(data_term + reg_term, method='admm')
solver.eval()

# -------------------- Define Forward --------------------- #

# Standard deviation of the Gaussian noise added to each observation in step_fn.
sigma = 7.65 / 255
max_iter = 10
# presumably per-iteration schedules for rho and the denoiser strength — confirm in log_descent
rhos, sigmas = log_descent(49, 7.65, max_iter, sigma=max(0.255 / 255, sigma))
37 |
38 |
def step_fn(gt):
    # Simulate a noisy observation of `gt` through the baseline PSF and
    # reconstruct it with the compiled ADMM solver.
    # Returns (ground truth on device, noisy observation, reconstruction).
    gt = gt.to(device).float()
    psf = rgb_collim_model.get_psf(fresnel_phase_c)
    inp = img_psf_conv(gt, psf, circular=circular)
    # Additive Gaussian noise with standard deviation `sigma`.
    inp = inp + torch.randn(*inp.shape, device=inp.device) * sigma
    # Fill the module-level placeholders used by the data term.
    y.value = inp
    PSF.value = psf
    out = solver.solve(x0=inp,
                       rhos=rhos,
                       lams={reg_term: sigmas},
                       max_iter=max_iter)
    return gt, inp, out
51 |
# -------------------- Evaluation --------------------- #


tl.metrics.set_data_format('chw')

# datasets = ['McMaster', 'Kodak24']
datasets = ['Urban100']

for dataset in datasets:
    root = 'data/test/' + dataset
    # NOTE(review): there is no path separator between 'results' and the
    # dataset name, so this logs to e.g. 'saved/pnp/resultsUrban100' — confirm intended.
    logger = tl.logging.Logger('saved/pnp/results' + dataset, name=dataset)
    tracker = tl.trainer.util.MetricTracker()

    timer = tl.utils.Timer()
    timer.tic()
    for idx, name in enumerate(os.listdir(root)):
        gt = load_sample_img(os.path.join(root, name))

        # Seed with the image index so the noise drawn in step_fn is reproducible.
        torch.manual_seed(idx)
        torch.cuda.manual_seed(idx)
        gt, inp, pred = step_fn(gt)
        pred = pred.clamp(0, 1)

        # These locals shadow any `psnr` / `ssim` pulled in by the wildcard imports.
        psnr = tl.metrics.psnr(pred, gt)
        ssim = tl.metrics.ssim(pred, gt)

        tracker.update('psnr', psnr)
        tracker.update('ssim', ssim)

        logger.info('{} PSNR {} SSIM {}'.format(name, psnr, ssim))
        logger.save_img(name, pred)

    logger.info('averge results')
    logger.info(tracker.summary())
    # Timer reading divided by the number of entries in `root`.
    print(timer.toc() / len(os.listdir(root)))
87 |
--------------------------------------------------------------------------------
/examples/papers/deltaprox_siggraph_2023/csmri/README.md:
--------------------------------------------------------------------------------
1 | # Compressed-MRI
2 |
3 | The application of Delta-Prox for end-to-end compressed-MRI.
4 |
5 | ## Prepare Data and Checkpoints
6 |
7 | Download the testing datasets from Hugging Face: [MICCAI_2020](https://huggingface.co/datasets/delta-prox/MICCAI_2020), [Medical_7_2020](https://huggingface.co/datasets/delta-prox/Medical_7_2020), and [examples](https://huggingface.co/datasets/delta-prox/examples).
8 |
9 | To train the rl solver, you need to download the training dataset, [Image128](https://huggingface.co/datasets/delta-prox/Image128), as well.
10 |
11 | The checkpoints are hosted at [aaronb/CSMRI](https://huggingface.co/aaronb/CSMRI).
12 |
13 | ## Evaluation & training
14 |
15 | All scripts can be directly executed, e.g.,
16 |
17 | ```bash
18 | python deq_tfpnp.py
19 | ```
20 |
21 | For more configuration arguments, use:
22 | ```bash
23 | python deq_tfpnp.py --help
24 | ```
25 |
26 |
--------------------------------------------------------------------------------
/examples/papers/deltaprox_siggraph_2023/csmri/download_datasets.py:
--------------------------------------------------------------------------------
1 | import dprox.utils.huggingface as hf
2 |
3 | hf.download_dataset('Medical_7_2020', local_dir='data/Medical7_2020')
4 | hf.download_dataset('MICCAI_2020', local_dir='data/MICCAI_2020')
5 |
--------------------------------------------------------------------------------
/examples/papers/deltaprox_siggraph_2023/csmri/pnp_drunet.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from tfpnp.utils.metric import psnr_qrnn3d
5 | from torch.utils.data import DataLoader
6 | from torchlight.logging import Logger
7 |
8 | from dprox import *
9 | from dprox.algo.tune import *
10 | from dprox.utils import *
11 | from dprox.contrib.csmri import CustomADMM, EvalDataset
12 |
13 |
def main():
    # Evaluate plug-and-play ADMM with the 'drunet' denoiser on CS-MRI test
    # sets and log the average PSNR per dataset. Requires CUDA.

    x = Variable()
    y = Placeholder()
    mask = Placeholder()
    data_term = csmri(x, mask, y)
    reg_term = deep_prior(x, denoiser='drunet')
    solver = CustomADMM([reg_term], [data_term]).cuda()

    def step_fn(batch):
        # Bind one batch to the placeholders, solve, and return (target, prediction).
        y.value = batch['y0'].cuda()
        mask.value = batch['mask'].cuda()
        target = batch['gt'].cuda()
        x0 = batch['x0'].cuda()
        max_iter = 24

        # Hand-tuned schedules per dataset; only one pair is active at a time.
        # In each pair the first call provides `sigmas` and the second call
        # overrides `rhos`.

        # Medical7_2020/radial_128_2/5
        # rhos, sigmas = log_descent(50, 40, max_iter)
        # rhos, _ = log_descent(125, 0.1, max_iter)

        # Medical7_2020/radial_128_2/10
        # rhos, sigmas = log_descent(40, 40, max_iter)
        # rhos, _ = log_descent(0.1, 0.1, max_iter)

        # Medical7_2020/radial_128_4/5
        rhos, sigmas = log_descent(80, 40, max_iter)
        rhos, _ = log_descent(10, 0.1, max_iter)

        # Medical7_2020/radial_128_4/15
        # rhos, sigmas = log_descent(80, 55, max_iter)
        # rhos, _ = log_descent(1, 0.1, max_iter)

        # MICCAI_2020/radial_128_4/5
        # rhos, sigmas = log_descent(60, 40, max_iter)
        # rhos, _ = log_descent(4, 0.1, max_iter)

        # MICCAI_2020/radial_128_8/5
        # rhos, sigmas = log_descent(60, 40, max_iter)
        # rhos, _ = log_descent(1, 0.1, max_iter)

        # rhos, sigmas = log_descent(50, 1, max_iter)
        pred = solver.solve(x0=x0, rhos=rhos, lams={reg_term: sigmas}, max_iter=max_iter).real
        return target, pred

    # Keep the active dataset in sync with the active schedule in step_fn.
    valid_datasets = {
        # 'Medical7_2020/radial_128_2/5': EvalDataset('data/csmri/Medical7_2020/radial_128_2/5'),
        'Medical7_2020/radial_128_4/5': EvalDataset('data/csmri/Medical7_2020/radial_128_4/5'),
        # 'Medical7_2020/radial_128_2/10': EvalDataset('data/csmri/Medical7_2020/radial_128_2/10'),
        # 'Medical7_2020/radial_128_8/15': EvalDataset('data/csmri/Medical7_2020/radial_128_8/15'),
        # 'Medical7_2020/radial_128_4/15': EvalDataset('data/csmri/Medical7_2020/radial_128_4/15'),
        # 'MICCAI_2020/radial_128_4/5': EvalDataset('data/csmri/MICCAI_2020/radial_128_4/5'),
        # 'MICCAI_2020/radial_128_8/15': EvalDataset('data/csmri/MICCAI_2020/radial_128_8/15'),
    }

    save_root = 'abc/pnp_drunet'

    for name, valid_dataset in valid_datasets.items():
        total_psnr = 0
        # DataLoader defaults are used (no batch size or shuffling is specified).
        test_loader = DataLoader(valid_dataset)

        save_dir = os.path.join(save_root, name)
        os.makedirs(save_dir, exist_ok=True)
        logger = Logger(save_dir, name=name)

        for idx, batch in enumerate(test_loader):
            with torch.no_grad():
                target, pred = step_fn(batch)
            psnr = psnr_qrnn3d(target.squeeze(0).cpu().numpy(),
                               pred.squeeze(0).cpu().numpy(),
                               data_range=1)
            total_psnr += psnr
            # Predictions are saved under their running index.
            logger.save_img(f'{idx}.png', pred)

        logger.info('{} avg psnr= {}'.format(name, total_psnr / len(test_loader)))


# Runs on import as well as on direct execution (no __main__ guard).
main()
91 |
--------------------------------------------------------------------------------
/examples/papers/deltaprox_siggraph_2023/csmri/pnp_unet.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from tfpnp.utils.metric import psnr_qrnn3d
5 | from torch.utils.data import DataLoader
6 | from torchlight.logging import Logger
7 |
8 | from dprox import *
9 | from dprox.algo.tune import *
10 | from dprox.utils import *
11 | from dprox.contrib.csmri import CustomADMM, EvalDataset
12 |
13 |
def main():
    # Evaluate plug-and-play ADMM with the 'unet' denoiser on CS-MRI test
    # sets and log the average PSNR per dataset. Requires CUDA.
    x = Variable()
    y = Placeholder()
    mask = Placeholder()
    data_term = csmri(x, mask, y)
    reg_term = deep_prior(x, denoiser='unet')
    solver = CustomADMM([reg_term], [data_term]).cuda()

    def step_fn(batch):
        # Bind one batch to the placeholders, solve, and return (target, prediction).
        y.value = batch['y0'].cuda()
        mask.value = batch['mask'].cuda()
        target = batch['gt'].cuda()
        x0 = batch['x0'].cuda()
        max_iter = 24

        # Hand-tuned schedules per dataset; only one pair is active at a time.
        # In each pair the first call provides `sigmas` and the second call
        # overrides `rhos`. The two radial_128_2 pairs used to be left active
        # and were recomputed for every batch only to be overwritten by the
        # radial_128_4/5 pair below, so they are commented out like the other
        # inactive settings.

        # Medical7_2020/radial_128_2/5
        # rhos, sigmas = log_descent(50, 40, max_iter)
        # rhos, _ = log_descent(125, 0.1, max_iter)

        # Medical7_2020/radial_128_2/10
        # rhos, sigmas = log_descent(40, 40, max_iter)
        # rhos, _ = log_descent(0.1, 0.1, max_iter)

        # Medical7_2020/radial_128_4/5
        rhos, sigmas = log_descent(70, 40, max_iter)
        rhos, _ = log_descent(120, 0.1, max_iter)

        # Medical7_2020/radial_128_4/15
        # rhos, sigmas = log_descent(50, 40, max_iter)
        # rhos, _ = log_descent(0.1, 0.1, max_iter)

        # # MICCAI_2020/radial_128_4/5
        # rhos, sigmas = log_descent(60, 40, max_iter)
        # rhos, _ = log_descent(4, 0.1, max_iter)

        # # MICCAI_2020/radial_128_8/5
        # rhos, sigmas = log_descent(60, 40, max_iter)
        # rhos, _ = log_descent(1, 0.1, max_iter)

        # rhos, sigmas = log_descent(50, 1, max_iter)
        pred = solver.solve(x0=x0, rhos=rhos, lams={reg_term: sigmas}, max_iter=max_iter).real
        return target, pred

    # Keep the active dataset in sync with the active schedule in step_fn.
    valid_datasets = {
        'Medical7_2020/radial_128_4/5': EvalDataset('data/Medical7_2020/radial_128_4/5'),
        # 'Medical7_2020/radial_128_4/15': EvalDataset('data/Medical7_2020/radial_128_4/15'),
        # 'MICCAI_2020/radial_128_4/5': EvalDataset('data/MICCAI_2020/radial_128_4/5'),
        # 'MICCAI_2020/radial_128_8/5': EvalDataset('data/MICCAI_2020/radial_128_8/5'),
    }

    save_root = 'abc/pnp_unet'

    for name, valid_dataset in valid_datasets.items():
        total_psnr = 0
        # DataLoader defaults are used (no batch size or shuffling is specified).
        test_loader = DataLoader(valid_dataset)

        save_dir = os.path.join(save_root, name)
        os.makedirs(save_dir, exist_ok=True)
        logger = Logger(save_dir, name=name)

        for batch in test_loader:
            with torch.no_grad():
                target, pred = step_fn(batch)
            psnr = psnr_qrnn3d(target.squeeze(0).cpu().numpy(),
                               pred.squeeze(0).cpu().numpy(),
                               data_range=1)
            total_psnr += psnr
            # NOTE(review): images are named by their PSNR rounded to two
            # decimals, so two samples with the same rounded value map to the
            # same file name — confirm this is acceptable.
            logger.save_img(f'{psnr:0.2f}.png', pred)

        logger.info('{} avg psnr= {}'.format(name, total_psnr / len(test_loader)))


# Runs on import as well as on direct execution (no __main__ guard).
main()
87 |
--------------------------------------------------------------------------------
/examples/papers/deltaprox_siggraph_2023/csmri/rl_unet_train.py:
--------------------------------------------------------------------------------
1 | from scipy.io import loadmat
2 | from pathlib import Path
3 | from tfpnp.utils.noise import GaussianModelD
4 |
5 | from dprox import *
6 | from dprox.algo.tune import *
7 | from dprox.utils import *
8 | from dprox.contrib.csmri import (CustomADMM, CustomEnv,
9 | EvalDataset, TrainDataset)
10 |
11 |
def build_solver():
    # Build the CS-MRI solver: a csmri data term plus a 'unet' deep prior,
    # combined in a CustomADMM instance.
    # Returns (solver, placeholders), where placeholders maps 'y' and 'mask'
    # to the Placeholder objects the data term reads from.
    x = Variable()
    y = Placeholder()
    mask = Placeholder()
    data_term = csmri(x, mask, y)
    reg_term = deep_prior(x, denoiser='unet')
    solver = CustomADMM([reg_term], [data_term])
    return solver, {'y': y, 'mask': mask}
20 |
21 |
def main():
    # Train an AutoTuneSolver (policy='resnet') around the CS-MRI ADMM solver,
    # then evaluate a saved actor checkpoint on the validation sets.
    solver, placeholders = build_solver()

    # dataset
    # Each call returns a local directory; an existing one is reused because
    # force_download=False.
    examples_root = Path(hf.download_dataset('examples', force_download=False))
    # NOTE(review): the README in this folder links the training set as
    # 'Image128' while this requests 'Images128' — confirm the repo name.
    train_root = Path(hf.download_dataset('Images128', force_download=False))
    val_root = Path(hf.download_dataset('Medical_7_2020', force_download=False))

    mask_dir = examples_root / 'csmri/masks'
    sigma_ns = [5, 10, 15]
    sampling_masks = ['radial_128_2', 'radial_128_4', 'radial_128_8']

    noise_model = GaussianModelD(sigma_ns)
    # Load the 'mask' entry of every sampling-mask .mat file.
    masks = [loadmat(str(mask_dir / f'{sampling_mask}.mat')).get('mask')
             for sampling_mask in sampling_masks]
    dataset = TrainDataset(train_root, fns=None, masks=masks, noise_model=noise_model)

    valid_datasets = {
        'Medical_7_2020/radial_128_2/15': EvalDataset(val_root / 'radial_128_2/15'),
        'Medical_7_2020/radial_128_4/15': EvalDataset(val_root / 'radial_128_4/15'),
        'Medical_7_2020/radial_128_8/15': EvalDataset(val_root / 'radial_128_8/15'),
    }

    # train

    tf_solver = AutoTuneSolver(solver, policy='resnet')

    # Keyword arguments forwarded verbatim to tf_solver.train.
    training_cfg = dict(
        rmsize=480,
        max_episode_step=30,
        train_steps=15000,
        warmup=20,
        save_freq=1000,
        validate_interval=10,
        episode_train_times=10,
        env_batch=48,
        loop_penalty=0.05,
        discount=0.99,
        lambda_e=0.2,
        tau=0.001,
        action_pack=1,
        log_dir='rl_unet',
        custom_env=CustomEnv,
    )
    tf_solver.train(dataset, valid_datasets, placeholders, **training_cfg)
    # NOTE(review): training logs to 'rl_unet' but the checkpoint is read from
    # 'ckpt/tfpnp_unet' — confirm this is where train() writes actor_best.pkl.
    ckpt_path = 'ckpt/tfpnp_unet/actor_best.pkl'
    tf_solver.eval(ckpt_path, valid_datasets, placeholders, custom_env=CustomEnv)


if __name__ == '__main__':
    main()
73 |
--------------------------------------------------------------------------------
/examples/papers/deltaprox_siggraph_2023/deraining/README.md:
--------------------------------------------------------------------------------
1 | # Image Deraining
2 |
3 | The application of Delta-Prox for image-deraining.
4 |
5 | ## Prepare Data and Checkpoints
6 |
7 | - Download test data from [Google drive](https://drive.google.com/file/d/1P_-RAvltEoEhfT-9GrWRdpEi6NSswTs8/view?usp=sharing). Put it into the `datasets/test` folder. For example, use `datasets/test/Rain100H`.
9 |
10 | - Download the checkpoints from [Huggingface](https://huggingface.co/delta-prox/image_deraining).
11 |
12 | ## Evaluation
13 |
14 | - Unrolled Proximal Gradient Descent with shared parameters and restormer as initializer.
15 |
16 | ```bash
17 | python test_unroll_share.py
18 | python evaluate_PSNR_SSIM.py
19 | ```
20 |
21 | - Unrolled Proximal Gradient Descent with unshared parameters
22 |
23 | ```bash
24 | python test_unroll.py
25 | python evaluate_PSNR_SSIM.py
26 | ```
27 |
28 | > To reproduce the results reported in the paper, please use the MATLAB script `evaluate_PSNR_SSIM.m`.
29 |
30 | ## Acknowledgement
31 |
32 | - [Restormer](https://github.com/swz30/Restormer)
33 | - [DGUNet](https://github.com/MC-E/Deep-Generalized-Unfolding-Networks-for-Image-Restoration)
--------------------------------------------------------------------------------
/examples/papers/deltaprox_siggraph_2023/deraining/derain/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Delta-Prox/15fc489e99c4de2676dc538d9c0552338f6fb894/examples/papers/deltaprox_siggraph_2023/deraining/derain/__init__.py
--------------------------------------------------------------------------------
/examples/papers/deltaprox_siggraph_2023/deraining/derain/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .dataset import DataLoaderTrain, DataLoaderVal, DataLoaderTest, DataLoaderTest_fine
3 |
4 | def get_training_data(rgb_dir, img_options):
5 | assert os.path.exists(rgb_dir)
6 | return DataLoaderTrain(rgb_dir, img_options)
7 |
8 | def get_validation_data(rgb_dir, img_options):
9 | assert os.path.exists(rgb_dir)
10 | return DataLoaderVal(rgb_dir, img_options)
11 |
12 | def get_test_data(rgb_dir, img_options):
13 | assert os.path.exists(rgb_dir)
14 | return DataLoaderTest(rgb_dir, img_options)
15 |
16 | def get_test_data_fine(rgb_dir, img_options):
17 | assert os.path.exists(rgb_dir)
18 | return DataLoaderTest_fine(rgb_dir, img_options)
--------------------------------------------------------------------------------
/examples/papers/deltaprox_siggraph_2023/deraining/test_unroll.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import fire
4 |
5 | import imageio
6 | import torch
7 | from torch.utils.data import DataLoader
8 | from tqdm import tqdm
9 |
10 | from dprox import *
11 | from dprox.utils import *
12 |
13 | from derain.data import get_test_data
14 | from derain.unroll import LearnableDegOp, unrolled_prior
15 |
16 |
17 | def build_solver():
18 | # custom linop
19 | A = LearnableDegOp().cuda()
20 | def forward_fn(input, step): return A.forward(input, step)
21 | def adjoint_fn(input, step): return A.adjoint(input, step)
22 | raining = LinOpFactory(forward_fn, adjoint_fn)
23 |
24 | # build solver
25 | x = Variable()
26 | b = Placeholder()
27 | data_term = sum_squares(raining(x), b)
28 | reg_term = unrolled_prior(x)
29 | obj = data_term + reg_term
30 | solver = compile(obj, method='pgd', device='cuda')
31 |
32 | # load parameters
33 | ckpt = torch.load('derain_pdg_unroll.pth')
34 | A.load_state_dict(ckpt['linop'])
35 | reg_term.load_state_dict(ckpt['prior'])
36 | rhos = ckpt['rhos']
37 |
38 | return solver, rhos, b
39 |
40 |
41 | @torch.no_grad()
42 | def main(
43 | dataset='Rain100H',
44 | result_dir='results/dprox_pdg_unroll',
45 | data_dir='datasets/test/',
46 | ):
47 | solver, rhos, b = build_solver()
48 |
49 | rgb_dir_test = os.path.join(data_dir, dataset, 'input')
50 | test_dataset = get_test_data(rgb_dir_test, img_options={})
51 | test_loader = DataLoader(dataset=test_dataset, batch_size=1,
52 | shuffle=False, num_workers=4,
53 | drop_last=False, pin_memory=True)
54 |
55 | result_dir = os.path.join(result_dir, dataset)
56 | os.makedirs(result_dir, exist_ok=True)
57 |
58 | for data in tqdm(test_loader):
59 | input = data[0].cuda()
60 | filenames = data[1]
61 |
62 | b.value = input
63 | output = solver.solve(x0=input, rhos=rhos, max_iter=7)
64 | output = output + input
65 |
66 | output = torch.clamp(output, 0, 1)
67 | output = output.permute(0, 2, 3, 1).cpu().detach().numpy() * 255
68 |
69 | for batch in range(len(output)):
70 | restored_img = output[batch].astype('uint8')
71 | imageio.imsave(
72 | os.path.join(result_dir, filenames[batch] + '.png'),
73 | restored_img
74 | )
75 |
76 |
77 | if __name__ == '__main__':
78 | fire.Fire(main)
79 |
--------------------------------------------------------------------------------
/examples/papers/deltaprox_siggraph_2023/deraining/test_unroll_share.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import fire
4 | import imageio
5 | import torch
6 | from torch.utils.data import DataLoader
7 | from tqdm import tqdm
8 |
9 | from dprox import *
10 | from dprox.utils import *
11 |
12 | from derain.data import get_test_data
13 | from derain.restormer import Restormer
14 | from derain.unroll_share import LearnableDegOp, unrolled_prior
15 |
16 |
17 | def build_solver():
18 | # custom linop
19 | A = LearnableDegOp().cuda()
20 | def forward_fn(input, step): return A.forward(input, step)
21 | def adjoint_fn(input, step): return A.adjoint(input, step)
22 | raining = LinOpFactory(forward_fn, adjoint_fn)
23 |
24 | # build solver
25 | x = Variable()
26 | b = Placeholder()
27 | data_term = sum_squares(raining(x), b)
28 | reg_term = unrolled_prior(x)
29 | obj = data_term + reg_term
30 | solver = compile(obj, method='pgd', device='cuda')
31 |
32 | # load parameters
33 | ckpt = torch.load('derain_pdg_unroll_share.pth')
34 | A.load_state_dict(ckpt['linop'])
35 | reg_term.load_state_dict(ckpt['prior'])
36 | rhos = ckpt['rhos']
37 |
38 | return solver, rhos, b
39 |
40 |
41 | @torch.no_grad()
42 | def main(
43 | dataset='Rain100H',
44 | result_dir='results/dprox_unroll_share',
45 | data_dir='datasets/test/',
46 | ):
47 | solver, rhos, b = build_solver()
48 |
49 | rgb_dir_test = os.path.join(data_dir, dataset, 'input')
50 | test_dataset = get_test_data(rgb_dir_test, img_options={})
51 | test_loader = DataLoader(dataset=test_dataset, batch_size=1,
52 | shuffle=False, num_workers=4,
53 | drop_last=False, pin_memory=True)
54 |
55 | result_dir = os.path.join(result_dir, dataset)
56 | os.makedirs(result_dir, exist_ok=True)
57 |
58 | restormer = Restormer().cuda()
59 | restormer.load_state_dict(torch.load('restormer.pth')['params'])
60 |
61 | for data in tqdm(test_loader):
62 | input = data[0].cuda()
63 | filenames = data[1]
64 |
65 | def init_hook(iter, state, rho, lam):
66 | if iter == 5:
67 | state[0] = restormer(input)
68 |
69 | b.value = input
70 | output = solver.solve(x0=input, rhos=rhos, max_iter=7, callback=init_hook)
71 | output = output + input
72 |
73 | output = torch.clamp(output, 0, 1)
74 | output = output.permute(0, 2, 3, 1).cpu().detach().numpy() * 255
75 |
76 | for batch in range(len(output)):
77 | restored_img = output[batch].astype('uint8')
78 | imageio.imsave(
79 | os.path.join(result_dir, filenames[batch] + '.png'),
80 | restored_img
81 | )
82 |
83 |
84 | if __name__ == '__main__':
85 | fire.Fire(main)
86 |
--------------------------------------------------------------------------------
/examples/papers/dgunet_cvpr_2021/main.py:
--------------------------------------------------------------------------------
1 | from dprox import *
2 | from dprox.utils import *
3 |
4 | from ops import dgu_linop, dgu_prior
5 |
6 |
7 | def test():
8 | b = hf.load_image('examples/derain/derain_input.png')
9 | gt = hf.load_image('examples/derain/derain_target.png')
10 | imshow(gt, b)
11 |
12 | # custom linop
13 | A = dgu_linop().cuda()
14 | def forward_fn(input, step): return A.forward(input, step)
15 | def adjoint_fn(input, step): return A.adjoint(input, step)
16 | raining = LinOpFactory(forward_fn, adjoint_fn)
17 |
18 | # build solver
19 | x = Variable()
20 | data_term = sum_squares(raining(x), b)
21 | reg_term = dgu_prior(x)
22 | obj = data_term + reg_term
23 | solver = compile(obj, method='pgd')
24 |
25 | # load parameters
26 | ckpt = hf.load_checkpoint('image_deraining/derain_pdg.pth')
27 | A.load_state_dict(ckpt['linop'])
28 | reg_term.load_state_dict(ckpt['prior'])
29 | rhos = ckpt['rhos']
30 |
31 | out = solver.solve(x0=b, rhos=rhos, max_iter=7)
32 | out = to_ndarray(out, debatch=True) + b
33 | print(psnr(out, gt)) # 35.92
34 | imshow(gt, out)
35 | assert psnr(out, gt) - 35.92 < 0.1
36 |
37 |
38 | if __name__ == '__main__':
39 | test()
40 |
--------------------------------------------------------------------------------
/examples/papers/dgunet_cvpr_2021/ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from dprox import ProxFn
5 | from network import Denoiser, ResBlock, default_conv
6 |
7 |
8 | class dgu_prior(ProxFn):
9 | def __init__(self, linop, denoiser=None):
10 | super().__init__(linop)
11 | if denoiser is not None:
12 | self.denoiser = denoiser()
13 | else:
14 | self.denoiser = Denoiser()
15 |
16 | def eval(self, v):
17 | raise NotImplementedError('deep prior cannot be explictly evaluated')
18 |
19 | def _prox(self, v: torch.Tensor, lam: torch.Tensor = None):
20 | """ v: [N, C, H, W]
21 | lam: [1]
22 | """
23 | out = self.denoiser(v, self.step)
24 | return out
25 |
26 |
27 | class dgu_linop(nn.Module):
28 | def __init__(self, diag=False):
29 | super().__init__()
30 | self.phi_0 = ResBlock(default_conv, 3, 3)
31 | self.phi_1 = ResBlock(default_conv, 3, 3)
32 | self.phi_6 = ResBlock(default_conv, 3, 3)
33 | self.phit_0 = ResBlock(default_conv, 3, 3)
34 | self.phit_1 = ResBlock(default_conv, 3, 3)
35 | self.phit_6 = ResBlock(default_conv, 3, 3)
36 |
37 | if diag:
38 | self.phid_0 = ResBlock(default_conv, 3, 3)
39 | self.phid_1 = ResBlock(default_conv, 3, 3)
40 | self.phid_6 = ResBlock(default_conv, 3, 3)
41 |
42 | self.max_step = 5
43 | self.step = 0
44 |
45 | def forward(self, x, step=None):
46 | if step is None: step = self.step
47 | if step == 0:
48 | return self.phi_0(x)
49 | elif step == self.max_step + 1:
50 | return self.phi_6(x)
51 | else:
52 | return self.phi_1(x)
53 |
54 | def adjoint(self, x, step=None):
55 | if step is None: step = self.step
56 | if step == 0:
57 | return self.phit_0(x)
58 | elif step == self.max_step + 1:
59 | return self.phit_6(x)
60 | else:
61 | return self.phit_1(x)
62 |
63 | def diag(self, x, step=None):
64 | if step is None: step = self.step
65 | if step == 0:
66 | return self.phid_0(x)
67 | elif step == self.max_step + 1:
68 | return self.phid_6(x)
69 | else:
70 | return self.phid_1(x)
71 |
--------------------------------------------------------------------------------
/examples/papers/dphsir_neurcomputing_2022/degrades/__init__.py:
--------------------------------------------------------------------------------
1 | from .general import AffineTransform, PerspectiveTransform, HSI2RGB
2 | from .blur import GaussianBlur, UniformBlur
3 | from .sr import GaussianDownsample, BiCubicDownsample, UniformDownsample
4 | from . import cs
5 | from .noise import GaussianNoise
6 |
--------------------------------------------------------------------------------
/examples/papers/dphsir_neurcomputing_2022/degrades/blur.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import ndimage
3 |
4 | from .utils import fspecial_gaussian
5 |
6 |
7 | class AbstractBlur:
8 | def __init__(self, kernel):
9 | self.kernel = kernel
10 |
11 | def __call__(self, img):
12 | # img_L = np.fft.ifftn(np.fft.fftn(img) * np.fft.fftn(np.expand_dims(self.k, axis=2), img.shape)).real
13 | img_L = ndimage.filters.convolve(img, np.expand_dims(self.kernel, axis=2), mode='wrap')
14 | return img_L
15 |
16 |
17 | class GaussianBlur(AbstractBlur):
18 | def __init__(self, ksize=8, sigma=3):
19 | k = fspecial_gaussian(ksize, sigma)
20 | super().__init__(k)
21 |
22 |
23 | class UniformBlur(AbstractBlur):
24 | def __init__(self, ksize):
25 | k = np.ones((ksize, ksize)) / (ksize*ksize)
26 | super().__init__(k)
27 |
--------------------------------------------------------------------------------
/examples/papers/dphsir_neurcomputing_2022/degrades/cs.py:
--------------------------------------------------------------------------------
1 | import hdf5storage
2 | import os
3 | import numpy as np
4 |
5 | CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
6 |
7 |
8 | class CASSI(object):
9 |     """ Multiply an image by a mask and sum over axis 2. With `size=(H, W, C)` a random 0/1 mask is drawn once per size
10 |         and cached on the class; without `size` the bundled `cs_mask_cassi.mat` is loaded (original note: 512x512x31 images only). """
11 |     cache = {}  # size tuple -> float32 mask, shared by all instances
12 |
13 |     def __init__(self, size=None):
14 |         if size:
15 |             size = tuple(size)
16 |             if size not in CASSI.cache:
17 |                 H, W, C = size
18 |                 m = np.random.choice([0, 1], size=(H,W), p=[1./2, 1./2])
19 |                 ms = [np.roll(m, i, axis=0) for i in range(C)]  # channel i is the base mask rolled by i along axis 0
20 |                 mask = np.stack(ms, axis=2)
21 |                 CASSI.cache[size] = mask.astype('float32')
22 |             self.mask = CASSI.cache[size]
23 |         else:
24 |             self.mask = hdf5storage.loadmat(os.path.join(CURRENT_DIR, 'kernels', 'cs_mask_cassi.mat'))['mask']
25 |
26 |     def __call__(self, img):
27 |         return np.sum(img * self.mask, axis=2)
28 |
--------------------------------------------------------------------------------
/examples/papers/dphsir_neurcomputing_2022/degrades/general.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import cv2
4 | import numpy as np
5 | from scipy.io import loadmat
6 |
7 | CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
8 |
9 |
10 | class AffineTransform(object):
11 | def __call__(self, x):
12 | srcTri = np.array([[0, 0], [x.shape[1] - 1, 0], [0, x.shape[0] - 1]]).astype(np.float32)
13 | dstTri = np.array([[0, x.shape[1]*0.05], [x.shape[1]*0.99, x.shape[0]*0], [x.shape[1]*0.05, x.shape[0]*0.99]]).astype(np.float32)
14 | warp_mat = cv2.getAffineTransform(srcTri, dstTri)
15 | warp_dst = cv2.warpAffine(x, warp_mat, (x.shape[1], x.shape[0]))
16 | return warp_dst
17 |
18 |
19 | class PerspectiveTransform(object):  # perspective warp that maps a `shift`-inset quadrilateral onto the full rectangle
20 |     def __init__(self, shift):
21 |         self.shift = shift  # inset applied to the first coordinate of the last two source corners
22 |
23 |     def __call__(self, img):
24 |         rows, cols, _ = img.shape
25 |         pts1 = np.float32([[0, 0], [rows, 0], [self.shift, cols], [rows-self.shift, cols]])
26 |         pts2 = np.float32([[0, 0], [rows, 0], [0, cols], [rows, cols]])
27 |         # NOTE(review): OpenCV points are (x, y) and dsize is (width, height); `rows` is used as the x extent here, which only matches square images - confirm
28 |         M = cv2.getPerspectiveTransform(pts1, pts2)
29 |         dst = cv2.warpPerspective(img, M, (rows, cols))
30 |         return dst
31 |
32 |
33 | class HSI2RGB(object):
34 | def __init__(self, srf=None):
35 | if srf is None:
36 | CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
37 | self.srf = loadmat(os.path.join(CURRENT_DIR, 'kernels', 'misr_spe_p.mat'))['P'] # (3,31)
38 | else:
39 | self.srf = srf
40 |
41 | def __call__(self, img):
42 | return img @ self.srf.transpose()
43 |
--------------------------------------------------------------------------------
/examples/papers/dphsir_neurcomputing_2022/degrades/inpaint.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class RandomMask:
5 | def __init__(self, ratio=0.2):
6 | """ ratio: ratio to keep
7 | """
8 | self.ratio = ratio
9 |
10 | def __call__(self, img):
11 | mask = (np.random.rand(*img.shape) > (1-self.ratio)).astype('float')
12 | img_L = img * mask
13 | return img_L, mask
14 |
15 |
16 | class StripeMask:
17 | """ Input: [W,H,B] """
18 |
19 | def __init__(self, bandwise=False):
20 | self.bandwise = bandwise
21 |
22 | def __call__(self, img):
23 | mask = np.ones_like(img)
24 | for i in range(0, img.shape[1]-5, 10):
25 | mask[:, i+3] = 0
26 | mask[:, i+4] = 0
27 | mask[:, i+5] = 0
28 | img_L = img * mask
29 | return img_L, mask
30 |
31 |
32 | class RandomStripe:
33 |     """ Input: [W,H,B]. Zeroes randomly placed stripes along axis 1, drawn independently for every band. """
34 |
35 |     def __init__(self, num_bands=4, bandwise=True, ratio=0.2):
36 |         self.num_bands = num_bands  # stored but not read by __call__ (every band is striped)
37 |         self.bandwise = bandwise  # stored but not read by __call__
38 |         self.ratio = ratio  # int(h*ratio) stripe start positions are drawn per band
39 |
40 |     def __call__(self, img):
41 |         mask = np.ones_like(img)
42 |         w, h, b = img.shape
43 |         # For each band, draw int(h*ratio) distinct start indices along axis 1 and zero
44 |         # the slice j:j+4 (shorter at the edge); slices from different starts may overlap.
45 |         for i in range(b):
46 |             stripes = np.random.choice(h, int(h*self.ratio), replace=False)
47 |             for j in stripes:
48 |                 mask[:, j:j+4, i] = 0
49 |         img_L = img * mask
50 |         return img_L, mask
51 |
52 |
53 | class FastHyStripe:
54 |     """ Input: [W,H,B]. Zeroes random stripes along axis 1 in `num_bands` consecutive bands starting at band 10. """
55 |
56 |     def __init__(self, num_bands=15, bandwise=False):
57 |         self.num_bands = num_bands  # how many consecutive bands (from start_band) receive stripes
58 |         self.bandwise = bandwise  # if True, one stripe pattern is drawn and copied to all affected bands
59 |
60 |     def __call__(self, img):
61 |         import time
62 |         np.random.seed(int(time.time()))
63 |         mask = np.ones_like(img)
64 |         w, h, b = img.shape
65 |         # NOTE: the seed call above reseeds NumPy's global RNG from the wall clock on every call.
66 |         # NOTE(review): bands up to start_band+num_bands-1 are indexed, so `b` must be at least that large - confirm callers.
67 |         start_band = 10
68 |         # each band draws 20 stripe starts: stripe 4 spans 30 indices, stripe 10 spans 15, the others 4 or 2
69 |         for i in range(start_band, start_band+self.num_bands):
70 |             stripes = np.random.choice(h, 20, replace=False)
71 |             for k, j in enumerate(stripes):
72 |                 t = np.random.rand()
73 |                 if k == 4:
74 |                     mask[:, j:j+30, i] = 0
75 |                 elif k == 10:
76 |                     mask[:, j:j+15, i] = 0
77 |                 elif t > 0.6:
78 |                     mask[:, j:j+4, i] = 0
79 |                 else:
80 |                     mask[:, j:j+2, i] = 0
81 |             if self.bandwise:
82 |                 break
83 |         if self.bandwise:
84 |             mask[:, :, start_band:start_band+self.num_bands] = np.expand_dims(mask[:, :, start_band], axis=-1)
85 |         img_L = img * mask
86 |         return img_L, mask
87 |
--------------------------------------------------------------------------------
/examples/papers/dphsir_neurcomputing_2022/degrades/noise.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | class GaussianNoise:
4 | def __init__(self, sigma):
5 | np.random.seed(seed=0) # for reproducibility
6 | self.sigma = sigma
7 |
8 | def __call__(self, img):
9 | img_L = img + np.random.normal(0, self.sigma, img.shape)
10 | return img_L
--------------------------------------------------------------------------------
/examples/papers/dphsir_neurcomputing_2022/degrades/sr.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import abc
4 |
5 | import hdf5storage
6 | import numpy as np
7 |
8 | from .utils import imresize_np
9 | from .blur import AbstractBlur, GaussianBlur, UniformBlur
10 |
11 | CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
12 |
13 |
14 | class AbstractDownsample(abc.ABC):
15 | def __init__(self, sf, kernel):
16 | self.sf = sf
17 | self.kernel = kernel
18 |
19 |
20 | class ClassicalDownsample(AbstractDownsample):
21 | def __init__(self, sf, blur: AbstractBlur):
22 | super().__init__(sf, blur.kernel)
23 | self.blur = blur
24 |
25 | def __call__(self, img):
26 | """ input: [w,h,c]
27 | data range: both (0,255), (0,1) are ok
28 | """
29 | img = self.blur(img)
30 | img = img[0::self.sf, 0::self.sf, ...]
31 | return img
32 |
33 |
34 | class GaussianDownsample(ClassicalDownsample):
35 | def __init__(self, sf, ksize=8, sigma=3):
36 | blur = GaussianBlur(ksize, sigma)
37 | super().__init__(sf, blur)
38 |
39 |
40 | class UniformDownsample(ClassicalDownsample):
41 | def __init__(self, sf):
42 | blur = UniformBlur(sf)
43 | super().__init__(sf, blur)
44 |
45 |
46 | class BiCubicDownsample(AbstractDownsample):
47 | kernel_path = os.path.join(CURRENT_DIR, 'kernels', 'kernels_bicubicx234.mat')
48 | valid_sfs = [2, 3, 4]
49 |
50 | def __init__(self, sf):
51 | if sf not in self.valid_sfs:
52 | raise ValueError(f'Invalid scale factor, choose from {self.valid_sfs}')
53 | self.sf = sf
54 | self.kernels = hdf5storage.loadmat(self.kernel_path)['kernels']
55 | self.kernel = self.kernels[0, sf-2].astype(np.float64)
56 |
57 | def __call__(self, img):
58 | """ input: [w,h,c]
59 | data range: both (0,255), (0,1) are ok
60 | """
61 | img_L = imresize_np(img, 1/self.sf)
62 | return img_L
63 |
--------------------------------------------------------------------------------
/examples/papers/dphsir_neurcomputing_2022/hsi_compress_sensing.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import numpy as np
3 | from scipy.io import loadmat
4 | import torch
5 |
6 | from dprox import *
7 | from dprox.utils import *
8 |
9 | from degrades.cs import CASSI
10 |
11 | # ------------------------------------- #
12 | # Prepare Data #
13 | # ------------------------------------- #
14 |
15 | I = loadmat('Lehavim_0910-1717.mat')['gt']
16 |
17 | down = CASSI()
18 | mask = down.mask.astype('float32')
19 | y = down(I).astype('float32')
20 |
21 | x0 = np.expand_dims(y, axis=-1) * mask
22 |
23 | print(I.shape, y.shape, x0.shape)
24 | imshow(I[:, :, 20], y, x0[:, :, 20])
25 |
26 | # %
27 | # ----------------------------------- #
28 | # Define and Solve #
29 | # ----------------------------------- #
30 |
31 | x = Variable()
32 | data_term = compress_sensing(x, mask, y)
33 | reg_term = deep_prior(x, denoiser='grunet')
34 |
35 | solver = compile(data_term+reg_term)
36 | solver.to(torch.device('cuda'))
37 |
38 | iter_num = 24
39 | rhos, sigmas = log_descent(50, 45, iter_num)
40 | x_pred = solver.solve(x0,
41 | rhos=rhos,
42 | weights={reg_term: sigmas},
43 | max_iter=iter_num,
44 | pbar=True)
45 |
46 | out = to_ndarray(x_pred, debatch=True)
47 |
48 | print(mpsnr(out, I)) # mpsnr: 39.18
49 | imshow(out[:, :, 20])
50 | # %%
51 |
--------------------------------------------------------------------------------
/examples/papers/dphsir_neurcomputing_2022/hsi_deblur.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import torch
3 | from scipy.io import loadmat
4 |
5 | from dprox import *
6 | from dprox.utils import *
7 |
8 | from degrades.blur import GaussianBlur
9 |
10 | # ------------------------------------- #
11 | # Prepare Data #
12 | # ------------------------------------- #
13 |
14 |
15 | I = loadmat('Lehavim_0910-1717.mat')['gt']
16 |
17 | blur = GaussianBlur()
18 | b = blur(I)
19 |
20 | print(I.shape, b.shape)
21 | imshow(I[:, :, 20], b[:, :, 20])
22 |
23 | # %%
24 | # ----------------------------------- #
25 | # Define and Solve #
26 | # ----------------------------------- #
27 |
28 | x = Variable()
29 | data_term = sum_squares(conv(x, blur.kernel), b)
30 | reg_term = deep_prior(x, denoiser='grunet')
31 | # NOTE(review): `device` is applied only to the input tensor below; the solver is never moved explicitly - confirm `compile` picks a matching device.
32 | device = torch.device('cuda')
33 | solver = compile(data_term + reg_term)
34 | # NOTE(review): `rhos` returned by log_descent is unused; `solve` receives the constant rhos=1e-10 and only `sigmas` is consumed.
35 | iter_num = 24
36 | rhos, sigmas = log_descent(35, 10, iter_num)
37 | x_pred = solver.solve(to_torch_tensor(b, batch=True).to(device),
38 |                       rhos=1e-10,
39 |                       weights={reg_term: sigmas},  # 54.97 if reg_term: 0.23
40 |                       max_iter=iter_num,
41 |                       eps=0,
42 |                       pbar=True)
43 |
44 | out = to_ndarray(x_pred, debatch=True)
45 | print(mpsnr(out, I))  # 54.22/55.10, 78.31 if rho = 1e-10
46 | imshow(out[:, :, 20])
47 |
48 | # two deep prior: 53.00
49 |
50 | # %%
51 |
--------------------------------------------------------------------------------
/examples/papers/dphsir_neurcomputing_2022/hsi_inpainting.py:
--------------------------------------------------------------------------------
1 | # %%
2 | from scipy.io import loadmat
3 |
4 | from dprox import *
5 | from dprox.utils import *
6 |
7 | from degrades.inpaint import FastHyStripe
8 |
9 | # ------------------------------------- #
10 | # Prepare Data #
11 | # ------------------------------------- #
12 |
13 | I = loadmat('Lehavim_0910-1717.mat')['gt']
14 |
15 | degrade = FastHyStripe()
16 | b, mask = degrade(I)
17 | mask = mask.astype('float')
18 |
19 | imshow(I[:,:,20], b[:,:,20])
20 |
21 | #%%
22 | # ----------------------------------- #
23 | # Define and Solve #
24 | # ----------------------------------- #
25 |
26 | S = MulElementwise(mask)
27 | # NOTE(review): `x` is never passed into `data_term`; `S` is built from the mask alone, unlike sibling scripts that wrap `x` (e.g. `conv(x, ...)`) - confirm this binds the variable.
28 | x = Variable()
29 | data_term = sum_squares(S, b)
30 | deep_reg = deep_prior(x, denoiser='grunet')
31 | problem = Problem(data_term+deep_reg)
32 | # schedule length (iter=24) must stay in step with max_iter below
33 | rhos, sigmas = log_descent(5, 4, iter=24, lam=0.6)
34 | x_pred = problem.solve(solver='admm',
35 |                        x0=b,
36 |                        weights={deep_reg: sigmas},
37 |                        rhos=rhos,
38 |                        max_iter=24,
39 |                        pbar=True)
40 |
41 | out = to_ndarray(x_pred, debatch=True)
42 |
43 | print(mpsnr(out, I)) # 74.92/74.88
44 | imshow(out[:,:,20])
45 |
46 | # %%
47 |
--------------------------------------------------------------------------------
/examples/papers/dphsir_neurcomputing_2022/hsi_misr.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import cv2
3 | from scipy.io import loadmat
4 | from dprox import *
5 | from dprox.utils import *
6 |
7 | from degrades.sr import GaussianDownsample
8 | from degrades.general import HSI2RGB
9 |
10 | # ------------------------------------- #
11 | # Prepare Data #
12 | # ------------------------------------- #
13 |
14 | I = loadmat('Lehavim_0910-1717.mat')['gt'].astype('float32')
15 |
16 | sf = 2
17 | down = GaussianDownsample(sf=sf)
18 | y = down(I)
19 |
20 | x0 = cv2.resize(y, (y.shape[1] * sf, y.shape[0] * sf), interpolation=cv2.INTER_CUBIC)
21 |
22 | srf = HSI2RGB().srf.T # 31*3
23 |
24 | T = mul_color(srf)
25 | z = to_ndarray(T(to_torch_tensor(I, batch=True)), debatch=True)
26 |
27 | print(I.shape, y.shape, x0.shape, z.shape)
28 | imshow(I[:, :, 20], y[:, :, 20], x0[:, :, 20], z)
29 |
30 | # %%
31 | # ----------------------------------- #
32 | # Define and Solve #
33 | # ----------------------------------- #
34 |
35 |
36 | x = Variable()
37 | data_term = sisr(x, down.kernel, sf, y)
38 | rgb_fidelity = misr(x, z, srf)
39 | reg = deep_prior(x, denoiser='grunet')
40 | problem = Problem(rgb_fidelity + data_term + reg)
41 | # NOTE(review): only one iteration is scheduled and run here (sibling scripts use 24); the PSNR quoted below may come from a longer run - confirm.
42 | rhos, sigmas = log_descent(35, 10, 1)
43 | x_pred = problem.solve(solver='admm',
44 |                        x0=x0,
45 |                        rhos=rhos,
46 |                        weights={reg: sigmas},
47 |                        max_iter=1,
48 |                        pbar=True)
49 |
50 | out = to_ndarray(x_pred, debatch=True)
51 | print(mpsnr(out, I))  # 59.11, 59.74
52 | imshow(out[:, :, 20])
53 | # %%
54 |
--------------------------------------------------------------------------------
/examples/papers/dphsir_neurcomputing_2022/hsi_sisr.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import cv2
3 | import torch
4 | from scipy.io import loadmat
5 |
6 | from dprox import *
7 | from dprox.utils import *
8 |
9 | from degrades import GaussianDownsample
10 |
11 | # ------------------------------------- #
12 | # Prepare Data #
13 | # ------------------------------------- #
14 |
15 | I = loadmat('Lehavim_0910-1717.mat')['gt'].astype('float32')
16 |
17 | sf = 2
18 | down = GaussianDownsample(sf=sf)
19 | y = down(I)
20 |
21 | x0 = cv2.resize(y, (y.shape[1]*sf, y.shape[0]*sf), interpolation=cv2.INTER_CUBIC)
22 |
23 | print(I.shape, y.shape, x0.shape)
24 | imshow(I[:, :, 20], y[:, :, 20], x0[:, :, 20])
25 |
26 |
27 | #%%
28 | # ----------------------------------- #
29 | # Define and Solve #
30 | # ----------------------------------- #
31 |
32 | x = Variable()
33 | data_term = sisr(x, down.kernel, sf, y)
34 | deep_reg = deep_prior(x, denoiser='grunet')
35 |
36 | solver = ADMM([data_term, deep_reg], partition=True)
37 | solver.to(torch.device('cuda'))
38 |
39 | rhos, sigmas = log_descent(35, 10, 24)
40 | x_pred = solver.solve(x0,
41 | rhos=rhos,
42 | weights={deep_reg: sigmas},
43 | max_iter=24,
44 | pbar=True)
45 |
46 | out = to_ndarray(x_pred, debatch=True)
47 | print(mpsnr(out, I)) # 47.9153/47.5494
48 | imshow(out[:, :, 20])
49 | # %%
50 |
--------------------------------------------------------------------------------
/examples/papers/dpir_tpami_2020/rgb_demosaic.py:
--------------------------------------------------------------------------------
1 | # Demosaic Ref
2 | # - https://sporco.readthedocs.io/en/latest/examples/ppp/ppp_admm_dmsc.html
3 | # - https://github.com/cszn/DPIR/blob/master/main_dpir_demosaick.py
4 |
5 | import cv2
6 | import scipy.misc
7 |
8 | from dprox import *
9 | from dprox.utils import *
10 |
11 | from utils import dm_matlab, mosaic_CFA_Bayer, tensor2uint, uint2tensor4
12 |
13 | # Prepare GT and Input
14 | # I = imread_rgb('/media/exthdd/laizeqiang/lzq/projects/hyper-pnp/OpenProx/examples/experiments/rgb/data/Kodak24/kodim02.png')
15 | I = scipy.misc.face()
16 | CFA, CFA4, b, mask = mosaic_CFA_Bayer(I)
17 |
18 | x0 = dm_matlab(uint2tensor4(CFA4))
19 | x0 = tensor2uint(x0)
20 | x0 = cv2.cvtColor(CFA, cv2.COLOR_BAYER_BG2RGB_EA) # essential for drunet, wo 14, w 41.72
21 | # x0 = b
22 |
23 | imshow(b)
24 | I = I.astype('float32') / 255
25 | x0 = x0.astype('float32') / 255
26 | b = b.astype('float32') / 255
27 |
28 | print(b.shape, mask.shape)
29 |
30 | # Define and Solve the Problem
31 | x = Variable()
32 |
33 | data_term = sum_squares(mosaic(x), b)
34 | deep_reg = deep_prior(x, denoiser='ffdnet_color', x8=True)
35 | problem = Problem(data_term + deep_reg)
36 |
37 |
38 | hi = 32
39 | low = 2
40 | iter_num = 40
41 | rhos, sigmas = log_descent(hi, low, iter_num)
42 | x_pred = problem.solve(solver='hqs',
43 | x0=x0,
44 | weights={deep_reg: sigmas},
45 | rhos=rhos,
46 | max_iter=iter_num,
47 | pbar=False)
48 | out = to_ndarray(x_pred, debatch=True).clip(0, 1)
49 | out[mask] = b[mask]
50 | psnr_ = psnr(out, I)
51 |
52 |
53 | # best: 32 2 44.766043261756685
54 | print(psnr_)
55 | imshow(out)
56 |
--------------------------------------------------------------------------------
/notebooks/README.md:
--------------------------------------------------------------------------------
1 | # Welcome to the $\nabla$-Prox tutorials
2 |
3 | Get familiar with the concepts and modules of $\nabla$-Prox. We provide several step-by-step tutorials for the framework and its applications.
4 |
5 | To run the notebooks locally, please follow the [installation tutorial](https://github.com/princeton-computational-imaging/Delta-Prox/tree/main#installation) to install the necessary dependencies. You can also run the tutorial notebooks in Colab.
6 |
7 | ### Framework
8 |
9 | | Tutorial | Colab |
10 | | -- | -- |
11 | | [Learning the Basics](learn_the_basic.ipynb) | [Open in Colab](https://colab.research.google.com/github/princeton-computational-imaging/Delta-Prox/blob/main/notebooks/learn_the_basic.ipynb) |
12 | | [Differentiable Linear Solver](differentiable_linear_solver.ipynb) | [Open in Colab](https://colab.research.google.com/github/princeton-computational-imaging/Delta-Prox/blob/main/notebooks/differentiable_linear_solver.ipynb) |
13 | | [Linear Operator](linear_operator.ipynb) | [Open in Colab](https://colab.research.google.com/github/princeton-computational-imaging/Delta-Prox/blob/main/notebooks/linear_operator.ipynb) |
14 | | [Primitive](primitive.ipynb) | [Open in Colab](https://colab.research.google.com/github/princeton-computational-imaging/Delta-Prox/blob/main/notebooks/primitive.ipynb) |
15 |
16 |
17 |
18 | ### Applications
19 |
20 | | Tutorial | Colab |
21 | | -- | -- |
22 | | [Compressive-MRI](csmri.ipynb) | [Open in Colab](https://colab.research.google.com/github/princeton-computational-imaging/Delta-Prox/blob/main/notebooks/csmri.ipynb) |
23 | | [Computational Optics](computational_optics.ipynb) | [Open in Colab](https://colab.research.google.com/github/princeton-computational-imaging/Delta-Prox/blob/main/notebooks/computational_optics.ipynb) |
24 | | [Image Deraining](deraining.ipynb) | [Open in Colab](https://colab.research.google.com/github/princeton-computational-imaging/Delta-Prox/blob/main/notebooks/deraining.ipynb) |
25 | | [Energy System Planning](energy_system_planning.ipynb) | [Open in Colab](https://colab.research.google.com/github/princeton-computational-imaging/Delta-Prox/blob/main/notebooks/energy_system_planning.ipynb) |
26 | | [Image Restoration](image_restoration.ipynb) | [Open in Colab](https://colab.research.google.com/github/princeton-computational-imaging/Delta-Prox/blob/main/notebooks/image_restoration.ipynb) |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | cvxpy
3 | torchlights
4 | termcolor
5 | tfpnp
6 | matplotlib
7 | tensorboardX
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
# Runtime dependencies handed to `install_requires` in the setup() call below.
# NOTE(review): requirements.txt also lists `torch`, which is absent here —
# confirm that leaving it to the user is intentional.
deps = [
    'imageio',
    'scikit_image',
    'matplotlib',
    'munch',
    'tfpnp',
    'cvxpy',
    'torchlights',
    'tensorboardX',
    'termcolor',
    'proximal',
    'opencv-python',
    'huggingface_hub',
]

# Package metadata; the packages to ship are discovered by find_packages().
setup(
    name='dprox',
    description='A domain-specific language and compiler that transforms optimization problems into differentiable proximal solvers.',
    url='https://github.com/princeton-computational-imaging/Delta-Prox',
    author='Zeqiang Lai',
    author_email='laizeqiang@outlook.com',
    packages=find_packages(),
    version='0.2.1',
    include_package_data=True,
    install_requires=deps,
)
29 |
--------------------------------------------------------------------------------
/tests/README.md:
--------------------------------------------------------------------------------
1 | # Unit Tests
2 |
3 | ```bash
4 | pip install pytest
5 | ```
6 |
7 | Use `-s` to enable printing to screen during the tests.
8 |
9 | ```bash
10 | # test all
11 | pytest -s tests
12 |
13 | # test one file
14 | pytest -s tests/test_algorithms.py
15 |
16 | # test one function
17 | pytest -s tests/test_algorithms.py::test_admm
18 | ```
--------------------------------------------------------------------------------
/tests/linalg/test_linear_solver.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import dprox as dp
3 | import numpy as np
4 | import dprox.utils
5 |
# Presumably seeds the random number generators so the system below is
# reproducible — confirm against dprox.utils.misc.seed_everything.
dprox.utils.misc.seed_everything(2023)

# Shared 5x5 test system A x = offset with a known solution x.
P = np.random.rand(5,5)
I = np.eye(5)
mu = 0.01
# Gram matrix plus a small multiple of the identity: symmetric positive definite.
A = P.T @ P + mu * I
x = np.random.rand(5)
offset = A @ x
# Relative tolerance passed to the allclose checks in the tests below.
rtol = 1e-8
15 |
16 |
class MatrixLinOp(dp.LinOp):
    """Linear operator backed by a dense matrix stored as a ``Parameter``."""

    def __init__(self, A):
        super().__init__()
        self.A = torch.nn.Parameter(A)

    def forward(self, x):
        """Return ``A @ x``."""
        return torch.matmul(self.A, x)

    def adjoint(self, x):
        """Return ``A.T @ x``."""
        return torch.matmul(self.A.T, x)
27 |
28 |
def test_gmres_scipy():
    """Solve ``A x = offset`` with SciPy's GMRES and compare with the known ``x``."""
    from scipy.sparse.linalg import gmres, LinearOperator

    def apply_A(b):
        return A @ b

    # The second return value (the solver's status flag) is ignored.
    xhat, _ = gmres(LinearOperator(shape=(5, 5), matvec=apply_A), offset)

    for item in ('gmres scipy', x, xhat):
        print(item)

    print(np.mean(np.abs(xhat - x)))
    assert np.allclose(xhat, x, rtol=rtol)
41 |
42 |
def test_cg_scipy():
    """Solve ``A x = offset`` with SciPy's CG and compare with the known ``x``."""
    from scipy.sparse.linalg import cg, LinearOperator

    def apply_A(b):
        return A @ b

    # The second return value (the solver's status flag) is ignored.
    xhat, _ = cg(LinearOperator(shape=(5, 5), matvec=apply_A), offset)

    for item in ('cg scipy', x, xhat):
        print(item)

    print(np.mean(np.abs(xhat - x)))
    assert np.allclose(xhat, x, rtol=rtol)
55 |
56 |
def test_cg():
    """Run dprox's ``cg``, ``cg2`` and ``pcg`` on ``A x = offset`` and compare with ``x``."""
    A2 = torch.from_numpy(A)
    x2 = torch.from_numpy(x)
    b2 = torch.from_numpy(offset)

    def K(v):
        return A2 @ v

    def report(title, xhat):
        # Print the solver name, its mean absolute error and the solution.
        print(title)
        print(torch.mean(torch.abs(xhat - x2)).item())
        print(xhat.numpy())

    solve = dp.linalg.solve
    xhat1 = solve.cg(K, b2)
    xhat2 = solve.cg2(K, b2)
    xhat3 = solve.pcg(K, b2)

    report('conjugate_gradient', xhat1)
    assert torch.allclose(xhat1, x2, rtol=rtol)

    # cg2 is only reported; its result is not asserted.
    report('conjugate_gradient2', xhat2)

    report('PCG', xhat3)
    assert torch.allclose(xhat3, x2, rtol=rtol)
81 |
82 |
def test_plss():
    """Solve ``A x = offset`` with dprox's ``plss`` through a ``MatrixLinOp``."""
    K = MatrixLinOp(torch.from_numpy(A))
    x2 = torch.from_numpy(x)
    b2 = torch.from_numpy(offset)

    xhat1 = dp.linalg.solve.plss(K, b2)
    mean_abs_err = torch.mean(torch.abs(xhat1 - x2)).item()

    print('PLSS')
    print(mean_abs_err)
    print(xhat1.detach().numpy())
    assert torch.allclose(xhat1, x2, rtol=rtol)
96 |
97 |
def test_minres():
    """Solve ``A x = offset`` with dprox's ``minres`` under ``torch.no_grad()``."""
    K = MatrixLinOp(torch.from_numpy(A))
    x2 = torch.from_numpy(x)
    b2 = torch.from_numpy(offset)

    with torch.no_grad():
        xhat1 = dp.linalg.solve.minres(K, b2)

    mean_abs_err = torch.mean(torch.abs(xhat1 - x2)).item()

    print('MINRES')
    print(mean_abs_err)
    print(xhat1.detach().numpy())
    assert torch.allclose(xhat1, x2, rtol=rtol)
112 |
--------------------------------------------------------------------------------
/tests/linalg/test_linear_solver_batch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import dprox as dp
3 | import numpy as np
4 |
5 |
# Seed NumPy so the random system below is identical on every run; without it
# the tolerance assertions in the tests depend on whichever matrix was drawn.
np.random.seed(2023)

# Shared 5x5 test system A x = offset with a known solution x.
P = np.random.rand(5, 5)
I = np.eye(5)
mu = 0.01
# Gram matrix plus a small multiple of the identity: symmetric positive definite.
A = P.T @ P + mu * I
x = np.random.rand(5)
offset = A @ x
# Relative tolerance passed to the allclose checks in the tests below.
rtol = 1e-8
13 |
14 |
class MatrixLinOp(dp.LinOp):
    """Linear operator backed by a dense matrix stored as a ``Parameter``."""

    def __init__(self, A):
        super().__init__()
        self.A = torch.nn.Parameter(A)

    def forward(self, x):
        """Return ``A @ x``."""
        return torch.matmul(self.A, x)

    def adjoint(self, x):
        """Return ``A.T @ x``."""
        return torch.matmul(self.A.T, x)
25 |
26 |
def test_cg():
    """Run dprox's ``cg`` and ``pcg`` on ``A x = offset`` and compare with ``x``."""
    A2 = torch.from_numpy(A)
    x2 = torch.from_numpy(x)
    b2 = torch.from_numpy(offset)

    def K(v):
        return A2 @ v

    def report(title, xhat):
        # Print the solver name, its mean absolute error and the solution.
        print(title)
        print(torch.mean(torch.abs(xhat - x2)).item())
        print(xhat.numpy())

    xhat1 = dp.linalg.solve.cg(K, b2)
    xhat3 = dp.linalg.solve.pcg(K, b2)

    report('conjugate_gradient', xhat1)
    assert torch.allclose(xhat1, x2, rtol=rtol)

    report('PCG', xhat3)
    assert torch.allclose(xhat3, x2, rtol=rtol)
50 |
--------------------------------------------------------------------------------
/tests/linalg/test_linear_solver_grad.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from functools import partial
3 |
4 |
# Absolute and relative tolerances shared by every gradient comparison below.
atol, rtol = 1e-6, 1e-3
# torch.allclose with the tolerances above already bound.
allclose = partial(torch.allclose, atol=atol, rtol=rtol)
7 |
8 |
def tol(input, other):
    """Return the largest value of ``|input - other| - atol - rtol * |other|``.

    A result of at most zero means every element is within the tolerances.
    """
    excess = (input - other).abs() - atol - rtol * other.abs()
    return excess.max().item()
12 |
13 |
def auto_diff(seed):
    """Return autograd's gradients of ``mean(solve(K, b))`` w.r.t. ``theta`` and ``b``.

    ``K = 2 * theta`` with a random 32x32 ``theta``; ``b = K @ x`` for a random
    ``x``. Both are drawn after seeding torch with ``seed``.
    """
    torch.manual_seed(seed)
    theta = torch.randn((32, 32), requires_grad=True)
    K = theta * 2
    x_true = torch.randn(32)
    # b is rebuilt as a leaf tensor so that it receives its own gradient.
    b = (K @ x_true).clone().detach().requires_grad_(True)

    solution = torch.linalg.solve(K, b)
    solution.mean().backward()

    grad_theta, grad_b = theta.grad, b.grad
    return grad_theta, grad_b
28 |
29 |
def differentiate_dloss_db(seed):
    """Return ``dloss/db`` computed with an explicit inverse and with a linear solve.

    Both results apply the inverse of ``K.T`` to the gradient autograd stored
    on the solution ``xhat``.
    """
    torch.manual_seed(seed)
    theta = torch.randn((32, 32), requires_grad=True)
    K = theta * 2
    x_true = torch.randn(32)
    b = (K @ x_true).clone().detach().requires_grad_(True)

    solution = torch.linalg.solve(K, b)
    solution.retain_grad()  # keep dloss/dxhat on this non-leaf tensor
    solution.mean().backward()
    upstream = solution.grad

    via_inverse = torch.inverse(K.T) @ upstream
    via_solve = torch.linalg.solve(K.T, upstream)
    return via_inverse, via_solve
48 |
49 |
def matrix_differentiate_dloss_dtheta(seed):
    """Return ``dloss/dtheta`` assembled from the explicit Jacobian of ``K``.

    Computes ``-inverse(K) @ dK/dtheta @ xhat`` and contracts it with the
    gradient autograd stored on ``xhat``.
    """
    torch.manual_seed(seed)
    theta = torch.randn((32, 32), requires_grad=True)
    K = theta * 2
    x_true = torch.randn(32)
    b = (K @ x_true).clone().detach().requires_grad_(True)

    solution = torch.linalg.solve(K, b)
    solution.retain_grad()  # keep dloss/dxhat on this non-leaf tensor
    solution.mean().backward()
    upstream = solution.grad

    def Kmat(theta):
        return theta * 2

    dK_dtheta = torch.autograd.functional.jacobian(Kmat, theta)
    neg_K_inv = -torch.inverse(K)
    dxhat_dtheta = neg_K_inv @ dK_dtheta @ solution

    return dxhat_dtheta @ upstream
74 |
75 |
def matrix_free_differentiate_dloss_dtheta(seed):
    """Return ``dloss/dtheta`` using linear solves instead of an explicit inverse.

    The Jacobian of ``theta -> (theta * 2) @ xhat`` is negated, used as the
    right-hand sides of solves with ``K``, and contracted with the gradient
    autograd stored on ``xhat``.
    """
    torch.manual_seed(seed)
    theta = torch.randn((32, 32), requires_grad=True)
    K = theta * 2
    x_true = torch.randn(32)
    b = (K @ x_true).clone().detach().requires_grad_(True)

    solution = torch.linalg.solve(K, b)
    solution.retain_grad()  # keep dloss/dxhat on this non-leaf tensor
    solution.mean().backward()
    upstream = solution.grad

    def linop(theta):
        return (theta * 2) @ solution

    jac = torch.autograd.functional.jacobian(linop, theta)
    # Move the output axis last and add a trailing axis so each entry is a column.
    db_dtheta = -jac.permute(1, 2, 0).unsqueeze(-1)
    dxhat_dtheta = torch.linalg.solve(K, db_dtheta).squeeze(-1)

    dloss_dtheta = dxhat_dtheta @ upstream
    return dloss_dtheta
99 |
100 |
def test_db():
    """Check both hand-derived ``dloss/db`` variants against autograd for 20 seeds."""
    for seed in range(20):
        db_ref = auto_diff(seed)[1]
        db_matrix, db_matrix_free = differentiate_dloss_db(seed)

        print('b', seed, torch.mean(torch.abs(db_ref - db_matrix)))
        print('b free', seed, torch.mean(torch.abs(db_ref - db_matrix_free)))

        assert allclose(db_matrix, db_ref)
        assert allclose(db_matrix_free, db_ref)
109 |
110 |
def test_theta():
    """Check both hand-derived ``dloss/dtheta`` variants against autograd for 20 seeds."""
    for seed in range(20):
        dtheta_ref = auto_diff(seed)[0]
        dtheta_matrix = matrix_differentiate_dloss_dtheta(seed)
        dtheta_matrix_free = matrix_free_differentiate_dloss_dtheta(seed)

        print('theta', seed, tol(dtheta_matrix, dtheta_ref))

        ref_peak = dtheta_ref.abs().max().item()
        free_peak_err = (dtheta_matrix_free - dtheta_ref).abs().max().item()
        print('theta free', seed, tol(dtheta_matrix_free, dtheta_ref), ref_peak, free_peak_err)

        assert allclose(dtheta_matrix, dtheta_ref)
        assert allclose(dtheta_matrix_free, dtheta_ref)
124 |
125 |
--------------------------------------------------------------------------------
/tests/linalg/test_pcg.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Delta-Prox/15fc489e99c4de2676dc538d9c0552338f6fb894/tests/linalg/test_pcg.py
--------------------------------------------------------------------------------
/tests/paper/test_derain.py:
--------------------------------------------------------------------------------
1 | # def run_derain():
2 | # pass
3 |
4 |
5 | # def test_Rain100H():
6 | # psnr, ssim = run_derain()
7 | # assert abs(psnr - 31.08) < 0.01
8 | # assert abs(ssim - 0.897) < 0.001
9 |
10 |
11 | # def test_Rain100H_with_init():
12 | # psnr, ssim = run_derain()
13 | # assert abs(psnr - 31.62) < 0.01
14 | # assert abs(ssim - 0.905) < 0.001
15 |
16 |
17 | # def test_Test1200():
18 | # psnr, ssim = run_derain()
19 | # assert abs(psnr - 32.95) < 0.01
20 | # assert abs(ssim - 0.913) < 0.001
21 |
22 |
23 | # def test_Test1200_with_init():
24 | # psnr, ssim = run_derain()
25 | # assert abs(psnr - 33.25) < 0.01
26 | # assert abs(ssim - 0.926) < 0.001
27 |
--------------------------------------------------------------------------------
/tests/paper/test_energy.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Delta-Prox/15fc489e99c4de2676dc538d9c0552338f6fb894/tests/paper/test_energy.py
--------------------------------------------------------------------------------
/tests/paper/test_optics.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torchlight as tl
6 | import torchlight.nn as tlnn
7 |
8 | from dprox import *
9 | from dprox import Variable
10 | from dprox.linop.conv import conv_doe
11 | from dprox.utils import *
12 | from dprox.contrib.optic.utils import load_sample_img
13 | from dprox.contrib.optic.doe_model import img_psf_conv, build_doe_model, DOEModelConfig
14 |
15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16 |
17 |
@torch.no_grad()
def run_optics(dataset):
    """Evaluate the pretrained DOE model with a compiled ADMM solver on ``dataset``.

    Every image is convolved with the model's PSF, perturbed with Gaussian
    noise and reconstructed by the solver; PSNR and SSIM of the clamped
    reconstruction are accumulated in a metric tracker.

    Args:
        dataset: Name handed to ``hf.download_dataset``.

    Returns:
        ``(tracker['psnr'], tracker['ssim'])``.
    """
    config = DOEModelConfig()
    rgb_collim_model = build_doe_model().to(device)
    circular = config.circular

    # -------------------- Define Model --------------------- #
    x = Variable()
    y = Placeholder()
    PSF = Placeholder()
    data_term = sum_squares(conv_doe(x, PSF, circular=circular), y)
    reg_term = deep_prior(x, denoiser='ffdnet_color')
    solver = compile(data_term + reg_term, method='admm')
    solver.eval()

    # ---------------- Setup Hyperparameter ----------------- #
    max_iter = 10
    sigma = 7.65 / 255.
    rhos, sigmas = log_descent(49, 7.65, max_iter, sigma=max(0.255 / 255, sigma))
    rhos = torch.tensor(rhos, device=device).float()
    sigmas = torch.tensor(sigmas, device=device).float()
    rgb_collim_model.rhos = nn.parameter.Parameter(rhos)
    rgb_collim_model.sigmas = nn.parameter.Parameter(sigmas)

    # ---------------- Forward Model ------------------------ #
    def step_fn(gt):
        gt = gt.to(device).float()
        psf = rgb_collim_model.get_psf()
        inp = img_psf_conv(gt, psf, circular=circular)
        inp = inp + torch.randn(*inp.shape, device=inp.device) * sigma
        y.value = inp
        PSF.value = psf

        # The per-call Timer that used to be started here was never read,
        # so it has been dropped; the outer timer measures the whole loop.
        out = solver.solve(x0=inp,
                           rhos=rgb_collim_model.rhos,
                           lams={reg_term: rgb_collim_model.sigmas.sqrt()},
                           max_iter=max_iter)

        return gt, inp, out

    print('Model size (M)', tlnn.benchmark.model_size(rgb_collim_model))
    print('Trainable model size (M)', tlnn.benchmark.trainable_model_size(rgb_collim_model))

    # --------------- Load and Run ------------------ #
    ckpt = hf.load_checkpoint('computational_optics/joint_dprox.pth')
    rgb_collim_model.load_state_dict(ckpt['model'])

    tl.metrics.set_data_format('chw')

    root = hf.download_dataset(dataset)
    # List the images once; the list drives the loop and the per-image timing.
    names = list_image_files(root)
    tracker = tl.trainer.util.MetricTracker()

    timer = tl.utils.Timer()
    timer.tic()
    for idx, name in enumerate(names):
        gt = load_sample_img(os.path.join(root, name))

        # Seed per image so the noise added in step_fn is reproducible.
        torch.manual_seed(idx)
        torch.cuda.manual_seed(idx)
        gt, inp, pred = step_fn(gt)
        pred = pred.clamp(0, 1)

        psnr = tl.metrics.psnr(pred, gt)
        ssim = tl.metrics.ssim(pred, gt)

        tracker.update('psnr', psnr)
        tracker.update('ssim', ssim)

        print('{} PSNR {} SSIM {}'.format(name, psnr, ssim))

    print('averge results')
    print(tracker.summary())
    print(timer.toc() / len(names))
    return tracker['psnr'], tracker['ssim']
96 |
97 |
def test_ubran100():
    """Regression check of the average PSNR/SSIM on the Urban100 dataset."""
    avg_psnr, avg_ssim = run_optics('Urban100')
    psnr_gap = abs(avg_psnr - 30.83)
    ssim_gap = abs(avg_ssim - 0.944)
    assert psnr_gap < 0.01
    assert ssim_gap < 0.001
102 |
103 |
def test_cbsd68():
    """Regression check of the average PSNR/SSIM on the CBSD68 dataset."""
    avg_psnr, avg_ssim = run_optics('CBSD68')
    psnr_gap = abs(avg_psnr - 32.01)
    ssim_gap = abs(avg_ssim - 0.942)
    assert psnr_gap < 0.01
    assert ssim_gap < 0.001
108 |
--------------------------------------------------------------------------------
/tests/problem/test_deraining.py:
--------------------------------------------------------------------------------
1 | from dprox import *
2 | from dprox.utils import *
3 | from dprox.contrib.derain import LearnableDegOp
4 |
def test():
    """Derain the sample image with the pretrained PGD solver and check its PSNR.

    The final assertion is two-sided: the original ``psnr - 35.92 < 0.1`` was
    also satisfied by any PSNR far below the expected value.
    """
    b = hf.load_image('examples/derain/derain_input.png')
    gt = hf.load_image('examples/derain/derain_target.png')
    imshow(gt, b)

    # custom linop
    A = LearnableDegOp().cuda()
    def forward_fn(input, step): return A.forward(input, step)
    def adjoint_fn(input, step): return A.adjoint(input, step)
    raining = LinOpFactory(forward_fn, adjoint_fn)

    # build solver
    x = Variable()
    data_term = sum_squares(raining(x), b)
    reg_term = unrolled_prior(x)
    obj = data_term + reg_term
    solver = compile(obj, method='pgd')

    # load parameters
    ckpt = hf.load_checkpoint('image_deraining/derain_pdg.pth')
    A.load_state_dict(ckpt['linop'])
    reg_term.load_state_dict(ckpt['prior'])
    rhos = ckpt['rhos']

    out = solver.solve(x0=b, rhos=rhos, max_iter=7)
    out = to_ndarray(out, debatch=True) + b
    score = psnr(out, gt)
    print(score)  # 35.92
    imshow(gt, out)
    assert abs(score - 35.92) < 0.1
34 |
--------------------------------------------------------------------------------
/tests/problem/test_energy_system.py:
--------------------------------------------------------------------------------
1 | import scipy.io
2 |
3 | import dprox as dp
4 | from dprox.utils.huggingface import load_path
5 | from dprox.contrib.energy_system import load_simple_cep_model
6 |
7 |
def test():
    """Solve the simple CEP model with ADMM and print the result.

    NOTE(review): nothing is asserted; the reference solution is loaded at
    the end but never compared with ``out``.
    """
    c, A_ub, A_eq, b_ub, b_eq = load_simple_cep_model()

    # Problem with objective c @ x and the two linear constraint sets.
    x = dp.Variable()
    prob = dp.Problem(c @ x, [A_ub @ x <= b_ub, A_eq @ x == b_eq])
    out = prob.solve(method='admm', adapt_params=True)
    print(out)

    # reference solution
    solution = scipy.io.loadmat(load_path("energy_system/simple_cep_model_20220916/esm_instance_solution.mat"))
18 |
19 | # pcg with custom backward
20 | # 83434.40576028604 170.29735095665677 0.0041508882261022985 200.8266268913533 0.569342163225932 29.714674732313096
21 | # tensor(-170.2974, device='cuda:0', dtype=torch.float64)
22 |
23 | # raw pcg
24 | # 83433.69636112433 174.7402594703876 0.004143616818155279 200.82662689427482 0.5693421632258844 29.691228047449876
25 | # tensor(-174.7403, device='cuda:0', dtype=torch.float64)
26 |
27 | # Obj: 8.35e+04, res_z: 0.00e+00, res_primal: 1.95e+02, res_dual: 9.19e-04, eps_primal: 2.00e+02, eps_dual: 1.00e-03, rho: 1.45e+01
28 | # tensor(-99.6752, device='cuda:0', dtype=torch.float64)
29 |
--------------------------------------------------------------------------------
/tests/problem/test_inverse_problems.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from dprox import *
4 | from dprox.utils import *
5 | from dprox import contrib
6 |
7 |
def test_csmri():
    """Reconstruct the compressed-sensing MRI sample with ADMM and check its PSNR."""
    x0, y0, gt, mask = contrib.csmri.sample()

    x = Variable()
    y = Placeholder()
    data_term = csmri(x, mask, y)
    reg_term = deep_prior(x, denoiser='unet')
    prob = Problem(data_term + reg_term)

    y.value = y0
    max_iter = 24
    rhos, sigmas = log_descent(30, 20, max_iter)
    solve_kwargs = dict(
        method='admm',
        device='cuda',
        x0=x0,
        rhos=rhos,
        lams={reg_term: sigmas},
        max_iter=max_iter,
        pbar=True,
    )
    prob.solve(**solve_kwargs)

    # Only the real part of the solution is scored.
    out = x.value.real
    print(psnr(out, gt))
    assert abs(psnr(out, gt) - 43.1) < 0.1
29 |
30 |
def test_deconv():
    """Deblur the sample face image with ADMM and check the resulting PSNR."""
    img = contrib.sample('face')
    psf = contrib.point_spread_function(15, 5)
    b = contrib.blurring(img, psf)

    x = Variable()
    data_term = sum_squares(conv(x, psf) - b)
    prior_term = deep_prior(x, denoiser='ffdnet_color')
    nonneg_term = nonneg(x)
    objective = data_term + prior_term + nonneg_term

    out = Problem(objective).solve(method='admm', x0=b, pbar=True)

    print(psnr(out, img))
    assert abs(psnr(out, img) - 34.5) < 0.1
46 |
47 |
def test_deconv2():
    """Deblur a noisy, blurred face image with ADMM and print the PSNR.

    NOTE(review): nothing is asserted here, and the TODO below marks the
    test as still buggy.
    """
    img = contrib.sample('face')
    psf = contrib.point_spread_function(ksize=15, sigma=5)
    # TODO: this still has bug
    y = contrib.blurring(img, psf) + np.random.randn(*img.shape).astype('float32') * 5 / 255.0
    # NOTE(review): the result of squeeze() is discarded, so `y` is left
    # unchanged by this line — confirm the intent.
    y.squeeze(0)
    print(img.shape, y.shape)

    x = Variable()
    data_term = sum_squares(conv(x, psf) - y)
    prior_term = deep_prior(x, 'ffdnet_color')
    reg_term = nonneg(x)
    objective = data_term + prior_term + reg_term
    p = Problem(objective)
    out = p.solve(method='admm', x0=y, pbar=True)
    print(psnr(out, img))
64 |
--------------------------------------------------------------------------------
/tests/problem/test_jd23.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from dprox import *
4 | from dprox.utils import *
5 | from dprox.linalg import LinearSolveConfig
6 | from dprox.contrib import *
7 |
8 |
def test_jd2():
    """Joint demosaicing and deblurring of the sample face image with ADMM."""
    img = to_torch_tensor(sample('face'), batch=True).float()
    psf = point_spread_function(15, 5)
    b = mosaicing(blurring(img, psf))

    x = Variable()
    data_term = sum_squares(mosaic(conv(x, psf)) - b)
    reg_term = deep_prior(x, denoiser='ffdnet_color')
    prob = Problem(
        data_term + reg_term,
        linear_solve_config=LinearSolveConfig(max_iters=50),
    )

    max_iter = 5
    rhos, sigmas = log_descent(35, 30, max_iter)
    solve_kwargs = dict(x0=b, rhos=rhos, lams={reg_term: sigmas}, max_iter=max_iter, pbar=True)
    out = prob.solve(method='admm', **solve_kwargs)

    print(psnr(out, img))  # 25.25
    assert abs(psnr(out, img) - 25.10) < 0.1
28 |
29 |
def load(name='face', return_tensor=True):
    """Read image ``name``, divide it by 255 as float32 and keep the first 768 rows.

    Returns a batched float tensor when ``return_tensor`` is true, otherwise
    the cropped array itself.
    """
    import imageio

    pixels = imageio.imread(name).copy().astype('float32') / 255
    cropped = pixels[:768, :, :]
    if not return_tensor:
        return cropped
    return to_torch_tensor(cropped, batch=True).float()
37 |
38 |
def test_jd2_batched():
    """Joint demosaicing and deblurring on a batch of two ``sample('face')`` images."""
    psf = point_spread_function(15, 5)

    def make_pair():
        # Ground-truth tensor and its blurred, mosaiced observation.
        clean = to_torch_tensor(sample('face'), batch=True).float()
        return clean, mosaicing(blurring(clean, psf))

    img1, b1 = make_pair()
    img2, b2 = make_pair()
    b = torch.cat([b1, b2], dim=0)
    print(b.shape)

    x = Variable()
    data_term = sum_squares(mosaic(conv(x, psf)) - b)
    reg_term = deep_prior(x, denoiser='ffdnet_color')
    prob = Problem(
        data_term + reg_term,
        linear_solve_config=LinearSolveConfig(max_iters=50),
    )

    max_iter = 5
    rhos, sigmas = log_descent(35, 30, max_iter)
    solve_kwargs = dict(x0=b, rhos=rhos, lams={reg_term: sigmas}, max_iter=max_iter, pbar=True)
    out = prob.solve(method='admm', **solve_kwargs)

    # Score each batch entry against its own ground truth.
    print(psnr(out[0:1], img1))
    print(psnr(out[1].unsqueeze(0), img2))

    assert abs(psnr(out[0:1], img1) - 25.10) < 0.1
    assert abs(psnr(out[1].unsqueeze(0), img2) - 25.10) < 0.1
--------------------------------------------------------------------------------
/tests/problem/test_ml_problems.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import dprox as dp
3 |
4 |
def test_lsq():
    """Solve the least-squares problem on ``2*x - rhs`` with ADMM; expect ``rhs / 2``.

    The result is compared with ``np.allclose`` rather than exact ``==``,
    because exact float equality on a solver's output is fragile.
    """
    x = dp.Variable((3,3))
    rhs = np.array([[1, 2, 3],[4,5,6],[7,8,9]])
    prob = dp.Problem(dp.sum_squares(2*x - rhs))
    prob.solve('admm', x0=np.zeros((3,3)))
    print(x.value)

    assert np.allclose(x.value.cpu().numpy(), rhs / 2)
13 |
14 |
def test_lsq1():
    """Same problem as ``sum_squares(2*x - rhs)``, passing ``rhs`` as the second argument.

    The result is compared with ``np.allclose`` rather than exact ``==``,
    because exact float equality on a solver's output is fragile.
    """
    x = dp.Variable((3,3))
    rhs = np.array([[1, 2, 3],[4,5,6],[7,8,9]])
    prob = dp.Problem(dp.sum_squares(2*x, rhs))
    prob.solve('admm', x0=np.zeros((3,3)))
    print(x.value)

    assert np.allclose(x.value.cpu().numpy(), rhs / 2)
23 |
24 |
def test_lsq2():
    """Solve a convolutional least-squares problem and check that the residual vanishes.

    The residual is checked in absolute value: the original ``out < 1e-5``
    was also satisfied by arbitrarily large negative residuals.
    """
    x = dp.Variable((3,3,1))
    rhs = np.array([[[1, 2, 3],[4,5,6],[7,8,9]]])
    kernel = np.array([[1,1],[1,1]]) / 4
    prob = dp.Problem(dp.sum_squares(dp.conv(x, kernel) - rhs))
    prob.solve('admm', x0=np.zeros((3,3,1)))
    out = dp.eval(dp.conv(x, kernel)-rhs, x.value, zero_out_constant=False)
    print(x.value)
    print(out)
    assert (abs(out) < 1e-5).all()
35 |
36 |
def test_lsq3():
    """One-dimensional variant of the ``2*x - rhs`` least-squares test.

    The result is compared with ``np.allclose`` rather than exact ``==``,
    because exact float equality on a solver's output is fragile.
    """
    x = dp.Variable((3))
    rhs = np.array([1, 2, 3])
    prob = dp.Problem(dp.sum_squares(2*x - rhs))
    prob.solve('admm', x0=np.zeros(3))
    print(x.value)

    assert np.allclose(x.value.cpu().numpy(), rhs / 2)
45 |
--------------------------------------------------------------------------------
/tests/test_algorithms.py:
--------------------------------------------------------------------------------
1 | from dprox import *
2 | from dprox.utils import *
3 | from dprox import contrib
4 |
5 |
def test_admm():
    """Deblur the sample face image with ``admm`` and compare the PSNR with 34.51."""
    img = contrib.sample('face')
    psf = contrib.point_spread_function(15, 5)
    b = contrib.blurring(img, psf)

    x = Variable()
    data_term = sum_squares(conv(x, psf) - b)
    prior_term = deep_prior(x, denoiser='ffdnet_color')
    nonneg_term = nonneg(x)
    objective = data_term + prior_term + nonneg_term

    out = Problem(objective).solve(method='admm', x0=b, pbar=True)

    print(psnr(out, img))
    assert abs(psnr(out, img) - 34.51) < 0.1
21 |
22 |
def test_ladmm():
    """Deblur the sample face image with ``ladmm`` and compare the PSNR with 34.51."""
    img = contrib.sample('face')
    psf = contrib.point_spread_function(15, 5)
    b = contrib.blurring(img, psf)

    x = Variable()
    data_term = sum_squares(conv(x, psf) - b)
    prior_term = deep_prior(x, denoiser='ffdnet_color')
    nonneg_term = nonneg(x)
    objective = data_term + prior_term + nonneg_term

    out = Problem(objective).solve(method='ladmm', x0=b, pbar=True)

    print(psnr(out, img))
    assert abs(psnr(out, img) - 34.51) < 0.1
38 |
39 |
def test_admm_vxu():
    """Deblur the sample face image with ``admm_vxu`` and compare the PSNR with 34.50."""
    img = contrib.sample('face')
    psf = contrib.point_spread_function(15, 5)
    b = contrib.blurring(img, psf)

    x = Variable()
    data_term = sum_squares(conv(x, psf) - b)
    prior_term = deep_prior(x, denoiser='ffdnet_color')
    nonneg_term = nonneg(x)
    objective = data_term + prior_term + nonneg_term

    out = Problem(objective).solve(method='admm_vxu', x0=b, pbar=True)

    print(psnr(out, img))
    assert abs(psnr(out, img) - 34.50) < 0.1
55 |
56 |
def test_hqs():
    """Deblur the sample face image with ``hqs`` and compare the PSNR with 34.08."""
    img = contrib.sample('face')
    psf = contrib.point_spread_function(15, 5)
    b = contrib.blurring(img, psf)

    x = Variable()
    data_term = sum_squares(conv(x, psf) - b)
    prior_term = deep_prior(x, denoiser='ffdnet_color')
    nonneg_term = nonneg(x)
    objective = data_term + prior_term + nonneg_term

    out = Problem(objective).solve(method='hqs', x0=b, pbar=True)

    print(psnr(out, img))
    assert abs(psnr(out, img) - 34.08) < 0.1
72 |
73 |
def test_pc():
    """Deblur the sample face image with ``pc`` and compare the PSNR with 29.87."""
    img = contrib.sample('face')
    psf = contrib.point_spread_function(15, 5)
    b = contrib.blurring(img, psf)

    x = Variable()
    data_term = sum_squares(conv(x, psf) - b)
    prior_term = deep_prior(x, denoiser='ffdnet_color')
    nonneg_term = nonneg(x)
    objective = data_term + prior_term + nonneg_term

    out = Problem(objective).solve(method='pc', x0=b, pbar=True)

    print(psnr(out, img))
    assert abs(psnr(out, img) - 29.87) < 0.1
89 |
90 |
def test_pgd():
    """Deblur the sample face image with ``pgd`` and compare the PSNR with 21.44."""
    img = contrib.sample('face')
    psf = contrib.point_spread_function(15, 5)
    b = contrib.blurring(img, psf)

    x = Variable()
    data_term = sum_squares(conv(x, psf) - b)
    prior_term = deep_prior(x, denoiser='ffdnet_color')
    objective = data_term + prior_term

    out = Problem(objective).solve(method='pgd', x0=b, pbar=True)

    print(psnr(out, img))
    assert abs(psnr(out, img) - 21.44) < 0.1
105 |
--------------------------------------------------------------------------------
/tests/test_grad.py:
--------------------------------------------------------------------------------
1 | from dprox import *
2 | from dprox.contrib.optic import *
3 | from dprox.utils import *
4 |
5 |
def test_deep_prior():
    """Check that a gradient reaches ``sigma`` through ``deep_prior``'s ``prox``.

    The original only printed ``sigma.grad`` and so could never fail on a
    missing gradient; the final assertion makes the check effective.
    """
    device = torch.device("cuda")
    x = Variable()
    reg_term = deep_prior(x, denoiser="ffdnet_color").to(device)

    sigma = torch.nn.Parameter(torch.tensor(0.1).to(device))
    inp = torch.randn((1, 3, 128, 128), device=device)
    x.value = inp
    y = reg_term.prox(inp, sigma)

    loss = torch.nn.functional.mse_loss(inp, y)
    loss.backward()
    print(sigma.grad)
    assert sigma.grad is not None
19 |
--------------------------------------------------------------------------------
/tests/test_linop.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import dprox as dp
3 |
4 | from dprox.contrib import fspecial_gaussian
5 | from dprox.utils import to_torch_tensor
6 |
7 | from scipy.misc import face
8 |
9 |
def test_conv():
    """A Gaussian ``conv`` graph passes its sanity check and keeps the input shape."""
    x = dp.Variable()
    K = dp.CompGraph(dp.conv(x, fspecial_gaussian(15, 5)))
    assert K.sanity_check()

    img = to_torch_tensor(face(), batch=True)
    out = K.forward(img)

    for shape in (img.shape, out.shape):
        print(shape)
    assert img.shape == out.shape
24 |
25 |
def test_constant():
    """Check ``value`` and ``offset`` of the expression ``3 * (x - [2, 2, 2])``."""
    x = dp.Variable()
    y = 3 * (x - torch.tensor([2, 2, 2]))

    input = torch.tensor([1, 2, 3])
    print(dp.eval(y, input, zero_out_constant=False))

    x.value = torch.tensor([1, 2, 3])

    print(y.variables)
    print(y.constants)
    print(y.value)
    print(y.offset)

    expected_value = 3 * (torch.tensor([1, 2, 3]) - torch.tensor([2, 2, 2]))
    expected_offset = -3 * torch.tensor([2, 2, 2])
    assert torch.allclose(y.value, expected_value)
    assert torch.allclose(y.offset, expected_offset)
41 |
42 |
def test_grad():
    """Sanity-check the sum of ``grad`` along dims 1 and 2, then run it forward."""
    x = dp.Variable()
    summed = dp.grad(x, dim=1) + dp.grad(x, dim=2)
    K = dp.CompGraph(summed)
    assert K.sanity_check()

    img = to_torch_tensor(face(), batch=True)
    print(img.shape)
    outputs = K.forward(img)
50 |
51 |
def test_mosaic():
    """Sanity-check the ``mosaic`` operator, then run it forward on the face image."""
    x = dp.Variable()
    K = dp.CompGraph(dp.mosaic(x))
    assert K.sanity_check()

    img = to_torch_tensor(face(), batch=True)
    print(img.shape)
    outputs = K.forward(img)
59 |
60 |
def test_sum():
    """The graph of ``x1 + x2`` adds its inputs and back-propagates a uniform gradient."""
    x1 = dp.Variable()
    x2 = dp.Variable()
    K = dp.CompGraph(x1 + x2)

    v1 = torch.randn((4, 4), requires_grad=True)
    v2 = torch.randn((4, 4))

    outputs = K.forward(v1, v2)
    print(outputs)
    assert torch.allclose(outputs, v1 + v2)

    # The mean over 16 entries gives every entry of v1 a gradient of 1/16.
    torch.mean(outputs).backward()
    print(v1.grad)
    expected_grad = torch.full_like(v1.grad, 1 / 16)
    assert torch.allclose(v1.grad, expected_grad)
79 |
80 |
def test_variable():
    """Print what a bare ``Variable`` and its graph return for ``[2, 2, 2]`` (no assertions)."""
    x = dp.Variable()

    for apply in (x.forward, x.adjoint):
        print(apply(torch.tensor([2, 2, 2])))

    K = dp.CompGraph(x)
    print(K.adjoint(torch.tensor([2, 2, 2])))
91 |
92 |
def test_vstack():
    """Stack ``mosaic`` and ``grad``, run forward/adjoint and assert the sanity check.

    The ``sanity_check()`` result used to be discarded here; it is now
    asserted, as the conv, grad and mosaic tests in this file already do.
    """
    x = dp.Variable()
    K = dp.CompGraph(dp.vstack([dp.mosaic(x), dp.grad(x)]))

    img = to_torch_tensor(face(), batch=True)
    print(img.shape)

    outputs = K.forward(img)
    inputs = K.adjoint(outputs)
    print(inputs.shape)
    assert K.sanity_check()
104 |
105 |
def test_complex():
    """Stack the linops of a data term and two regularisers and run them forward."""
    from dprox import contrib

    img = contrib.sample('face')
    psf = contrib.point_spread_function(15, 5)
    b = contrib.blurring(img, psf)

    x = dp.Variable()
    data_term = dp.sum_squares(dp.conv(x, psf) - b)
    reg_term = dp.deep_prior(x, denoiser='ffdnet_color')
    reg2 = dp.nonneg(x)

    linops = [data_term.linop, reg_term.linop, reg2.linop]
    K = dp.CompGraph(dp.vstack(linops))
    K.forward(b)
119 |
--------------------------------------------------------------------------------
/tests/test_linop_primitive.py:
--------------------------------------------------------------------------------
1 | import dprox as dp
2 | import torch
3 |
def test_eval():
    """Evaluating a bare ``Variable`` on zeros yields zeros."""
    variable = dp.Variable()
    out = dp.eval(variable, torch.zeros(64, 64))
    print(out)

    expected = torch.zeros(64, 64)
    assert torch.allclose(out, expected)
--------------------------------------------------------------------------------
/tests/test_primitive.py:
--------------------------------------------------------------------------------
1 | import dprox as dp
2 |
3 |
4 | # def test():
5 | # dp.compile()
6 | # dp.specialize()
7 | # dp.optimize()
8 | # dp.visualize()
9 | # dp.validate()
--------------------------------------------------------------------------------