├── torch_submod
│   ├── __init__.py
│   └── graph_cuts.py
├── docs
│   ├── source
│   │   ├── torch
│   │   │   ├── __init__.py
│   │   │   └── autograd.py
│   │   ├── refs.bib
│   │   ├── index.rst
│   │   └── conf.py
│   ├── requirements_docs.txt
│   └── Makefile
├── setup.cfg
├── .gitignore
├── src
│   ├── blocks.h
│   └── blocks.cpp
├── .travis.yml
├── readme.md
├── tests
│   ├── test_blocks.py
│   └── test_grad.py
├── license.txt
└── setup.py
/torch_submod/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/source/torch/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [aliases]
2 | test=pytest
3 |
4 | [bdist_wheel]
5 | universal=1
--------------------------------------------------------------------------------
/docs/requirements_docs.txt:
--------------------------------------------------------------------------------
1 | sphinxcontrib-bibtex
2 | sphinx_rtd_theme
3 | mock
--------------------------------------------------------------------------------
/docs/source/torch/autograd.py:
--------------------------------------------------------------------------------
1 | # Mock used when building the docs without pytorch.
2 | class Function(object):
3 |     def __init__(self, *args, **kwargs):
4 |         pass
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .eggs
2 | tmp
3 | *.egg-info
4 | build
5 | */__pycache__
6 | torch_submod/blocks*.so
7 | *.swp
8 | .cache
9 | *.pyc
10 | .hypothesis
11 | .idea
12 | .ipynb_checkpoints
--------------------------------------------------------------------------------
/src/blocks.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <pybind11/numpy.h>
4 |
5 |
6 | pybind11::array_t<float> blockwise_means(const pybind11::array_t<int>& blocks,
7 |                                          const pybind11::array_t<float>& input);
8 | pybind11::array_t<int> blocks_2d(const pybind11::array_t<float>& matrix);
--------------------------------------------------------------------------------
/docs/source/refs.bib:
--------------------------------------------------------------------------------
1 | @article{niculae2017regularized,
2 |   title={A Regularized Framework for Sparse and Structured Neural Attention},
3 |   author={Niculae, Vlad and Blondel, Mathieu},
4 |   journal={arXiv preprint arXiv:1705.07704},
5 |   year={2017}
6 | }
7 |
8 | @inproceedings{djolonga17learning,
9 |   author={Djolonga, Josip and Krause, Andreas},
10 |   booktitle={Neural Information Processing Systems (NIPS)},
11 |   title={Differentiable Learning of Submodular Models},
12 |   year={2017}
13 | }
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | # This has been adapted from
2 | # https://conda.io/docs/user-guide/tasks/use-conda-with-travis-ci.html
3 | language: python
4 | python:
5 |   # We don't actually use the Travis Python, but this keeps it organized.
6 | - "2.7" 7 | - "3.5" 8 | addons: 9 | apt: 10 | packages: 11 | - liblapack-dev 12 | - liblapacke-dev 13 | - libopenblas-dev 14 | - libboost-dev 15 | install: 16 | - sudo apt-get update 17 | # We do this conditionally because it saves us some downloading if the 18 | # version is the same. 19 | - if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then 20 | wget https://repo.continuum.io/miniconda/Miniconda2-latest-Linux-x86_64.sh -O miniconda.sh; 21 | else 22 | wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh; 23 | fi 24 | - bash miniconda.sh -b -p $HOME/miniconda 25 | - export PATH="$HOME/miniconda/bin:$PATH" 26 | - hash -r 27 | - conda config --set always_yes yes --set changeps1 no 28 | - conda update -q conda 29 | # Useful for debugging any issues with conda 30 | - conda info -a 31 | 32 | - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION 33 | - source activate test-environment 34 | - conda install pytorch -c soumith 35 | - conda install numpy 36 | - conda install scipy 37 | - conda install scikit-learn 38 | - pip install pybind11 39 | - python setup.py install 40 | script: 41 | python setup.py test 42 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # `torch-submod` 2 | [![Documentation Status](https://readthedocs.org/projects/torch-submod/badge/?version=latest)](http://torch-submod.readthedocs.io/en/latest/?badge=latest) 3 | [![Build Status](https://travis-ci.org/josipd/torch-submod.svg?branch=master)](https://travis-ci.org/josipd/torch-submod) 4 | 5 | A PyTorch library for differentiable submodular minimization. 6 | 7 | At the moment only one- and two-dimensional graph cut functions have been 8 | implemented, so that this package provides differentiable (with respect to the 9 | input signal *and* the weights) total variation solvers. 10 | 11 | Please refer to the [documentation](https://torch-submod.readthedocs.io) 12 | for more information about the project. 13 | You can also have a look at the following [notebook](notebooks/denoising.ipynb) 14 | that showcases how to learn weights for image denoising. 15 | 16 | ### Installation 17 | 18 | After installing PyTorch, you can install the package with: 19 | 20 | ``` 21 | python setup.py install 22 | ``` 23 | 24 | ### Testing 25 | 26 | To run the tests you simply have to run: 27 | 28 | ``` 29 | python setup.py test 30 | ``` 31 | 32 | ### Bibliography 33 | 34 | * *[DK17]* J. Djolonga and A. Krause. Differentiable learning of submodular models. In Advances in Neural Information Processing Systems (NIPS), 2017. 35 | * *[NB17]* V. Niculae and M. Blondel. A regularized framework for sparse and structured neural attention. arXiv preprint arXiv:1705.07704, 2017. 36 | -------------------------------------------------------------------------------- /tests/test_blocks.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function 2 | import numpy as np 3 | from torch_submod.blocks import blockwise_means, blocks_2d 4 | 5 | 6 | def test_2d_blocks(): 7 | matrix = np.asarray([ 8 | [-1, -3, 4, 5, 1, 1, 1], 9 | [-1, -1, 4, 4, 4, 1, 1]], dtype=np.float32) 10 | blocks = np.asarray([ 11 | [0, 1, 2, 3, 4, 4, 4], 12 | [0, 0, 2, 2, 2, 4, 4]], dtype=np.int32) 13 | assert np.all(blocks == blocks_2d(matrix)) 14 | 15 | # Test row and column matrices. 
16 | matrix = np.asarray([[ 17 | .5, .1, .1, .1, .5, .5, 3, -3, 4, 5, 6, 6, 6, 7, 8]], dtype=np.float32) 18 | blocks = np.asarray([[ 19 | 0, 1, 1, 1, 2, 2, 3, 4, 5, 6, 7, 7, 7, 8, 9]], dtype=np.int32) 20 | assert np.all(blocks == blocks_2d(matrix)) 21 | # Also with transpose. 22 | assert np.all(blocks.T == blocks_2d(matrix.T)) 23 | 24 | # TODO(josipd): Try with non-2d, check that an exception is thrown. 25 | 26 | 27 | def test_blockwise_means(): 28 | blocks = np.asarray([ 29 | 0, 1, 0, 0, 1, 1, 0, 2], dtype=np.int32) 30 | vector = np.asarray([ 31 | 0, 1, 2, 3, 4, 5, 6, 7], dtype=np.float32) 32 | b0 = np.mean([0, 2, 3, 6]) 33 | b1 = np.mean([1, 4, 5]) 34 | b2 = np.mean([7]) 35 | expected = np.asarray([ 36 | b0, b1, b0, b0, b1, b1, b0, b2]) 37 | assert np.allclose(expected, blockwise_means(blocks, vector)) 38 | assert np.allclose(expected, blockwise_means(blocks + 10, vector)) 39 | -------------------------------------------------------------------------------- /license.txt: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017, Josip Djolonga 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | * Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | * Redistributions in binary form must reproduce the above copyright 11 | notice, this list of conditions and the following disclaimer in the 12 | documentation and/or other materials provided with the distribution. 13 | * Neither the name of the ETH Zurich nor the 14 | names of its contributors may be used to endorse or promote products 15 | derived from this software without specific prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 18 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 19 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 20 | DISCLAIMED. IN NO EVENT SHALL Josip Djolonga BE LIABLE FOR ANY 21 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 22 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 23 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 24 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
27 | -------------------------------------------------------------------------------- /tests/test_grad.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function 2 | import torch 3 | from torch.autograd import Variable 4 | from torch_submod.graph_cuts import ( 5 | TotalVariation2d, TotalVariation2dWeighted, TotalVariation1d) 6 | from torch.autograd.gradcheck import gradcheck 7 | from hypothesis import given, settings 8 | import hypothesis.strategies as st 9 | 10 | 11 | @given(st.integers(10, 100), st.floats(0.1, 10)) 12 | def test_1d(n, w): 13 | x = Variable(torch.randn(n), requires_grad=True) 14 | w = Variable(torch.Tensor([w]), requires_grad=True) 15 | tv_args = {'method': 'condattautstring'} 16 | assert gradcheck(TotalVariation1d(tv_args=tv_args), (x, w), 17 | eps=1e-5, atol=1e-2, rtol=1e-3) 18 | 19 | 20 | @given(st.integers(10, 100), st.floats(0.1, 10)) 21 | def test_1dw(n, w): 22 | x = Variable(10 * torch.randn(n), requires_grad=True) 23 | w = Variable(0.1 + w * torch.rand(n - 1), requires_grad=True) 24 | tv_args = {'method': 'tautstring'} 25 | assert gradcheck(TotalVariation1d(tv_args=tv_args), (x, w), 26 | eps=5e-5, atol=5e-2, rtol=1e-2) 27 | 28 | 29 | @settings(deadline=30000, max_examples=30, timeout=120) 30 | @given(st.integers(5, 20), st.integers(5, 20), st.floats(0.1, 10)) 31 | def test_2d(n, m, w): 32 | x = Variable(torch.randn(n, m), requires_grad=True) 33 | w = Variable(0.1 + torch.Tensor([w]), requires_grad=True) 34 | tv_args = {'method': 'dr', 'max_iters': 1000, 'n_threads': 6} 35 | assert gradcheck(TotalVariation2d(tv_args=tv_args), (x, w), 36 | eps=1e-5, atol=1e-2, rtol=1e-3) 37 | 38 | 39 | @settings(deadline=30000, max_examples=30, timeout=120) 40 | @given(st.integers(5, 10), st.integers(5, 10), st.floats(0.1, 10)) 41 | def test_2dw(n, m, w): 42 | x = Variable(torch.randn(n, m), requires_grad=True) 43 | w_r = Variable(0.1 + w * torch.rand(n, m-1), requires_grad=True) 44 | w_c = Variable(0.1 + w * torch.rand(n-1, m), requires_grad=True) 45 | tv_args = {'max_iters': 1000, 'n_threads': 6} 46 | assert gradcheck(TotalVariation2dWeighted(tv_args=tv_args), (x, w_r, w_c), 47 | eps=1e-5, atol=5e-2, rtol=1e-3) 48 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # The script is adapted from 2 | # https://github.com/pybind/python_example/blob/master/setup.py 3 | 4 | from setuptools import setup, Extension 5 | from setuptools.command.build_ext import build_ext 6 | import sys 7 | import setuptools 8 | 9 | __version__ = '0.1' 10 | 11 | 12 | class get_pybind_include(object): 13 | """Helper class to determine the pybind11 include path 14 | The purpose of this class is to postpone importing pybind11 15 | until it is actually installed, so that the ``get_include()`` 16 | method can be invoked. """ 17 | 18 | def __init__(self, user=False): 19 | self.user = user 20 | 21 | def __str__(self): 22 | import pybind11 23 | return pybind11.get_include(self.user) 24 | 25 | 26 | ext_modules = [ 27 | Extension( 28 | 'torch_submod.blocks', 29 | ['src/blocks.cpp'], 30 | include_dirs=[ 31 | # Path to pybind11 headers 32 | get_pybind_include(), 33 | get_pybind_include(user=True), 34 | ], 35 | language='c++' 36 | ), 37 | ] 38 | 39 | 40 | # As of Python 3.6, CCompiler has a `has_flag` method. 
41 | # cf http://bugs.python.org/issue26689
42 | def has_flag(compiler, flagname):
43 |     """Return a boolean indicating whether a flag name is supported on
44 |     the specified compiler.
45 |     """
46 |     import tempfile
47 |     with tempfile.NamedTemporaryFile('w', suffix='.cpp') as f:
48 |         f.write('int main (int argc, char **argv) { return 0; }')
49 |         try:
50 |             compiler.compile([f.name], extra_postargs=[flagname])
51 |         except setuptools.distutils.errors.CompileError:
52 |             return False
53 |     return True
54 |
55 |
56 | def cpp_flag(compiler):
57 |     """Return the -std=c++[11/14] compiler flag.
58 |     C++14 is preferred over C++11 (when it is available).
59 |     """
60 |     if has_flag(compiler, '-std=c++14'):
61 |         return '-std=c++14'
62 |     elif has_flag(compiler, '-std=c++11'):
63 |         return '-std=c++11'
64 |     else:
65 |         raise RuntimeError('Unsupported compiler -- at least C++11 support '
66 |                            'is needed!')
67 |
68 |
69 | class BuildExt(build_ext):
70 |     """A custom build extension for adding compiler-specific options."""
71 |     c_opts = {
72 |         'msvc': ['/EHsc'],
73 |         'unix': [],
74 |     }
75 |
76 |     if sys.platform == 'darwin':
77 |         c_opts['unix'] += ['-stdlib=libc++', '-mmacosx-version-min=10.7']
78 |
79 |     def build_extensions(self):
80 |         ct = self.compiler.compiler_type
81 |         opts = self.c_opts.get(ct, [])
82 |         if ct == 'unix':
83 |             opts.append('-DVERSION_INFO="%s"' %
84 |                         self.distribution.get_version())
85 |             opts.append(cpp_flag(self.compiler))
86 |             if has_flag(self.compiler, '-fvisibility=hidden'):
87 |                 opts.append('-fvisibility=hidden')
88 |         elif ct == 'msvc':
89 |             opts.append('/DVERSION_INFO=\\"%s\\"' %
90 |                         self.distribution.get_version())
91 |         for ext in self.extensions:
92 |             ext.extra_compile_args = opts
93 |         build_ext.build_extensions(self)
94 |
95 | setup(
96 |     name='torch_submod',
97 |     version=__version__,
98 |     author='Josip Djolonga',
99 |     author_email='josipd@inf.ethz.ch',
100 |     url='https://github.com/josipd/torch-submod',
101 |     description='A PyTorch library for differentiable submodular minimization',
102 |     long_description='',
103 |     packages=['torch_submod'],
104 |     ext_modules=ext_modules,
105 |     install_requires=[
106 |         'pybind11>=2.2',
107 |         'numpy',
108 |     ],
109 |     setup_requires=[
110 |         'pytest-runner',
111 |         'prox_tv',
112 |         'scikit-learn',
113 |     ],
114 |     tests_require=[
115 |         'pytest',
116 |         'hypothesis',
117 |     ],
118 |     cmdclass={'build_ext': BuildExt},
119 |     zip_safe=False,
120 |     license='license.txt',
121 | )
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | ########################################
2 | Welcome to torch-submod's documentation!
3 | ########################################
4 |
5 | ************
6 | Introduction
7 | ************
8 |
9 | A library implementing layers that solve the min-norm problem for submodular
10 | functions.
11 | The computation of the Jacobian (i.e., backpropagation) is done using the
12 | methods from :cite:`djolonga17learning`. At the moment only one- and
13 | two-dimensional graph-cut functions are implemented, in which case the
14 | min-norm problem is also known as a total variation problem.
15 |
16 | The package thus provides total variation solvers that are differentiable
17 | with respect to both the input signal *and* the weights; the solved
18 | objective is sketched below.
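
As a concrete instance (restated here from the ``TotalVariation1d`` reference
further down; the two-dimensional variants add analogous column terms), the
one-dimensional layer returns

.. math::

    \textrm{argmin}_{\mathbf z}
    \frac{1}{2} \|\mathbf{x}-\mathbf{z}\|^2 +
    \sum_{i=1}^{n-1} w_i |z_i - z_{i+1}|.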
19 |
20 | ************
21 | Installation
22 | ************
23 |
24 | Once you install PyTorch (following
25 | `these instructions <http://pytorch.org/>`_), you can install
26 | the package as::
27 |
28 |     python setup.py install
29 |
30 | *****
31 | Usage
32 | *****
33 |
34 | For example, let us try to learn row- and column-weights that will denoise
35 | a simple image.
36 | Let us create an image that is zero everywhere, except its top-left quadrant,
37 | which is filled with ones. Then, we will corrupt it with normal noise, and try
38 | to recover it using a total-variation solver with learned weights.
39 |
40 | Note that an extended version of the example below, together with
41 | visualizations, is provided in the repository as a
42 | `jupyter notebook <https://github.com/josipd/torch-submod/blob/master/notebooks/denoising.ipynb>`_.
43 |
44 |
45 | >>> from __future__ import division, print_function
46 | >>> import torch
47 | >>> from torch.autograd import Variable
48 | >>> from torch_submod.graph_cuts import TotalVariation2dWeighted as tv2d
49 | >>>
50 | >>> torch.manual_seed(0)
51 | >>> m, n = 50, 100  # The image dimensions.
52 | >>> std = 1e-1  # The standard deviation of noise (note: unused below).
53 | >>> x = torch.zeros((m, n))
54 | >>> x[:m//2, :n//2] += 1
55 | >>> x_noisy = x + torch.normal(torch.zeros(x.size()))
56 | >>>
57 | >>> x = Variable(x, requires_grad=False)
58 | >>> x_noisy = Variable(x_noisy, requires_grad=False)
59 | >>>
60 | >>> # The learnable parameters.
61 | >>> log_w_row = Variable(- 3 * torch.ones(1), requires_grad=True)
62 | >>> log_w_col = Variable(- 3 * torch.ones(1), requires_grad=True)
63 | >>> scale = Variable(torch.ones(1), requires_grad=True)
64 | >>>
65 | >>> optimizer = torch.optim.SGD([log_w_row, log_w_col, scale], lr=.5)
66 | >>> losses = []
67 | >>> for iter_no in range(1000):
68 | ...     w_row = torch.exp(log_w_row)
69 | ...     w_col = torch.exp(log_w_col)
70 | ...     y = tv2d()(scale * x_noisy,
71 | ...                w_row.expand((m, n-1)), w_col.expand((m - 1, n)))
72 | ...     optimizer.zero_grad()
73 | ...     loss = torch.mean((y - x)**2)
74 | ...     loss.backward()
75 | ...     if iter_no % 100 == 0:
76 | ...         losses.append(loss.data[0])
77 | ...     optimizer.step()
78 | >>> print('\n'.join(map(str, losses)))
79 | 0.809337258339
80 | 0.100806325674
81 | 0.0123300831765
82 | 0.00451330607757
83 | 0.00304582691751
84 | 0.00262771383859
85 | 0.00248298258521
86 | 0.00242520542815
87 | 0.00239872303791
88 | 0.00239089410752
89 |
90 |
91 |
92 | ================
93 | Function classes
94 | ================
95 |
96 | ^^^^^^^^^^
97 | Graph cuts
98 | ^^^^^^^^^^
99 |
100 | To solve the total-variation problem we use the
101 | `prox_tv <https://github.com/albarji/proxTV>`_ library. Please refer to the
102 | documentation accompanying that package to find out more about the set of
103 | available methods. Namely, each function accepts a ``tv_args`` dictionary
104 | argument, which is passed on to the solver. The idea of averaging within the
105 | connected components, enabled when ``average_connected=True``, first appeared
106 | for the one-dimensional case in :cite:`niculae2017regularized`.
107 |
108 | *Note*: At the moment the total variation problems can be solved only
109 | *on the CPU*, so please make sure that all variables are placed on the CPU.
110 |
111 |
112 | .. autoclass:: torch_submod.graph_cuts.TotalVariation2dWeighted
113 |     :members: __init__, forward
114 |
115 | .. autoclass:: torch_submod.graph_cuts.TotalVariation2d
116 |     :members: __init__, forward
117 |
118 | .. autoclass:: torch_submod.graph_cuts.TotalVariation1d
119 |     :members: __init__, forward
120 |
121 | ************
122 | Bibliography
123 | ************
124 |
125 | .. bibliography:: refs.bib
126 |
127 |
128 | ==================
129 | Indices and tables
130 | ==================
131 |
132 | * :ref:`genindex`
133 | * :ref:`modindex`
134 | * :ref:`search`
--------------------------------------------------------------------------------
/src/blocks.cpp:
--------------------------------------------------------------------------------
1 | #include "blocks.h"
2 |
3 | #include <algorithm>
4 | #include <unordered_map>
5 | #include <vector>
6 | #include <boost/pending/disjoint_sets.hpp>
7 |
8 | namespace py = pybind11;
9 |
10 | namespace {
11 |
12 | const char* blockwise_means_doc = R"(blockwise_means(blocks, input)
13 |
14 | Average the elements of the given vector within each block.
15 |
16 | Specifically, the coordinate ``i`` of the returned vector will contain the mean
17 | of all entries ``j`` in ``input`` that have ``blocks[j] = blocks[i]``.
18 |
19 | Arguments
20 | ---------
21 | blocks : numpy.ndarray
22 |     A vector of ints that denote the block memberships of each position.
23 |
24 |     The set of numbers in this vector should start at zero and be consecutive.
25 |
26 | input : numpy.ndarray
27 |     A vector of the same size as ``blocks`` containing the data to be averaged.
28 |
29 | Returns
30 | -------
31 | numpy.ndarray
32 |     A vector of the same size as the inputs that contains the block-wise averages.
33 | )";
34 |
35 | }  // namespace
36 |
37 | py::array_t<float> blockwise_means(const py::array_t<int>& blocks_,
38 |                                    const py::array_t<float>& input_) {
39 |   if (blocks_.ndim() != 1 || input_.ndim() != 1) {
40 |     throw std::runtime_error("the given arrays must be one dimensional");
41 |   }
42 |   if (blocks_.shape(0) != input_.shape(0)) {
43 |     throw std::runtime_error("the number of elements in the two arrays must match");
44 |   }
45 |
46 |   auto blocks = blocks_.unchecked<1>();
47 |   auto input = input_.unchecked<1>();
48 |
49 |   int n_blocks = 0;
50 |   for (int i = 0; i < blocks.shape(0); i++) {
51 |     if (blocks[i] < 0) {
52 |       throw std::runtime_error("the block ids must be non-negative");
53 |     }
54 |     n_blocks = std::max(n_blocks, 1 + blocks(i));
55 |   }
56 |
57 |   std::vector<float> blocks_sums(n_blocks);
58 |   std::vector<int> blocks_total(n_blocks);
59 |
60 |   for (int i = 0; i < blocks.shape(0); i++) {
61 |     blocks_sums[blocks(i)] += input[i];
62 |     ++blocks_total[blocks[i]];
63 |   }
64 |   for (int i = 0; i < n_blocks; i++) {
65 |     blocks_sums[i] /= static_cast<float>(blocks_total[i]);
66 |   }
67 |
68 |   py::array_t<float> output_ = py::array_t<float>(blocks.shape(0));
69 |   auto output = output_.mutable_unchecked<1>();
70 |   for (int i = 0; i < output_.shape(0); i++) {
71 |     output(i) = blocks_sums[blocks[i]];
72 |   }
73 |   return output_;
74 | }
75 |
76 |
77 | namespace {
78 |
79 | const char* blocks_2d_doc = R"(blocks_2d(matrix)
80 |
81 | Return the connected components of the matrix.
82 |
83 | Two positions are connected iff they hold the same value, and they differ in
84 | one coordinate by one (i.e., it is a 4-connected grid).
85 |
86 | Arguments
87 | ---------
88 | matrix : numpy.ndarray
89 |     The two-dimensional matrix.
90 |
91 | Returns
92 | -------
93 | numpy.ndarray
94 |     A matrix of ints, same size as the input.
95 |
96 |     The positions corresponding to the same connected component have the
97 |     same label. The labels are consecutive integers starting at zero.
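
Example
-------
An illustrative sketch of the labeling (components are numbered in row-major
order of first appearance; assumes the extension is built and numpy imported
as ``np``):

    >>> blocks_2d(np.asarray([[1, 1, 2],
    ...                       [1, 3, 3]], dtype=np.float32))
    array([[0, 0, 1],
           [0, 2, 2]], dtype=int32)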
98 | )"; 99 | 100 | } // namespace 101 | 102 | 103 | py::array_t blocks_2d(const py::array_t& matrix_) { 104 | if (matrix_.ndim() != 2) { 105 | throw std::runtime_error("the given matrix must be two dimensional"); 106 | } 107 | auto matrix = matrix_.unchecked<2>(); 108 | 109 | std::vector ranks(matrix.size()); 110 | std::vector parents(matrix.size()); 111 | boost::disjoint_sets union_find(&ranks[0], &parents[0]); 112 | 113 | #define IDX(i, j) ((i) * static_cast(matrix.shape(1)) + (j)) 114 | 115 | for (int i = 0; i < matrix_.shape(0); i++) { 116 | for (int j = 0; j < matrix_.shape(1); j++) { 117 | int idx = IDX(i, j); 118 | union_find.make_set(idx); 119 | if (i > 0 && matrix(i, j) == matrix(i - 1, j)) { 120 | union_find.union_set(idx, IDX(i - 1, j)); 121 | } 122 | if (j > 0 && matrix(i, j) == matrix(i, j - 1)) { 123 | union_find.union_set(idx, IDX(i, j - 1)); 124 | } 125 | } 126 | } 127 | 128 | std::unordered_map root_to_idx; 129 | py::array_t output_ = py::array_t({matrix_.shape(0), 130 | matrix_.shape(1)}); 131 | auto output = output_.mutable_unchecked<2>(); 132 | int next_id = 0; 133 | for (int i = 0; i < matrix.shape(0); i++) { 134 | for (int j = 0; j < matrix.shape(1); j++) { 135 | int idx = IDX(i, j); 136 | int root = union_find.find_set(idx); 137 | auto iter = root_to_idx.find(root); 138 | if (iter == root_to_idx.end()) { 139 | output(i, j) = next_id; 140 | root_to_idx[root] = next_id++; 141 | } else { 142 | output(i, j) = iter->second; 143 | } 144 | } 145 | } 146 | 147 | return output_; 148 | } 149 | 150 | 151 | PYBIND11_MODULE(blocks, m) { 152 | py::options options; 153 | options.disable_function_signatures(); 154 | m.def("blockwise_means", blockwise_means, blockwise_means_doc); 155 | m.def("blocks_2d", blocks_2d, blocks_2d_doc); 156 | } 157 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = build 9 | 10 | # User-friendly check for sphinx-build 11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 12 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 13 | endif 14 | 15 | # Internal variables. 
16 | PAPEROPT_a4 = -D latex_paper_size=a4 17 | PAPEROPT_letter = -D latex_paper_size=letter 18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source 19 | # the i18n builder cannot share the environment and doctrees with the others 20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source 21 | 22 | .PHONY: help 23 | help: 24 | @echo "Please use \`make ' where is one of" 25 | @echo " html to make standalone HTML files" 26 | @echo " dirhtml to make HTML files named index.html in directories" 27 | @echo " singlehtml to make a single large HTML file" 28 | @echo " pickle to make pickle files" 29 | @echo " json to make JSON files" 30 | @echo " htmlhelp to make HTML files and a HTML help project" 31 | @echo " qthelp to make HTML files and a qthelp project" 32 | @echo " applehelp to make an Apple Help Book" 33 | @echo " devhelp to make HTML files and a Devhelp project" 34 | @echo " epub to make an epub" 35 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 36 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 37 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 38 | @echo " text to make text files" 39 | @echo " man to make manual pages" 40 | @echo " texinfo to make Texinfo files" 41 | @echo " info to make Texinfo files and run them through makeinfo" 42 | @echo " gettext to make PO message catalogs" 43 | @echo " changes to make an overview of all changed/added/deprecated items" 44 | @echo " xml to make Docutils-native XML files" 45 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 46 | @echo " linkcheck to check all external links for integrity" 47 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 48 | @echo " coverage to run coverage check of the documentation (if enabled)" 49 | 50 | .PHONY: clean 51 | clean: 52 | rm -rf $(BUILDDIR)/* 53 | 54 | .PHONY: html 55 | html: 56 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 57 | @echo 58 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 59 | 60 | .PHONY: dirhtml 61 | dirhtml: 62 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 63 | @echo 64 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 65 | 66 | .PHONY: singlehtml 67 | singlehtml: 68 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 69 | @echo 70 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 71 | 72 | .PHONY: pickle 73 | pickle: 74 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 75 | @echo 76 | @echo "Build finished; now you can process the pickle files." 77 | 78 | .PHONY: json 79 | json: 80 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 81 | @echo 82 | @echo "Build finished; now you can process the JSON files." 83 | 84 | .PHONY: htmlhelp 85 | htmlhelp: 86 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 87 | @echo 88 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 89 | ".hhp project file in $(BUILDDIR)/htmlhelp." 
90 | 91 | .PHONY: qthelp 92 | qthelp: 93 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 94 | @echo 95 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 96 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 97 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/torch-submod.qhcp" 98 | @echo "To view the help file:" 99 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/torch-submod.qhc" 100 | 101 | .PHONY: applehelp 102 | applehelp: 103 | $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp 104 | @echo 105 | @echo "Build finished. The help book is in $(BUILDDIR)/applehelp." 106 | @echo "N.B. You won't be able to view it unless you put it in" \ 107 | "~/Library/Documentation/Help or install it in your application" \ 108 | "bundle." 109 | 110 | .PHONY: devhelp 111 | devhelp: 112 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 113 | @echo 114 | @echo "Build finished." 115 | @echo "To view the help file:" 116 | @echo "# mkdir -p $$HOME/.local/share/devhelp/torch-submod" 117 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/torch-submod" 118 | @echo "# devhelp" 119 | 120 | .PHONY: epub 121 | epub: 122 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 123 | @echo 124 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 125 | 126 | .PHONY: latex 127 | latex: 128 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 129 | @echo 130 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 131 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 132 | "(use \`make latexpdf' here to do that automatically)." 133 | 134 | .PHONY: latexpdf 135 | latexpdf: 136 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 137 | @echo "Running LaTeX files through pdflatex..." 138 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 139 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 140 | 141 | .PHONY: latexpdfja 142 | latexpdfja: 143 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 144 | @echo "Running LaTeX files through platex and dvipdfmx..." 145 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 146 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 147 | 148 | .PHONY: text 149 | text: 150 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 151 | @echo 152 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 153 | 154 | .PHONY: man 155 | man: 156 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 157 | @echo 158 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 159 | 160 | .PHONY: texinfo 161 | texinfo: 162 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 163 | @echo 164 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 165 | @echo "Run \`make' in that directory to run these through makeinfo" \ 166 | "(use \`make info' here to do that automatically)." 167 | 168 | .PHONY: info 169 | info: 170 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 171 | @echo "Running Texinfo files through makeinfo..." 172 | make -C $(BUILDDIR)/texinfo info 173 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 174 | 175 | .PHONY: gettext 176 | gettext: 177 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 178 | @echo 179 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 
180 | 181 | .PHONY: changes 182 | changes: 183 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 184 | @echo 185 | @echo "The overview file is in $(BUILDDIR)/changes." 186 | 187 | .PHONY: linkcheck 188 | linkcheck: 189 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 190 | @echo 191 | @echo "Link check complete; look for any errors in the above output " \ 192 | "or in $(BUILDDIR)/linkcheck/output.txt." 193 | 194 | .PHONY: doctest 195 | doctest: 196 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 197 | @echo "Testing of doctests in the sources finished, look at the " \ 198 | "results in $(BUILDDIR)/doctest/output.txt." 199 | 200 | .PHONY: coverage 201 | coverage: 202 | $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage 203 | @echo "Testing of coverage in the sources finished, look at the " \ 204 | "results in $(BUILDDIR)/coverage/python.txt." 205 | 206 | .PHONY: xml 207 | xml: 208 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 209 | @echo 210 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 211 | 212 | .PHONY: pseudoxml 213 | pseudoxml: 214 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 215 | @echo 216 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 217 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # torch-submod documentation build configuration file, created by 4 | # sphinx-quickstart on Wed Nov 1 18:16:41 2017. 5 | # 6 | # This file is execfile()d with the current directory set to its 7 | # containing dir. 8 | # 9 | # Note that not all possible configuration values are present in this 10 | # autogenerated file. 11 | # 12 | # All configuration values have a default; values that are commented out 13 | # serve to show the default. 14 | 15 | import sys 16 | import os 17 | from mock import Mock 18 | import sphinx_rtd_theme 19 | 20 | # If extensions (or modules to document with autodoc) are in another directory, 21 | # add these directories to sys.path here. If the directory is relative to the 22 | # documentation root, use os.path.abspath to make it absolute, like shown here. 23 | sys.path.insert(0, os.path.abspath('../..')) 24 | try: 25 | import torch 26 | except ImportError: 27 | # Mock torch.autograd.Function 28 | sys.path.insert(0, os.path.abspath('.')) 29 | 30 | class MockIgnoringArgs(object): 31 | def __init__(self, *args, **kwargs): 32 | pass 33 | 34 | # The following packages have to be mocked as we use autodoc. 35 | MOCK_MODULES = ['prox_tv', 36 | 'prox_tv.tv1_2d', 37 | 'prox_tv.tv1_1d', 38 | 'prox_tv.tv1w_2d', 39 | 'prox_tv.tv1w_1d', 40 | 'numpy', 41 | 'sklearn', 42 | 'sklearn.isotonic', 43 | 'sklearn.isotonic.isotonic_regression', 44 | 'torch_submod.blocks', 45 | 'torch_submod.blocks.blockwise_means', 46 | 'torch_submod.blocks.blocks_2d'] 47 | sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES) 48 | 49 | 50 | 51 | # -- General configuration ------------------------------------------------ 52 | 53 | # If your documentation needs a minimal Sphinx version, state it here. 54 | #needs_sphinx = '1.0' 55 | 56 | # Add any Sphinx extension module names here, as strings. They can be 57 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 58 | # ones. 
59 | extensions = [ 60 | 'sphinx.ext.autodoc', 61 | 'sphinx.ext.doctest', 62 | 'sphinx.ext.intersphinx', 63 | 'sphinx.ext.mathjax', 64 | 'sphinx.ext.napoleon', 65 | 'sphinxcontrib.bibtex', 66 | ] 67 | 68 | # Add any paths that contain templates here, relative to this directory. 69 | templates_path = ['_templates'] 70 | 71 | # The suffix(es) of source filenames. 72 | # You can specify multiple suffix as a list of string: 73 | # source_suffix = ['.rst', '.md'] 74 | source_suffix = '.rst' 75 | 76 | # The encoding of source files. 77 | #source_encoding = 'utf-8-sig' 78 | 79 | # The master toctree document. 80 | master_doc = 'index' 81 | 82 | # General information about the project. 83 | project = u'torch-submod' 84 | copyright = u'2017, Josip Djolonga' 85 | author = u'Josip Djolonga' 86 | 87 | # The version info for the project you're documenting, acts as replacement for 88 | # |version| and |release|, also used in various other places throughout the 89 | # built documents. 90 | # 91 | # The short X.Y version. 92 | version = u'0.1' 93 | # The full version, including alpha/beta/rc tags. 94 | release = u'0.1' 95 | 96 | # The language for content autogenerated by Sphinx. Refer to documentation 97 | # for a list of supported languages. 98 | # 99 | # This is also used if you do content translation via gettext catalogs. 100 | # Usually you set "language" from the command line for these cases. 101 | language = None 102 | 103 | # There are two options for replacing |today|: either, you set today to some 104 | # non-false value, then it is used: 105 | #today = '' 106 | # Else, today_fmt is used as the format for a strftime call. 107 | #today_fmt = '%B %d, %Y' 108 | 109 | # List of patterns, relative to source directory, that match files and 110 | # directories to ignore when looking for source files. 111 | exclude_patterns = [] 112 | 113 | # The reST default role (used for this markup: `text`) to use for all 114 | # documents. 115 | #default_role = None 116 | 117 | # If true, '()' will be appended to :func: etc. cross-reference text. 118 | #add_function_parentheses = True 119 | 120 | # If true, the current module name will be prepended to all description 121 | # unit titles (such as .. function::). 122 | #add_module_names = True 123 | 124 | # If true, sectionauthor and moduleauthor directives will be shown in the 125 | # output. They are ignored by default. 126 | #show_authors = False 127 | 128 | # The name of the Pygments (syntax highlighting) style to use. 129 | pygments_style = 'sphinx' 130 | 131 | # A list of ignored prefixes for module index sorting. 132 | #modindex_common_prefix = [] 133 | 134 | # If true, keep warnings as "system message" paragraphs in the built documents. 135 | #keep_warnings = False 136 | 137 | # If true, `todo` and `todoList` produce output, else they produce nothing. 138 | todo_include_todos = False 139 | 140 | 141 | # -- Options for HTML output ---------------------------------------------- 142 | 143 | # The theme to use for HTML and HTML Help pages. See the documentation for 144 | # a list of builtin themes. 145 | html_theme = 'sphinx_rtd_theme' 146 | 147 | # Theme options are theme-specific and customize the look and feel of a theme 148 | # further. For a list of options available for each theme, see the 149 | # documentation. 150 | #html_theme_options = {} 151 | 152 | # Add any paths that contain custom themes here, relative to this directory. 153 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 154 | 155 | # The name for this set of Sphinx documents. 
If None, it defaults to 156 | # " v documentation". 157 | #html_title = None 158 | 159 | # A shorter title for the navigation bar. Default is the same as html_title. 160 | #html_short_title = None 161 | 162 | # The name of an image file (relative to this directory) to place at the top 163 | # of the sidebar. 164 | #html_logo = None 165 | 166 | # The name of an image file (relative to this directory) to use as a favicon of 167 | # the docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 168 | # pixels large. 169 | #html_favicon = None 170 | 171 | # Add any paths that contain custom static files (such as style sheets) here, 172 | # relative to this directory. They are copied after the builtin static files, 173 | # so a file named "default.css" will overwrite the builtin "default.css". 174 | html_static_path = ['_static'] 175 | 176 | # Add any extra paths that contain custom files (such as robots.txt or 177 | # .htaccess) here, relative to this directory. These files are copied 178 | # directly to the root of the documentation. 179 | #html_extra_path = [] 180 | 181 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, 182 | # using the given strftime format. 183 | #html_last_updated_fmt = '%b %d, %Y' 184 | 185 | # If true, SmartyPants will be used to convert quotes and dashes to 186 | # typographically correct entities. 187 | #html_use_smartypants = True 188 | 189 | # Custom sidebar templates, maps document names to template names. 190 | #html_sidebars = {} 191 | 192 | # Additional templates that should be rendered to pages, maps page names to 193 | # template names. 194 | #html_additional_pages = {} 195 | 196 | # If false, no module index is generated. 197 | #html_domain_indices = True 198 | 199 | # If false, no index is generated. 200 | #html_use_index = True 201 | 202 | # If true, the index is split into individual pages for each letter. 203 | #html_split_index = False 204 | 205 | # If true, links to the reST sources are added to the pages. 206 | #html_show_sourcelink = True 207 | 208 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 209 | #html_show_sphinx = True 210 | 211 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 212 | #html_show_copyright = True 213 | 214 | # If true, an OpenSearch description file will be output, and all pages will 215 | # contain a tag referring to it. The value of this option must be the 216 | # base URL from which the finished HTML is served. 217 | #html_use_opensearch = '' 218 | 219 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 220 | #html_file_suffix = None 221 | 222 | # Language to be used for generating the HTML full-text search index. 223 | # Sphinx supports the following languages: 224 | # 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja' 225 | # 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr' 226 | #html_search_language = 'en' 227 | 228 | # A dictionary with options for the search language support, empty by default. 229 | # Now only 'ja' uses this config value 230 | #html_search_options = {'type': 'default'} 231 | 232 | # The name of a javascript file (relative to the configuration directory) that 233 | # implements a search results scorer. If empty, the default will be used. 234 | #html_search_scorer = 'scorer.js' 235 | 236 | # Output file base name for HTML help builder. 
237 | htmlhelp_basename = 'torch-submoddoc' 238 | 239 | # -- Options for LaTeX output --------------------------------------------- 240 | 241 | latex_elements = { 242 | # The paper size ('letterpaper' or 'a4paper'). 243 | #'papersize': 'letterpaper', 244 | 245 | # The font size ('10pt', '11pt' or '12pt'). 246 | #'pointsize': '10pt', 247 | 248 | # Additional stuff for the LaTeX preamble. 249 | #'preamble': '', 250 | 251 | # Latex figure (float) alignment 252 | #'figure_align': 'htbp', 253 | } 254 | 255 | # Grouping the document tree into LaTeX files. List of tuples 256 | # (source start file, target name, title, 257 | # author, documentclass [howto, manual, or own class]). 258 | latex_documents = [ 259 | (master_doc, 'torch-submod.tex', u'torch-submod Documentation', 260 | u'Josip Djolonga', 'manual'), 261 | ] 262 | 263 | # The name of an image file (relative to this directory) to place at the top of 264 | # the title page. 265 | #latex_logo = None 266 | 267 | # For "manual" documents, if this is true, then toplevel headings are parts, 268 | # not chapters. 269 | #latex_use_parts = False 270 | 271 | # If true, show page references after internal links. 272 | #latex_show_pagerefs = False 273 | 274 | # If true, show URL addresses after external links. 275 | #latex_show_urls = False 276 | 277 | # Documents to append as an appendix to all manuals. 278 | #latex_appendices = [] 279 | 280 | # If false, no module index is generated. 281 | #latex_domain_indices = True 282 | 283 | 284 | # -- Options for manual page output --------------------------------------- 285 | 286 | # One entry per manual page. List of tuples 287 | # (source start file, name, description, authors, manual section). 288 | man_pages = [ 289 | (master_doc, 'torch-submod', u'torch-submod Documentation', 290 | [author], 1) 291 | ] 292 | 293 | # If true, show URL addresses after external links. 294 | #man_show_urls = False 295 | 296 | 297 | # -- Options for Texinfo output ------------------------------------------- 298 | 299 | # Grouping the document tree into Texinfo files. List of tuples 300 | # (source start file, target name, title, author, 301 | # dir menu entry, description, category) 302 | texinfo_documents = [ 303 | (master_doc, 'torch-submod', u'torch-submod Documentation', 304 | author, 'torch-submod', 'One line description of project.', 305 | 'Miscellaneous'), 306 | ] 307 | 308 | # Documents to append as an appendix to all manuals. 309 | #texinfo_appendices = [] 310 | 311 | # If false, no module index is generated. 312 | #texinfo_domain_indices = True 313 | 314 | # How to display URL addresses: 'footnote', 'no', or 'inline'. 315 | #texinfo_show_urls = 'footnote' 316 | 317 | # If true, do not generate a @detailmenu in the "Top" node's menu. 318 | #texinfo_no_detailmenu = False 319 | 320 | 321 | # Example configuration for intersphinx: refer to the Python standard library. 322 | intersphinx_mapping = {'python': ('https://docs.python.org/3.4', None), 323 | 'torch': ('http://pytorch.org/docs/master', None), 324 | 'numpy': ('http://docs.scipy.org/doc/numpy/', None)} 325 | -------------------------------------------------------------------------------- /torch_submod/graph_cuts.py: -------------------------------------------------------------------------------- 1 | """Argmin-differentiable total variation functions.""" 2 | from __future__ import division, print_function 3 | 4 | # Must import these first, gomp issues with pytorch. 
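# (Presumably because torch bundles its own OpenMP/gomp runtime, which can
# clash with the one the prox_tv solvers link against if torch loads first.)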
5 | from prox_tv import tv1w_2d, tv1_2d, tv1w_1d, tv1_1d
6 |
7 | import numpy as np
8 | from sklearn.isotonic import isotonic_regression as isotonic
9 |
10 | import torch
11 | from torch.autograd import Function
12 | from .blocks import blockwise_means, blocks_2d
13 |
14 | __all__ = ("TotalVariation2d", "TotalVariation2dWeighted", "TotalVariation1d")
15 |
16 |
17 | class TotalVariation2dWeighted(Function):
18 |     r"""A two dimensional total variation function.
19 |
20 |     Specifically, given as input the unaries `x`, positive row weights
21 |     :math:`\mathbf{r}` and positive column weights :math:`\mathbf{c}`, the
22 |     output is computed as
23 |
24 |     .. math::
25 |
26 |         \textrm{argmin}_{\mathbf z}
27 |         \frac{1}{2} \|\mathbf{x}-\mathbf{z}\|^2 +
28 |         \sum_{i, j} r_{i,j} |z_{i, j} - z_{i, j + 1}| +
29 |         \sum_{i, j} c_{i,j} |z_{i, j} - z_{i + 1, j}|.
30 |
31 |     Arguments
32 |     ---------
33 |     refine: bool
34 |         If set, the solution will be refined with isotonic regression.
35 |     average_connected: bool
36 |         How to compute the approximate derivative.
37 |
38 |         If ``True``, will average within each connected component.
39 |         If ``False``, it will average within each block of equal values.
40 |         Typically, you want this set to true.
41 |     tv_args: dict
42 |         The dictionary of arguments passed to the total variation solver.
43 |     """
44 |     def __init__(self, refine=True, average_connected=True, tv_args=None):
45 |         self.tv_args = tv_args if tv_args is not None else {}
46 |         self.refine = refine
47 |         self.average_connected = average_connected
48 |
49 |     def forward(self, x, weights_row, weights_col):
50 |         r"""Solve the total variation problem and return the solution.
51 |
52 |         Arguments
53 |         ---------
54 |         x: :class:`torch:torch.Tensor`
55 |             A tensor with shape ``(m, n)`` holding the input signal.
56 |         weights_row: :class:`torch:torch.Tensor`
57 |             The horizontal edge weights.
58 |
59 |             Tensor of shape ``(m, n - 1)``, or ``(1,)`` if all weights
60 |             are equal.
61 |         weights_col: :class:`torch:torch.Tensor`
62 |             The vertical edge weights.
63 |
64 |             Tensor of shape ``(m - 1, n)``, or ``(1,)`` if all weights
65 |             are equal.
66 |
67 |         Returns
68 |         -------
69 |         :class:`torch:torch.Tensor`
70 |             The solution to the total variation problem, of shape ``(m, n)``.
71 |         """
72 |         opt = tv1w_2d(x.numpy(), weights_col.numpy(), weights_row.numpy(),
73 |                       **self.tv_args)
74 |         if self.refine:
75 |             opt = self._refine(opt, x, weights_row, weights_col)
76 |         opt = torch.Tensor(opt).view_as(x)
77 |         self.save_for_backward(opt)
78 |         return opt
79 |
80 |     def _grad_x(self, opt, grad_output):
81 |         if self.average_connected:
82 |             blocks = blocks_2d(opt.numpy())
83 |         else:
84 |             _, blocks = np.unique(opt.numpy().ravel(), return_inverse=True)
85 |         grad_x = blockwise_means(blocks.ravel(), grad_output.numpy().ravel())
86 |         # We need the clone as there seems to be a double-free error in py27,
87 |         # namely, torch free()s the array after numpy has already free()d it.
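        # Illustrative example (hypothetical values): if opt == [[1, 1],
        # [2, 2]], there are two connected components (the two rows), so the
        # first row of grad_x holds the mean of grad_output over the top row,
        # and the second row the mean over the bottom row.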
88 |         return torch.from_numpy(grad_x).view(opt.size()).clone()
89 |
90 |     def _grad_w_row(self, opt, grad_x):
91 |         """Compute the derivative with respect to the row weights."""
92 |         diffs_row = torch.sign(opt[:, :-1] - opt[:, 1:])
93 |         return - diffs_row * (grad_x[:, :-1] - grad_x[:, 1:])
94 |
95 |     def _grad_w_col(self, opt, grad_x):
96 |         """Compute the derivative with respect to the column weights."""
97 |         diffs_col = torch.sign(opt[:-1, :] - opt[1:, :])
98 |         return - diffs_col * (grad_x[:-1, :] - grad_x[1:, :])
99 |
100 |     def backward(self, grad_output):
101 |         opt, = self.saved_tensors
102 |         grad_weights_row, grad_weights_col = None, None
103 |         grad_x = self._grad_x(opt, grad_output)
104 |
105 |         if self.needs_input_grad[1]:
106 |             grad_weights_row = self._grad_w_row(opt, grad_x)
107 |
108 |         if self.needs_input_grad[2]:
109 |             grad_weights_col = self._grad_w_col(opt, grad_x)
110 |
111 |         return grad_x, grad_weights_row, grad_weights_col
112 |
113 |     def _refine(self, opt, x, weights_row, weights_col):
114 |         """Refine the solution by solving an isotonic regression.
115 |
116 |         The weights can either be two-dimensional tensors, or of shape (1,)."""
117 |         idx = np.argsort(opt.ravel())  # Will pick an arbitrary order cone.
118 |         ordered_vec = np.zeros_like(idx, dtype=np.float)
119 |         ordered_vec[idx] = np.arange(np.size(opt))
120 |         f = self._linearize(ordered_vec.reshape(opt.shape),
121 |                             weights_row.numpy(), weights_col.numpy())
122 |         opt_idx = isotonic((x.view(-1).numpy() - f.ravel())[idx])
123 |         opt = np.zeros_like(opt_idx)
124 |         opt[idx] = opt_idx
125 |         return opt
126 |
127 |     def _linearize(self, y, weights_row, weights_col):
128 |         """Compute a linearization of the graph-cut function at the given point.
129 |
130 |         Arguments
131 |         ---------
132 |         y : numpy.ndarray
133 |             The point where the linearization is computed, shape ``(m, n)``.
134 |         weights_row : numpy.ndarray
135 |             The non-negative row weights, with shape ``(m, n - 1)``.
136 |         weights_col : numpy.ndarray
137 |             The non-negative column weights, with shape ``(m - 1, n)``.
138 |
139 |         Returns
140 |         -------
141 |         numpy.ndarray
142 |             The linearization of the graph-cut function at ``y``."""
143 |         diffs_col = np.sign(y[1:, :] - y[:-1, :])
144 |         diffs_row = np.sign(y[:, 1:] - y[:, :-1])
145 |
146 |         f = np.zeros_like(y)  # The linearization.
147 |         f[:, 1:] += diffs_row * weights_row
148 |         f[:, :-1] -= diffs_row * weights_row
149 |         f[1:, :] += diffs_col * weights_col
150 |         f[:-1, :] -= diffs_col * weights_col
151 |
152 |         return f
153 |
154 |
155 | class TotalVariation2d(TotalVariation2dWeighted):
156 |     r"""A two dimensional total variation function with tied edge weights.
157 |
158 |     Specifically, given as input the unaries `x` and edge weight ``w``, the
159 |     returned value is given by:
160 |
161 |     .. math::
162 |
163 |         \textrm{argmin}_{\mathbf z}
164 |         \frac{1}{2} \|\mathbf{x}-\mathbf{z}\|^2 +
165 |         \sum_{i, j} w |z_{i, j} - z_{i, j + 1}| +
166 |         \sum_{i, j} w |z_{i, j} - z_{i + 1, j}|.
167 |
168 |     Arguments
169 |     ---------
170 |     refine: bool
171 |         If set, the solution will be refined with isotonic regression.
172 |     average_connected: bool
173 |         How to compute the approximate derivative.
174 |
175 |         If ``True``, will average within each connected component.
176 |         If ``False``, it will average within each block of equal values.
177 |         Typically, you want this set to true.
178 |     tv_args: dict
179 |         The dictionary of arguments passed to the total variation solver.
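
    Example
    -------
    A minimal illustrative sketch (the ``tv_args`` value mirrors the one used
    in the tests; actual output values depend on the solver)::

        >>> from torch.autograd import Variable
        >>> f = TotalVariation2d(tv_args={'method': 'dr'})
        >>> x = Variable(torch.randn(4, 5), requires_grad=True)
        >>> w = Variable(torch.ones(1), requires_grad=True)
        >>> y = f(x, w)  # A piecewise-constant image of shape (4, 5).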
180 | """ 181 | def __init__(self, refine=True, average_connected=True, tv_args=None): 182 | super(TotalVariation2d, self).__init__( 183 | refine=refine, 184 | average_connected=average_connected, 185 | tv_args=tv_args) 186 | 187 | def forward(self, x, w): 188 | r"""Solve the total variation problem and return the solution. 189 | 190 | Arguments 191 | --------- 192 | x: :class:`torch:torch.Tensor` 193 | A tensor with shape ``(m, n)`` holding the input signal. 194 | weights_row: :class:`torch:torch.Tensor` 195 | The horizontal edge weights. 196 | 197 | Tensor of shape ``(m, n - 1)``, or ``(1,)`` if all weights 198 | are equal. 199 | weights_col: :class:`torch:torch.Tensor` 200 | The vertical edge weights. 201 | 202 | Tensor of shape ``(m - 1, n)``, or ``(1,)`` if all weights 203 | are equal. 204 | 205 | Returns 206 | ------- 207 | :class:`torch:torch.Tensor` 208 | The solution to the total variation problem, of shape ``(m, n)``. 209 | """ 210 | assert w.size() == (1,) 211 | opt = tv1_2d(x.numpy(), w.numpy()[0], **self.tv_args) 212 | 213 | if self.refine: # Should we improve it with isotonic regression. 214 | opt = self._refine(opt, x, w, w) 215 | 216 | opt = torch.Tensor(opt).view_as(x) 217 | self.save_for_backward(opt) 218 | return opt 219 | 220 | def backward(self, grad_output): 221 | opt, = self.saved_tensors 222 | grad_x = self._grad_x(opt, grad_output) 223 | grad_w = None 224 | 225 | if self.needs_input_grad[1]: 226 | grad_w = (torch.sum(self._grad_w_row(opt, grad_x)) + 227 | torch.sum(self._grad_w_col(opt, grad_x))) 228 | grad_w = torch.Tensor([grad_w]) 229 | 230 | return grad_x, grad_w 231 | 232 | 233 | class TotalVariation1d(TotalVariation2dWeighted): 234 | r"""A one dimensional total variation function. 235 | 236 | Specifically, given as input the signal `x` and weights :math:`\mathbf{w}`, 237 | the output is computed as 238 | 239 | .. math:: 240 | 241 | \textrm{argmin}_{\mathbf z} 242 | \frac{1}{2} \|\mathbf{x}-\mathbf{z}\|^2 + 243 | \sum_{i=1}^{n-1} w_i |z_i - z_{i+1}|. 244 | 245 | Arguments 246 | --------- 247 | average_connected: bool 248 | How to compute the approximate derivative. 249 | 250 | If ``True``, will average within each connected component. 251 | If ``False``, it will average within each block of equal values. 252 | Typically, you want this set to true. 253 | tv_args: dict 254 | The dictionary of arguments passed to the total variation solver. 255 | """ 256 | def __init__(self, average_connected=True, tv_args=None): 257 | if tv_args is None: 258 | self.tv_args = {} 259 | else: 260 | self.tv_args = tv_args 261 | self.average_connected = average_connected 262 | 263 | def forward(self, x, weights): 264 | r"""Solve the total variation problem and return the solution. 265 | 266 | Arguments 267 | --------- 268 | x: :class:`torch:torch.Tensor` 269 | A tensor with shape ``(n,)`` holding the input signal. 270 | weights: :class:`torch:torch.Tensor` 271 | The edge weights. 272 | 273 | Shape ``(n-1,)``, or ``(1,)`` if all weights are equal. 274 | 275 | Returns 276 | ------- 277 | :class:`torch:torch.Tensor` 278 | The solution to the total variation problem, of shape ``(m, n)``. 
279 | """ 280 | self.equal_weights = weights.size() == (1,) 281 | if self.equal_weights: 282 | opt = tv1_1d(x.numpy().ravel(), weights.numpy()[0], 283 | **self.tv_args) 284 | else: 285 | opt = tv1w_1d(x.numpy().ravel(), weights.numpy().ravel(), 286 | **self.tv_args) 287 | opt = torch.Tensor(opt).view_as(x) 288 | 289 | self.save_for_backward(opt) 290 | return opt 291 | 292 | def backward(self, grad_output): 293 | opt, = self.saved_tensors 294 | grad_weights = None 295 | 296 | opt = opt.view((1, -1)) 297 | grad_x = self._grad_x(opt, grad_output) 298 | 299 | if self.needs_input_grad[1]: 300 | grad_weights = self._grad_w_row(opt, grad_x).view(-1) 301 | if self.equal_weights: 302 | grad_weights = torch.Tensor([torch.sum(grad_weights)]) 303 | 304 | return grad_x.view(-1), grad_weights 305 | --------------------------------------------------------------------------------