├── .gitattributes ├── .github └── workflows │ └── pypi-publish.yml ├── .gitignore ├── .readthedocs.yaml ├── README.md ├── docs ├── Makefile ├── README.md ├── make.bat ├── requirements.txt └── source │ ├── _static │ ├── css │ │ └── custom.css │ ├── favicon.ico │ ├── image │ │ ├── example_deconv.png │ │ ├── optic_results.png │ │ └── psf.png │ ├── logo.svg │ └── pipeline_dprox.gif │ ├── api │ ├── algo.md │ ├── index.md │ ├── linalg.md │ ├── linop.md │ ├── primitive.md │ ├── proxfn.md │ └── utils.md │ ├── citation.md │ ├── conf.py │ ├── index.md │ ├── started │ ├── index.md │ ├── install.md │ └── quicktour.nblink │ └── tutorials │ ├── computational_optics.nblink │ ├── csmri.nblink │ ├── deraining.nblink │ ├── differentiable_linear_solver.nblink │ ├── energy_system.nblink │ ├── image_restoration.nblink │ ├── index.md │ ├── learn_the_basic.nblink │ ├── linear_operator.nblink │ └── training.nblink ├── dprox ├── __init__.py ├── algo │ ├── __init__.py │ ├── admm.py │ ├── base.py │ ├── hqs.py │ ├── invert.py │ ├── lp │ │ ├── __init__.py │ │ ├── solvers.py │ │ └── utils.py │ ├── opt │ │ ├── __init__.py │ │ ├── absorb.py │ │ ├── equil.py │ │ └── merge.py │ ├── pc.py │ ├── pgd.py │ ├── primitives.py │ ├── problem.py │ ├── specialization │ │ ├── __init__.py │ │ ├── deq │ │ │ ├── __init__.py │ │ │ ├── solver.py │ │ │ ├── training.py │ │ │ └── utils │ │ │ │ ├── __init__.py │ │ │ │ ├── jacobian.py │ │ │ │ ├── layer_utils.py │ │ │ │ ├── optimizations.py │ │ │ │ ├── radam.py │ │ │ │ └── solvers.py │ │ ├── rl │ │ │ ├── __init__.py │ │ │ └── solver.py │ │ └── unroll.py │ └── tune │ │ ├── __init__.py │ │ ├── dpir.py │ │ └── learnable.py ├── contrib │ ├── __init__.py │ ├── csmri.py │ ├── derain.py │ ├── energy_system.py │ ├── optic │ │ ├── __init__.py │ │ ├── common.py │ │ ├── doe_model.py │ │ ├── doe_model_hybrid.py │ │ ├── unet.py │ │ └── utils.py │ └── restoration.py ├── linalg │ ├── __init__.py │ ├── custom.py │ └── solve │ │ ├── __init__.py │ │ ├── solver_cg.py │ │ ├── 
solver_minres.py │ │ └── solver_plss.py ├── linop │ ├── __init__.py │ ├── base.py │ ├── blackbox.py │ ├── comp_graph.py │ ├── constaints.py │ ├── constant.py │ ├── conv.py │ ├── edge.py │ ├── grad.py │ ├── mul.py │ ├── placeholder.py │ ├── scale.py │ ├── subsample.py │ ├── sum.py │ ├── variable.py │ └── vstack.py ├── proxfn │ ├── __init__.py │ ├── base.py │ ├── fast │ │ ├── __init__.py │ │ ├── cs.py │ │ ├── csmri.py │ │ ├── pr.py │ │ ├── spi.py │ │ └── sr.py │ ├── nlm │ │ ├── __init__.py │ │ ├── nlm.py │ │ └── patch_nlm.py │ ├── nonneg.py │ ├── norm.py │ ├── pnp │ │ ├── __init__.py │ │ ├── denoisers │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── composite.py │ │ │ ├── models │ │ │ │ ├── TV_denoising.py │ │ │ │ ├── __init__.py │ │ │ │ ├── basicblock.py │ │ │ │ ├── network_dncnn.py │ │ │ │ ├── network_ffdnet.py │ │ │ │ ├── network_unet.py │ │ │ │ ├── qrnn │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── conv.py │ │ │ │ │ ├── grunet.py │ │ │ │ │ ├── layer.py │ │ │ │ │ ├── qrnn3d.py │ │ │ │ │ └── sync_batchnorm │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── batchnorm.py │ │ │ │ │ │ ├── comm.py │ │ │ │ │ │ ├── replicate.py │ │ │ │ │ │ └── unittest.py │ │ │ │ └── unet │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── basicblock.py │ │ │ │ │ └── unet.py │ │ │ └── wrapper.py │ │ └── prior.py │ ├── sum_square.py │ └── unrolling │ │ ├── __init__.py │ │ ├── dgu.py │ │ └── prior.py └── utils │ ├── __init__.py │ ├── containar.py │ ├── huggingface.py │ ├── init │ ├── __init__.py │ ├── mosaic.py │ └── sr.py │ ├── io.py │ ├── metrics.py │ ├── misc.py │ └── psf2otf.py ├── examples ├── README.md ├── applications │ ├── csmri.py │ ├── deconv.py │ ├── demosaic.py │ ├── joint_demosaic_deconv.py │ └── super_resolution.py └── papers │ ├── README.md │ ├── deltaprox_siggraph_2023 │ ├── computional_optics │ │ ├── README.md │ │ ├── e2e_optics_dprox.py │ │ ├── e2e_optics_dprox_joint.py │ │ ├── e2e_optics_unet.py │ │ └── pnp_optics.py │ ├── csmri │ │ ├── README.md │ │ ├── deq_tfpnp.py │ │ ├── deq_unet.py │ │ ├── 
download_datasets.py │ │ ├── pnp_drunet.py │ │ ├── pnp_unet.py │ │ ├── rl_unet.py │ │ ├── rl_unet_train.py │ │ └── unroll_unet.py │ └── deraining │ │ ├── README.md │ │ ├── derain │ │ ├── __init__.py │ │ ├── data.py │ │ ├── dataset.py │ │ ├── restormer.py │ │ ├── unroll.py │ │ └── unroll_share.py │ │ ├── evaluate_PSNR_SSIM.m │ │ ├── evaluate_PSNR_SSIM.py │ │ ├── test_unroll.py │ │ └── test_unroll_share.py │ ├── dgunet_cvpr_2021 │ ├── main.py │ ├── network.py │ └── ops.py │ ├── dphsir_neurcomputing_2022 │ ├── degrades │ │ ├── __init__.py │ │ ├── blur.py │ │ ├── cs.py │ │ ├── general.py │ │ ├── inpaint.py │ │ ├── noise.py │ │ ├── sr.py │ │ └── utils.py │ ├── hsi_compress_sensing.py │ ├── hsi_deblur.py │ ├── hsi_inpainting.py │ ├── hsi_misr.py │ └── hsi_sisr.py │ ├── dpir_tpami_2020 │ ├── rgb_demosaic.py │ └── utils.py │ └── tfpnp_icml_2020 │ └── csmri.py ├── notebooks ├── README.md ├── computational_optics.ipynb ├── csmri.ipynb ├── deraining.ipynb ├── differentiable_linear_solver.ipynb ├── energy_system_planning.ipynb ├── image_restoration.ipynb ├── learn_the_basic.ipynb ├── linear_operator.ipynb ├── primitive.ipynb ├── quickstart.ipynb └── training.ipynb ├── requirements.txt ├── setup.py └── tests ├── README.md ├── linalg ├── test_linear_solver.py ├── test_linear_solver_batch.py ├── test_linear_solver_grad.py ├── test_linear_solver_torch.py └── test_pcg.py ├── paper ├── test_csmri.py ├── test_derain.py ├── test_energy.py └── test_optics.py ├── problem ├── test_deraining.py ├── test_energy_system.py ├── test_inverse_problems.py ├── test_jd23.py └── test_ml_problems.py ├── test_algorithms.py ├── test_grad.py ├── test_linop.py ├── test_linop_primitive.py └── test_primitive.py /.gitattributes: -------------------------------------------------------------------------------- 1 | notebooks/** linguist-vendored -------------------------------------------------------------------------------- /.github/workflows/pypi-publish.yml: 
-------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.x' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 29 | run: | 30 | python setup.py sdist bdist_wheel 31 | twine upload dist/* -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.txt 2 | *.pth 3 | *.mat 4 | simple_cep_model_20220916 5 | data 6 | abc 7 | datasets 8 | upload.sh 9 | examples/applications/computional_optics/optics 10 | *.png 11 | # *.mat 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | # *.png 17 | # *.jpg 18 | *.tif 19 | log 20 | ckpt 21 | # C extensions 22 | *.so 23 | 24 | # Distribution / packaging 25 | .Python 26 | build/ 27 | develop-eggs/ 28 | dist/ 29 | downloads/ 30 | eggs/ 31 | .eggs/ 32 | lib/ 33 | lib64/ 34 | parts/ 35 | sdist/ 36 | var/ 37 | wheels/ 38 | pip-wheel-metadata/ 39 | share/python-wheels/ 40 | *.egg-info/ 41 | .installed.cfg 42 | *.egg 43 | MANIFEST 44 | 45 | # PyInstaller 46 | # Usually these files are written by a python script from a template 47 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
48 | *.manifest 49 | *.spec 50 | 51 | # Installer logs 52 | pip-log.txt 53 | pip-delete-this-directory.txt 54 | 55 | # Unit test / coverage reports 56 | htmlcov/ 57 | .tox/ 58 | .nox/ 59 | .coverage 60 | .coverage.* 61 | .cache 62 | nosetests.xml 63 | coverage.xml 64 | *.cover 65 | *.py,cover 66 | .hypothesis/ 67 | .pytest_cache/ 68 | 69 | # Translations 70 | *.mo 71 | *.pot 72 | 73 | # Django stuff: 74 | *.log 75 | local_settings.py 76 | db.sqlite3 77 | db.sqlite3-journal 78 | 79 | # Flask stuff: 80 | instance/ 81 | .webassets-cache 82 | 83 | # Scrapy stuff: 84 | .scrapy 85 | 86 | # Sphinx documentation 87 | docs/_build/ 88 | 89 | # PyBuilder 90 | target/ 91 | 92 | # Jupyter Notebook 93 | .ipynb_checkpoints 94 | 95 | # IPython 96 | profile_default/ 97 | ipython_config.py 98 | 99 | # pyenv 100 | .python-version 101 | 102 | # pipenv 103 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 104 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 105 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 106 | # install all needed dependencies. 107 | #Pipfile.lock 108 | 109 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 110 | __pypackages__/ 111 | 112 | # Celery stuff 113 | celerybeat-schedule 114 | celerybeat.pid 115 | 116 | # SageMath parsed files 117 | *.sage.py 118 | 119 | # Environments 120 | .env 121 | .venv 122 | env/ 123 | venv/ 124 | ENV/ 125 | env.bak/ 126 | venv.bak/ 127 | 128 | # Spyder project settings 129 | .spyderproject 130 | .spyproject 131 | 132 | # Rope project settings 133 | .ropeproject 134 | 135 | # mkdocs documentation 136 | /site 137 | 138 | # mypy 139 | .mypy_cache/ 140 | .dmypy.json 141 | dmypy.json 142 | 143 | # Pyre type checker 144 | .pyre/ 145 | .vscode 146 | .rsyncignore 147 | .ssh -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | build: 10 | os: ubuntu-20.04 11 | tools: 12 | python: "3.9" 13 | # You can also specify other tool versions: 14 | # nodejs: "16" 15 | # rust: "1.55" 16 | # golang: "1.17" 17 | 18 | # Build documentation in the docs/ directory with Sphinx 19 | sphinx: 20 | configuration: docs/source/conf.py 21 | 22 | # If using Sphinx, optionally build your docs in additional formats such as PDF 23 | # formats: 24 | # - pdf 25 | 26 | # Optionally declare the Python requirements required to build your docs 27 | python: 28 | install: 29 | - requirements: docs/requirements.txt -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 
6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Documentation 2 | 3 | Source of the documentation. 4 | 5 | - Build html. 6 | 7 | ```bash 8 | make html 9 | ``` 10 | 11 | - Preview at `build/html/index.html` 12 | 13 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | myst-parser 2 | sphinx-rtd-theme 3 | furo 4 | sphinx-copybutton 5 | sphinx-inline-tabs 6 | nbsphinx 7 | nbsphinx_link 8 | linkify-it-py 9 | linkify 10 | ipython 11 | 12 | torch 13 | imageio 14 | scikit_image 15 | matplotlib 16 | munch 17 | tfpnp 18 | cvxpy 19 | torchlights 20 | tensorboardX 21 | termcolor 22 | proximal 23 | opencv-python 24 | huggingface_hub 25 | torchvision -------------------------------------------------------------------------------- /docs/source/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | 2 | .sidebar-logo { 3 | display: block; 4 | margin: 0; 5 | max-width: 50%; 6 | } 7 | 8 | .nbsphinx-gallery { 9 | display: grid; 10 | grid-template-columns: repeat(auto-fill, minmax(200px, 1fr)); 11 | gap: 5px; 12 | margin-top: 1em; 13 | margin-bottom: 1em; 14 | } 15 | 16 | h1 { 17 | font-size: 2em 18 | } 19 | 20 | h2 { 21 | font-size: 1.5em 22 | } 23 | 24 | h3 { 25 | font-size: 1.25em 26 | } 27 | 28 | h4 { 29 | font-size: 1.125em 30 | } 31 | 32 | h5 { 33 | font-size: 1.07em 34 | } 35 | 36 | h6 { 37 | font-size: 1em 38 | } -------------------------------------------------------------------------------- /docs/source/_static/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/Delta-Prox/15fc489e99c4de2676dc538d9c0552338f6fb894/docs/source/_static/favicon.ico 
-------------------------------------------------------------------------------- /docs/source/_static/image/example_deconv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/Delta-Prox/15fc489e99c4de2676dc538d9c0552338f6fb894/docs/source/_static/image/example_deconv.png -------------------------------------------------------------------------------- /docs/source/_static/image/optic_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/Delta-Prox/15fc489e99c4de2676dc538d9c0552338f6fb894/docs/source/_static/image/optic_results.png -------------------------------------------------------------------------------- /docs/source/_static/image/psf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/Delta-Prox/15fc489e99c4de2676dc538d9c0552338f6fb894/docs/source/_static/image/psf.png -------------------------------------------------------------------------------- /docs/source/_static/logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | DeltaProx-Logo 4 | 12 | 13 | 14 | 15 | 16 | - 17 | 18 | 19 | Prox 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /docs/source/_static/pipeline_dprox.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/Delta-Prox/15fc489e99c4de2676dc538d9c0552338f6fb894/docs/source/_static/pipeline_dprox.gif -------------------------------------------------------------------------------- /docs/source/api/algo.md: -------------------------------------------------------------------------------- 1 | # Proximal Algorithms 2 | 3 | Extending ∇-Prox for more proximal 
algorithms is straightforward. We define a new algorithm class that inherits from the base class `Algorithm`. The required methods to be implemented are `partition` and `_iter`, representing the problem partition and a single algorithm iteration. 4 | 5 | The `partition` method takes a list of proxable functions and returns their splits as a list of `psi_fn` and `omega_fn`. The `_iter` method performs a single iteration of the proximal algorithm; it takes the current state and two parameters, `rho` for the penalty strength on multipliers and `lam` for proximal operators. 6 | 7 | The state is generally a list of variables, including the auxiliary ones that an algorithm creates. ∇-Prox provides a state by returning the output of previous executions of `_iter` or the initial state provided by the `initialize` method. 8 | 9 | ```python 10 | class new_algorithm(Algorithm): 11 | def partition(cls, prox_fns: List[ProxFn]): 12 | # Perform problem partition according to the algorithm's need. 13 | 14 | def __init__(...): 15 | # Custom initialization code. 16 | 17 | def _iter(self, state, rho, lam): 18 | # Code to perform a single iteration of the algorithm. 19 | return ... 20 | 21 | def initialize(self, x0): 22 | # Return the initial state. 23 | return ... 24 | 25 | def nparams(self): 26 | # (Optional) Return the number of hyperparameters of 27 | # this algorithm. 28 | return ... 29 | 30 | def state_split(self): 31 | # (Optional) Return the split size of the packed state. 32 | # Useful for deep equilibrium/reinforcement learning. 33 | return ... 34 | ``` 35 | 36 | Implementing `partition`, `initialize`, and `_iter` is generally sufficient to evaluate the proximal algorithm for a given problem. 
37 | 38 | To integrate the new algorithm with deep equilibrium learning (DEQ) and deep reinforcement learning (RL), users have to implement two additional helper methods, i.e., `nparams` for counting the number of hyperparameters and `state_split` for describing the structure of the state that is returned by `_iter`. 39 | 40 | For example, assuming `_iter` returns the state as nested arrays such as `[x,[v1,v2],[u1,u2]]`, the output of `state_split` should be `[1,[2],[2]]`. ∇-Prox exploits these properties to perform the necessary packing and unpacking for the iteration states to achieve a unified interface for the internal DEQ and RL implementations. 41 | -------------------------------------------------------------------------------- /docs/source/api/index.md: -------------------------------------------------------------------------------- 1 | # API Documentation 2 | 3 | ```{toctree} 4 | :hidden: 5 | 6 | primitive 7 | algo 8 | linop 9 | proxfn 10 | linalg 11 | ``` 12 | 13 | ∇-Prox is a domain-specific language (DSL) and compiler that transforms optimization problems into differentiable proximal solvers. 14 | 15 | ## Index 16 | 17 | - [Proximal Algorithms](algo.md) 18 | - [Proximal Functions](proxfn.md) 19 | - [Linear Operators](linop.md) 20 | - [Linear System Solver](linalg.md) 21 | 22 | 23 | -------------------------------------------------------------------------------- /docs/source/api/linalg.md: -------------------------------------------------------------------------------- 1 | # Linear System Solver 2 | 3 | 4 | ## Unified Interface 5 | 6 | ```{eval-rst} 7 | .. autofunction:: dprox.linalg.linear_solve 8 | .. autoclass:: dprox.linalg.LinearSolveConfig 9 | ``` 10 | 11 | ## Linear Solvers 12 | 13 | ```{eval-rst} 14 | .. autofunction:: dprox.linalg.solve.solver_cg.cg 15 | .. autofunction:: dprox.linalg.solve.solver_cg.pcg 16 | .. autofunction:: dprox.linalg.solve.solver_minres.minres 17 | .. 
autofunction:: dprox.linalg.solve.solver_plss.plss 18 | ``` -------------------------------------------------------------------------------- /docs/source/api/linop.md: -------------------------------------------------------------------------------- 1 | # Linear Operator 2 | 3 | Defining new linear operators mainly involves the definition of the forward and adjoint routines. The following code shows the template for defining them. Similar to a [proxable function](proxfn.md), the operator is defined as a class inheriting from the base class `LinOp`. 4 | 5 | ```python 6 | class new_linop(LinOp): 7 | def __init__(...): 8 | """ Custom initialization code. 9 | """ 10 | 11 | def forward(self, inputs): 12 | """The forward operator. Compute x -> Kx 13 | """ 14 | 15 | def adjoint(self, inputs): 16 | """The adjoint operator. Compute x -> K^Tx 17 | """ 18 | 19 | def is_diag(self, freq): 20 | """(Optional) Check if the linear operator is diagonalizable or 21 | diagonalizable in the frequency domain. 22 | """ 23 | 24 | def get_diag(self, x, freq): 25 | """(Optional) Return the diagonal/frequency diagonal matrix that 26 | matches the shape of input x. 27 | """ 28 | ``` 29 | 30 | By default, the linear operator is not diagonal. To introduce a diagonal linear operator, one must implement the `is_diag` and `get_diag` methods for checking the diagonalizability and acquiring the diagonal matrix. These methods allow ∇-Prox to construct more efficient solvers, e.g., ADMM with a closed-form matrix inverse for the least-squares update. 31 | 32 | ## Sanity Check 33 | 34 | It is not always easy to correctly implement the forward and adjoint operations of a linear operator. To facilitate the testing of these operators, ∇-Prox provides an implementation of the **dot-product test** for verifying that the `forward` and `adjoint` are adjoint to each other. 
35 | 36 | Basically, the idea of the dot-product test comes from the associative property of linear algebra, which gives the following equation: 37 | 38 | $$ 39 | y^T(Ax) = (A^Ty)^Tx 40 | $$ 41 | 42 | Here, $x$ and $y$ are randomly generated data, and $A$ and $A^T$ denote the forward and adjoint of the linear operator. ∇-Prox makes use of this property and generates a large number of random data arguments to check if this equation always holds for a given precision. 43 | 44 | To use this utility, users can call `validate(linop, tol=1e-6)` and specify the tolerance of the difference between the two sides of the equation. 45 | 46 | ```python 47 | import dprox as dp 48 | from dprox.utils.examples import fspecial_gaussian 49 | 50 | x = dp.Variable() 51 | psf = fspecial_gaussian(15, 5) 52 | op = dp.conv(x, psf) 53 | assert dp.validate(op) 54 | ``` 55 | 56 | 57 | ## Indices 58 | 59 | ```{eval-rst} 60 | .. autoclass:: dprox.linop.base.LinOp 61 | :members: forward, adjoint, is_gram_diag, is_diag, get_diag, variables, constants, is_constant, value, offset, norm_bound, T, gram, clone 62 | .. autoclass:: dprox.linop.conv.conv 63 | :members: forward, adjoint 64 | .. autoclass:: dprox.linop.conv.conv_doe 65 | :members: forward, adjoint 66 | .. autoclass:: dprox.linop.sum.sum 67 | :members: forward, adjoint 68 | .. autoclass:: dprox.linop.sum.copy 69 | :members: forward, adjoint 70 | .. autoclass:: dprox.linop.vstack.vstack 71 | :members: forward, adjoint 72 | .. autoclass:: dprox.linop.vstack.split 73 | :members: forward, adjoint 74 | .. autoclass:: dprox.linop.variable.Variable 75 | :members: forward, adjoint 76 | .. autoclass:: dprox.linop.subsample.mosaic 77 | :members: forward, adjoint 78 | .. autoclass:: dprox.linop.scale.scale 79 | :members: forward, adjoint 80 | .. autoclass:: dprox.linop.placeholder.Placeholder 81 | :members: forward, adjoint 82 | .. 
autoclass:: dprox.linop.grad.grad 83 | :members: forward, adjoint 84 | ``` -------------------------------------------------------------------------------- /docs/source/api/primitive.md: -------------------------------------------------------------------------------- 1 | # Primitives 2 | 3 | ```{eval-rst} 4 | .. autofunction:: dprox.compile 5 | .. autofunction:: dprox.specialize 6 | ``` -------------------------------------------------------------------------------- /docs/source/api/proxfn.md: -------------------------------------------------------------------------------- 1 | # Proximal Functions 2 | 3 | The following code shows a template for defining a new proxable function. As previously mentioned, we define the function as a class inheriting from the base class `ProxFn`, and implement all the required methods. Then, ∇-Prox will properly handle all other necessary steps so that the new proxable function can work with operators, algorithms, and training utilities of the existing system. 4 | 5 | ```python 6 | class new_func(ProxFn): 7 | def __init__(...): 8 | """ Custom initialization code. 9 | """ 10 | 11 | def _prox(self, tau, v): 12 | """ Code to compute the function's proximal operator. 13 | """ 14 | return ... 15 | 16 | def _eval(self, v): 17 | """ (Optional) Code to evaluate the function. 18 | """ 19 | return ... 20 | 21 | def _grad(self, v): 22 | """ (Optional) Code to compute the analytic gradient. 23 | """ 24 | return ... 25 | ``` 26 | 27 | More specifically, defining a new function only requires a method `_prox` to be implemented, which evaluates the proximal operator of the given function. 28 | 29 | Users can optionally implement the `_grad` function to provide a routine for computing the analytic gradient of the proxable function. This facilitates the algorithms that partially rely on the gradient evaluation, e.g., proximal gradient descent. 
30 | 31 | Alternatively, users can also implement the `_eval` method that computes the forwarding results of the proxable function if it is possible. ∇-Prox takes the `_eval` routine and computes the gradient with auto-diff if `_grad` is not implemented. 32 | 33 | 34 | ```{eval-rst} 35 | .. autoclass:: dprox.proxfn.base.ProxFn 36 | ``` -------------------------------------------------------------------------------- /docs/source/api/utils.md: -------------------------------------------------------------------------------- 1 | # Utility 2 | 3 | ## Metrics 4 | 5 | 6 | 7 | 8 | ## IO 9 | 10 | ```{eval-rst} 11 | .. autofunction:: dprox.utils.io.imread_rgb 12 | .. autofunction:: dprox.utils.io.imshow 13 | .. autofunction:: dprox.utils.io.imread 14 | ``` -------------------------------------------------------------------------------- /docs/source/citation.md: -------------------------------------------------------------------------------- 1 | # Citation 2 | 3 | The following publications discuss the ideas behind ∇-Prox: 4 | 5 | > **∇-Prox: Differentiable Proximal Algorithm Modeling for Large-Scale Optimization**
6 | > Zeqiang Lai, Kaixuan Wei, Ying Fu, Philipp Härtel, and Felix Heide.
7 | > ACM Transactions on Graphics, SIGGRAPH, 2023. 8 | 9 | > **ProxImaL: Efficient Image Optimization Using Proximal Algorithms**
10 | > F. Heide, S. Diamond, M. Niessner, J. Ragan-Kelley, W. Heidrich, and G. Wetzstein.
11 | > ACM Transactions on Graphics, SIGGRAPH, 2016. 12 | 13 | ```bibtex 14 | @article{deltaprox2023, 15 | title = {∇-Prox: Differentiable Proximal Algorithm Modeling for Large-Scale Optimization}, 16 | author = {Lai, Zeqiang and Wei, Kaixuan and Fu, Ying and H\"{a}rtel, Philipp and Heide, Felix}, 17 | journal={ACM Transactions on Graphics (TOG)}, 18 | volume = {42}, 19 | number = {4}, 20 | articleno = {105}, 21 | pages = {1--19}, 22 | year={2023}, 23 | publisher = {Association for Computing Machinery}, 24 | address = {New York, NY, USA}, 25 | doi = {10.1145/3592144}, 26 | } 27 | 28 | @article{heide2016proximal, 29 | title={Proximal: Efficient image optimization using proximal algorithms}, 30 | author={Heide, Felix and Diamond, Steven and Nie{\ss}ner, Matthias and Ragan-Kelley, Jonathan and Heidrich, Wolfgang and Wetzstein, Gordon}, 31 | journal={ACM Transactions on Graphics (TOG)}, 32 | volume={35}, 33 | number={4}, 34 | pages={1--15}, 35 | year={2016}, 36 | publisher = {Association for Computing Machinery}, 37 | address = {New York, NY, USA}, 38 | doi={10.1145/2897824.2925875}, 39 | } 40 | ``` 41 | -------------------------------------------------------------------------------- /docs/source/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | hide-toc: true 3 | --- 4 | 5 | 6 |
7 |

8 | 9 | 10 | 11 |

12 | 13 | 14 | ```{toctree} 15 | :maxdepth: 3 16 | :hidden: 17 | 18 | Get Started 19 | tutorials/index 20 | api/index 21 | citation 22 | ``` 23 | 24 | 25 | ```{toctree} 26 | :caption: Useful Links 27 | :hidden: 28 | PyPI Page 29 | GitHub Repository 30 | Project Page 31 | Paper 32 | ``` 33 | 34 |
35 | 36 | 🎉 ∇-Prox is a domain-specific language (DSL) and compiler that transforms optimization problems into differentiable proximal solvers. 37 |
38 | 🎉 ∇-Prox allows for rapid prototyping of learning-based bi-level optimization problems for a diverse range of applications, by [optimized algorithm unrolling](api/primitive), [deep equilibrium learning](api/primitive), and [deep reinforcement learning](api/primitive). 39 | 40 | The library includes the following major components: 41 | 42 | - A library of differentiable [proximal algorithms](api/algo), [proximal operators](api/proxfn), and [linear operators](api/linop). 43 | - Interchangeable [specialization](api/primitive) strategies for balancing trade-offs between speed and memory. 44 | - Out-of-the-box [training utilities](api/primitive) for learning-based bi-level optimization with a few lines of code. 45 | 46 | ```{nbgallery} 47 | ``` 48 | 49 | -------------------------------------------------------------------------------- /docs/source/started/index.md: -------------------------------------------------------------------------------- 1 | # Get Started 2 | 3 | ```{toctree} 4 | :hidden: 5 | 6 | quicktour 7 | install 8 | ``` 9 | 10 | ## Installation 11 | 12 | To get started with ∇-Prox, please follow the [Installation Documentation](install) for detailed instructions on how to install the library. 13 | 14 | ## Quick Tour 15 | 16 | - Take a [Quick Tour](quicktour) to get familiar with the features and functionalities of ∇-Prox. 17 | 18 | - Explore the [API Reference](../api/index) for a complete list of classes and functions. 19 | 20 | - For advanced topics and best practices, refer to the [tutorials](../tutorials/index). 21 | 22 | 23 | Happy coding with ∇-Prox! 🎉 -------------------------------------------------------------------------------- /docs/source/started/install.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | 4 | ∇-Prox works with PyTorch. To install Pytorch, please follow the [PyTorch installation instructions](https://pytorch.org/get-started/locally/). 
5 | 6 | 7 | **Install with pip** 8 | 9 | ```bash 10 | pip install dprox 11 | ``` 12 | 13 | **Install from source** 14 | 15 | ```bash 16 | pip install git+https://github.com/princeton-computational-imaging/Delta-Prox.git 17 | ``` 18 | 19 | **Editable installation** 20 | 21 | You will need an editable install if you would like to: 22 | 23 | 1. Use the main version of the source code. 24 | 2. Test changes in the code. 25 | 26 | To do so, clone the repository and install 🎉 Delta Prox with the following commands: 27 | 28 | ``` 29 | git clone https://github.com/princeton-computational-imaging/Delta-Prox.git 30 | cd Delta-Prox 31 | pip install -e . 32 | ``` 33 | 34 | ```{caution} 35 | Note that you must keep the Delta-Prox folder for editable installation if you want to keep using the library. 36 | ``` 37 | -------------------------------------------------------------------------------- /docs/source/started/quicktour.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path": "../../../notebooks/quickstart.ipynb" 3 | } -------------------------------------------------------------------------------- /docs/source/tutorials/computational_optics.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path": "../../../notebooks/computational_optics.ipynb" 3 | } -------------------------------------------------------------------------------- /docs/source/tutorials/csmri.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path": "../../../notebooks/csmri.ipynb" 3 | } -------------------------------------------------------------------------------- /docs/source/tutorials/deraining.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path": "../../../notebooks/deraining.ipynb" 3 | } -------------------------------------------------------------------------------- 
/docs/source/tutorials/differentiable_linear_solver.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path": "../../../notebooks/differentiable_linear_solver.ipynb" 3 | } -------------------------------------------------------------------------------- /docs/source/tutorials/energy_system.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path": "../../../notebooks/energy_system_planning.ipynb" 3 | } -------------------------------------------------------------------------------- /docs/source/tutorials/image_restoration.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path": "../../../notebooks/image_restoration.ipynb" 3 | } -------------------------------------------------------------------------------- /docs/source/tutorials/index.md: -------------------------------------------------------------------------------- 1 | # Tutorials 2 | 3 | ## Basics 4 | ```{nbgallery} 5 | learn_the_basic 6 | differentiable_linear_solver 7 | linear_operator 8 | ``` 9 | 10 | ## Tasks 11 | 12 | ```{nbgallery} 13 | computational_optics 14 | csmri 15 | deraining 16 | energy_system 17 | image_restoration 18 | ``` -------------------------------------------------------------------------------- /docs/source/tutorials/learn_the_basic.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path": "../../../notebooks/learn_the_basic.ipynb" 3 | } -------------------------------------------------------------------------------- /docs/source/tutorials/linear_operator.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path": "../../../notebooks/linear_operator.ipynb" 3 | } -------------------------------------------------------------------------------- /docs/source/tutorials/training.nblink: 
class HQS(ADMM):
    """ADMM variant whose state is only ``(x, z)``.

    Presumably "half-quadratic splitting" (inferred from the name only --
    TODO confirm). Compared with the three-entry state used by sibling
    solvers in this package, no extra variable is carried between iterations.
    """

    def initialize(self, x0):
        """Build the initial state: ``x0`` itself and ``K(x0)`` as a list."""
        split_vars = self.K.forward(x0, return_list=True)
        return x0, split_vars

    def _iter(self, state, rho, lam):
        """Run one iteration and return the updated ``(x, z)``.

        The incoming ``x`` is discarded: the new ``x`` comes from the
        least-squares solve on ``z``. ``z`` is updated in place, entry by
        entry, with the prox of the matching ``psi`` function.
        """
        _, z = state
        x = self.least_square.solve(z, rho)
        # Evaluate K once and reuse the pieces for every prox below.
        Kx = self.K.forward(x, return_list=True)
        for idx, psi in enumerate(self.psi_fns):
            z[idx] = psi.prox(Kx[idx], lam=lam[psi])
        return x, z

    @property
    def state_split(self):
        """Layout of the state: one ``x`` plus one ``z`` entry per psi function."""
        n_psi = len(self.psi_fns)
        return [1, [n_psi]]
def get_least_square_solver(psi_fns, omega_fns, try_diagonalize, try_freq_diagonalize, linear_solve_config):
    """Pick the solver object used for the least-squares sub-problem.

    If ``omega_fns`` contains an ``ext_sum_squares`` term and every *other*
    function (psi and omega alike) acts directly on a ``Variable``, the
    result of that term's ``setup`` is returned. Otherwise a generic
    ``least_squares`` object is built from all functions.

    Args:
        psi_fns: list of proximal functions handled by splitting.
        omega_fns: list of functions handled by the least-squares step.
        try_diagonalize: forwarded to ``least_squares``.
        try_freq_diagonalize: forwarded to ``least_squares``.
        linear_solve_config: forwarded to ``least_squares`` by keyword.
    """
    every_fn = psi_fns + omega_fns
    ext_terms = [term for term in omega_fns if isinstance(term, ext_sum_squares)]

    for candidate in ext_terms:
        remaining = [other for other in every_fn if other is not candidate]
        if not all(isinstance(other.linop, Variable) for other in remaining):
            continue
        offsets = [
            other.b
            for other in omega_fns
            if other is not candidate and other not in ext_terms
        ]
        # NOTE(review): setup is invoked on the first ext_sum_squares term,
        # not on `candidate`; the two only differ when several such terms
        # exist -- confirm this is intended.
        return ext_terms[0].setup(offsets)

    return least_squares(
        omega_fns,
        psi_fns,
        try_diagonalize,
        try_freq_diagonalize,
        linear_solve_config=linear_solve_config,
    )
def absorb_linop(prox_fn):
    """Try to fold the outermost linear operator into ``prox_fn``.

    Two cases are recognised:

    * ``sum_squares`` applied to a ``mosaic`` operator: a new
      ``weighted_sum_squares`` is built from the mosaic's first input node,
      the mosaic operator itself and ``prox_fn.offset``.
    * any function applied to a ``scale`` operator: the scale is stripped
      and its scalar multiplied into ``prox_fn.beta``. This mutates
      ``prox_fn`` in place.

    Returns:
        A one-element list holding the resulting function. It is the same
        ``prox_fn`` object when nothing could be absorbed.
    """
    top = prox_fn.linop

    if isinstance(top, mosaic) and isinstance(prox_fn, sum_squares):
        return [weighted_sum_squares(top.input_nodes[0], top, prox_fn.offset)]

    if isinstance(top, scale):
        # Fold the scalar into the function and drop the scale node.
        factor = top.scalar
        prox_fn.linop = top.input_nodes[0]
        prox_fn.beta = prox_fn.beta * factor

    return [prox_fn]
def get_alpha_beta(m, n):
    """Return the target row/column norms ``((n/m)**0.25, (m/n)**0.25)``."""
    return (n / m)**(0.25), (m / n)**(0.25)


def project(x, M):
    """Project x onto [-M, M]^n.

    The clipping is done in place: ``x`` itself is overwritten and returned.
    """
    return np.minimum(M, np.maximum(x, -M, out=x), out=x)

# Comparison method.


def f(A, u, v, gamma, p=2):
    """Evaluate the regularised equilibration objective for a dense ``A``.

    Computes ``(1/p) * exp(p*u)^T |A|^p exp(p*v) - alpha^p * sum(u)
    - beta^p * sum(v) + (gamma/2) * (|u|^2 + |v|^2)`` with ``alpha, beta``
    taken from ``get_alpha_beta(*A.shape)``.

    Args:
        A: 2-D array of shape ``(m, n)``.
        u: length-``m`` array of log row scalings.
        v: length-``n`` array of log column scalings.
        gamma: weight of the quadratic regulariser.
        p: exponent applied to ``|A|`` and to the scalings.
    """
    m, n = A.shape
    alpha, beta = get_alpha_beta(m, n)
    total = (1. / p) * np.exp(p * u).T.dot(np.power(np.abs(A), p)).dot(np.exp(p * v))
    total += -alpha**p * u.sum() - beta**p * v.sum() + (gamma / 2) * ((u * u).sum() + (v * v).sum())
    return np.sum(total)


def get_grad(A, u, v, gamma, p=2):
    """Gradient of ``f`` and the diagonal-Newton step for ``u`` and ``v``.

    The gradient w.r.t. ``u`` is ``exp(p*u) * (|A|^p @ exp(p*v)) - alpha^p
    + gamma*u`` (and symmetrically for ``v``); the matching diagonal of the
    Hessian is ``p * (...) + gamma``, which is what the step divides by.
    Both now honour ``p`` so the result stays consistent with ``f`` for
    ``p != 2``; for the default ``p=2`` the values are unchanged.

    Returns:
        ``(du, dv, grad_u, grad_v)``.
    """
    m, n = A.shape
    alpha, beta = get_alpha_beta(m, n)

    A_p = np.power(np.abs(A), p)
    row_scale = np.exp(p * u)
    col_scale = np.exp(p * v)

    # Scale rows directly instead of materialising a dense diagonal matrix.
    tmp = row_scale * A_p.dot(col_scale)
    grad_u = tmp - alpha**p + gamma * u
    du = -grad_u / (p * tmp + gamma)

    tmp = col_scale * A_p.T.dot(row_scale)
    grad_v = tmp - beta**p + gamma * v
    dv = -grad_v / (p * tmp + gamma)

    return du, dv, grad_u, grad_v
def merge_all(prox_fns):
    """Merge as many prox functions as possible.

    Pairs are merged greedily in index order; each function takes part in
    at most one merge per pass, and passes repeat until one finds nothing
    to merge. When the very first pass merges nothing the input list object
    itself is returned.
    """
    while True:
        consumed = []
        combined = []
        for i, left in enumerate(prox_fns):
            for right in prox_fns[i + 1:]:
                if left in consumed or right in consumed:
                    continue
                if not can_merge(left, right):
                    continue
                consumed += [left, right]
                combined.append(merge_fns(left, right))
        if not combined:
            return prox_fns
        prox_fns = combined + [fn for fn in prox_fns if fn not in consumed]


def can_merge(lh_prox, rh_prox):
    """Can lh_prox and rh_prox be merged into a single function?

    True only when both share an equal ``lin_op`` and at least one of them
    is exactly a ``zero_prox`` or a ``sum_squares``.

    NOTE(review): ``zero_prox`` and ``sum_squares`` are not imported in the
    visible part of this module (the import is commented out), so the type
    checks below would raise NameError once two lin ops compare equal --
    confirm where these names are meant to come from.
    """
    # Lin ops must be the same.
    if not (lh_prox.lin_op == rh_prox.lin_op):
        return False
    kinds = (type(lh_prox), type(rh_prox))
    return zero_prox in kinds or sum_squares in kinds


def merge_fns(lh_prox, rh_prox):
    """Merge the two functions into a single function.

    The ``c``, ``gamma`` and ``d`` attributes of both inputs are summed and
    handed to ``copy`` on the function that survives the merge.
    """
    assert can_merge(lh_prox, rh_prox)
    pair = [lh_prox, rh_prox]
    kinds = [type(fn) for fn in pair]
    c_sum = lh_prox.c + rh_prox.c
    gamma_sum = lh_prox.gamma + rh_prox.gamma
    d_sum = lh_prox.d + rh_prox.d

    # Merge a linear term into the other proxable function.
    if zero_prox in kinds:
        survivor = pair[1 - kinds.index(zero_prox)]
        return survivor.copy(c=c_sum, gamma=gamma_sum, d=d_sum)

    # Merge a sum squares term into the other proxable function.
    if sum_squares in kinds:
        sq_pos = kinds.index(sum_squares)
        sq_fn = pair[sq_pos]
        survivor = pair[1 - sq_pos]
        coeff = sq_fn.alpha * sq_fn.beta
        return survivor.copy(
            c=c_sum - 2 * coeff * sq_fn.b,
            gamma=gamma_sum + coeff * sq_fn.beta,
            d=d_sum + np.square(sq_fn.b).sum(),
        )

    raise ValueError("Unknown merge strategy.")
class ProximalGradientDescent(Algorithm):
    """Proximal gradient descent for one differentiable and one proxable term.

    Each iteration takes a gradient step on the differentiable function and
    then applies the prox of the other one.
    """

    @classmethod
    def partition(cls, prox_fns: List[ProxFn]):
        """Split ``prox_fns`` into ``(psi_fns, omega_fns)``.

        Functions exposing a ``grad`` attribute go to ``omega_fns``; the rest
        go to ``psi_fns``.

        Raises:
            ValueError: if not exactly two functions are given, or none of
                them has a ``grad`` attribute.
        """
        # Adjacent string literals are used instead of a backslash
        # continuation so the message does not contain source indentation.
        if len(prox_fns) != 2:
            raise ValueError(
                'Proximal gradient descent only supports '
                'two proximal functions for now.'
            )

        omega_fns = [fn for fn in prox_fns if hasattr(fn, 'grad')]
        psi_fns = [fn for fn in prox_fns if fn not in omega_fns]

        if len(omega_fns) == 0:
            raise ValueError(
                'Proximal gradient descent requires '
                'at least one proximal function to be differentiable.'
            )

        # NOTE(review): if both functions expose ``grad``, ``psi_fns`` is
        # empty here and ``__init__`` fails on ``psi_fns[0]`` -- confirm how
        # that case should be handled.
        return psi_fns, omega_fns

    def __init__(
        self,
        psi_fns: List[ProxFn],
        omega_fns: List[ProxFn],
        *args,
        **kwargs
    ):
        """Store the first omega function as the smooth term and the first
        psi function as the proxable term. Extra arguments are accepted and
        ignored."""
        super().__init__(psi_fns, omega_fns)
        self.diff_fn = omega_fns[0]
        self.prox_fn = psi_fns[0]

    def _iter(self, state, rho, lam):
        """One step: ``x <- prox(x - rho * grad(x))`` with ``rho`` as step size."""
        x = state[0]
        v = x - expand(rho) * self.diff_fn.grad(x)
        x = self.prox_fn.prox(v, lam[self.prox_fn])
        return [x]

    def initialize(self, x0):
        """The state is just the current iterate."""
        return [x0]

    @property
    def state_split(self):
        """The state holds a single entry."""
        return [1]

    @property
    def nparams(self):
        """One parameter per psi function plus one more (``len(psi_fns) + 1``)."""
        return len(self.psi_fns) + 1
class DEQ(nn.Module):
    """Deep-equilibrium wrapper around a fixed-point map ``fn(z, x, *args)``.

    The forward pass finds ``z*`` with ``z* = fn(z*, x, *args)`` using
    ``self.solver`` without tracking gradients. In training mode one extra
    tracked application of ``fn`` is made and a backward hook replaces the
    incoming gradient by the fixed point of ``y -> y J + grad`` (``J`` being
    the Jacobian of ``fn`` at ``z*``), i.e. implicit differentiation.

    Args:
        fn: callable ``fn(z, x, *args)`` returning the next iterate.
        f_thres: ``threshold`` passed to the solver in the forward pass.
        b_thres: ``threshold`` passed to the solver in the backward pass.
    """

    def __init__(self, fn, f_thres=40, b_thres=40):
        super().__init__()
        self.fn = fn

        self.f_thres = f_thres
        self.b_thres = b_thres

        self.solver = anderson
        # Handle of the currently registered backward hook, if any.
        self.hook = None

    def forward(self, x, *args, **kwargs):
        """Return the equilibrium state, starting the iteration from ``x``.

        ``f_thres`` / ``b_thres`` may be overridden per call via keyword
        arguments; positional ``args`` are forwarded to ``fn``.
        """
        f_thres = kwargs.get('f_thres', self.f_thres)
        b_thres = kwargs.get('b_thres', self.b_thres)

        z0 = x

        # Forward pass: solve for the fixed point without building a graph.
        with torch.no_grad():
            out = self.solver(lambda z: self.fn(z, x, *args), z0, threshold=f_thres)
            z_star = out['result']
            new_z_star = z_star

        # Prepare the backward pass: one tracked step from z_star.
        if self.training:
            new_z_star = self.fn(z_star.requires_grad_(), x, *args)

            def backward_hook(grad):
                if self.hook is not None:
                    # Remove the hook first: the solve below differentiates
                    # through new_z_star again and must not re-enter here.
                    self.hook.remove()
                    # Only synchronize when the gradient lives on a CUDA
                    # device; the unconditional call fails on CPU-only setups.
                    if grad.is_cuda:
                        torch.cuda.synchronize()
                # Compute the fixed point of yJ + grad, where J=J_f is the Jacobian of f at z_star
                out = self.solver(lambda y: torch.autograd.grad(new_z_star, z_star, y, retain_graph=True)[0] + grad,
                                  torch.zeros_like(grad), threshold=b_thres)
                new_grad = out['result']
                return new_grad

            self.hook = new_z_star.register_hook(backward_hook)

        return new_z_star
def jac_loss_estimate(f0, z0, vecs=2, create_graph=True):
    """Hutchinson estimate of tr(J^T J) = tr(J J^T), normalised by size.

    Args:
        f0 (torch.Tensor): Output of the function f whose Jacobian J is analysed.
        z0 (torch.Tensor): Input of f; ``f0`` must depend on it in the graph.
        vecs (int, optional): Number of Gaussian probe vectors. Defaults to 2.
        create_graph (bool, optional): Keep the estimate differentiable so it
            can be used as a training loss. Defaults to True.

    Returns:
        torch.Tensor: mean squared norm of the vector-Jacobian products,
        divided by the number of elements of ``z0``.
    """
    total = 0
    for _ in range(vecs):
        probe = torch.randn(*z0.shape).to(z0)
        # The graph is kept so further probes (and the caller) can reuse it.
        vjp = torch.autograd.grad(
            f0, z0, probe, retain_graph=True, create_graph=create_graph
        )[0]
        total += vjp.norm()**2
    return total / vecs / np.prod(z0.shape)

def power_method(f0, z0, n_iters=200):
    """Power iteration on the Jacobian of f, one run per batch entry.

    Args:
        f0 (torch.Tensor): Output of the function f whose Jacobian J is analysed.
        z0 (torch.Tensor): Input of f; its first dimension is treated as batch.
        n_iters (int, optional): Number of iterations. Defaults to 200.

    Returns:
        tuple: (unit-norm direction shaped like ``z0``, absolute value of the
        Rayleigh-quotient estimate with shape ``(batch, 1)``).
    """
    evector = torch.randn_like(z0)
    bsz = evector.shape[0]
    last = n_iters - 1
    for i in range(n_iters):
        # The graph can be released on the final iteration.
        vTJ = torch.autograd.grad(
            f0, z0, evector, retain_graph=(i < last), create_graph=False
        )[0]
        flat_v = evector.reshape(bsz, -1)
        flat_g = vTJ.reshape(bsz, -1)
        numerator = (vTJ * evector).reshape(bsz, -1).sum(1, keepdim=True)
        denominator = (flat_v * flat_v).sum(1, keepdim=True)
        evalue = numerator / denominator
        evector = (flat_g / flat_g.norm(dim=1, keepdim=True)).reshape_as(z0)
    return (evector, torch.abs(evalue))
def conv3x3(in_planes, out_planes, stride=1, bias=False):
    """3x3 convolution with padding 1."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=bias)


def conv5x5(in_planes, out_planes, stride=1, bias=False):
    """5x5 convolution with padding 2."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=5, stride=stride, padding=2, bias=bias)


def norm_diff(new, old, show_list=False):
    """Distance between two equally long sequences of tensors.

    Args:
        new: sequence of tensors.
        old: sequence of tensors, indexed in parallel with ``new``.
        show_list: when True, return the per-entry norms instead of the total.

    Returns:
        A list of floats (one norm per entry) if ``show_list`` is set,
        otherwise a single float: the square root of the summed squared norms.
    """
    # The module does not import numpy, so the previous ``np.sqrt`` raised
    # NameError; the standard library is enough for a scalar square root.
    import math

    diffs = [(new[i] - old[i]).norm().item() for i in range(len(new))]
    if show_list:
        return diffs
    return math.sqrt(sum(d**2 for d in diffs))
nn.ModuleList(clone(solver, max_iter, share=share)) 26 | else: 27 | self.solver = solver 28 | self.solvers = [self.solver for _ in range(max_iter)] 29 | 30 | self.max_iter = max_iter 31 | self.share = share 32 | 33 | self.learned_params = learned_params 34 | if learned_params: 35 | self.rhos = nn.parameter.Parameter(torch.ones(max_iter)) 36 | self.lams = {} 37 | for fn in solver.psi_fns: 38 | lam = nn.parameter.Parameter(torch.ones(max_iter)) 39 | setattr(self, str(fn), lam) 40 | self.lams[fn] = lam 41 | 42 | @auto_convert_to_tensor(['x0', 'rhos', 'lams'], batchify=['x0']) 43 | def solve(self, x0=None, rhos=None, lams=None, max_iter=None): 44 | x0, rhos, lams = move(x0, rhos, lams, device=self.solvers[0].device) 45 | 46 | if self.learned_params: 47 | rhos, lams = self.rhos, self.lams 48 | 49 | max_iter = self.max_iter if max_iter is None else max_iter 50 | 51 | state = self.solvers[0].initialize(x0) 52 | 53 | for i in range(max_iter): 54 | rho = rhos[..., i:i + 1] 55 | lam = {self.solvers[i].psi_fns[0]: v[..., i:i + 1] for k, v in lams.items()} 56 | state = self.solvers[i].iters(state, rho, lam, 1, False) 57 | 58 | return state[0] 59 | -------------------------------------------------------------------------------- /dprox/algo/tune/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/Delta-Prox/15fc489e99c4de2676dc538d9c0552338f6fb894/dprox/algo/tune/__init__.py -------------------------------------------------------------------------------- /dprox/algo/tune/dpir.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def get_rho_sigma_admm(sigma=2.55 / 255, iter_num=15, modelSigma1=49.0, modelSigma2=2.55, w=1.0, lam=0.23): 6 | modelSigmaS = np.logspace(np.log10(modelSigma1), np.log10(modelSigma2), iter_num).astype(np.float32) 7 | modelSigmaS_lin = np.linspace(modelSigma1, 
def log_descent(upper, lower, iter=24, sigma=0.255 / 255,
                w=1.0, lam=0.23, sqrt=False):
    """
    Generate per-iteration ``rhos`` and ``sigmas`` that decay from ``upper``
    to ``lower``.

    :param upper: First value of the schedule (on a 0-255 scale; divided by 255 internally)
    :param lower: Last value of the schedule (same scale as ``upper``)
    :param iter: Number of schedule entries, defaults to 24
    :param sigma: Noise level used in the ``rhos`` formula
    :param w: Blend weight between the log-spaced (``w``) and the linearly
    spaced (``1 - w``) schedule
    :param lam: Multiplier in ``rho = lam * sigma**2 / s**2``
    :param sqrt: If False (default) the returned ``sigmas`` are the squared
    schedule values; if True they are returned unsquared
    :return: two float tensors of length ``iter``: ``rhos`` and ``sigmas``.
    """
    log_sched = np.logspace(np.log10(upper), np.log10(lower), iter).astype(np.float32)
    lin_sched = np.linspace(upper, lower, iter).astype(np.float32)
    sigmas = (log_sched * w + lin_sched * (1 - w)) / 255.
    # rhos are always derived from the unsquared schedule.
    rhos = [lam * (sigma**2) / (s**2) for s in sigmas]
    if not sqrt:
        sigmas = list(sigmas**2)
    rhos = torch.tensor(rhos).float()
    sigmas = torch.tensor(sigmas).float()
    return rhos, sigmas


def f(params):
    """Map each entry ``p`` to ``sqrt(p) * 255``."""
    return list(map(lambda p: np.sqrt(p) * 255, params))
# this can be 70.94


def log_descent2(upper, lower, iter=24, sigma=0.255 / 255, w=1.0, lam=0.23):
    """Variant of ``log_descent`` returning plain Python lists.

    The blended schedule is squared and divided by ``sigma**2``; ``rhos``
    are ``lam * sigma**2 / r**2`` for each such ratio ``r`` and the returned
    ``sigmas`` are the reciprocals ``1 / r``.
    """
    log_sched = np.logspace(np.log10(upper), np.log10(lower), iter).astype(np.float32)
    lin_sched = np.linspace(upper, lower, iter).astype(np.float32)
    blended = (log_sched * w + lin_sched * (1 - w)) / 255.
    ratios = list((blended**2) / (sigma**2))
    rhos = [lam * (sigma**2) / (r**2) for r in ratios]
    sigmas = [1 / r for r in ratios]
    return rhos, sigmas
class LearnableParamProvider(nn.Module):
    """Module that hands out one scalar tensor per step index.

    NOTE(review): despite the name nothing is learnable at the moment --
    the constructor arguments are unused, no parameters are registered and
    ``forward`` returns hard-coded constants.
    """

    def __init__(self, steps, default_value=0.5):
        super().__init__()

    def forward(self, step):
        """Return the constant associated with ``step`` as a 0-d tensor.

        Steps 0 and 6 have dedicated values; every other step shares one.
        """
        if step == 0:
            value = 0.3442
        elif step == 6:
            value = 0.6111
        else:
            value = 0.3168
        return torch.tensor(value)
class ResBlock(nn.Module):
    """Residual block computing ``x + res_scale * body(x)``.

    The body is
    ``conv(n_feats -> 64) [-> BatchNorm2d(64)] -> act -> conv(64 -> n_feats) [-> BatchNorm2d(n_feats)]``.

    Args:
        conv: Factory called as ``conv(in_channels, out_channels, kernel_size, bias=...)``
            returning the convolution layer to use.
        n_feats: Number of input (and output) channels of the block.
        kernel_size: Kernel size forwarded to ``conv``.
        bias: Forwarded to ``conv``.
        bn: If true, a ``BatchNorm2d`` follows each convolution.
        act: Activation module placed after the first convolution.
        res_scale: Factor applied to the body output before the skip connection.
    """

    # NOTE(review): the default ``act`` instance is created once, when the
    # class is defined, so every block built with the default shares one PReLU
    # module (and its weight). This is left as-is because untying it would
    # change the parameter set; pass an explicit ``act`` for an independent one.
    def __init__(
        self, conv, n_feats, kernel_size,
        bias=True, bn=False, act=nn.PReLU(), res_scale=1
    ):
        super(ResBlock, self).__init__()
        hidden = 64  # width of the intermediate feature map

        layers = [conv(n_feats, hidden, kernel_size, bias=bias)]
        if bn:
            # The normalisation must match the channels produced by the layer
            # before it (64 here); using n_feats failed unless n_feats == 64.
            layers.append(nn.BatchNorm2d(hidden))
        layers.append(act)
        layers.append(conv(hidden, n_feats, kernel_size, bias=bias))
        if bn:
            layers.append(nn.BatchNorm2d(n_feats))

        self.body = nn.Sequential(*layers)
        self.res_scale = res_scale

    def forward(self, x):
        """Apply the body, scale it, and add the input as a skip connection."""
        res = self.body(x).mul(self.res_scale)
        res += x
        return res
def load_simple_cep_model():
    """Load the simple CEP linear program and split it by constraint sense.

    Reads ``esm_instance.mat`` (located through ``load_path``), prints a short
    summary of the problem size, and separates the rows marked ``'<'`` from
    those marked ``'='``.

    Returns:
        ``(c, A_ub, A_eq, b_ub, b_eq)``: objective vector, inequality and
        equality constraint matrices, and their right-hand sides.

    Raises:
        AssertionError: if the ``'<'`` and ``'='`` rows do not account for
            every constraint row.
    """
    mat_file = load_path("energy_system/simple_cep_model_20220916/esm_instance.mat")
    model_components = io.loadmat(mat_file)

    n_con, n_var = model_components["A"].shape
    print("Number of linear constraints (w/o bound constraints):", n_con)
    print("Number of decision variables:", n_var)

    A = model_components["A"].astype(np.float64)
    b = model_components["rhs"].astype(np.float64)
    types = model_components["sense"]

    # Rows flagged '<' form the inequality block.
    is_upper_bound = types == '<'
    A_ub = A[is_upper_bound]
    b_ub = b[is_upper_bound][:, 0]
    n1 = sum(is_upper_bound)
    print('n1, A_ub, b_ub:', n1, A_ub.shape, b_ub.shape)

    # Rows flagged '=' form the equality block.
    is_equality = types == '='
    A_eq = A[is_equality]
    b_eq = b[is_equality][:, 0]
    n2 = sum(is_equality)
    print('n2, A_eq, b_eq:', n2, A_eq.shape, b_eq.shape)

    # Every constraint row must be one of the two kinds.
    assert n1 + n2 == n_con

    c = model_components["obj"][:, 0]
    print('c:', c.shape)

    return c, A_ub, A_eq, b_ub, b_eq
def fspecial_gaussian(hsize, sigma):
    """Return a normalised ``hsize`` x ``hsize`` Gaussian kernel.

    Args:
        hsize: Side length of the square kernel.
        sigma: Standard deviation of the Gaussian.

    Returns:
        A 2-D ``numpy`` array whose entries sum to 1 (unless every entry was
        thresholded to zero, in which case the zeros are returned as-is).
    """
    half = (hsize - 1.0) / 2.0
    coords = np.arange(-half, half + 1)
    x, y = np.meshgrid(coords, coords)

    h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))

    # Zero out entries that are negligible relative to the peak. ``np.finfo``
    # replaces the deprecated ``scipy.finfo`` alias of the same function.
    h[h < np.finfo(float).eps * h.max()] = 0

    total = h.sum()
    if total != 0:
        h = h / total
    return h
def masks_CFA_Bayer(shape):
    """Return boolean (R, G, B) sampling masks for an RGGB Bayer pattern.

    Args:
        shape: Shape of each mask, as accepted by ``np.zeros``.

    Returns:
        A tuple ``(R, G, B)`` of boolean arrays; each is true where the 2x2
        RGGB tile samples that colour.
    """
    planes = {name: np.zeros(shape) for name in "RGB"}

    # Position of each letter of "RGGB" inside the 2x2 tile, as (row, column).
    tile_offsets = ((0, 0), (0, 1), (1, 0), (1, 1))
    for name, (row, col) in zip("RGGB", tile_offsets):
        planes[name][row::2, col::2] = 1

    return (
        planes["R"].astype(bool),
        planes["G"].astype(bool),
        planes["B"].astype(bool),
    )
@dataclass
class LinearSolveConfig:
    """Defines default configuration parameters for solving linear equations.

    Args:
        rtol (float): The relative tolerance level for convergence, defaults to 1e-6.
        max_iters (int): The maximum number of iterations allowed for convergence,
            defaults to 100.
        verbose (bool): Whether to print progress updates during the solving process.
        solver_type (str): Key selecting the solver implementation from the
            ``SOLVERS`` registry, defaults to ``"cg"``.
        solver_kwargs (dict): Additional keyword arguments passed to the solver function.
        use_analytic_grad (bool): When true, ``linear_solve`` routes the solve through
            the custom ``LinearSolve`` autograd function; otherwise the solver is
            called directly. Defaults to True.
    """

    rtol: float = 1e-6
    max_iters: int = 100
    verbose: bool = False
    solver_type: str = "cg"
    # default_factory gives every config instance its own dict.
    solver_kwargs: dict = field(default_factory=dict)
    use_analytic_grad: bool = True
def linear_solve(A: torch.nn.Module, b: torch.Tensor, config: LinearSolveConfig = None):
    """Solve the linear system ``A x = b``, optionally with an analytic gradient.

    Args:
        A (torch.nn.Module): The linear operator; it must be callable as ``A(x)``
            to apply the forward operator.
        b (torch.Tensor): Right-hand side of the linear system ``A x = b``.
        config (LinearSolveConfig, optional): Solver options (solver type, tolerance,
            maximum number of iterations, ...). When omitted, a fresh default
            ``LinearSolveConfig`` is created for this call.

    Returns:
        The solution of ``A x = b``.
    """
    # A default instance in the signature would be built once and shared by all
    # calls, so a caller mutating it (e.g. its ``solver_kwargs`` dict) would
    # silently affect later calls. Build the default per call instead.
    if config is None:
        config = LinearSolveConfig()

    if config.use_analytic_grad:
        # Trainable parameters of A are passed explicitly so the custom
        # autograd function receives them as inputs.
        return LinearSolve.apply(A, b, config, *_trainable_parameters(A))

    solver = _build_solver(config)
    return solver(A, b)
class BlackBox(LinOp):
    """A black-box lin op specified by the user.

    Parameters
    ----------
    *args
        Input nodes of this operator; forwarded to the ``LinOp`` constructor.
    forward : callable
        Called as ``forward(*inputs, step=...)`` to apply the operator.
    adjoint : callable
        Called as ``adjoint(*inputs, step=...)`` to apply the adjoint operator.
    diag : callable, optional
        Called as ``diag(x, step)`` by ``get_diag``. When given, the operator
        reports its Gram matrix as diagonal.
    norm_bound : float, optional
        An upper bound on the spectral norm of the operator.
    """

    def __init__(self, *args, forward=None, adjoint=None, diag=None, norm_bound=None):
        self._forward = forward
        self._adjoint = adjoint
        self._norm_bound = norm_bound
        self._diag = diag
        super(BlackBox, self).__init__(args)

    def forward(self, *inputs):
        """The forward operator.

        Delegates to the user-supplied ``forward`` callable, passing the
        current ``self.step`` as the ``step`` keyword.
        """
        if len(inputs) == 1:
            return self._forward(inputs[0], step=self.step)
        return self._forward(*inputs, step=self.step)

    def adjoint(self, *inputs):
        """The adjoint operator.

        Delegates to the user-supplied ``adjoint`` callable, passing the
        current ``self.step`` as the ``step`` keyword.
        """
        if len(inputs) == 1:
            return self._adjoint(inputs[0], step=self.step)
        return self._adjoint(*inputs, step=self.step)

    def norm_bound(self, input_mags):
        """Gives an upper bound on the magnitudes of the outputs given inputs.

        Parameters
        ----------
        input_mags : list
            List of magnitudes of inputs.

        Returns
        -------
        float
            Magnitude of outputs.
        """
        if self._norm_bound is None:
            # No user-supplied bound: fall back to the base-class estimate.
            return super(BlackBox, self).norm_bound(input_mags)
        return self._norm_bound * input_mags[0]

    def is_gram_diag(self, freq=False):
        """Is the lin op's Gram matrix diagonal (in the frequency domain)?

        True exactly when a ``diag`` callable was supplied; ``freq`` is ignored.
        """
        # Identity check: ``!= None`` would dispatch to a custom ``__ne__``.
        return self._diag is not None

    def get_diag(self, x, freq=False):
        """Return ``diag(x, step)`` from the user-supplied callable.

        Note that ``step`` is passed positionally here, unlike in ``forward``
        and ``adjoint``; ``freq`` is ignored.
        """
        return self._diag(x, self.step)
25 | """ 26 | return self.value 27 | 28 | def adjoint(self, value, **kwargs): 29 | """The adjoint operator. 30 | 31 | Reads from inputs and writes to outputs. 32 | """ 33 | return self.value*0 34 | 35 | # ---------------------------------------------------------------------------- # 36 | # Diagonal # 37 | # ---------------------------------------------------------------------------- # 38 | 39 | def is_diag(self, freq=False): 40 | """Is the lin op diagonal (in the frequency domain)? 41 | """ 42 | return True 43 | 44 | def get_diag(self, freq=False): 45 | """Returns the diagonal representation (A^TA)^(1/2). 46 | 47 | Parameters 48 | ---------- 49 | freq : bool 50 | Is the diagonal representation in the frequency domain? 51 | Returns 52 | ------- 53 | dict of variable to ndarray 54 | The diagonal operator acting on each variable. 55 | """ 56 | return {} 57 | 58 | # ---------------------------------------------------------------------------- # 59 | # Property # 60 | # ---------------------------------------------------------------------------- # 61 | 62 | @property 63 | def variables(self): 64 | return [] 65 | 66 | @property 67 | def constants(self): 68 | return [self] 69 | 70 | @property 71 | def value(self): 72 | return self._value.to(self.device) 73 | 74 | 75 | def norm_bound(self, input_mags): 76 | """Gives an upper bound on the magnitudes of the outputs given inputs. 77 | 78 | Parameters 79 | ---------- 80 | input_mags : list 81 | List of magnitudes of inputs. 82 | 83 | Returns 84 | ------- 85 | float 86 | Magnitude of outputs. 
class Edge(object):
    """The edge between two lin ops.

    Attributes are plain fields: ``start`` and ``end`` are the connected
    nodes, ``data`` holds the value carried by the edge (``None`` until
    set), and ``mag`` is scratch space used to get norm bounds.
    """

    def __init__(self, start, end):
        self.start, self.end = start, end
        self.data = None
        self.mag = None  # Used to get norm bounds.

    @property
    def size(self):
        """Number of elements in ``data`` (product of its shape)."""
        shape = self.data.shape
        return np.prod(shape)

    def __repr__(self) -> str:
        has_data = self.data is not None
        return f'Edge(id={id(self)}, start={self.start}, end={self.end}, data={has_data})'
class mul_color(LinOp):
    """Lin op that mixes the channel axis of its input with a matrix ``srf``.

    ``srf`` may be a fixed array/tensor or a ``Placeholder``; in the latter
    case the stored parameter is replaced whenever the placeholder changes.
    """

    def __init__(
        self,
        arg: LinOp,
        srf: Union[Placeholder, torch.Tensor, np.array],
    ):
        super().__init__([arg])
        # srf: [C, C_2]
        self._srf = srf

        if not isinstance(srf, Placeholder):
            self.srf = nn.parameter.Parameter(to_torch_tensor(srf, batch=True))
            return

        def _refresh(new_value):
            # Re-wrap the new placeholder value as the active parameter.
            self.srf = nn.parameter.Parameter(new_value)

        self._srf.change(_refresh)

    def forward(self, x, **kwargs):
        """Apply the channel mixing with ``srf``."""
        return self.apply(x, self.srf)

    def adjoint(self, x, **kwargs):
        """Apply the channel mixing with the transpose of ``srf``."""
        return self.apply(x, self.srf.T)

    def apply(self, x, srf):
        """Multiply the flattened spatial view of ``x`` by ``srf.T``.

        ``x`` is unpacked as a 4-D ``(N, C, H, W)`` tensor; the result keeps
        ``N``, ``H`` and ``W`` and takes its channel count from the product.
        """
        batch, channels, height, width = x.shape
        flat = x.reshape(batch, channels, height * width)  # N,C,HW
        mixed = srf.T @ flat  # N,C2,HW
        return mixed.reshape(batch, -1, height, width)
class Placeholder(Constant):
    """A constant whose value can be assigned later, with change callbacks.

    Callbacks registered through ``change`` are invoked, in registration
    order, every time ``value`` is assigned.
    """

    def __init__(self, default=None):
        super().__init__(default)
        self.watchers = []

    @property
    def value(self):
        """The current value, moved to this op's device."""
        current = self._value
        return current.to(self.device)

    @value.setter
    def value(self, val):
        """Assign a value to the variable and notify every watcher."""
        self._value = val
        for notify in self.watchers:
            notify(val)

    def change(self, fn):
        """Register ``fn`` to be called with each newly assigned value."""
        self.watchers.append(fn)
24 | """ 25 | return input * self.scalar 26 | 27 | def adjoint(self, input, **kwargs): 28 | """The adjoint operator. 29 | 30 | Reads from inputs and writes to outputs. 31 | """ 32 | return self.forward(input) 33 | 34 | # ---------------------------------------------------------------------------- # 35 | # Diagonal # 36 | # ---------------------------------------------------------------------------- # 37 | 38 | def is_gram_diag(self, freq=False): 39 | """Is the lin Gram diagonal (in the frequency domain)? 40 | """ 41 | return self.input_nodes[0].is_gram_diag(freq) 42 | 43 | def is_diag(self, freq=False): 44 | """Is the lin op diagonal (in the frequency domain)? 45 | """ 46 | return self.input_nodes[0].is_diag(freq) 47 | 48 | def get_diag(self, ref, freq=False): 49 | """Returns the diagonal representation (A^TA)^(1/2). 50 | 51 | Parameters 52 | ---------- 53 | freq : bool 54 | Is the diagonal representation in the frequency domain? 55 | Returns 56 | ------- 57 | dict of variable to ndarray 58 | The diagonal operator acting on each variable. 59 | """ 60 | var_diags = self.input_nodes[0].get_diag(ref, freq) * self.scalar 61 | return var_diags * torch.conj(var_diags) 62 | 63 | # ---------------------------------------------------------------------------- # 64 | # Property # 65 | # ---------------------------------------------------------------------------- # 66 | 67 | def norm_bound(self, input_mags): 68 | """Gives an upper bound on the magnitudes of the outputs given inputs. 69 | 70 | Parameters 71 | ---------- 72 | input_mags : list 73 | List of magnitudes of inputs. 74 | 75 | Returns 76 | ------- 77 | float 78 | Magnitude of outputs. 
class mosaic(LinOp):
    """Lin op that multiplies its input by an RGGB Bayer sampling mask.

    The mask depends only on the last two dimensions of the input shape and
    is cached per spatial size.
    """

    def __init__(self, arg):
        super(mosaic, self).__init__([arg])
        self.cache = {}  # spatial size (last two dims) -> mask

    # ---------------------------------------------------------------------------- #
    #                                  Computation                                 #
    # ---------------------------------------------------------------------------- #

    def forward(self, input, **kwargs):
        """The forward operator: element-wise product with the Bayer mask."""
        mask = self._mask(input.shape).to(input.device)
        return mask * input

    def adjoint(self, input, **kwargs):
        """The adjoint operator; identical to ``forward`` (masking is self-adjoint)."""
        return self.forward(input)

    @staticmethod
    def masks_CFA_Bayer(shape):
        """Return boolean (R, G, B) masks of the given shape for an RGGB tile."""
        pattern = 'RGGB'
        channels = dict((channel, np.zeros(shape)) for channel in 'RGB')
        for channel, (y, x) in zip(pattern, [(0, 0), (0, 1), (1, 0), (1, 1)]):
            channels[channel][y::2, x::2] = 1
        return tuple(channels[c].astype(bool) for c in 'RGB')

    def _mask(self, shape):
        """Return the (cached) mask for the spatial size ``shape[-2:]``.

        The cache is keyed by the spatial size for both the lookup and the
        store. Previously the lookup used the full shape while the entry was
        stored under ``shape[-2:]``, so inputs with more than two dimensions
        never hit the cache and the mask was rebuilt on every call.
        """
        key = tuple(shape[-2:])
        if key not in self.cache:
            R_m, G_m, B_m = self.masks_CFA_Bayer(key)
            mask = np.concatenate((R_m[..., None], G_m[..., None], B_m[..., None]), axis=-1)
            # NOTE(review): presumably ``to_torch_tensor(..., batch=True)`` moves the
            # channel axis forward and adds a batch axis - confirm against the helper.
            self.cache[key] = to_nn_parameter(to_torch_tensor(mask.astype('float32'), batch=True))
        return self.cache[key]

    # ---------------------------------------------------------------------------- #
    #                                   Diagonal                                   #
    # ---------------------------------------------------------------------------- #

    def is_gram_diag(self, freq=False):
        """Is the lin op's Gram matrix diagonal (in the frequency domain)?
        """
        return self.is_self_diag(freq) and self.input_nodes[0].is_diag(freq)

    def is_self_diag(self, freq=False):
        """The mask itself is diagonal only outside the frequency domain."""
        return not freq

    def get_diag(self, x, freq=False):
        """Return the mask matching ``x.shape``, moved to this op's device.

        Parameters
        ----------
        x : object with a ``shape`` attribute
            Reference whose shape selects the mask.
        freq : bool
            Must be false; a frequency-domain diagonal is not available.
        """
        assert not freq
        return self._mask(x.shape).to(self.device)

    # ---------------------------------------------------------------------------- #
    #                                   Property                                   #
    # ---------------------------------------------------------------------------- #

    def norm_bound(self, input_mags):
        """Gives an upper bound on the magnitudes of the outputs given inputs.

        Parameters
        ----------
        input_mags : list
            List of magnitudes of inputs.

        Returns
        -------
        float
            Magnitude of outputs.
        """
        return input_mags[0]
33 | """ 34 | return all([arg.is_diag(freq) for arg in self.input_nodes]) 35 | 36 | def is_gram_diag(self, freq=False): 37 | """Is the lin op diagonal (in the frequency domain)? 38 | """ 39 | return all([arg.is_gram_diag(freq) for arg in self.input_nodes]) 40 | 41 | def get_diag(self, ref, freq=False): 42 | """Returns the diagonal representation (A^TA)^(1/2). 43 | 44 | Parameters 45 | ---------- 46 | freq : bool 47 | Is the diagonal representation in the frequency domain? 48 | Returns 49 | ------- 50 | dict of variable to ndarray 51 | The diagonal operator acting on each variable. 52 | """ 53 | # var_diags = {var: torch.zeros(var.size) for var in self.variables()} 54 | # for arg in self.input_nodes: 55 | # arg_diags = arg.get_diag(shape, freq) 56 | # for var, diag in arg_diags.items(): 57 | # var_diags[var] = var_diags[var] + diag 58 | # return var_diags.values()[0] 59 | return self.input_nodes[0].get_diag(ref, freq) 60 | 61 | def norm_bound(self, input_mags): 62 | """Gives an upper bound on the magnitudes of the outputs given inputs. 63 | 64 | Parameters 65 | ---------- 66 | input_mags : list 67 | List of magnitudes of inputs. 68 | 69 | Returns 70 | ------- 71 | float 72 | Magnitude of outputs. 73 | """ 74 | return torch.sum(input_mags) 75 | 76 | 77 | class copy(sum): 78 | 79 | def __init__(self, arg): 80 | super(copy, self).__init__([arg]) 81 | 82 | def forward(self, inputs, **kwargs): 83 | """The forward operator. 84 | 85 | Reads from inputs and writes to outputs. 86 | """ 87 | return super(copy, self).adjoint(inputs, **kwargs) 88 | 89 | def adjoint(self, *inputs, **kwargs): 90 | """The adjoint operator. 91 | 92 | Reads from inputs and writes to outputs. 93 | """ 94 | return super(copy, self).forward(*inputs, **kwargs) 95 | 96 | def norm_bound(self, input_mags): 97 | """Gives an upper bound on the magnitudes of the outputs given inputs. 98 | 99 | Parameters 100 | ---------- 101 | input_mags : list 102 | List of magnitudes of inputs. 
class Variable(LinOp):
    """A variable.

    Leaf node of the operator graph: it has no input nodes and acts as the
    identity in both ``forward`` and ``adjoint``.
    """

    def __init__(self, shape=None, value=None, name=None):
        super(Variable, self).__init__([])
        self.uuid = uuid.uuid1()
        self._value = value
        self.shape = shape
        self.varname = name
        self.initval = None

    # ---------------------------------------------------------------------------- #
    #                                  Computation                                 #
    # ---------------------------------------------------------------------------- #

    def forward(self, inputs, **kwargs):
        """The forward operator: returns ``inputs`` unchanged."""
        return inputs

    def adjoint(self, inputs, **kwargs):
        """The adjoint operator: returns ``inputs`` unchanged."""
        return inputs

    # ---------------------------------------------------------------------------- #
    #                                   Diagonal                                   #
    # ---------------------------------------------------------------------------- #

    def is_diag(self, freq=False):
        """Is the lin op diagonal (in the frequency domain)?
        """
        return True

    def get_diag(self, ref, freq=False):
        """Return an all-ones diagonal with the shape of ``ref``.

        Parameters
        ----------
        ref : object with a ``shape`` attribute
            Reference whose shape (and device, when it has one) is used.
        freq : bool
            Unused; the identity is all ones in either domain.

        Returns
        -------
        torch.Tensor
            Tensor of ones shaped like ``ref``.
        """
        # Allocate on ref's device when it has one, so the result can be
        # combined with ref without a device mismatch; otherwise fall back to
        # the default device as before.
        return torch.ones(ref.shape, device=getattr(ref, 'device', None))

    # ---------------------------------------------------------------------------- #
    #                                   Property                                   #
    # ---------------------------------------------------------------------------- #

    @property
    def variables(self):
        return [self]

    @property
    def value(self):
        """The current value, moved to this op's device."""
        return self._value.to(self.device)

    @value.setter
    def value(self, val):
        """Assign a value to the variable.
        """
        self._value = val

    def norm_bound(self, input_mags):
        """Gives an upper bound on the magnitudes of the outputs given inputs.

        Parameters
        ----------
        input_mags : list
            List of magnitudes of inputs.

        Returns
        -------
        float
            Magnitude of outputs.
        """
        return 1.0

    # ---------------------------------------------------------------------------- #
    #                                 Python Magic                                 #
    # ---------------------------------------------------------------------------- #

    def __repr__(self):
        return f'Variable(id={self.uuid}, shape={self.shape}, value={"None" if self._value is None else "somevalue"})'
class ProxFn(nn.Module):
    """Abstract base class for proximal operators.

    A subclass represents a function ``f`` and implements ``_prox``, the
    proximal operator

        prox_{lam * f}(v) = argmin_x f(x) + 1 / (2 * lam) * ||x - v||_2^2

    ``prox`` wraps ``_prox`` with the scaling (``alpha``), affine (``beta``)
    and translation (``offset``) rules, so subclasses only need to handle the
    plain function.
    """

    def __init__(self, linop: LinOp, alpha=1, beta=1):
        """Store the linear operator and the scaling parameters.

        Parameters
        ----------
        linop : LinOp
            Linear operator the function is applied to.
        alpha : scalar
            Multiplier applied to ``lam`` before calling ``_prox``.
        beta : scalar
            Factor of the affine rule applied around ``_prox``.
        """
        super().__init__()
        self.linop = linop
        self.alpha = alpha
        self.beta = beta
        self.step = 0
        self.dag = CompGraph(linop, zero_out_constant=True)

    @property
    def offset(self):
        """Negated offset of the wrapped linear operator."""
        return -self.linop.offset

    def unwrap(self, value):
        """Return a tensor for ``value``.

        A ``Placeholder`` yields its current ``value``; anything else is
        converted with ``to_torch_tensor(..., batch=True)`` and moved to the
        device of ``self.linop``.
        """
        if isinstance(value, Placeholder):
            return value.value
        return to_torch_tensor(value, batch=True).to(self.linop.device)

    def eval(self, v):
        """Evaluate the function at ``v``; must be provided by subclasses.

        NOTE(review): this name shadows ``nn.Module.eval``; calling
        ``proxfn.eval()`` without an argument raises ``TypeError``.

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError

    def prox(self, v, lam):
        """Apply the proximal operator, including alpha/beta/offset handling.

        v: [B,C,H,W], lam: [B]
        """
        # A per-sample lam is reshaped so that it broadcasts against v.
        if len(lam.shape) == 1:
            lam = lam.view(lam.shape[0], 1, 1, 1)

        fn = self._prox
        fn = prox_scaled(fn, self.alpha)
        fn = prox_affine(fn, self.beta)
        fn = prox_translated(fn, self.offset)
        return fn(v, lam)

    def convex_conjugate_prox(self, v, lam):
        """Proximal operator of the convex conjugate via Moreau's identity."""
        return v - self.prox(v / lam, lam)

    def _prox(self, v, lam):
        """Proximal operator of the plain function; subclasses override this.

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError

    def __mul__(self, other):
        """Scale the function in place: ``ProxFn * positive scalar``.

        Sets ``alpha`` to ``other`` and returns ``self``.

        Raises
        ------
        TypeError
            If ``other`` is not a positive scalar.
        """
        if np.isscalar(other) and other > 0:
            self.alpha = other
            return self
        raise TypeError("Can only multiply by a positive scalar.")

    def __rmul__(self, other):
        """Called for Number * ProxFn.
        """
        return self * other

    def __add__(self, other):
        """ProxFn + ProxFn(s): build a flat list of functions.
        """
        if isinstance(other, ProxFn):
            return [self, other]
        if isinstance(other, list):
            return [self] + other
        return NotImplemented

    def __radd__(self, other):
        """Called for list + ProxFn.
        """
        if isinstance(other, list):
            return other + [self]
        return NotImplemented

    def __str__(self):
        return f'{self.__class__.__name__}'
class csmri(ext_sum_squares):
    """Data term for compressed-sensing MRI with a masked Fourier measurement.

    Stores the sampling ``mask`` and the measurement ``y`` as given; both are
    resolved to tensors with ``self.unwrap`` each time ``_prox`` runs.
    """

    def __init__(self, linop, mask, y):
        super().__init__(linop)
        self.mask = mask
        self.y = y

    def _prox(self, v, lam, num_psi):
        """Blend the transform of ``v`` with ``y`` on the masked entries.

        ``lam`` is reshaped to ``[B, 1, 1, 1]`` when it is one-dimensional.
        Entries selected by the boolean mask are replaced with
        ``(lam * z + y) / (1 + lam * num_psi)``; the others keep the value of
        ``z = fft2(v)``. The result is passed through ``ifft2``.
        """
        if lam.ndim == 1:
            lam = lam.view(lam.shape[0], 1, 1, 1)

        measurement = self.unwrap(self.y)
        sampled = self.unwrap(self.mask).bool()

        spectrum = fft2(v)
        # The clone keeps the value used for `blended` separate from the
        # in-place masked write on `spectrum` below.
        blended = ((lam * spectrum.clone()) + measurement) / (1 + lam * num_psi)
        spectrum[sampled] = blended[sampled]

        return ifft2(spectrum)
def cdp_backward(data, mask):
    """
    Compute the backward model of cdp.

    Applies an orthonormal 2-D inverse FFT to ``data``, multiplies the result
    by the complex conjugate of ``mask`` and averages over dimension 1.

    Args:
        data (torch.Tensor): Field_data (batch_size*sampling_rate*height*width).
        mask (torch.Tensor): mask (batch_size*sampling_rate*height*width)

    Returns:
        backward_data (torch.Tensor): the complex field of backward data (batch_size*1*height*width)
    """
    spatial = torch.fft.ifft2(data, norm='ortho')
    demodulated = spatial * torch.conj(mask)
    return torch.mean(demodulated, dim=1, keepdim=True)
def spi_inverse(ztilde, K1, K, mu):
    r"""
    Proximal operator ``Prox_{\frac{1}{\mu} D}`` for single photon imaging
    assert alpha == K and q == 1

    Entries where ``K1 == 0`` have the closed form ``ztilde - K0 / mu`` with
    ``K0 = K**2 - K1``. The remaining entries are found by ten bisection
    steps on ``func`` inside the bracket ``[1e-5, 1.1]``. The result is
    clamped to ``[0, 1]``.

    Args:
        ztilde (torch.Tensor): point at which the operator is evaluated.
        K1 (torch.Tensor): per-entry counts, broadcastable to ``ztilde``.
        K (torch.Tensor): oversampling factor, used as ``K**2``.
        mu (torch.Tensor): penalty weight.

    Returns:
        torch.Tensor: tensor shaped like ``ztilde`` with values in ``[0, 1]``.
    """
    z = torch.zeros_like(ztilde)

    K0 = K**2 - K1
    indices_0 = (K1 == 0)
    nonzero_counts = (K1 != 0)

    z[indices_0] = ztilde[indices_0] - (K0 / mu)[indices_0]

    def func(y):
        # Function whose root is searched for on the entries with K1 != 0.
        return K1 / (torch.exp(y) - 1) - mu * y - K0 + mu * ztilde

    indices_1 = torch.logical_not(indices_0)

    # Element-wise bisection; every entry carries its own bracket.
    bmin = 1e-5 * torch.ones_like(ztilde)
    bmax = 1.1 * torch.ones_like(ztilde)

    bave = (bmin + bmax) / 2.0

    for _ in range(10):
        tmp = func(bave)
        indices_pos = torch.logical_and(tmp > 0, indices_1)
        indices_neg = torch.logical_and(tmp < 0, indices_1)
        indices_zero = torch.logical_and(tmp == 0, indices_1)
        # Entries that hit the root exactly stop being refined.
        indices_0 = torch.logical_or(indices_0, indices_zero)
        indices_1 = torch.logical_not(indices_0)

        bmin[indices_pos] = bave[indices_pos]
        bmax[indices_neg] = bave[indices_neg]
        bave[indices_1] = (bmin[indices_1] + bmax[indices_1]) / 2.0

    z[nonzero_counts] = bave[nonzero_counts]
    return torch.clamp(z, 0.0, 1.0)
class nonneg(ProxFn):
    """Proximal function that maps negative entries to zero."""

    def __init__(self, linop=None):
        super().__init__(linop)

    def _prox(self, v, lam):
        """Return the element-wise maximum of ``v`` and zero; ``lam`` is unused."""
        floor = torch.zeros_like(v)
        return torch.maximum(v, floor)
def soft_threshold(v, lam):
    """ Element-wise soft thresholding.

    ref: https://www.tensorflow.org/probability/api_docs/python/tfp/math/soft_threshold

    Solves argmin_x lam * |x|_1 + 0.5 * (x-v)^2: magnitudes are reduced by
    ``lam`` and floored at zero, the sign of ``v`` is kept.
    """
    shrunk = torch.abs(v) - lam
    floor = torch.zeros_like(v)
    return torch.sign(v) * torch.maximum(shrunk, floor)
class Augment(nn.Module):
    """Wrap a denoiser with a cycling flip/rotation augmentation.

    Each call to ``denoise`` picks the transform ``self.iter % 8``, applies
    it to the input, runs the wrapped denoiser, undoes the transform on the
    output and advances the counter.
    """

    def __init__(self, base_denoiser):
        super().__init__()
        self.base_denoiser = base_denoiser
        self.iter = 0

    def denoise(self, x: torch.Tensor, sigma: torch.Tensor):
        """Denoise ``x`` under the current augmentation and advance the counter.

        ``x`` is transformed over dimensions 2 and 3; ``sigma`` is forwarded
        unchanged to ``base_denoiser.denoise``.
        """
        mode = self.iter % 8

        x = self.augment(x, mode)
        x = self.base_denoiser.denoise(x, sigma)

        # Modes 3 and 5 are pure rotations and undo each other; every other
        # mode is its own inverse.
        inverse_mode = 8 - mode if mode in (3, 5) else mode
        x = self.augment(x, inverse_mode)

        self.iter += 1
        return x

    def reset(self):
        """Restart the augmentation cycle at mode 0."""
        self.iter = 0

    @staticmethod
    def augment(img, mode=0):
        """Apply one of eight flip/rotation transforms over dims 2 and 3.

        Raises
        ------
        ValueError
            If ``mode`` is not an integer in ``0..7``.
        """
        if mode == 0:
            return img
        elif mode == 1:
            return img.rot90(1, [2, 3]).flip([2])
        elif mode == 2:
            return img.flip([2])
        elif mode == 3:
            return img.rot90(3, [2, 3])
        elif mode == 4:
            return img.rot90(2, [2, 3]).flip([2])
        elif mode == 5:
            return img.rot90(1, [2, 3])
        elif mode == 6:
            return img.rot90(2, [2, 3])
        elif mode == 7:
            return img.rot90(3, [2, 3]).flip([2])
        raise ValueError(f'unsupported augmentation mode: {mode!r}')
def TV_denoising(y0, lamda, iteration=100):
    """Total-variation denoising of a 3-D tensor along its first two axes.

    Keeps one dual variable per difference direction (``zh`` along axis 1,
    ``zv`` along axis 0). Each iteration rebuilds the estimate from the two
    dual variables, then updates them with step ``1 / alpha`` and clips them
    to ``[-lamda / 2, lamda / 2]``.

    Parameters
    ----------
    y0 : torch.Tensor
        Input of shape ``(w, h, b)``.
    lamda : float
        Regularisation weight; the dual variables are clipped to ``lamda / 2``.
    iteration : int
        Number of update steps.

    Returns
    -------
    torch.Tensor
        Estimate computed in the last iteration, or ``y0`` itself when
        ``iteration`` is not positive.
    """
    device = y0.device
    w, h, b = y0.shape
    zh = torch.zeros([w, h-1, b], device=device, dtype=torch.float32)
    zv = torch.zeros([w-1, h, b], device=device, dtype=torch.float32)
    alpha = 5
    # With zero dual variables the estimate equals the input; binding it here
    # keeps the return well defined when the loop does not run.
    x0 = y0
    for _ in range(iteration):
        x0h = y0 - dht_3d(zh)
        x0v = y0 - dvt_3d(zv)
        x0 = (x0h + x0v) / 2
        zh = clip(zh + 1/alpha*dh(x0), lamda/2)
        zv = clip(zv + 1/alpha*dv(x0), lamda/2)
    return x0
class BNReLUConv3d(nn.Sequential):
    """Sequential block: batch norm -> ReLU -> bias-free 3-D convolution."""

    def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=False):
        super().__init__()
        # Registered under the names 'bn', 'relu' and 'conv', in this order.
        stages = (
            ('bn', BatchNorm3d(in_channels)),
            ('relu', nn.ReLU(inplace=inplace)),
            ('conv', nn.Conv3d(in_channels, channels, k, s, p, bias=False)),
        )
        for label, layer in stages:
            self.add_module(label, layer)
class UpsampleConv3d(torch.nn.Module):
    """Optional trilinear upsampling followed by a 3-D convolution.

    Upsampling first and convolving afterwards is used here in place of a
    transposed convolution.
    ref: http://distill.pub/2016/deconv-checkerboard/
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=True, upsample=None):
        super().__init__()
        self.upsample = upsample
        # The upsampling layer only exists when a scale factor was given.
        if upsample:
            self.upsample_layer = torch.nn.Upsample(
                scale_factor=upsample, mode='trilinear', align_corners=True)

        self.conv3d = torch.nn.Conv3d(
            in_channels, out_channels, kernel_size, stride, padding, bias=bias)

    def forward(self, x):
        """Upsample ``x`` when configured to, then convolve it."""
        resized = self.upsample_layer(x) if self.upsample else x
        return self.conv3d(resized)
class BasicUpsampleConv3d(nn.Sequential):
    """Sequential block: optional batch norm, then a bias-free UpsampleConv3d."""

    def __init__(self, in_channels, channels, k=3, s=1, p=1, upsample=(1,2,2), bn=True):
        super().__init__()
        if bn:
            self.add_module('bn', BatchNorm3d(in_channels))
        upsample_conv = UpsampleConv3d(
            in_channels, channels, k, s, p, bias=False, upsample=upsample)
        self.add_module('upsample_conv', upsample_conv)
def execute_replication_callbacks(modules):
    """
    Run the replication callback `__data_parallel_replicate__` on every sub-module of every replica.

    The callback is invoked as `__data_parallel_replicate__(ctx, copy_id)`.

    One `CallbackContext` is created per sub-module position of the first replica and is
    shared by the sub-modules at that position in all replicas, so copies of the same
    sub-module can exchange information through it.

    Replicas are visited in list order, so the callbacks of the first copy run before
    those of any later copy.
    """
    primary = modules[0]
    contexts = [CallbackContext() for _ in primary.modules()]

    for copy_id, replica in enumerate(modules):
        for position, submodule in enumerate(replica.modules()):
            if not hasattr(submodule, '__data_parallel_replicate__'):
                continue
            submodule.__data_parallel_replicate__(contexts[position], copy_id)
def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object so that replication callbacks run.

    The instance's `replicate` is replaced by a wrapper that calls the previous
    `replicate` and then `execute_replication_callbacks` on the replicas.
    Useful when you have a customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    original_replicate = data_parallel.replicate

    def replicate_with_callbacks(module, device_ids):
        replicas = original_replicate(module, device_ids)
        execute_replication_callbacks(replicas)
        return replicas

    # Keep the name and docstring of the method that is being wrapped.
    data_parallel.replicate = functools.wraps(original_replicate)(replicate_with_callbacks)
class TorchTestCase(unittest.TestCase):
    """``unittest.TestCase`` with a closeness assertion for tensors."""

    def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
        """Assert that ``a`` and ``b`` are element-wise close.

        Both arguments are converted with ``as_numpy`` and compared with
        ``np.allclose`` using the given absolute and relative tolerances.
        The failure message reports the largest absolute and relative
        differences.
        """
        npa, npb = as_numpy(a), as_numpy(b)
        self.assertTrue(
            # rtol is forwarded so that the parameter actually takes effect.
            np.allclose(npa, npb, rtol=rtol, atol=atol),
            'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
        )
def clone(x, nums, shared):
    """Return a list with ``nums`` entries derived from ``x``.

    With ``shared`` true every entry is ``x`` itself; otherwise every entry
    is an independent ``copy.deepcopy`` of ``x``.
    """
    if shared:
        return [x for _ in range(nums)]
    return [copy.deepcopy(x) for _ in range(nums)]
class unrolled_prior(ProxFn):
    """Prior whose proximal step is a denoiser indexed by the current step.

    ``denoiser`` is a zero-argument callable that builds the denoiser; when it
    is ``None`` the default ``Denoiser`` is instantiated.
    """

    def __init__(self, linop, denoiser=None):
        super().__init__(linop)
        factory = Denoiser if denoiser is None else denoiser
        self.denoiser = factory()

    def eval(self, v):
        raise NotImplementedError('deep prior cannot be explictly evaluated')

    def _prox(self, v: torch.Tensor, lam: torch.Tensor=None):
        """ v: [N, C, H, W]
            lam: [1]

        ``lam`` is accepted for interface compatibility but not used; the
        denoiser receives ``v`` and ``self.step``.
        """
        return self.denoiser(v, self.step)
class _FlaggedArray(np.ndarray):
    """ndarray subclass that exists only so instances can carry attributes.

    Instances of the base ``np.ndarray`` type have no ``__dict__``, so
    ``arr.is_dp_array = True`` raises ``AttributeError`` on them. Instances of
    a Python-level subclass do accept new attributes.
    """


def array(*args, **kwargs):
    """Create a numpy array flagged as a delta-prox array.

    All arguments are forwarded to ``np.array``.

    :return: an ``np.ndarray`` (subclass) instance holding the same data, with
        the extra attribute ``is_dp_array`` set to ``True``.
    """
    # View-cast instead of copying: same buffer, attribute-capable type.
    out = np.array(*args, **kwargs).view(_FlaggedArray)
    out.is_dp_array = True
    return out
def download_url(url: str, output_path: str) -> None:
    """
    download a file from a given URL and save it to a specified output path while
    displaying a progress bar.

    Args:
        url (str): The URL of the file to be downloaded
        output_path (str): file path (including file name and extension) where the
            downloaded file will be saved.
    """
    with DownloadProgressBar(
        unit="B", unit_scale=True, miniters=1, desc=url.split("/")[-1]
    ) as t:
        urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to)


def _hf_resolve_url(base_path: str, repo_type: str, user_id: str) -> str:
    """Build the huggingface.co ``resolve/main`` URL for ``<repo_id>/<file path>``."""
    repo_id, _, file_path = base_path.partition("/")
    if not repo_id or not file_path:
        raise ValueError(
            f"base_path must look like '<repo_id>/<path inside repo>', got {base_path!r}"
        )
    base_url = "https://huggingface.co"
    if repo_type == "datasets":
        base_url += "/" + repo_type
    # URLs always use forward slashes, so the path is never passed through
    # os.path.join (which would produce backslashes on Windows).
    return f"{base_url}/{user_id}/{repo_id}/resolve/main/{file_path}"


def load_path(base_path: str, repo_type="datasets", user_id="delta-prox") -> str:
    """
    return a local path for ``base_path``, downloading the file from huggingface
    into the cache directory if it is not available yet.

    Args:
        base_path (str): Either an existing local path (returned unchanged), or
            ``<repo_id>/<path inside repo>``; the latter is looked up under
            ``CACHE_DIR`` and downloaded on a cache miss.
        repo_type (str): ``"datasets"`` adds the ``/datasets`` URL prefix; any
            other value uses the bare ``huggingface.co/<user_id>/...`` form.
        user_id (str): huggingface user or organisation owning the repository.

    Return:
        a string which is the local path to the file.

    Raises:
        ValueError: if a download is needed but ``base_path`` has no
            ``<repo_id>/`` prefix.
    """
    if os.path.exists(base_path):
        return base_path

    save_path = os.path.join(CACHE_DIR, base_path)
    if not os.path.exists(save_path):
        url = _hf_resolve_url(base_path, repo_type, user_id)
        print(f"{base_path} not found")
        print("Try to download from huggingface: ", url)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        # Download next to the target and rename on success, so an interrupted
        # download never leaves a truncated file that looks like a cache hit.
        part_path = save_path + ".part"
        try:
            download_url(url, part_path)
            os.replace(part_path, save_path)
        finally:
            if os.path.exists(part_path):
                os.remove(part_path)
        print("Downloaded to ", save_path)
    return save_path
def shift_pixel(x, sf, upper_left=True):
    """shift pixel for super-resolution with different scale factors

    The image is resampled with bilinear interpolation at positions displaced
    by ``(sf - 1) / 2`` pixels along both axes; positions falling outside the
    image are clamped to the border.

    Args:
        x: HxWxC or HxW array (the first two axes are read as rows, columns)
        sf: scale factor
        upper_left: shift direction; ``True`` samples at ``index + shift``,
            ``False`` at ``index - shift``

    Returns:
        For 2-D input a new float64 array. For 3-D input the channels are
        written back into ``x`` in place and ``x`` itself is returned.
    """
    # interp2d, used here previously, is deprecated and removed in recent SciPy.
    from scipy.interpolate import RegularGridInterpolator

    h, w = x.shape[:2]
    shift = (sf - 1) * 0.5
    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)

    offset = shift if upper_left else -shift
    x1 = np.clip(xv + offset, 0, w - 1)
    y1 = np.clip(yv + offset, 0, h - 1)

    # Query grid of (row, col) coordinates, shape (h, w, 2).
    rows, cols = np.meshgrid(y1, x1, indexing='ij')
    query = np.stack([rows, cols], axis=-1)

    def resample(plane):
        values = np.asarray(plane, dtype=np.float64)
        return RegularGridInterpolator((yv, xv), values, method='linear')(query)

    if x.ndim == 2:
        x = resample(x)
    elif x.ndim == 3:
        for i in range(x.shape[-1]):
            x[:, :, i] = resample(x[:, :, i])

    return x
# Example: compressed-sensing MRI reconstruction with plug-and-play ADMM,
# using the 'unet' denoiser as the learned prior.
from dprox import *
from dprox.utils import *
from dprox import contrib

# NOTE(review): judging by the names, the sample provides an initial estimate,
# the measurements, the ground truth and the sampling mask - confirm against
# contrib.csmri.sample().
x0, y0, gt, mask = contrib.csmri.sample()

# Objective = CS-MRI data term + deep denoiser prior on the unknown x.
x = Variable()
y = Placeholder()
data_term = csmri(x, mask, y)
reg_term = deep_prior(x, denoiser='unet')
prob = Problem(data_term + reg_term)

# The measurement is bound to the placeholder only now, after the problem is built.
y.value = y0
max_iter = 24
# One (rho, sigma) pair per iteration; the sigmas are routed to the prior via `lams`.
rhos, sigmas = log_descent(30, 20, max_iter)
prob.solve(
    method='admm',
    device='cuda',  # requires a CUDA-capable device
    x0=x0, rhos=rhos, lams={reg_term: sigmas}, max_iter=max_iter, pbar=True
)
# Only the real part of the solution is evaluated and displayed.
out = x.value.real

print(psnr(out, gt))  # 43
imshow(out)
# Example: 2x single-image super-resolution with plug-and-play ADMM,
# using the 'ffdnet_color' denoiser as the learned prior.
from dprox import *
from dprox.utils import *
from dprox.contrib import *

img = sample('face')
psf = point_spread_function(5, 3)
# y is the observation fed to the data term; x0 initialises the solver.
y, x0 = downsampling(img, psf, sf=2)

x = Variable()
data_term = sisr(x, y, kernel=psf, sf=2)
reg_term = deep_prior(x, denoiser='ffdnet_color')
prob = Problem(data_term + reg_term)

# Single source of truth for the iteration count: it sizes the schedule
# and is the value handed to the solver.
max_iter = 24
rhos, sigmas = log_descent(35, 35, max_iter)

out = prob.solve(method='admm', x0=x0, rhos=rhos, lams={reg_term: sigmas}, max_iter=max_iter, pbar=True)

print(psnr(out, img))  # 32.9
imshow(out)
-------------------------------------------------------------------------------- 1 | # Papers with $\nabla$-Prox 2 | 3 | A list of papers that are implemented / replicated in $\nabla$-Prox. 4 | 5 | > If you want to add your project to the list, contact the maintainers or file a pull request. 6 | 7 | | Paper | Code | 8 | | ---- | --- | 9 | |[∇-Prox: Differentiable Proximal Algorithm Modeling for Large-Scale Optimization](https://dl.acm.org/doi/abs/10.1145/3592144). SIGGRAPH'23 | [Link](deltaprox_siggraph_2023/) | 10 | | [Deep Generalized Unfolding Networks for Image Restoration](https://arxiv.org/abs/2204.13348). CVPR'22 | [Link](deltaprox_siggraph_2023/) | 11 | | [Deep Plug-and-Play Prior for Hyperspectral Image Restoration](http://arxiv.org/abs/2209.08240). Neurocomputing'22 | [Link](dphsir_neurcomputing_2022/) | 12 | | [Plug-and-Play Image Restoration with Deep Denoiser Prior](http://arxiv.org/abs/2008.13751). TPAMI'20 | [Link](dpir_tpami_2020/) | 13 | | [Tuning-free Plug-and-Play Proximal Algorithm for Inverse Imaging Problems](http://arxiv.org/abs/2002.09611). ICML'20 | [Link](tfpnp_icml_2020/) | -------------------------------------------------------------------------------- /examples/papers/deltaprox_siggraph_2023/computional_optics/README.md: -------------------------------------------------------------------------------- 1 | # End-to-End Computational Optics 2 | 3 | This example applies Delta-Prox to end-to-end computational optics. 4 | 5 | ## Prepare Data and Checkpoints 6 | 7 | - Download [BSD500](https://huggingface.co/datasets/delta-prox/BSD500) for training, and [CBSD68](https://huggingface.co/datasets/delta-prox/CBSD68) and [Urban100](https://huggingface.co/datasets/delta-prox/Urban100) for evaluation. 8 | - The pretrained models are hosted at [Huggingface](https://huggingface.co/delta-prox/computational_optics).
def step_fn(gt):
    """Simulate one blurred, noisy capture of ``gt`` and restore it.

    Reads the module-level model, placeholders, solver and schedules.

    Returns:
        ``(gt, inp, out)``: the ground truth moved to ``device`` as float, the
        simulated measurement, and the solver output.
    """
    gt = gt.to(device).float()

    # Forward model: convolve the image with the PSF of the baseline profile.
    psf = rgb_collim_model.get_psf(fresnel_phase_c)
    blurred = img_psf_conv(gt, psf, circular=circular)

    # Additive Gaussian noise with standard deviation `sigma`.
    noise = torch.randn(*blurred.shape, device=blurred.device) * sigma
    inp = blurred + noise

    # Bind this sample's measurement and PSF before solving.
    y.value = inp
    PSF.value = psf

    out = solver.solve(
        x0=inp,
        rhos=rhos,
        lams={reg_term: sigmas},
        max_iter=max_iter,
    )
    return gt, inp, out
4 | 5 | ## Prepare Data and Checkpoints 6 | 7 | Download the testing datasets from Hugging Face: [MICCAI_2020](https://huggingface.co/datasets/delta-prox/MICCAI_2020), [Medical_7_2020](https://huggingface.co/datasets/delta-prox/Medical_7_2020), and [examples](https://huggingface.co/datasets/delta-prox/examples). 8 | 9 | To train the RL solver, you also need to download the training dataset, [Image128](https://huggingface.co/datasets/delta-prox/Image128). 10 | 11 | The checkpoints are hosted at [aaronb/CSMRI](https://huggingface.co/aaronb/CSMRI). 12 | 13 | ## Evaluation & training 14 | 15 | All scripts can be directly executed, e.g., 16 | 17 | ```bash 18 | python deq_tfpnp.py 19 | ``` 20 | 21 | For more configuration options, use 22 | ```bash 23 | python deq_tfpnp.py --help 24 | ``` 25 | 26 | -------------------------------------------------------------------------------- /examples/papers/deltaprox_siggraph_2023/csmri/download_datasets.py: -------------------------------------------------------------------------------- 1 | import dprox.utils.huggingface as hf 2 | 3 | hf.download_dataset('Medical_7_2020', local_dir='data/Medical7_2020') 4 | hf.download_dataset('MICCAI_2020', local_dir='data/MICCAI_2020') 5 | -------------------------------------------------------------------------------- /examples/papers/deltaprox_siggraph_2023/csmri/pnp_drunet.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from tfpnp.utils.metric import psnr_qrnn3d 5 | from torch.utils.data import DataLoader 6 | from torchlight.logging import Logger 7 | 8 | from dprox import * 9 | from dprox.algo.tune import * 10 | from dprox.utils import * 11 | from dprox.contrib.csmri import CustomADMM, EvalDataset 12 | 13 | 14 | def main(): 15 | 16 | x = Variable() 17 | y = Placeholder() 18 | mask = Placeholder() 19 | data_term = csmri(x, mask, y) 20 | reg_term = deep_prior(x, denoiser='drunet') 21 | solver = CustomADMM([reg_term],
[data_term]).cuda() 22 | 23 | def step_fn(batch): 24 | y.value = batch['y0'].cuda() 25 | mask.value = batch['mask'].cuda() 26 | target = batch['gt'].cuda() 27 | x0 = batch['x0'].cuda() 28 | max_iter = 24 29 | 30 | # Medical7_2020/radial_128_2/5 31 | # rhos, sigmas = log_descent(50, 40, max_iter) 32 | # rhos, _ = log_descent(125, 0.1, max_iter) 33 | 34 | # Medical7_2020/radial_128_2/10 35 | # rhos, sigmas = log_descent(40, 40, max_iter) 36 | # rhos, _ = log_descent(0.1, 0.1, max_iter) 37 | 38 | # Medical7_2020/radial_128_4/5 39 | rhos, sigmas = log_descent(80, 40, max_iter) 40 | rhos, _ = log_descent(10, 0.1, max_iter) 41 | 42 | # Medical7_2020/radial_128_4/15 43 | # rhos, sigmas = log_descent(80, 55, max_iter) 44 | # rhos, _ = log_descent(1, 0.1, max_iter) 45 | 46 | # MICCAI_2020/radial_128_4/5 47 | # rhos, sigmas = log_descent(60, 40, max_iter) 48 | # rhos, _ = log_descent(4, 0.1, max_iter) 49 | 50 | # MICCAI_2020/radial_128_8/5 51 | # rhos, sigmas = log_descent(60, 40, max_iter) 52 | # rhos, _ = log_descent(1, 0.1, max_iter) 53 | 54 | # rhos, sigmas = log_descent(50, 1, max_iter) 55 | pred = solver.solve(x0=x0, rhos=rhos, lams={reg_term: sigmas}, max_iter=max_iter).real 56 | return target, pred 57 | 58 | valid_datasets = { 59 | # 'Medical7_2020/radial_128_2/5': EvalDataset('data/csmri/Medical7_2020/radial_128_2/5'), 60 | 'Medical7_2020/radial_128_4/5': EvalDataset('data/csmri/Medical7_2020/radial_128_4/5'), 61 | # 'Medical7_2020/radial_128_2/10': EvalDataset('data/csmri/Medical7_2020/radial_128_2/10'), 62 | # 'Medical7_2020/radial_128_8/15': EvalDataset('data/csmri/Medical7_2020/radial_128_8/15'), 63 | # 'Medical7_2020/radial_128_4/15': EvalDataset('data/csmri/Medical7_2020/radial_128_4/15'), 64 | # 'MICCAI_2020/radial_128_4/5': EvalDataset('data/csmri/MICCAI_2020/radial_128_4/5'), 65 | # 'MICCAI_2020/radial_128_8/15': EvalDataset('data/csmri/MICCAI_2020/radial_128_8/15'), 66 | } 67 | 68 | save_root = 'abc/pnp_drunet' 69 | 70 | for name, valid_dataset in 
def main():
    """Evaluate plug-and-play ADMM with the 'unet' prior on CS-MRI test sets.

    For every dataset in ``valid_datasets`` the solver is run on each sample,
    the restored image is saved under ``abc/pnp_unet/<dataset name>/`` and the
    average PSNR is logged. Requires CUDA.
    """
    x = Variable()
    y = Placeholder()
    mask = Placeholder()
    data_term = csmri(x, mask, y)
    reg_term = deep_prior(x, denoiser='unet')
    solver = CustomADMM([reg_term], [data_term]).cuda()

    def step_fn(batch):
        """Bind one batch to the placeholders and solve; returns (target, pred)."""
        y.value = batch['y0'].cuda()
        mask.value = batch['mask'].cuda()
        target = batch['gt'].cuda()
        x0 = batch['x0'].cuda()
        max_iter = 24

        # Hand-tuned schedules, one preset per dataset. Exactly one preset must
        # be active; the others stay commented out (previously several were
        # live and silently overwritten by the last one).

        # Medical7_2020/radial_128_2/5
        # rhos, sigmas = log_descent(50, 40, max_iter)
        # rhos, _ = log_descent(125, 0.1, max_iter)

        # Medical7_2020/radial_128_2/10
        # rhos, sigmas = log_descent(40, 40, max_iter)
        # rhos, _ = log_descent(0.1, 0.1, max_iter)

        # Medical7_2020/radial_128_4/5
        # The sigmas come from the first schedule, the rhos from the second.
        _, sigmas = log_descent(70, 40, max_iter)
        rhos, _ = log_descent(120, 0.1, max_iter)

        # Medical7_2020/radial_128_4/15
        # rhos, sigmas = log_descent(50, 40, max_iter)
        # rhos, _ = log_descent(0.1, 0.1, max_iter)

        # MICCAI_2020/radial_128_4/5
        # rhos, sigmas = log_descent(60, 40, max_iter)
        # rhos, _ = log_descent(4, 0.1, max_iter)

        # MICCAI_2020/radial_128_8/5
        # rhos, sigmas = log_descent(60, 40, max_iter)
        # rhos, _ = log_descent(1, 0.1, max_iter)

        # rhos, sigmas = log_descent(50, 1, max_iter)
        pred = solver.solve(x0=x0, rhos=rhos, lams={reg_term: sigmas}, max_iter=max_iter).real
        return target, pred

    valid_datasets = {
        'Medical7_2020/radial_128_4/5': EvalDataset('data/Medical7_2020/radial_128_4/5'),
        # 'Medical7_2020/radial_128_4/15': EvalDataset('data/Medical7_2020/radial_128_4/15'),
        # 'MICCAI_2020/radial_128_4/5': EvalDataset('data/MICCAI_2020/radial_128_4/5'),
        # 'MICCAI_2020/radial_128_8/5': EvalDataset('data/MICCAI_2020/radial_128_8/5'),
    }

    save_root = 'abc/pnp_unet'

    for name, valid_dataset in valid_datasets.items():
        total_psnr = 0
        test_loader = DataLoader(valid_dataset)

        save_dir = os.path.join(save_root, name)
        os.makedirs(save_dir, exist_ok=True)
        logger = Logger(save_dir, name=name)

        for idx, batch in enumerate(test_loader):
            with torch.no_grad():
                target, pred = step_fn(batch)
            score = psnr_qrnn3d(target.squeeze(0).cpu().numpy(),
                                pred.squeeze(0).cpu().numpy(),
                                data_range=1)
            total_psnr += score
            # The index keeps names unique: two samples with the same rounded
            # PSNR used to overwrite each other.
            logger.save_img(f'{idx}_{score:0.2f}.png', pred)

        logger.info('{} avg psnr= {}'.format(name, total_psnr / len(test_loader)))
def build_solver():
    """Build the CS-MRI ADMM solver with the 'unet' deep prior.

    Returns:
        ``(solver, placeholders)`` where ``placeholders`` maps ``'y'`` and
        ``'mask'`` to the placeholders the data term reads.
    """
    x = Variable()
    measurement = Placeholder()
    sampling_mask = Placeholder()

    data_term = csmri(x, sampling_mask, measurement)
    reg_term = deep_prior(x, denoiser='unet')

    solver = CustomADMM([reg_term], [data_term])
    return solver, {'y': measurement, 'mask': sampling_mask}


def main():
    """Train the auto-tuning policy for the CS-MRI solver, then evaluate it."""
    solver, placeholders = build_solver()

    # dataset
    examples_root = Path(hf.download_dataset('examples', force_download=False))
    train_root = Path(hf.download_dataset('Images128', force_download=False))
    val_root = Path(hf.download_dataset('Medical_7_2020', force_download=False))

    mask_dir = examples_root / 'csmri/masks'
    sigma_ns = [5, 10, 15]
    sampling_masks = ['radial_128_2', 'radial_128_4', 'radial_128_8']

    noise_model = GaussianModelD(sigma_ns)
    masks = []
    for mask_name in sampling_masks:
        mat = loadmat(str(mask_dir / f'{mask_name}.mat'))
        masks.append(mat.get('mask'))
    dataset = TrainDataset(train_root, fns=None, masks=masks, noise_model=noise_model)

    # One validation set per sampling mask, all from the "15" subfolder.
    valid_datasets = {
        f'Medical_7_2020/{mask_name}/15': EvalDataset(val_root / f'{mask_name}/15')
        for mask_name in sampling_masks
    }

    # train
    tf_solver = AutoTuneSolver(solver, policy='resnet')

    training_cfg = {
        'rmsize': 480,
        'max_episode_step': 30,
        'train_steps': 15000,
        'warmup': 20,
        'save_freq': 1000,
        'validate_interval': 10,
        'episode_train_times': 10,
        'env_batch': 48,
        'loop_penalty': 0.05,
        'discount': 0.99,
        'lambda_e': 0.2,
        'tau': 0.001,
        'action_pack': 1,
        'log_dir': 'rl_unet',
        'custom_env': CustomEnv,
    }
    tf_solver.train(dataset, valid_datasets, placeholders, **training_cfg)

    # NOTE(review): this path does not sit under log_dir='rl_unet' - confirm
    # it is where the trained actor is actually written.
    ckpt_path = 'ckpt/tfpnp_unet/actor_best.pkl'
    tf_solver.eval(ckpt_path, valid_datasets, placeholders, custom_env=CustomEnv)
29 | 30 | ## Acknowledgement 31 | 32 | - [Restormer](https://github.com/swz30/Restormer) 33 | - [DGUNet](https://github.com/MC-E/Deep-Generalized-Unfolding-Networks-for-Image-Restoration) -------------------------------------------------------------------------------- /examples/papers/deltaprox_siggraph_2023/deraining/derain/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/Delta-Prox/15fc489e99c4de2676dc538d9c0552338f6fb894/examples/papers/deltaprox_siggraph_2023/deraining/derain/__init__.py -------------------------------------------------------------------------------- /examples/papers/deltaprox_siggraph_2023/deraining/derain/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .dataset import DataLoaderTrain, DataLoaderVal, DataLoaderTest, DataLoaderTest_fine 3 | 4 | def get_training_data(rgb_dir, img_options): 5 | assert os.path.exists(rgb_dir) 6 | return DataLoaderTrain(rgb_dir, img_options) 7 | 8 | def get_validation_data(rgb_dir, img_options): 9 | assert os.path.exists(rgb_dir) 10 | return DataLoaderVal(rgb_dir, img_options) 11 | 12 | def get_test_data(rgb_dir, img_options): 13 | assert os.path.exists(rgb_dir) 14 | return DataLoaderTest(rgb_dir, img_options) 15 | 16 | def get_test_data_fine(rgb_dir, img_options): 17 | assert os.path.exists(rgb_dir) 18 | return DataLoaderTest_fine(rgb_dir, img_options) -------------------------------------------------------------------------------- /examples/papers/deltaprox_siggraph_2023/deraining/test_unroll.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import fire 4 | 5 | import imageio 6 | import torch 7 | from torch.utils.data import DataLoader 8 | from tqdm import tqdm 9 | 10 | from dprox import * 11 | from dprox.utils import * 12 | 13 | from derain.data import get_test_data 14 | 
@torch.no_grad()
def main(
    dataset='Rain100H',
    result_dir='results/dprox_pdg_unroll',
    data_dir='datasets/test/',
):
    """Run the unrolled PGD solver on one test set and save the results as PNG.

    Args:
        dataset: name of the sub-directory of ``data_dir`` to evaluate; its
            ``input`` folder is read.
        result_dir: root output directory; images go to ``result_dir/dataset``.
        data_dir: root directory holding the test sets.
    """
    solver, rhos, b = build_solver()

    rgb_dir_test = os.path.join(data_dir, dataset, 'input')
    test_dataset = get_test_data(rgb_dir_test, img_options={})
    test_loader = DataLoader(dataset=test_dataset, batch_size=1,
                             shuffle=False, num_workers=4,
                             drop_last=False, pin_memory=True)

    result_dir = os.path.join(result_dir, dataset)
    os.makedirs(result_dir, exist_ok=True)

    for data in tqdm(test_loader):
        # `rainy` rather than `input`, to avoid shadowing the builtin.
        rainy = data[0].cuda()
        filenames = data[1]

        b.value = rainy
        output = solver.solve(x0=rainy, rhos=rhos, max_iter=7)
        # The solver output is added back onto the observation.
        output = torch.clamp(output + rainy, 0, 1)

        # Round before the uint8 cast: a bare cast truncates towards zero and
        # biases every pixel down by up to one grey level.
        images = (output * 255).round().permute(0, 2, 3, 1)
        images = images.to(torch.uint8).cpu().detach().numpy()

        for filename, restored_img in zip(filenames, images):
            imageio.imsave(
                os.path.join(result_dir, filename + '.png'),
                restored_img
            )
def build_solver():
    """Assemble the PGD deraining solver and load its trained weights.

    The objective is ``sum_squares(raining(x), b) + unrolled_prior(x)``, where
    ``raining`` wraps the learnable degradation operator's ``forward`` and
    ``adjoint`` methods.  Weights are read from
    ``derain_pdg_unroll_share.pth`` in the current working directory, using
    the checkpoint keys ``'linop'``, ``'prior'`` and ``'rhos'``.

    Returns:
        A ``(solver, rhos, b)`` tuple: the compiled solver, the ``rhos`` entry
        of the checkpoint, and the placeholder that the caller fills with the
        observation before solving.
    """
    # custom linop: expose the learnable operator's forward/adjoint as a LinOp
    A = LearnableDegOp().cuda()
    def forward_fn(input, step): return A.forward(input, step)
    def adjoint_fn(input, step): return A.adjoint(input, step)
    raining = LinOpFactory(forward_fn, adjoint_fn)

    # build solver
    x = Variable()
    b = Placeholder()
    data_term = sum_squares(raining(x), b)
    reg_term = unrolled_prior(x)
    obj = data_term + reg_term
    solver = compile(obj, method='pgd', device='cuda')

    # load parameters into the operator and the prior after compilation
    ckpt = torch.load('derain_pdg_unroll_share.pth')
    A.load_state_dict(ckpt['linop'])
    reg_term.load_state_dict(ckpt['prior'])
    rhos = ckpt['rhos']

    return solver, rhos, b
def test():
    """Reproduce the DGUNet deraining result and check its PSNR.

    Loads the example input/target pair, solves the PGD problem with the
    pretrained operator and prior, and asserts the PSNR is within 0.1 dB of
    the reference value 35.92.
    """
    b = hf.load_image('examples/derain/derain_input.png')
    gt = hf.load_image('examples/derain/derain_target.png')
    imshow(gt, b)

    # custom linop
    A = dgu_linop().cuda()
    def forward_fn(input, step): return A.forward(input, step)
    def adjoint_fn(input, step): return A.adjoint(input, step)
    raining = LinOpFactory(forward_fn, adjoint_fn)

    # build solver
    x = Variable()
    data_term = sum_squares(raining(x), b)
    reg_term = dgu_prior(x)
    obj = data_term + reg_term
    solver = compile(obj, method='pgd')

    # load parameters
    ckpt = hf.load_checkpoint('image_deraining/derain_pdg.pth')
    A.load_state_dict(ckpt['linop'])
    reg_term.load_state_dict(ckpt['prior'])
    rhos = ckpt['rhos']

    out = solver.solve(x0=b, rhos=rhos, max_iter=7)
    out = to_ndarray(out, debatch=True) + b

    score = psnr(out, gt)
    print(score)  # 35.92
    imshow(gt, out)
    # Two-sided check: without abs() any arbitrarily low PSNR would pass.
    assert abs(score - 35.92) < 0.1
class dgu_prior(ProxFn):
    """Learned prior whose proximal step is a denoiser network call."""

    def __init__(self, linop, denoiser=None):
        """Wrap ``linop`` and instantiate the denoiser.

        ``denoiser`` is a zero-argument callable that builds the network;
        when omitted, ``Denoiser`` is used.
        """
        super().__init__(linop)
        factory = Denoiser if denoiser is None else denoiser
        self.denoiser = factory()

    def eval(self, v):
        """Always raises: the prior has no closed-form value."""
        raise NotImplementedError('deep prior cannot be explictly evaluated')

    def _prox(self, v: torch.Tensor, lam: torch.Tensor = None):
        """Apply the denoiser to ``v`` at the current ``self.step``.

        v: [N, C, H, W]
        lam: [1] (accepted but not used)
        """
        return self.denoiser(v, self.step)
class AbstractBlur:
    """Blur a multi-band image with one 2-D kernel shared by every band."""

    def __init__(self, kernel):
        # 2-D array; it gains a trailing singleton axis when applied.
        self.kernel = kernel

    def __call__(self, img):
        """Return ``img`` convolved with the kernel.

        The kernel is expanded to shape ``(kh, kw, 1)`` so that each band of
        the 3-D ``img`` is filtered independently; borders wrap around
        (``mode='wrap'``).
        """
        kernel = np.expand_dims(self.kernel, axis=2)
        # Use the public `ndimage.convolve`; the `ndimage.filters` namespace
        # is deprecated in SciPy.
        return ndimage.convolve(img, kernel, mode='wrap')
class PerspectiveTransform(object):
    """Keystone warp that stretches an inset bottom edge to full width."""

    def __init__(self, shift):
        # Horizontal inset of the two bottom source corners.
        self.shift = shift

    def __call__(self, img):
        """Warp ``img`` and return an image of the same height and width.

        OpenCV points are ``(x, y)`` and ``dsize`` is ``(width, height)``, so
        the column count supplies x extents and the row count y extents.
        Using rows for x, as the previous code did, is only correct for
        square images.
        """
        height, width = img.shape[:2]
        shift = self.shift
        src = np.float32([[0, 0], [width, 0],
                          [shift, height], [width - shift, height]])
        dst = np.float32([[0, 0], [width, 0],
                          [0, height], [width, height]])

        M = cv2.getPerspectiveTransform(src, dst)
        return cv2.warpPerspective(img, M, (width, height))
class FastHyStripe:
    """Zero out random stripes (dead lines) in a run of bands.

    Input: [W,H,B]
    """

    def __init__(self, num_bands=15, bandwise=False, start_band=10, seed=None):
        """
        Args:
            num_bands: how many consecutive bands receive stripes.
            bandwise: if True, every affected band shares the stripe pattern
                drawn for the first one.
            start_band: index of the first affected band.
            seed: optional seed for the private random generator; ``None``
                draws fresh entropy.
        """
        self.num_bands = num_bands
        self.bandwise = bandwise
        self.start_band = start_band
        # Private generator: the global NumPy RNG is left untouched, so
        # seeding done elsewhere in the program stays valid.
        self._rng = np.random.default_rng(seed)

    def __call__(self, img):
        """Return ``(img * mask, mask)``; ``mask`` is 0 on dead lines, else 1."""
        mask = np.ones_like(img)
        h = img.shape[1]
        first = self.start_band
        last = first + self.num_bands

        for band in range(first, last):
            stripes = self._rng.choice(h, 20, replace=False)
            for k, j in enumerate(stripes):
                t = self._rng.random()
                # Two fixed wide stripes per band; the rest are 4 or 2 wide.
                if k == 4:
                    width = 30
                elif k == 10:
                    width = 15
                elif t > 0.6:
                    width = 4
                else:
                    width = 2
                mask[:, j:j + width, band] = 0
            if self.bandwise:
                # Only the first band is drawn; it is copied below.
                break

        if self.bandwise:
            mask[:, :, first:last] = np.expand_dims(mask[:, :, first], axis=-1)

        img_L = img * mask
        return img_L, mask
class BiCubicDownsample(AbstractDownsample):
    """Bicubic downsampling by an integer scale factor of 2, 3 or 4."""

    kernel_path = os.path.join(CURRENT_DIR, 'kernels', 'kernels_bicubicx234.mat')
    valid_sfs = [2, 3, 4]

    def __init__(self, sf):
        """Load the stored kernel that matches ``sf``.

        Raises:
            ValueError: if ``sf`` is not one of ``valid_sfs``.
        """
        if sf not in self.valid_sfs:
            raise ValueError(f'Invalid scale factor, choose from {self.valid_sfs}')
        self.kernels = hdf5storage.loadmat(self.kernel_path)['kernels']
        # The kernel table is indexed from scale factor 2 (column sf-2).
        kernel = self.kernels[0, sf-2].astype(np.float64)
        # Initialise through the base class, as ClassicalDownsample does, so
        # `sf` and `kernel` are set in one place.
        super().__init__(sf, kernel)

    def __call__(self, img):
        """ input: [w,h,c]
            data range: both (0,255), (0,1) are ok
        """
        return imresize_np(img, 1/self.sf)
= loadmat('Lehavim_0910-1717.mat')['gt'] 16 | 17 | down = CASSI() 18 | mask = down.mask.astype('float32') 19 | y = down(I).astype('float32') 20 | 21 | x0 = np.expand_dims(y, axis=-1) * mask 22 | 23 | print(I.shape, y.shape, x0.shape) 24 | imshow(I[:, :, 20], y, x0[:, :, 20]) 25 | 26 | # % 27 | # ----------------------------------- # 28 | # Define and Solve # 29 | # ----------------------------------- # 30 | 31 | x = Variable() 32 | data_term = compress_sensing(x, mask, y) 33 | reg_term = deep_prior(x, denoiser='grunet') 34 | 35 | solver = compile(data_term+reg_term) 36 | solver.to(torch.device('cuda')) 37 | 38 | iter_num = 24 39 | rhos, sigmas = log_descent(50, 45, iter_num) 40 | x_pred = solver.solve(x0, 41 | rhos=rhos, 42 | weights={reg_term: sigmas}, 43 | max_iter=iter_num, 44 | pbar=True) 45 | 46 | out = to_ndarray(x_pred, debatch=True) 47 | 48 | print(mpsnr(out, I)) # mpsnr: 39.18 49 | imshow(out[:, :, 20]) 50 | # %% 51 | -------------------------------------------------------------------------------- /examples/papers/dphsir_neurcomputing_2022/hsi_deblur.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import torch 3 | from scipy.io import loadmat 4 | 5 | from dprox import * 6 | from dprox.utils import * 7 | 8 | from degrades.blur import GaussianBlur 9 | 10 | # ------------------------------------- # 11 | # Prepare Data # 12 | # ------------------------------------- # 13 | 14 | 15 | I = loadmat('Lehavim_0910-1717.mat')['gt'] 16 | 17 | blur = GaussianBlur() 18 | b = blur(I) 19 | 20 | print(I.shape, b.shape) 21 | imshow(I[:, :, 20], b[:, :, 20]) 22 | 23 | # %% 24 | # ----------------------------------- # 25 | # Define and Solve # 26 | # ----------------------------------- # 27 | 28 | x = Variable() 29 | data_term = sum_squares(conv(x, blur.kernel), b) 30 | reg_term = deep_prior(x, denoiser='grunet') 31 | 32 | device = torch.device('cuda') 33 | solver = compile(data_term + reg_term) 34 | 35 | iter_num = 24 36 | 
rhos, sigmas = log_descent(35, 10, iter_num) 37 | x_pred = solver.solve(to_torch_tensor(b, batch=True).to(device), 38 | rhos=1e-10, 39 | weights={reg_term: sigmas}, # 54.97 if reg_term: 0.23 40 | max_iter=iter_num, 41 | eps=0, 42 | pbar=True) 43 | 44 | out = to_ndarray(x_pred, debatch=True) 45 | print(mpsnr(out, I)) # 54.22/55.10, 78.31 if rho = 1e-10 46 | imshow(out[:, :, 20]) 47 | 48 | # two deep prior: 53.00 49 | 50 | # %% 51 | -------------------------------------------------------------------------------- /examples/papers/dphsir_neurcomputing_2022/hsi_inpainting.py: -------------------------------------------------------------------------------- 1 | # %% 2 | from scipy.io import loadmat 3 | 4 | from dprox import * 5 | from dprox.utils import * 6 | 7 | from degrades.inpaint import FastHyStripe 8 | 9 | # ------------------------------------- # 10 | # Prepare Data # 11 | # ------------------------------------- # 12 | 13 | I = loadmat('Lehavim_0910-1717.mat')['gt'] 14 | 15 | degrade = FastHyStripe() 16 | b, mask = degrade(I) 17 | mask = mask.astype('float') 18 | 19 | imshow(I[:,:,20], b[:,:,20]) 20 | 21 | #%% 22 | # ----------------------------------- # 23 | # Define and Solve # 24 | # ----------------------------------- # 25 | 26 | S = MulElementwise(mask) 27 | 28 | x = Variable() 29 | data_term = sum_squares(S, b) 30 | deep_reg = deep_prior(x, denoiser='grunet') 31 | problem = Problem(data_term+deep_reg) 32 | 33 | rhos, sigmas = log_descent(5, 4, iter=24, lam=0.6) 34 | x_pred = problem.solve(solver='admm', 35 | x0=b, 36 | weights={deep_reg: sigmas}, 37 | rhos=rhos, 38 | max_iter=24, 39 | pbar=True) 40 | 41 | out = to_ndarray(x_pred, debatch=True) 42 | 43 | print(mpsnr(out, I)) # 74.92/74.88 44 | imshow(out[:,:,20]) 45 | 46 | # %% 47 | -------------------------------------------------------------------------------- /examples/papers/dphsir_neurcomputing_2022/hsi_misr.py: -------------------------------------------------------------------------------- 1 | # %% 2 | 
import cv2 3 | from scipy.io import loadmat 4 | from dprox import * 5 | from dprox.utils import * 6 | 7 | from degrades.sr import GaussianDownsample 8 | from degrades.general import HSI2RGB 9 | 10 | # ------------------------------------- # 11 | # Prepare Data # 12 | # ------------------------------------- # 13 | 14 | I = loadmat('Lehavim_0910-1717.mat')['gt'].astype('float32') 15 | 16 | sf = 2 17 | down = GaussianDownsample(sf=sf) 18 | y = down(I) 19 | 20 | x0 = cv2.resize(y, (y.shape[1] * sf, y.shape[0] * sf), interpolation=cv2.INTER_CUBIC) 21 | 22 | srf = HSI2RGB().srf.T # 31*3 23 | 24 | T = mul_color(srf) 25 | z = to_ndarray(T(to_torch_tensor(I, batch=True)), debatch=True) 26 | 27 | print(I.shape, y.shape, x0.shape, z.shape) 28 | imshow(I[:, :, 20], y[:, :, 20], x0[:, :, 20], z) 29 | 30 | # %% 31 | # ----------------------------------- # 32 | # Define and Solve # 33 | # ----------------------------------- # 34 | 35 | 36 | x = Variable() 37 | data_term = sisr(x, down.kernel, sf, y) 38 | rgb_fidelity = misr(x, z, srf) 39 | reg = deep_prior(x, denoiser='grunet') 40 | problem = Problem(rgb_fidelity + data_term + reg) 41 | 42 | rhos, sigmas = log_descent(35, 10, 1) 43 | x_pred = problem.solve(solver='admm', 44 | x0=x0, 45 | rhos=rhos, 46 | weights={reg: sigmas}, 47 | max_iter=1, 48 | pbar=True) 49 | 50 | out = to_ndarray(x_pred, debatch=True) 51 | print(mpsnr(out, I)) # 59.11, 59.74 52 | imshow(out[:, :, 20]) 53 | # %% 54 | -------------------------------------------------------------------------------- /examples/papers/dphsir_neurcomputing_2022/hsi_sisr.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import cv2 3 | import torch 4 | from scipy.io import loadmat 5 | 6 | from dprox import * 7 | from dprox.utils import * 8 | 9 | from degrades import GaussianDownsample 10 | 11 | # ------------------------------------- # 12 | # Prepare Data # 13 | # ------------------------------------- # 14 | 15 | I = 
# Demosaic Ref
# - https://sporco.readthedocs.io/en/latest/examples/ppp/ppp_admm_dmsc.html
# - https://github.com/cszn/DPIR/blob/master/main_dpir_demosaick.py

import cv2

try:
    # `scipy.misc.face` was deprecated and later removed from SciPy; the
    # sample image now lives in `scipy.datasets` (which needs `pooch`).
    from scipy.datasets import face
except ImportError:  # older SciPy without scipy.datasets
    from scipy.misc import face

from dprox import *
from dprox.utils import *

from utils import dm_matlab, mosaic_CFA_Bayer, tensor2uint, uint2tensor4

# Prepare GT and Input
I = face()
CFA, CFA4, b, mask = mosaic_CFA_Bayer(I)

x0 = dm_matlab(uint2tensor4(CFA4))
x0 = tensor2uint(x0)
# NOTE: the MATLAB-style initialisation above is overridden by the OpenCV
# demosaic below, which is the initial point actually used.
x0 = cv2.cvtColor(CFA, cv2.COLOR_BAYER_BG2RGB_EA)  # essential for drunet, wo 14, w 41.72
# x0 = b

imshow(b)
I = I.astype('float32') / 255
x0 = x0.astype('float32') / 255
b = b.astype('float32') / 255

print(b.shape, mask.shape)

# Define and Solve the Problem
x = Variable()

data_term = sum_squares(mosaic(x), b)
deep_reg = deep_prior(x, denoiser='ffdnet_color', x8=True)
problem = Problem(data_term + deep_reg)


hi = 32
low = 2
iter_num = 40
rhos, sigmas = log_descent(hi, low, iter_num)
x_pred = problem.solve(solver='hqs',
                       x0=x0,
                       weights={deep_reg: sigmas},
                       rhos=rhos,
                       max_iter=iter_num,
                       pbar=False)
out = to_ndarray(x_pred, debatch=True).clip(0, 1)
# Re-impose the observed samples at the measured positions.
out[mask] = b[mask]
psnr_ = psnr(out, I)


# best: 32 2 44.766043261756685
print(psnr_)
imshow(out)
6 | 7 | ### Framework 8 | 9 | | Tutorial | Colab | 10 | | -- | -- | 11 | | [Learning the Basics](learn_the_basic.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/princeton-computational-imaging/Delta-Prox/blob/main/notebooks/learn_the_basic.ipynb)| 12 | | [Differentiable Linear Solver](differentiable_linear_solver.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/princeton-computational-imaging/Delta-Prox/blob/main/notebooks/differentiable_linear_solver.ipynb)| 13 | | [Linear Operator](linear_operator.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/princeton-computational-imaging/Delta-Prox/blob/main/notebooks/linear_operator.ipynb)| 14 | | [Primitive](primitive.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/princeton-computational-imaging/Delta-Prox/blob/main/notebooks/primitive.ipynb)| 15 | 16 | 17 | 18 | ### Applications 19 | 20 | | Tutorial | Colab | 21 | | -- | -- | 22 | | [Compressive-MRI](csmri.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/princeton-computational-imaging/Delta-Prox/blob/main/notebooks/csmri.ipynb) | 23 | | [Computational Optics ](computational_optics.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/princeton-computational-imaging/Delta-Prox/blob/main/notebooks/computational_optics.ipynb) | 24 | | [Image Deraining](deraining.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/princeton-computational-imaging/Delta-Prox/blob/main/notebooks/deraining.ipynb) | 25 | | [Energy System 
Planning](energy_system_planning.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/princeton-computational-imaging/Delta-Prox/blob/main/notebooks/energy_system_planning.ipynb) | 26 | | [Image Restoration](image_restoration.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/princeton-computational-imaging/Delta-Prox/blob/main/notebooks/image_restoration.ipynb) | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | cvxpy 3 | torchlights 4 | termcolor 5 | tfpnp 6 | matplotlib 7 | tensorboardX -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | deps = [ 4 | 'imageio', 5 | 'scikit_image', 6 | 'matplotlib', 7 | 'munch', 8 | 'tfpnp', 9 | 'cvxpy', 10 | 'torchlights', 11 | 'tensorboardX', 12 | 'termcolor', 13 | 'proximal', 14 | 'opencv-python', 15 | 'huggingface_hub', 16 | ] 17 | 18 | setup( 19 | name='dprox', 20 | description='A domain-specific language and compiler that transforms optimization problems into differentiable proximal solvers.', 21 | url='https://github.com/princeton-computational-imaging/Delta-Prox', 22 | author='Zeqiang Lai', 23 | author_email='laizeqiang@outlook.com', 24 | packages=find_packages(), 25 | version='0.2.1', 26 | include_package_data=True, 27 | install_requires=deps, 28 | ) 29 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | # Unit Tests 2 | 3 | ```bash 4 | pip install pytest 5 | ``` 6 | 7 | Use `-s` to enable printing to screen during the tests. 
def test_gmres_scipy():
    """SciPy GMRES reference: solve ``A x = offset`` and compare with ``x``."""
    from scipy.sparse.linalg import gmres, LinearOperator
    A_op = LinearOperator(shape=(5,5), matvec=lambda b: A@b)

    xhat, info = gmres(A_op, offset)
    # info == 0 means the solver reported convergence; fail loudly here
    # rather than comparing against a non-converged iterate.
    assert info == 0, f'gmres did not converge (info={info})'

    print('gmres scipy')
    print(x)
    print(xhat)

    print(np.mean(np.abs(xhat-x)))
    assert np.allclose(xhat, x, rtol=rtol)
print(xhat1.numpy()) 70 | assert torch.allclose(xhat1, x2, rtol=rtol) 71 | 72 | print('conjugate_gradient2') 73 | print(torch.mean(torch.abs(xhat2-x2)).item()) 74 | print(xhat2.numpy()) 75 | # assert torch.allclose(xhat2, x2, rtol=rtol) 76 | 77 | print('PCG') 78 | print(torch.mean(torch.abs(xhat3-x2)).item()) 79 | print(xhat3.numpy()) 80 | assert torch.allclose(xhat3, x2, rtol=rtol) 81 | 82 | 83 | def test_plss(): 84 | A2 = torch.from_numpy(A) 85 | # K = lambda x: A2@x 86 | K = MatrixLinOp(A2) 87 | x2 = torch.from_numpy(x) 88 | b2 = torch.from_numpy(offset) 89 | 90 | xhat1 = dp.linalg.solve.plss(K, b2) 91 | 92 | print('PLSS') 93 | print(torch.mean(torch.abs(xhat1-x2)).item()) 94 | print(xhat1.detach().numpy()) 95 | assert torch.allclose(xhat1, x2, rtol=rtol) 96 | 97 | 98 | def test_minres(): 99 | A2 = torch.from_numpy(A) 100 | # K = lambda x: A2@x 101 | K = MatrixLinOp(A2) 102 | x2 = torch.from_numpy(x) 103 | b2 = torch.from_numpy(offset) 104 | 105 | with torch.no_grad(): 106 | xhat1 = dp.linalg.solve.minres(K, b2) 107 | 108 | print('MINRES') 109 | print(torch.mean(torch.abs(xhat1-x2)).item()) 110 | print(xhat1.detach().numpy()) 111 | assert torch.allclose(xhat1, x2, rtol=rtol) 112 | -------------------------------------------------------------------------------- /tests/linalg/test_linear_solver_batch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import dprox as dp 3 | import numpy as np 4 | 5 | 6 | P = np.random.rand(5, 5) 7 | I = np.eye(5) 8 | mu = 0.01 9 | A = P.T @ P + mu * I 10 | x = np.random.rand(5) 11 | offset = A @ x 12 | rtol = 1e-8 13 | 14 | 15 | class MatrixLinOp(dp.LinOp): 16 | def __init__(self, A): 17 | super().__init__() 18 | self.A = torch.nn.parameter.Parameter(A) 19 | 20 | def forward(self, x): 21 | return self.A @ x 22 | 23 | def adjoint(self, x): 24 | return self.A.T @ x 25 | 26 | 27 | def test_cg(): 28 | A2 = torch.from_numpy(A) 29 | def K(x): return A2 @ x 30 | x2 = 
torch.from_numpy(x) 31 | b2 = torch.from_numpy(offset) 32 | 33 | xhat1 = dp.linalg.solve.cg(K, b2) 34 | xhat3 = dp.linalg.solve.pcg(K, b2) 35 | 36 | print('conjugate_gradient') 37 | print(torch.mean(torch.abs(xhat1 - x2)).item()) 38 | print(xhat1.numpy()) 39 | assert torch.allclose(xhat1, x2, rtol=rtol) 40 | 41 | # print('conjugate_gradient2') 42 | # print(torch.mean(torch.abs(xhat2-x2)).item()) 43 | # print(xhat2.numpy()) 44 | # assert torch.allclose(xhat2, x2, rtol=rtol) 45 | 46 | print('PCG') 47 | print(torch.mean(torch.abs(xhat3 - x2)).item()) 48 | print(xhat3.numpy()) 49 | assert torch.allclose(xhat3, x2, rtol=rtol) 50 | -------------------------------------------------------------------------------- /tests/linalg/test_linear_solver_grad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from functools import partial 3 | 4 | 5 | atol, rtol = 1e-6, 1e-3 6 | allclose = partial(torch.allclose, atol=atol, rtol=rtol) 7 | 8 | 9 | def tol(input, other): 10 | tmp = torch.abs(input-other) - atol - rtol * torch.abs(other) 11 | return torch.max(tmp).item() 12 | 13 | 14 | def auto_diff(seed): 15 | torch.manual_seed(seed) 16 | theta = torch.randn((32,32), requires_grad=True) 17 | K = theta * 2 18 | x = torch.randn((32)) 19 | b = K @ x 20 | b = b.clone().detach().requires_grad_(True) 21 | 22 | xhat = torch.linalg.solve(K, b) 23 | 24 | loss = xhat.mean() 25 | loss.backward() 26 | 27 | return theta.grad, b.grad 28 | 29 | 30 | def differentiate_dloss_db(seed): 31 | torch.manual_seed(seed) 32 | theta = torch.randn((32,32), requires_grad=True) 33 | K = theta * 2 34 | x = torch.randn((32)) 35 | b = K @ x 36 | b = b.clone().detach().requires_grad_(True) 37 | 38 | xhat = torch.linalg.solve(K, b) 39 | xhat.retain_grad() 40 | 41 | loss = xhat.mean() 42 | loss.backward() 43 | 44 | db_matrix = torch.inverse(K.T) @ xhat.grad 45 | db_matrix_free = torch.linalg.solve(K.T, xhat.grad) 46 | 47 | return db_matrix, db_matrix_free 48 | 49 | 
50 | def matrix_differentiate_dloss_dtheta(seed): 51 | torch.manual_seed(seed) 52 | theta = torch.randn((32,32), requires_grad=True) 53 | K = theta * 2 54 | x = torch.randn((32)) 55 | b = K @ x 56 | b = b.clone().detach().requires_grad_(True) 57 | 58 | xhat = torch.linalg.solve(K, b) 59 | xhat.retain_grad() 60 | 61 | loss = xhat.mean() 62 | loss.backward() 63 | 64 | 65 | def Kmat(theta): 66 | return theta * 2 67 | 68 | dK_dtheta = torch.autograd.functional.jacobian(Kmat, theta) 69 | dxhat_dtheta = -torch.inverse(K) @ dK_dtheta @ xhat 70 | 71 | dloss_dtheta = dxhat_dtheta @ xhat.grad 72 | 73 | return dloss_dtheta 74 | 75 | 76 | def matrix_free_differentiate_dloss_dtheta(seed): 77 | torch.manual_seed(seed) 78 | theta = torch.randn((32,32), requires_grad=True) 79 | K = theta * 2 80 | x = torch.randn((32)) 81 | b = K @ x 82 | b = b.clone().detach().requires_grad_(True) 83 | 84 | xhat = torch.linalg.solve(K, b) 85 | xhat.retain_grad() 86 | 87 | loss = xhat.mean() 88 | loss.backward() 89 | 90 | def linop(theta): 91 | return theta*2 @ xhat 92 | 93 | db_dtheta = -torch.autograd.functional.jacobian(linop, theta).permute(1,2,0).unsqueeze(-1) 94 | dxhat_dtheta = torch.linalg.solve(K, db_dtheta).squeeze(-1) 95 | 96 | dloss_dtheta = dxhat_dtheta @ xhat.grad 97 | 98 | return dloss_dtheta 99 | 100 | 101 | def test_db(): 102 | for seed in range(20): 103 | _, db_ref = auto_diff(seed) 104 | db_matrix, db_matrix_free = differentiate_dloss_db(seed) 105 | print('b', seed, torch.mean(torch.abs(db_ref- db_matrix))) 106 | print('b free', seed, torch.mean(torch.abs(db_ref- db_matrix_free))) 107 | assert allclose(db_matrix, db_ref) 108 | assert allclose(db_matrix_free, db_ref) 109 | 110 | 111 | def test_theta(): 112 | for seed in range(20): 113 | dtheta_ref, _ = auto_diff(seed) 114 | dtheta_matrix = matrix_differentiate_dloss_dtheta(seed) 115 | dtheta_matrix_free = matrix_free_differentiate_dloss_dtheta(seed) 116 | 117 | print('theta', seed, tol(dtheta_matrix, dtheta_ref)) 118 | 
print('theta free', seed, tol(dtheta_matrix_free, dtheta_ref), 119 | dtheta_ref.abs().max().item(), 120 | (dtheta_matrix_free-dtheta_ref).abs().max().item()) 121 | 122 | assert allclose(dtheta_matrix, dtheta_ref) 123 | assert allclose(dtheta_matrix_free, dtheta_ref) 124 | 125 | -------------------------------------------------------------------------------- /tests/linalg/test_pcg.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/Delta-Prox/15fc489e99c4de2676dc538d9c0552338f6fb894/tests/linalg/test_pcg.py -------------------------------------------------------------------------------- /tests/paper/test_derain.py: -------------------------------------------------------------------------------- 1 | # def run_derain(): 2 | # pass 3 | 4 | 5 | # def test_Rain100H(): 6 | # psnr, ssim = run_derain() 7 | # assert abs(psnr - 31.08) < 0.01 8 | # assert abs(ssim - 0.897) < 0.001 9 | 10 | 11 | # def test_Rain100H_with_init(): 12 | # psnr, ssim = run_derain() 13 | # assert abs(psnr - 31.62) < 0.01 14 | # assert abs(ssim - 0.905) < 0.001 15 | 16 | 17 | # def test_Test1200(): 18 | # psnr, ssim = run_derain() 19 | # assert abs(psnr - 32.95) < 0.01 20 | # assert abs(ssim - 0.913) < 0.001 21 | 22 | 23 | # def test_Test1200_with_init(): 24 | # psnr, ssim = run_derain() 25 | # assert abs(psnr - 33.25) < 0.01 26 | # assert abs(ssim - 0.926) < 0.001 27 | -------------------------------------------------------------------------------- /tests/paper/test_energy.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/Delta-Prox/15fc489e99c4de2676dc538d9c0552338f6fb894/tests/paper/test_energy.py -------------------------------------------------------------------------------- /tests/paper/test_optics.py: -------------------------------------------------------------------------------- 1 | import os 2 | 
3 | import torch 4 | import torch.nn as nn 5 | import torchlight as tl 6 | import torchlight.nn as tlnn 7 | 8 | from dprox import * 9 | from dprox import Variable 10 | from dprox.linop.conv import conv_doe 11 | from dprox.utils import * 12 | from dprox.contrib.optic.utils import load_sample_img 13 | from dprox.contrib.optic.doe_model import img_psf_conv, build_doe_model, DOEModelConfig 14 | 15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | 17 | 18 | @torch.no_grad() 19 | def run_optics(dataset): 20 | config = DOEModelConfig() 21 | rgb_collim_model = build_doe_model().to(device) 22 | circular = config.circular 23 | 24 | # -------------------- Define Model --------------------- # 25 | x = Variable() 26 | y = Placeholder() 27 | PSF = Placeholder() 28 | data_term = sum_squares(conv_doe(x, PSF, circular=circular), y) 29 | reg_term = deep_prior(x, denoiser='ffdnet_color') 30 | solver = compile(data_term + reg_term, method='admm') 31 | solver.eval() 32 | 33 | # ---------------- Setup Hyperparameter ----------------- # 34 | max_iter = 10 35 | sigma = 7.65 / 255. 
36 | rhos, sigmas = log_descent(49, 7.65, max_iter, sigma=max(0.255 / 255, sigma)) 37 | rhos = torch.tensor(rhos, device=device).float() 38 | sigmas = torch.tensor(sigmas, device=device).float() 39 | rgb_collim_model.rhos = nn.parameter.Parameter(rhos) 40 | rgb_collim_model.sigmas = nn.parameter.Parameter(sigmas) 41 | 42 | # ---------------- Forward Model ------------------------ # 43 | def step_fn(gt): 44 | gt = gt.to(device).float() 45 | psf = rgb_collim_model.get_psf() 46 | inp = img_psf_conv(gt, psf, circular=circular) 47 | inp = inp + torch.randn(*inp.shape, device=inp.device) * sigma 48 | y.value = inp 49 | PSF.value = psf 50 | 51 | timer = tl.utils.Timer() 52 | timer.tic() 53 | 54 | out = solver.solve(x0=inp, 55 | rhos=rgb_collim_model.rhos, 56 | lams={reg_term: rgb_collim_model.sigmas.sqrt()}, 57 | max_iter=max_iter) 58 | 59 | return gt, inp, out 60 | 61 | print('Model size (M)', tlnn.benchmark.model_size(rgb_collim_model)) 62 | print('Trainable model size (M)', tlnn.benchmark.trainable_model_size(rgb_collim_model)) 63 | 64 | # --------------- Load and Run ------------------ # 65 | ckpt = hf.load_checkpoint('computational_optics/joint_dprox.pth') 66 | rgb_collim_model.load_state_dict(ckpt['model']) 67 | 68 | tl.metrics.set_data_format('chw') 69 | 70 | root = dataset 71 | root = hf.download_dataset(root) 72 | tracker = tl.trainer.util.MetricTracker() 73 | 74 | timer = tl.utils.Timer() 75 | timer.tic() 76 | for idx, name in enumerate(list_image_files(root)): 77 | gt = load_sample_img(os.path.join(root, name)) 78 | 79 | torch.manual_seed(idx) 80 | torch.cuda.manual_seed(idx) 81 | gt, inp, pred = step_fn(gt) 82 | pred = pred.clamp(0, 1) 83 | 84 | psnr = tl.metrics.psnr(pred, gt) 85 | ssim = tl.metrics.ssim(pred, gt) 86 | 87 | tracker.update('psnr', psnr) 88 | tracker.update('ssim', ssim) 89 | 90 | print('{} PSNR {} SSIM {}'.format(name, psnr, ssim)) 91 | 92 | print('averge results') 93 | print(tracker.summary()) 94 | print(timer.toc() / 
len(list_image_files(root))) 95 | return tracker['psnr'], tracker['ssim'] 96 | 97 | 98 | def test_ubran100(): 99 | psnr, ssim = run_optics('Urban100') 100 | assert abs(psnr - 30.83) < 0.01 101 | assert abs(ssim - 0.944) < 0.001 102 | 103 | 104 | def test_cbsd68(): 105 | psnr, ssim = run_optics('CBSD68') 106 | assert abs(psnr - 32.01) < 0.01 107 | assert abs(ssim - 0.942) < 0.001 108 | -------------------------------------------------------------------------------- /tests/problem/test_deraining.py: -------------------------------------------------------------------------------- 1 | from dprox import * 2 | from dprox.utils import * 3 | from dprox.contrib.derain import LearnableDegOp 4 | 5 | def test(): 6 | b = hf.load_image('examples/derain/derain_input.png') 7 | gt = hf.load_image('examples/derain/derain_target.png') 8 | imshow(gt, b) 9 | 10 | # custom linop 11 | A = LearnableDegOp().cuda() 12 | def forward_fn(input, step): return A.forward(input, step) 13 | def adjoint_fn(input, step): return A.adjoint(input, step) 14 | raining = LinOpFactory(forward_fn, adjoint_fn) 15 | 16 | # build solver 17 | x = Variable() 18 | data_term = sum_squares(raining(x), b) 19 | reg_term = unrolled_prior(x) 20 | obj = data_term + reg_term 21 | solver = compile(obj, method='pgd') 22 | 23 | # load parameters 24 | ckpt = hf.load_checkpoint('image_deraining/derain_pdg.pth') 25 | A.load_state_dict(ckpt['linop']) 26 | reg_term.load_state_dict(ckpt['prior']) 27 | rhos = ckpt['rhos'] 28 | 29 | out = solver.solve(x0=b, rhos=rhos, max_iter=7) 30 | out = to_ndarray(out, debatch=True) + b 31 | print(psnr(out, gt)) # 35.92 32 | imshow(gt, out) 33 | assert psnr(out, gt) - 35.92 < 0.1 34 | -------------------------------------------------------------------------------- /tests/problem/test_energy_system.py: -------------------------------------------------------------------------------- 1 | import scipy.io 2 | 3 | import dprox as dp 4 | from dprox.utils.huggingface import load_path 5 | from 
dprox.contrib.energy_system import load_simple_cep_model 6 | 7 | 8 | def test(): 9 | c, A_ub, A_eq, b_ub, b_eq = load_simple_cep_model() 10 | 11 | x = dp.Variable() 12 | prob = dp.Problem(c @ x, [A_ub @ x <= b_ub, A_eq @ x == b_eq]) 13 | out = prob.solve(method='admm', adapt_params=True) 14 | print(out) 15 | 16 | # reference solution 17 | solution = scipy.io.loadmat(load_path("energy_system/simple_cep_model_20220916/esm_instance_solution.mat")) 18 | 19 | # pcg with custom backward 20 | # 83434.40576028604 170.29735095665677 0.0041508882261022985 200.8266268913533 0.569342163225932 29.714674732313096 21 | # tensor(-170.2974, device='cuda:0', dtype=torch.float64) 22 | 23 | # raw pcg 24 | # 83433.69636112433 174.7402594703876 0.004143616818155279 200.82662689427482 0.5693421632258844 29.691228047449876 25 | # tensor(-174.7403, device='cuda:0', dtype=torch.float64) 26 | 27 | # Obj: 8.35e+04, res_z: 0.00e+00, res_primal: 1.95e+02, res_dual: 9.19e-04, eps_primal: 2.00e+02, eps_dual: 1.00e-03, rho: 1.45e+01 28 | # tensor(-99.6752, device='cuda:0', dtype=torch.float64) 29 | -------------------------------------------------------------------------------- /tests/problem/test_inverse_problems.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from dprox import * 4 | from dprox.utils import * 5 | from dprox import contrib 6 | 7 | 8 | def test_csmri(): 9 | x0, y0, gt, mask = contrib.csmri.sample() 10 | 11 | x = Variable() 12 | y = Placeholder() 13 | data_term = csmri(x, mask, y) 14 | reg_term = deep_prior(x, denoiser='unet') 15 | prob = Problem(data_term + reg_term) 16 | 17 | y.value = y0 18 | max_iter = 24 19 | rhos, sigmas = log_descent(30, 20, max_iter) 20 | prob.solve( 21 | method='admm', 22 | device='cuda', 23 | x0=x0, rhos=rhos, lams={reg_term: sigmas}, max_iter=max_iter, pbar=True 24 | ) 25 | out = x.value.real 26 | 27 | print(psnr(out, gt)) 28 | assert abs(psnr(out, gt) - 43.1) < 0.1 29 | 30 | 31 | def 
test_deconv(): 32 | img = contrib.sample('face') 33 | psf = contrib.point_spread_function(15, 5) 34 | b = contrib.blurring(img, psf) 35 | 36 | x = Variable() 37 | data_term = sum_squares(conv(x, psf) - b) 38 | reg_term = deep_prior(x, denoiser='ffdnet_color') 39 | reg2 = nonneg(x) 40 | prob = Problem(data_term + reg_term + reg2) 41 | 42 | out = prob.solve(method='admm', x0=b, pbar=True) 43 | 44 | print(psnr(out, img)) 45 | assert abs(psnr(out, img) - 34.5) < 0.1 46 | 47 | 48 | def test_deconv2(): 49 | img = contrib.sample('face') 50 | psf = contrib.point_spread_function(ksize=15, sigma=5) 51 | # TODO: this still has a bug. NOTE(review): the bare y.squeeze(0) below discards its result (squeeze is not in-place) -- confirm whether it should be assigned or removed 52 | y = contrib.blurring(img, psf) + np.random.randn(*img.shape).astype('float32') * 5 / 255.0 53 | y.squeeze(0) 54 | print(img.shape, y.shape) 55 | 56 | x = Variable() 57 | data_term = sum_squares(conv(x, psf) - y) 58 | prior_term = deep_prior(x, 'ffdnet_color') 59 | reg_term = nonneg(x) 60 | objective = data_term + prior_term + reg_term 61 | p = Problem(objective) 62 | out = p.solve(method='admm', x0=y, pbar=True) 63 | print(psnr(out, img)) 64 | -------------------------------------------------------------------------------- /tests/problem/test_jd23.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from dprox import * 4 | from dprox.utils import * 5 | from dprox.linalg import LinearSolveConfig 6 | from dprox.contrib import * 7 | 8 | 9 | def test_jd2(): 10 | img = sample('face') 11 | img = to_torch_tensor(img, batch=True).float() 12 | psf = point_spread_function(15, 5) 13 | b = blurring(img, psf) 14 | b = mosaicing(b) 15 | 16 | x = Variable() 17 | data_term = sum_squares(mosaic(conv(x, psf)) - b) 18 | reg_term = deep_prior(x, denoiser='ffdnet_color') 19 | prob = Problem(data_term + reg_term, linear_solve_config=LinearSolveConfig(max_iters=50)) 20 | 21 | max_iter = 5 22 | rhos, sigmas = log_descent(35, 30, max_iter) 23 | out = prob.solve(method='admm', x0=b, rhos=rhos, lams={reg_term: 
sigmas}, max_iter=max_iter, pbar=True) 24 | 25 | print(psnr(out, img)) # 25.25 26 | 27 | assert abs(psnr(out, img) - 25.10) < 0.1 28 | 29 | 30 | def load(name='face', return_tensor=True): 31 | import imageio 32 | s = imageio.imread(name).copy().astype('float32') / 255 33 | s = s[:768,:,:] 34 | if return_tensor: 35 | s = to_torch_tensor(s, batch=True).float() 36 | return s 37 | 38 | 39 | def test_jd2_batched(): 40 | psf = point_spread_function(15, 5) 41 | 42 | # img = load('tests/face2.png') 43 | img = sample('face') 44 | img1 = to_torch_tensor(img, batch=True).float() 45 | b = blurring(img1, psf) 46 | b1 = mosaicing(b) 47 | 48 | img = sample('face') 49 | img2 = to_torch_tensor(img, batch=True).float() 50 | b = blurring(img2, psf) 51 | b2 = mosaicing(b) 52 | b = torch.cat([b1, b2], dim=0) 53 | print(b.shape) 54 | 55 | x = Variable() 56 | data_term = sum_squares(mosaic(conv(x, psf)) - b) 57 | reg_term = deep_prior(x, denoiser='ffdnet_color') 58 | prob = Problem(data_term + reg_term, linear_solve_config=LinearSolveConfig(max_iters=50)) 59 | 60 | max_iter = 5 61 | rhos, sigmas = log_descent(35, 30, max_iter) 62 | out = prob.solve(method='admm', x0=b, rhos=rhos, lams={reg_term: sigmas}, max_iter=max_iter, pbar=True) 63 | 64 | print(psnr(out[0:1], img1)) # 29.689 65 | print(psnr(out[1].unsqueeze(0), img2)) # 29.689 66 | 67 | # assert abs(psnr(out[0:1], img1) - 29.92) < 0.1 68 | assert abs(psnr(out[0:1], img1) - 25.10) < 0.1 69 | assert abs(psnr(out[1].unsqueeze(0), img2) - 25.10) < 0.1 -------------------------------------------------------------------------------- /tests/problem/test_ml_problems.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import dprox as dp 3 | 4 | 5 | def test_lsq(): 6 | x = dp.Variable((3,3)) 7 | rhs = np.array([[1, 2, 3],[4,5,6],[7,8,9]]) 8 | prob = dp.Problem(dp.sum_squares(2*x - rhs)) 9 | prob.solve('admm', x0=np.zeros((3,3))) 10 | print(x.value) 11 | 12 | assert 
(x.value.cpu().numpy() == rhs / 2).all() 13 | 14 | 15 | def test_lsq1(): 16 | x = dp.Variable((3,3)) 17 | rhs = np.array([[1, 2, 3],[4,5,6],[7,8,9]]) 18 | prob = dp.Problem(dp.sum_squares(2*x, rhs)) 19 | prob.solve('admm', x0=np.zeros((3,3))) 20 | print(x.value) 21 | 22 | assert (x.value.cpu().numpy() == rhs / 2).all() 23 | 24 | 25 | def test_lsq2(): 26 | x = dp.Variable((3,3,1)) 27 | rhs = np.array([[[1, 2, 3],[4,5,6],[7,8,9]]]) 28 | kernel = np.array([[1,1],[1,1]]) / 4 29 | prob = dp.Problem(dp.sum_squares(dp.conv(x, kernel) - rhs)) 30 | prob.solve('admm', x0=np.zeros((3,3,1))) 31 | out = dp.eval(dp.conv(x, kernel)-rhs, x.value, zero_out_constant=False) 32 | print(x.value) 33 | print(out) 34 | assert (out < 1e-5).all() 35 | 36 | 37 | def test_lsq3(): 38 | x = dp.Variable((3)) 39 | rhs = np.array([1, 2, 3]) 40 | prob = dp.Problem(dp.sum_squares(2*x - rhs)) 41 | prob.solve('admm', x0=np.zeros(3)) 42 | print(x.value) 43 | 44 | assert (x.value.cpu().numpy() == rhs / 2).all() 45 | -------------------------------------------------------------------------------- /tests/test_algorithms.py: -------------------------------------------------------------------------------- 1 | from dprox import * 2 | from dprox.utils import * 3 | from dprox import contrib 4 | 5 | 6 | def test_admm(): 7 | img = contrib.sample('face') 8 | psf = contrib.point_spread_function(15, 5) 9 | b = contrib.blurring(img, psf) 10 | 11 | x = Variable() 12 | data_term = sum_squares(conv(x, psf) - b) 13 | reg_term = deep_prior(x, denoiser='ffdnet_color') 14 | reg2 = nonneg(x) 15 | prob = Problem(data_term + reg_term + reg2) 16 | 17 | out = prob.solve(method='admm', x0=b, pbar=True) 18 | 19 | print(psnr(out, img)) 20 | assert abs(psnr(out, img) - 34.51) < 0.1 21 | 22 | 23 | def test_ladmm(): 24 | img = contrib.sample('face') 25 | psf = contrib.point_spread_function(15, 5) 26 | b = contrib.blurring(img, psf) 27 | 28 | x = Variable() 29 | data_term = sum_squares(conv(x, psf) - b) 30 | reg_term = deep_prior(x, 
denoiser='ffdnet_color') 31 | reg2 = nonneg(x) 32 | prob = Problem(data_term + reg_term + reg2) 33 | 34 | out = prob.solve(method='ladmm', x0=b, pbar=True) 35 | 36 | print(psnr(out, img)) 37 | assert abs(psnr(out, img) - 34.51) < 0.1 38 | 39 | 40 | def test_admm_vxu(): 41 | img = contrib.sample('face') 42 | psf = contrib.point_spread_function(15, 5) 43 | b = contrib.blurring(img, psf) 44 | 45 | x = Variable() 46 | data_term = sum_squares(conv(x, psf) - b) 47 | reg_term = deep_prior(x, denoiser='ffdnet_color') 48 | reg2 = nonneg(x) 49 | prob = Problem(data_term + reg_term + reg2) 50 | 51 | out = prob.solve(method='admm_vxu', x0=b, pbar=True) 52 | 53 | print(psnr(out, img)) 54 | assert abs(psnr(out, img) - 34.50) < 0.1 55 | 56 | 57 | def test_hqs(): 58 | img = contrib.sample('face') 59 | psf = contrib.point_spread_function(15, 5) 60 | b = contrib.blurring(img, psf) 61 | 62 | x = Variable() 63 | data_term = sum_squares(conv(x, psf) - b) 64 | reg_term = deep_prior(x, denoiser='ffdnet_color') 65 | reg2 = nonneg(x) 66 | prob = Problem(data_term + reg_term + reg2) 67 | 68 | out = prob.solve(method='hqs', x0=b, pbar=True) 69 | 70 | print(psnr(out, img)) 71 | assert abs(psnr(out, img) - 34.08) < 0.1 72 | 73 | 74 | def test_pc(): 75 | img = contrib.sample('face') 76 | psf = contrib.point_spread_function(15, 5) 77 | b = contrib.blurring(img, psf) 78 | 79 | x = Variable() 80 | data_term = sum_squares(conv(x, psf) - b) 81 | reg_term = deep_prior(x, denoiser='ffdnet_color') 82 | reg2 = nonneg(x) 83 | prob = Problem(data_term + reg_term + reg2) 84 | 85 | out = prob.solve(method='pc', x0=b, pbar=True) 86 | 87 | print(psnr(out, img)) 88 | assert abs(psnr(out, img) - 29.87) < 0.1 89 | 90 | 91 | def test_pgd(): 92 | img = contrib.sample('face') 93 | psf = contrib.point_spread_function(15, 5) 94 | b = contrib.blurring(img, psf) 95 | 96 | x = Variable() 97 | data_term = sum_squares(conv(x, psf) - b) 98 | reg_term = deep_prior(x, denoiser='ffdnet_color') 99 | prob = Problem(data_term + 
reg_term) 100 | 101 | out = prob.solve(method='pgd', x0=b, pbar=True) 102 | 103 | print(psnr(out, img)) 104 | assert abs(psnr(out, img) - 21.44) < 0.1 105 | -------------------------------------------------------------------------------- /tests/test_grad.py: -------------------------------------------------------------------------------- 1 | from dprox import * 2 | from dprox.contrib.optic import * 3 | from dprox.utils import * 4 | 5 | 6 | def test_deep_prior(): 7 | device = torch.device("cuda") 8 | x = Variable() 9 | reg_term = deep_prior(x, denoiser="ffdnet_color").to(device) 10 | 11 | sigma = torch.nn.Parameter(torch.tensor(0.1).to(device)) 12 | inp = torch.randn((1, 3, 128, 128), device=device) 13 | x.value = inp 14 | y = reg_term.prox(inp, sigma) 15 | 16 | loss = torch.nn.functional.mse_loss(inp, y) 17 | loss.backward() 18 | print(sigma.grad) 19 | -------------------------------------------------------------------------------- /tests/test_linop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import dprox as dp 3 | 4 | from dprox.contrib import fspecial_gaussian 5 | from dprox.utils import to_torch_tensor 6 | 7 | from scipy.misc import face 8 | 9 | 10 | def test_conv(): 11 | x = dp.Variable() 12 | psf = fspecial_gaussian(15, 5) 13 | op = dp.conv(x, psf) 14 | 15 | K = dp.CompGraph(op) 16 | assert K.sanity_check() 17 | 18 | img = to_torch_tensor(face(), batch=True) 19 | out = K.forward(img) 20 | 21 | print(img.shape) 22 | print(out.shape) 23 | assert img.shape == out.shape 24 | 25 | 26 | def test_constant(): 27 | x = dp.Variable() 28 | y = 3 * (x - torch.tensor([2, 2, 2])) 29 | 30 | input = torch.tensor([1, 2, 3]) 31 | print(dp.eval(y, input, zero_out_constant=False)) 32 | 33 | x.value = torch.tensor([1, 2, 3]) 34 | 35 | print(y.variables) 36 | print(y.constants) 37 | print(y.value) 38 | print(y.offset) 39 | assert torch.allclose(y.value, 3 * (torch.tensor([1, 2, 3]) - torch.tensor([2, 2, 2]))) 40 | assert 
torch.allclose(y.offset, - 3 * torch.tensor([2, 2, 2])) 41 | 42 | 43 | def test_grad(): 44 | x = dp.Variable() 45 | K = dp.CompGraph(dp.grad(x, dim=1) + dp.grad(x, dim=2)) 46 | assert K.sanity_check() 47 | img = to_torch_tensor(face(), batch=True) 48 | print(img.shape) 49 | outputs = K.forward(img) 50 | 51 | 52 | def test_mosaic(): 53 | x = dp.Variable() 54 | K = dp.CompGraph(dp.mosaic(x)) 55 | assert K.sanity_check() 56 | img = to_torch_tensor(face(), batch=True) 57 | print(img.shape) 58 | outputs = K.forward(img) 59 | 60 | 61 | def test_sum(): 62 | x1 = dp.Variable() 63 | x2 = dp.Variable() 64 | 65 | K = dp.CompGraph(x1 + x2) 66 | 67 | v1 = torch.randn((4, 4), requires_grad=True) 68 | v2 = torch.randn((4, 4)) 69 | 70 | outputs = K.forward(v1, v2) 71 | 72 | print(outputs) 73 | assert torch.allclose(outputs, v1 + v2) 74 | 75 | loss = torch.mean(outputs) 76 | loss.backward() 77 | print(v1.grad) 78 | assert torch.allclose(v1.grad, torch.full_like(v1.grad, 1 / 16)) 79 | 80 | 81 | def test_variable(): 82 | x = dp.Variable() 83 | 84 | print(x.forward(torch.tensor([2, 2, 2]))) 85 | print(x.adjoint(torch.tensor([2, 2, 2]))) 86 | 87 | K = dp.CompGraph(x) 88 | 89 | out = K.adjoint(torch.tensor([2, 2, 2])) 90 | print(out) 91 | 92 | 93 | def test_vstack(): 94 | x = dp.Variable() 95 | K = dp.CompGraph(dp.vstack([dp.mosaic(x), dp.grad(x)])) 96 | 97 | img = to_torch_tensor(face(), batch=True) 98 | print(img.shape) 99 | 100 | outputs = K.forward(img) 101 | inputs = K.adjoint(outputs) 102 | print(inputs.shape) 103 | K.sanity_check() 104 | 105 | 106 | def test_complex(): 107 | from dprox import contrib 108 | 109 | img = contrib.sample('face') 110 | psf = contrib.point_spread_function(15, 5) 111 | b = contrib.blurring(img, psf) 112 | 113 | x = dp.Variable() 114 | data_term = dp.sum_squares(dp.conv(x, psf) - b) 115 | reg_term = dp.deep_prior(x, denoiser='ffdnet_color') 116 | reg2 = dp.nonneg(x) 117 | K = dp.CompGraph(dp.vstack([fn.linop for fn in [data_term, reg_term, reg2]])) 118 | 
K.forward(b) 119 | -------------------------------------------------------------------------------- /tests/test_linop_primitive.py: -------------------------------------------------------------------------------- 1 | import dprox as dp 2 | import torch 3 | 4 | def test_eval(): 5 | op = dp.Variable() 6 | out = dp.eval(op, torch.zeros(64,64)) 7 | print(out) 8 | assert torch.allclose(out, torch.zeros(64,64)) -------------------------------------------------------------------------------- /tests/test_primitive.py: -------------------------------------------------------------------------------- 1 | import dprox as dp 2 | 3 | 4 | # def test(): 5 | # dp.compile() 6 | # dp.specialize() 7 | # dp.optimize() 8 | # dp.visualize() 9 | # dp.validate() --------------------------------------------------------------------------------