├── .github
└── workflows
│ ├── install_and_test.yml
│ └── release.yml
├── LICENSE
├── MANIFEST.in
├── README.md
├── docs
└── source
│ ├── _static
│ └── logo.svg
│ ├── _templates
│ └── logo.html
│ ├── api
│ ├── density.rst
│ ├── drnn.rst
│ ├── normalizer.rst
│ ├── qrnn.rst
│ ├── quantiles.rst
│ └── quantnn.rst
│ ├── api_reference.rst
│ ├── cdf_qrnn.png
│ ├── conf.py
│ ├── drnn.rst
│ ├── examples.rst
│ ├── index.rst
│ ├── notebooks
│ ├── cloud_top_pressure_retrieval.ipynb
│ ├── rain_rate_retrieval.ipynb
│ └── simple_example.ipynb
│ ├── pdf_drnn.png
│ ├── qrnn.rst
│ ├── qrnn.svg
│ ├── quantiles.svg
│ └── user_guide.rst
├── notebooks
├── cloud_top_pressure_retrieval.ipynb
├── comparison_xgboost.ipynb
├── convolutional_rain_rate_retrieval.ipynb
├── migrating_from_mse_regression.ipynb
├── rain_rate_retrieval.ipynb
├── simple_example.ipynb
└── tensor_board_logging.ipynb
├── quantnn.yml
├── quantnn
├── __init__.py
├── a_priori.py
├── backends
│ ├── __init__.py
│ ├── pytorch.py
│ └── tensor.py
├── common.py
├── data.py
├── data
│ └── matplotlib_style.rc
├── density.py
├── drnn.py
├── examples
│ ├── __init__.py
│ ├── gprof_conv.py
│ ├── gprof_simple.py
│ ├── modis_ctp.py
│ └── simple.py
├── files
│ ├── __init__.py
│ └── sftp.py
├── generic
│ └── __init__.py
├── logging
│ ├── __init__.py
│ └── multiprocessing.py
├── metrics.py
├── models
│ ├── __init__.py
│ ├── keras
│ │ ├── __init__.py
│ │ ├── padding.py
│ │ ├── unet.py
│ │ └── xception.py
│ └── pytorch
│ │ ├── __init__.py
│ │ ├── aggregators.py
│ │ ├── base.py
│ │ ├── blocks.py
│ │ ├── common.py
│ │ ├── decoders.py
│ │ ├── downsampling.py
│ │ ├── encoders.py
│ │ ├── factories.py
│ │ ├── fully_connected.py
│ │ ├── generative.py
│ │ ├── lightning.py
│ │ ├── logging.py
│ │ ├── normalization.py
│ │ ├── resnet.py
│ │ ├── stages.py
│ │ ├── torchvision.py
│ │ ├── unet.py
│ │ ├── upsampling.py
│ │ └── xception.py
├── mrnn.py
├── neural_network_model.py
├── normalizer.py
├── packed_tensor.py
├── plotting.py
├── qrnn.py
├── quantiles.py
├── transformations.py
└── utils.py
├── setup.py
└── test
├── conftest.py
├── files
├── test_files.py
└── test_sftp.py
├── models
├── pytorch
│ ├── test_aggregators.py
│ ├── test_base.py
│ ├── test_blocks.py
│ ├── test_decoders.py
│ ├── test_downsampling.py
│ ├── test_encoders.py
│ ├── test_fully_connected.py
│ ├── test_normalization.py
│ ├── test_torchvision.py
│ └── test_upsampling.py
├── test_keras.py
└── test_pytorch.py
├── test_data.py
├── test_data
├── x_train.npy
└── y_train.npy
├── test_density.py
├── test_drnn.py
├── test_generic.py
├── test_normalizer.py
├── test_packed_tensor.py
├── test_qrnn.py
├── test_quantiles.py
├── test_tensor_backends.py
└── test_utils.py
/.github/workflows/install_and_test.yml:
--------------------------------------------------------------------------------
1 | name: install_and_test
2 | on: [push]
3 | jobs:
4 | install_and_test_job:
5 | strategy:
6 | matrix:
7 | os: [ubuntu-latest]
8 | python: [3.8, 3.9]
9 | runs-on: ${{ matrix.os }}
10 | steps:
11 | - uses: actions/checkout@v2
12 | - uses: actions/setup-python@v2
13 | with:
14 | python-version: ${{ matrix.python }}
15 | - run: pip install torch torchvision pytest tensorflow
16 | - run: pip install .
17 | - run: pytest test
18 | env:
19 | QUANTNN_SFTP_USER: ${{ secrets.QUANTNN_SFTP_USER }}
20 | QUANTNN_SFTP_PASSWORD: ${{ secrets.QUANTNN_SFTP_PASSWORD }}
21 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: release
2 | on:
3 | push:
4 | tags:
5 | - '*'
6 | jobs:
7 | release_job:
8 | runs-on: ubuntu-latest
9 | steps:
10 | - uses: actions/checkout@v2
11 | with:
12 | ref: 'main'
13 | - uses: actions/setup-python@v2
14 | with:
15 | python-version: '3.8'
16 | - run: pip install .
17 | - run: pip install wheel twine
18 | - run: python setup.py sdist bdist_wheel
19 | - run: python -m twine upload -u __token__ -p ${{ secrets.TWINE_TOKEN }} dist/*
20 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2020 Simon Pfreundschuh
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 |
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 |
7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
8 |
9 | End license text.
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include quantnn/data/matplotlib_style.rc
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # quantnn
2 |
3 | The ``quantnn`` package provides an implementation of quantile regression neural
4 | networks on top of Keras and PyTorch.
5 |
--------------------------------------------------------------------------------
/docs/source/_templates/logo.html:
--------------------------------------------------------------------------------
1 |
6 |
7 |
15 |

quantnn
16 |
17 |
--------------------------------------------------------------------------------
/docs/source/api/density.rst:
--------------------------------------------------------------------------------
1 | .. automodule:: quantnn.density
2 | :members:
3 |
--------------------------------------------------------------------------------
/docs/source/api/drnn.rst:
--------------------------------------------------------------------------------
1 | .. automodule:: quantnn.drnn
2 | :members:
3 |
--------------------------------------------------------------------------------
/docs/source/api/normalizer.rst:
--------------------------------------------------------------------------------
1 | .. automodule:: quantnn.normalizer
2 | :members:
3 |
--------------------------------------------------------------------------------
/docs/source/api/qrnn.rst:
--------------------------------------------------------------------------------
1 | .. automodule:: quantnn.qrnn
2 | :members:
3 | :inherited-members:
4 |
--------------------------------------------------------------------------------
/docs/source/api/quantiles.rst:
--------------------------------------------------------------------------------
1 | .. automodule:: quantnn.quantiles
2 | :members:
3 |
--------------------------------------------------------------------------------
/docs/source/api/quantnn.rst:
--------------------------------------------------------------------------------
1 | .. automodule:: quantnn
2 | :members:
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 | :caption: Submodules
7 |
8 | backends
9 | qrnn
10 | quantiles
11 | drnn
12 | density
13 | data
14 | normalizer
15 |
16 |
17 |
18 |
19 |
--------------------------------------------------------------------------------
/docs/source/api_reference.rst:
--------------------------------------------------------------------------------
1 | API Reference
2 | =============
3 |
4 | This part of the documentation contains the source code documentation of
5 | the **quantnn** package. It provides detailed information on the usage of specific
6 | components of the package.
7 |
8 | .. toctree::
9 | :maxdepth: 2
10 |
11 | api/quantnn
12 | api/data
13 |
--------------------------------------------------------------------------------
/docs/source/cdf_qrnn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/simonpf/quantnn/c4d650cf0c6da5b4a704905b6c267d1ca996466f/docs/source/cdf_qrnn.png
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | # import os
14 | # import sys
15 | # sys.path.insert(0, os.path.abspath('.'))
16 |
17 | # -- Project information -----------------------------------------------------
18 |
19 | project = 'quantnn'
20 | copyright = '2020, Simon Pfreundschuh'
21 | author = 'Simon Pfreundschuh'
22 |
23 | # The full version, including alpha/beta/rc tags
24 | release = '0.0.1'
25 |
26 |
27 | # -- General configuration ---------------------------------------------------
28 |
29 | # Add any Sphinx extension module names here, as strings. They can be
30 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
31 | # ones.
32 | extensions = ["nbsphinx", "sphinx.ext.autodoc", "sphinx.ext.napoleon",
33 | "sphinx.ext.autosummary"]
34 |
35 | # Add any paths that contain templates here, relative to this directory.
36 | templates_path = ['_templates']
37 |
38 | # List of patterns, relative to source directory, that match files and
39 | # directories to ignore when looking for source files.
40 | # This pattern also affects html_static_path and html_extra_path.
41 | exclude_patterns = []
42 |
43 |
44 | # -- Options for HTML output -------------------------------------------------
45 |
46 | # The theme to use for HTML and HTML Help pages. See the documentation for
47 | # a list of builtin themes.
48 | #
49 | html_theme = 'smpl'
50 | html_theme_options = {
51 | "navigation_bar_minimum_height": "15vh",
52 | "navigation_bar_targets": ["index.html",
53 | "user_guide.html",
54 | "examples.html",
55 | "api_reference.html"],
56 | "navigation_bar_names": ["Home", "User guide", "Examples", "API Reference"],
57 | "navigation_bar_element_padding": "40px",
58 | "navigation_bar_background_color": "#333333",
59 | "navigation_bar_element_hover_color": "#ff5050",
60 | "navigation_bar_border_color": "#ff5050",
61 | "navigation_bar_border_style": "solid",
62 | "navigation_bar_border_width": "0px 0px 0px 0px",
63 |
64 | "link_color": "#ff5050",
65 | "link_visited_color": "#ff5050",
66 | "link_hover_color": "#990000",
67 |
68 | "sidebars_right": [],
69 | "sidebars_left":["localtoc.html", "globaltoc.html"],
70 | "globaltoc_maxdepth": 1,
71 | "inline_code_border_radius": "2px",
72 | "sidebar_left_border_color": "#ff5050",
73 |
74 | "highlight_border_color": "#ff5050",
75 | }
76 |
77 | # Add any paths that contain custom static files (such as style sheets) here,
78 | # relative to this directory. They are copied after the builtin static files,
79 | # so a file named "default.css" will overwrite the builtin "default.css".
80 | html_static_path = ['_static']
81 |
--------------------------------------------------------------------------------
/docs/source/drnn.rst:
--------------------------------------------------------------------------------
1 | ==========================================
2 | Density regression neural networks (DRNNs)
3 | ==========================================
4 |
5 | How it works
6 | ------------
7 |
8 | Defining a model
9 | ----------------
10 |
11 | Training
12 | --------
13 |
14 | Evaluation
15 | ----------
16 |
17 | Handling output
18 | ---------------
19 |
20 | Loading and saving models
21 | -------------------------
22 |
--------------------------------------------------------------------------------
/docs/source/examples.rst:
--------------------------------------------------------------------------------
1 | Examples
2 | ========
3 |
4 | This section contains examples of applications of the **quantnn** package.
5 |
6 | .. toctree::
7 | :maxdepth: 2
8 | :caption: Contents:
9 |
10 | notebooks/simple_example
11 | notebooks/cloud_top_pressure_retrieval
12 | notebooks/rain_rate_retrieval
13 |
14 |
15 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. quantnn documentation master file, created by
2 | sphinx-quickstart on Sun Dec 6 10:44:55 2020.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | quantnn
7 | =======
8 |
9 | **quantnn** is a Python package for solving probabilistic regression problems using
10 | (deterministic) deep neural networks, i.e. to predict a conditional distribution
11 | :math:`P(y|x)` for given input :math:`x`.
12 |
13 | It currently provides implementations of two neural-network based methods to solve
14 | these types of problems:
15 |
16 | 1. **Quantile regression neural networks (QRNNs)**: QRNNs learn to predict the quantiles
17 | of the conditional distribution :math:`P(y|x)`, which can be used to estimate its
18 | cumulative distribution function (CDF).
19 |
20 | .. figure:: cdf_qrnn.png
21 | :width: 800
22 | :alt: A conditional Cumulative distribution function predicted using a QRNN.
23 |
24 | Example of a QRNN applied to predict the quantiles of a function with
25 | heteroscedastic noise.
26 |
27 | 2. **Density regression neural networks (DRNNs)**: DRNNs learn to predict a binned
28 | version of the probability density function (PDF) of :math:`P(y|x)`.
29 |
30 | .. figure:: pdf_drnn.png
31 | :width: 800
32 | :alt: A conditional probability density function predicted using a DRNN.
33 |
34 | Example of a DRNN applied to predict the PDF of a function with
35 | heteroscedastic noise.
36 |
37 |
38 |
39 | Features
40 | --------
41 |
42 | - A flexible, high-level implementation of QRNN and DRNNs supporting both PyTorch and Keras
43 | as backends.
44 | - Generic functions to manipulate and process QRNN and DRNN predictions such as computing the
45 | posterior mean or classifying inputs.
46 |
47 | Installation
48 | ------------
49 |
50 | The currently recommended way of installing the **quantnn** package is to check out the source from
51 | `GitHub <https://github.com/simonpf/quantnn>`_ and install it in editable mode using ``pip``:
52 |
53 | .. code-block:: bash
54 |
55 | pip install -e .
56 |
57 | Content
58 | -------
59 |
60 | .. toctree::
61 | :maxdepth: 2
62 |
63 | user_guide
64 | examples
65 | api_reference
66 |
67 |
68 |
69 |
--------------------------------------------------------------------------------
/docs/source/notebooks/cloud_top_pressure_retrieval.ipynb:
--------------------------------------------------------------------------------
1 | ../../../notebooks/cloud_top_pressure_retrieval.ipynb
--------------------------------------------------------------------------------
/docs/source/notebooks/rain_rate_retrieval.ipynb:
--------------------------------------------------------------------------------
1 | ../../../notebooks/rain_rate_retrieval.ipynb
--------------------------------------------------------------------------------
/docs/source/notebooks/simple_example.ipynb:
--------------------------------------------------------------------------------
1 | ../../../notebooks/simple_example.ipynb
--------------------------------------------------------------------------------
/docs/source/pdf_drnn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/simonpf/quantnn/c4d650cf0c6da5b4a704905b6c267d1ca996466f/docs/source/pdf_drnn.png
--------------------------------------------------------------------------------
/docs/source/qrnn.rst:
--------------------------------------------------------------------------------
1 | ===========================================
2 | Quantile regression neural networks (QRNNs)
3 | ===========================================
4 |
5 | Consider first the case of a simple, one-dimensional input vector :math:`x` with
6 | input features :math:`x_1, \ldots, x_n` and a corresponding scalar output value
7 | :math:`y`. If the problem of mapping :math:`x` to :math:`y` does not admit a unique
8 | solution a more suitable approach than predicting a single value for :math:`y`
9 | is to instead predict its conditional distribution :math:`P(y|x)` for given
10 | input :math:`x`.
11 |
12 | QRNNs do this by learning to predict a sequence of quantiles :math:`y_\tau` of
13 | the distribution :math:`P(y | \mathbf{x})`. The quantile :math:`y_\tau` for
14 | :math:`\tau \in [0, 1]` is defined as the value for which :math:`P(y \leq y_\tau
15 | | x) = \tau`. Since the quantiles :math:`y_{\tau_0}, y_{\tau_1}, \ldots`
16 | correspond to the values of the inverse :math:`F^{-1}` of the cumulative
17 | distribution function :math:`F(y') = P(y \leq y' | x)`, the QRNN output can be
18 | interpreted as a piece-wise linear approximation of the CDF of :math:`P(y|x)`.
19 |
20 | .. figure:: qrnn.svg
21 | :width: 600
22 | :alt: Illustration of a QRNN
23 |
24 | A QRNN is a neural network which predicts a sequence of quantiles of
25 | the posterior distribution :math:`P(y|x)`.
26 |
27 | How it works
28 | ------------
29 |
30 | QRNNs make use of quantile regression to learn to predict the quantiles of the
31 | distribution :math:`P(y|x)`. This works because the quantile :math:`y_\tau` for
32 | a given quantile fraction :math:`\tau \in [0, 1]` corresponds to a minimum of the
33 | expected value :math:`\mathbf{E}_y {\mathcal{L}_\tau(y, y_\tau)}` of the quantile
34 | loss function
35 |
36 | .. math::
37 |
38 | \mathcal{L}_\tau(y, y_\tau) =
39 | \begin{cases}
40 | \tau (y - y_\tau) & \text{if } y > y_\tau \\
41 | (1 - \tau) (y_\tau - y) & \text{otherwise}.
42 | \end{cases}
43 |
44 | A proof of this can be found on `Wikipedia <https://en.wikipedia.org/wiki/Quantile_regression>`_.
45 |
46 | Because of this property, training a neural network using the quantile loss
47 | function :math:`\mathcal{L}_\tau` will teach the network to predict the
48 | corresponding quantile :math:`y_\tau`. QRNNs extend this principle to a
49 | sequence of quantiles corresponding to an arbitrary selection of quantile
50 | fractions :math:`\tau_1, \tau_2, \ldots` which are optimized simultaneously.
51 |
52 | Defining a model
53 | ----------------
54 |
55 | To define a QRNN model you need to specify the quantiles to predict as well as the
56 | architecture of the underlying network to use. The code snippet below shows how
57 | to create a QRNN model to predict the 1st through 99th percentiles.
58 |
59 | As neural network model, the QRNN uses a fully-connected neural network with four hidden
60 | layers with 128 neurons each and ReLU activation functions. Since this is taken
61 | from the simple example notebook, the number of input features is set to 1.
62 |
63 | .. code ::
64 |
65 | from quantnn import QRNN
66 | quantiles = np.linspace(0.01, 0.99, 99)
67 |
68 | layers = 4
69 | neurons = 128
70 | activation = "relu"
71 | model = (layers, neurons, activation)
72 |
73 | qrnn = QRNN(quantiles, input_dimensions=1, model=model)
74 |
75 | QRNN provides a simplified interface to create QRNNs with simple fully-connected
76 | architectures. The QRNN class, however, doesn't restrict you to use a fully-connected
77 | network but supports any other suitable architecture such as, for example, a
78 | fully-convolutional DenseNet-type architecture:
79 |
80 | .. code ::
81 |
82 | model = ... # Define model as Pytorch Module or Keras model object.
83 | qrnn = QRNN(quantiles, model)
84 |
85 | Training
86 | --------
87 |
88 | Evaluation
89 | ----------
90 |
91 | Saving and loading
92 | ------------------
93 |
94 |
--------------------------------------------------------------------------------
/docs/source/user_guide.rst:
--------------------------------------------------------------------------------
1 | User guide
2 | ==========
3 |
4 | This section describes the basic usage of the **quantnn** package.
5 |
6 | Overview
7 | --------
8 |
9 | The main functionality of quantnn is implemented by two model classes:
10 | :py:class:`quantnn.qrnn.QRNN` and :py:class:`quantnn.drnn.DRNN`. The
11 | :py:class:`~quantnn.qrnn.QRNN` class provides an implementation of quantile
12 | regression neural networks (QRNNs), whereas the :py:class:`~quantnn.drnn.DRNN`
13 | implements density regression neural networks (DRNNs).
14 |
15 | Basic workflow
16 | --------------
17 |
18 | The basic usage of both the :py:class:`~quantnn.qrnn.QRNN` and :py:class:`~quantnn.drnn.DRNN`
19 | classes is similar and follows the generic machine learning workflow:
20 |
21 | 1. **Defining the model**: A model is defined by instantiating the
22 | corresponding model class. For this, you need to define the architecture of
23 | the underlying neural network and specify which quantiles to predict (QRNN)
24 | or the binning of the PDF (DRNN).
25 |
26 | 2. **Training the model**: The training phase is similar as for any other deep
27 | neural network. The neural network backend (PyTorch or Keras) takes care of
28 | the heavy lifting, you only have to choose the training parameters.
29 |
30 | 3. **Evaluating the model**: Finally, you will of course want to make
31 | predictions with your model. This is done using the model
32 | :py:meth:`~quantnn.QRNN.predict` method, which will produce a tensor of
33 | either quantiles (QRNN) or the binned PDF (DRNN) of the posterior
34 | distribution. To further process these predictions, the
35 | :py:mod:`quantnn.quantiles` and :py:mod:`quantnn.density` modules provide
36 | functions that can be used to derive statistics of the probabilistic
37 | results.
38 |
39 | 4. **Loading and saving the model**: To reuse your trained model you can save
40 | and load it using the corresponding class methods of the
41 | :py:class:`~quantnn.qrnn.QRNN` and :py:class:`~quantnn.drnn.DRNN` classes.
42 |
43 | .. note ::
44 |
45 | Care has been taken to design the interfaces of the :py:class:`~quantnn.qrnn.QRNN`
46 | and :py:class:`~quantnn.drnn.DRNN` classes as consistently as possible so that
47 | both classes can be used interchangeably to the largest extent possible.
48 |
49 | Content
50 | -------
51 |
52 | .. toctree::
53 | :maxdepth: 2
54 |
55 | qrnn
56 | drnn
57 |
58 |
59 |
60 |
61 |
62 |
--------------------------------------------------------------------------------
/quantnn.yml:
--------------------------------------------------------------------------------
1 | name: quantnn
2 | channels:
3 | - conda-forge
4 | - defaults
5 | dependencies:
6 | - python=3.9
7 | - torchvision
8 | - torchaudio
9 | - cudatoolkit=10.2
10 | - pytorch
11 | - tensorflow
12 | - tensorboard
13 | prefix: /home/simonpf/miniconda3/envs/quantnn
14 |
--------------------------------------------------------------------------------
/quantnn/__init__.py:
--------------------------------------------------------------------------------
1 | r"""
2 | =======
3 | quantnn
4 | =======
5 |
6 | The quantnn package provides functionality for probabilistic modeling and prediction
7 | using deep neural networks.
8 |
9 | The two main features of the quantnn package are implemented by the
10 | :py:class:`~quantnn.qrnn.QRNN` and :py:class:`~quantnn.drnn.DRNN` classes, which implement
11 | quantile regression neural networks (QRNNs) and density regression neural networks (DRNNs),
12 | respectively.
13 |
14 | The modules :py:mod:`quantnn.quantiles` and :py:mod:`quantnn.density` provide generic
15 | (backend agnostic) functions to manipulate probabilistic predictions.
16 | """
17 | import logging as _logging
18 | import os
19 |
20 | from rich.logging import RichHandler
21 | from quantnn.neural_network_model import set_default_backend, get_default_backend
22 | from quantnn.qrnn import QRNN
23 | from quantnn.drnn import DRNN
24 | from quantnn.quantiles import (
25 | cdf,
26 | pdf,
27 | posterior_mean,
28 | probability_less_than,
29 | probability_larger_than,
30 | sample_posterior,
31 | sample_posterior_gaussian,
32 | quantile_loss,
33 | )
34 |
# The log level is taken from the QUANTNN_LOG_LEVEL environment variable
# (case-insensitive) and falls back to "WARNING" when the variable is unset.
_LOG_LEVEL = os.getenv("QUANTNN_LOG_LEVEL", "WARNING").upper()

# Configure the root logger at import time so that messages are rendered
# through rich's handler.
_logging.basicConfig(
    handlers=[RichHandler()],
    level=_LOG_LEVEL,
    format="%(message)s",
    datefmt="[%X]",
)
39 |
--------------------------------------------------------------------------------
/quantnn/a_priori.py:
--------------------------------------------------------------------------------
1 | """
2 | ================
3 | quantnn.a_priori
4 | ================
5 |
6 | Defines classes to represent a priori distributions.
7 | """
8 | from quantnn.generic import (
9 | get_array_module,
10 | expand_dims,
11 | as_type,
12 | concatenate,
13 | tensordot,
14 | exp,
15 | )
16 |
17 |
class LookupTable:
    """
    A priori distribution stored as a piece-wise linear lookup table.

    The distribution is a univariate function described by a sequence of
    nodes ``x`` and the (non-normalized) density values ``y`` at those nodes.
    Evaluation is written purely in terms of tensor operations.
    """

    def __init__(self, x, y):
        """
        Store the nodes of the lookup table.

        Args:
            x: The x values at which the value of the a priori is known.
            y: The corresponding non-normalized density.
        """
        self.x = x
        self.y = y

    @staticmethod
    def _axis_index(n_dims, axis, entry):
        """
        Build an index tuple that applies ``entry`` along ``axis`` and selects
        everything along the remaining ``n_dims - 1`` axes.
        """
        index = [slice(None)] * n_dims
        index[axis] = entry
        return tuple(index)

    def __call__(self, x, dist_axis=1):
        """
        Evaluate the a priori by linear interpolation between the nodes.

        Args:
            x: Tensor containing the values at which to evaluate the a priori.
            dist_axis: The axis of ``x`` along which the evaluation points are
                laid out. It is forced to 0 when ``x`` is one-dimensional.

        Returns:
            Tensor holding the interpolated a priori value for every element
            of ``x``. Points that do not fall into any node interval
            ``(x_l, x_r]`` evaluate to 0.
        """
        n_dims = len(x.shape)
        if n_dims == 1:
            dist_axis = 0
        xp = get_array_module(x)
        n_points = x.shape[dist_axis]

        # Index tuples selecting the left and right end of every node interval.
        lower = self._axis_index(n_dims, dist_axis, slice(0, -1))
        upper = self._axis_index(n_dims, dist_axis, slice(1, None))

        # Lay the nodes out along 'dist_axis' so they broadcast against 'x'.
        node_shape = [1] * n_dims
        node_shape[dist_axis] = -1
        nodes_x = self.x.reshape(node_shape)
        nodes_y = self.y.reshape(node_shape)

        x_lo, x_hi = nodes_x[lower], nodes_x[upper]
        y_lo, y_hi = nodes_y[lower], nodes_y[upper]

        pieces = []
        for i in range(n_points):
            # Keep 'dist_axis' as a length-1 axis so that x_i broadcasts
            # against the node intervals.
            x_i = x[self._axis_index(n_dims, dist_axis, slice(i, i + 1))]

            # 1 for the interval (x_lo, x_hi] that contains x_i, 0 elsewhere.
            inside = as_type(xp, (x_lo < x_i) * (x_hi >= x_i), x_i)

            value = y_lo * (x_hi - x_i) * inside
            value += y_hi * (x_i - x_lo) * inside
            # Divide by the interval width inside the interval and by 1
            # everywhere else, where the numerator is already 0.
            value /= inside * (x_hi - x_lo) + (1.0 - inside)

            # Collapse the interval axis, then restore it as a length-1 axis.
            pieces.append(expand_dims(xp, value.sum(dist_axis), dist_axis))

        return concatenate(xp, pieces, dist_axis)
92 |
93 |
class Gaussian:
    """
    Gaussian a priori evaluated as ``exp(-0.5 * dx^T S dx)`` with
    ``dx = x - x_a``.

    No normalization constant is applied to the result.
    """

    def __init__(self, x_a, s, dist_axis=-1):
        """
        Create a Gaussian a priori.

        Args:
            x_a: Tensor that is subtracted from the input before the
                quadratic form is evaluated.
            s: Matrix of the quadratic form; presumably the inverse
                covariance matrix -- TODO confirm against callers.
            dist_axis: The axis of the input tensor along which ``x_a`` is
                laid out and over which the quadratic form is reduced.
        """
        self.x_a = x_a
        self.s = s
        # Honor the requested axis; it was previously hard-coded to -1 so
        # the argument had no effect.
        self.dist_axis = dist_axis

    def __call__(self, x, dist_axis=1):
        """
        Evaluate the a priori.

        Args:
            x: Tensor containing the values at which to evaluate the a priori.
            dist_axis: Not used; the axis given to the constructor is used
                instead. Kept so the call signature stays unchanged.

        Returns:
            Tensor ``exp(-0.5 * dx^T S dx)`` with the distribution axis
            summed out.
        """
        xp = get_array_module(x)
        n_dims = len(x.shape)

        # Reshape x_a so that it broadcasts along the distribution axis.
        shape = [1] * n_dims
        shape[self.dist_axis] = -1
        x_a = self.x_a.reshape(shape)

        dx = x - x_a

        # Contract the distribution axis of dx with the last axis of s.
        # NOTE(review): if 'tensordot' follows numpy.tensordot semantics the
        # remaining axis of s ends up last, so 'dx * sdx' only lines up when
        # the distribution axis is the last axis of x -- confirm before
        # relying on other axes.
        sdx = tensordot(xp, dx, self.s, ((self.dist_axis,), (-1,)))
        log_density = -0.5 * (dx * sdx).sum(self.dist_axis)

        return exp(xp, log_density)
113 |
--------------------------------------------------------------------------------
/quantnn/backends/__init__.py:
--------------------------------------------------------------------------------
1 | from quantnn.backends.pytorch import PyTorch
2 | from quantnn.common import UnsupportedTensorType
3 |
# Every tensor backend class known to quantnn.
_TENSOR_BACKEND_CLASSES = [PyTorch]

# The backend classes whose 'available()' check passes, evaluated once at
# import time.
TENSOR_BACKENDS = [
    backend_class
    for backend_class in _TENSOR_BACKEND_CLASSES
    if backend_class.available()
]
7 |
8 |
def get_tensor_backend(tensor):
    """
    Look up the tensor backend responsible for a given tensor.

    The entries of ``TENSOR_BACKENDS`` are queried in order and the first one
    whose ``matches_tensor`` accepts ``tensor`` is returned.

    Args:
        tensor: A tensor type of any of the supported backends.

    Return:
        The backend class that provides the interface to the tensor library
        corresponding to ``tensor``.

    Raises:
        :py:class:`~quantnn.common.UnsupportedTensorType` when the tensor type
        is not supported by quantnn.
    """
    matching = next(
        (
            candidate
            for candidate in TENSOR_BACKENDS
            if candidate.matches_tensor(tensor)
        ),
        None,
    )
    if matching is None:
        raise UnsupportedTensorType(
            f"The provided tensor of type {type(tensor)} is not supported by quantnn."
        )
    return matching
30 |
--------------------------------------------------------------------------------
/quantnn/backends/pytorch.py:
--------------------------------------------------------------------------------
1 | from quantnn.backends.tensor import TensorBackend
2 |
3 |
4 | class PyTorch(TensorBackend):
5 | """
6 | TensorBackend implementation using torch tensors.
7 | """
8 |
@classmethod
def available(cls):
    """
    Report whether this backend can be used.

    Returns:
        ``True`` if ``torch`` can be imported, ``False`` if the import raises
        an ``ImportError``.
    """
    try:
        import torch  # noqa: F401 -- imported only to probe availability
    except ImportError:
        return False
    else:
        return True
16 |
@classmethod
def matches_tensor(cls, tensor):
    """
    Check whether ``tensor`` belongs to this backend.

    Returns:
        ``True`` if ``tensor`` is an instance of ``torch.Tensor``.
    """
    from torch import Tensor

    return isinstance(tensor, Tensor)
22 |
@classmethod
def from_numpy(cls, array, like=None):
    """
    Convert a numpy array to a torch tensor.

    Args:
        array: The numpy array to convert.
        like: Optional tensor; if given, the result is converted to its
            dtype and moved to its device.

    Returns:
        The torch tensor created from ``array``.
    """
    import torch

    converted = torch.from_numpy(array)
    if like is None:
        return converted
    return converted.type(like.dtype).to(like.device)
31 |
@classmethod
def to_numpy(cls, array):
    """
    Convert a torch tensor to a numpy array.

    ``bfloat16`` tensors are converted with ``.float()`` first. The tensor is
    moved to the CPU and detached before the conversion.

    Args:
        array: The torch tensor to convert.

    Returns:
        The numpy array holding the tensor's data.
    """
    import torch

    if array.dtype == torch.bfloat16:
        array = array.float()
    on_host = array.cpu()
    return on_host.detach().numpy()
38 |
@classmethod
def as_type(cls, tensor, like):
    """
    Convert ``tensor`` to the type of ``like`` using ``Tensor.type_as``.

    Args:
        tensor: The tensor to convert.
        like: The tensor whose type the result should have.

    Returns:
        The converted tensor.
    """
    converted = tensor.type_as(like)
    return converted
42 |
@classmethod
def sample_uniform(cls, shape=None, like=None):
    """
    Draw samples with ``torch.rand``.

    Args:
        shape: Shape of the tensor to create. Defaults to the shape of
            ``like`` when omitted.
        like: Optional tensor that determines dtype and device of the
            result (and its shape when ``shape`` is omitted).

    Returns:
        Tensor of random samples.

    Raises:
        ValueError: If neither ``shape`` nor ``like`` is given.
    """
    import torch

    if like is None:
        if shape is None:
            raise ValueError(
                "'sample_uniform' requires at least one of the arguments "
                "'shape' and 'like'. "
            )
        return torch.rand(shape, dtype=None, device=None)

    if shape is None:
        shape = like.shape
    return torch.rand(shape, dtype=like.dtype, device=like.device)
61 |
@classmethod
def sample_gaussian(cls, shape=None, like=None):
    """
    Draw samples from a normal distribution with mean 0 and standard
    deviation 1 using ``torch.normal``.

    Args:
        shape: Shape of the tensor to create. Defaults to the shape of
            ``like`` when omitted.
        like: Optional tensor that determines dtype and device of the
            result (and its shape when ``shape`` is omitted).

    Returns:
        Tensor of random samples.

    Raises:
        ValueError: If neither ``shape`` nor ``like`` is given.
    """
    import torch

    if shape is None and like is None:
        # The message used to name 'sample_uniform' (copy-paste error).
        raise ValueError(
            "'sample_gaussian' requires at least one of the arguments "
            "'shape' and 'like'. "
        )
    dtype = None
    device = None
    if like is not None:
        dtype = like.dtype
        device = like.device
        if shape is None:
            shape = like.shape

    return torch.normal(0, 1, shape, dtype=dtype, device=device)
80 |
@classmethod
def size(cls, tensor):
    """Return the total number of elements of ``tensor`` (``Tensor.numel``)."""
    n_elements = tensor.numel()
    return n_elements
84 |
@classmethod
def concatenate(cls, tensors, dimension):
    """
    Concatenate a sequence of tensors with ``torch.cat``.

    Args:
        tensors: The tensors to concatenate.
        dimension: Index of the dimension along which to concatenate.

    Returns:
        The concatenated tensor.
    """
    import torch

    joined = torch.cat(tensors, dim=dimension)
    return joined
90 |
@classmethod
def expand_dims(cls, tensor, dimension_index):
    """
    Insert a length-1 dimension into ``tensor`` using ``Tensor.unsqueeze``.

    Args:
        tensor: The tensor to expand.
        dimension_index: Position at which the new dimension is inserted.

    Returns:
        The tensor with the additional dimension.
    """
    expanded = tensor.unsqueeze(dimension_index)
    return expanded
94 |
@classmethod
def exp(cls, tensor):
    """Return the element-wise exponential of ``tensor`` (``Tensor.exp``)."""
    result = tensor.exp()
    return result
98 |
@classmethod
def log(cls, tensor):
    """Return the element-wise natural logarithm of ``tensor`` (``Tensor.log``)."""
    result = tensor.log()
    return result
102 |
@classmethod
def pad_zeros(cls, tensor, n, dimension_index):
    """
    Pad ``tensor`` with ``n`` zeros on both sides of one dimension.

    Args:
        tensor: The tensor to pad.
        n: Number of zeros added at each end of the dimension.
        dimension_index: Index of the dimension to pad; negative indices
            are wrapped with the modulo operator.

    Returns:
        The padded tensor.
    """
    from torch.nn import functional

    rank = len(tensor.shape)
    axis = dimension_index % rank

    # The pad list passed to 'pad' holds one (left, right) pair per dimension,
    # starting with the last dimension, so the pair for 'axis' starts at
    # 2 * (rank - 1 - axis).
    padding = [0] * (2 * rank)
    left = 2 * (rank - 1 - axis)
    padding[left] = n
    padding[left + 1] = n
    return functional.pad(tensor, padding, "constant", 0.0)
113 |
@classmethod
def pad_zeros_left(cls, tensor, n, dimension_index):
    """
    Pad ``tensor`` with ``n`` zeros at the start of one dimension.

    Args:
        tensor: The tensor to pad.
        n: Number of zeros inserted before the first element of the
            dimension.
        dimension_index: Index of the dimension to pad. Negative indices
            are wrapped using the number of dimensions of ``tensor``.

    Return:
        The padded tensor.
    """
    import torch

    rank = len(tensor.shape)
    axis = dimension_index % rank
    # The pad specification lists (before, after) pairs starting from the
    # last dimension; only the 'before' entry of 'axis' is non-zero.
    offset = 2 * rank - 2 - 2 * axis
    pad = [0] * 2 * rank
    pad[offset] = n
    pad[offset + 1] = 0
    return torch.nn.functional.pad(tensor, pad, "constant", 0.0)
124 |
@classmethod
def arange(cls, start, end, step, like=None):
    """
    Create a tensor of evenly spaced values using ``torch.arange``.

    Args:
        start: First value of the sequence.
        end: Upper limit of the sequence, passed to ``torch.arange``.
        step: Spacing between consecutive values.
        like: Optional tensor whose dtype and device are used for the
            result. Without it the result is ``float32`` on torch's
            default device.

    Return:
        The tensor containing the sequence.
    """
    import torch

    if like is None:
        dtype, device = torch.float32, None
    else:
        dtype, device = like.dtype, like.device
    return torch.arange(start, end, step, dtype=dtype, device=device)
135 |
@classmethod
def reshape(cls, tensor, shape):
    """
    Reshape ``tensor`` to the given shape.

    Args:
        tensor: The tensor to reshape.
        shape: The requested shape.

    Return:
        Result of ``tensor.reshape(shape)``.
    """
    reshaped = tensor.reshape(shape)
    return reshaped
139 |
@classmethod
def trapz(cls, y, x, dimension):
    """
    Integrate ``y`` over ``x`` with the trapezoidal rule.

    Args:
        y: Tensor with the values of the integrand.
        x: Tensor with the sample points corresponding to ``y``.
        dimension: Index of the dimension along which to integrate.

    Return:
        Result of ``torch.trapz(y, x, dim=dimension)``.
    """
    import torch

    integral = torch.trapz(y, x, dim=dimension)
    return integral
145 |
@classmethod
def cumsum(cls, y, dimension):
    """
    Compute the cumulative sum of ``y`` along one dimension.

    Args:
        y: The input tensor.
        dimension: Index of the dimension along which to accumulate.

    Return:
        Result of ``torch.cumsum(y, dimension)``.
    """
    import torch

    running_total = torch.cumsum(y, dimension)
    return running_total
151 |
@classmethod
def zeros(cls, shape=None, like=None):
    """
    Create a tensor filled with zeros.

    Args:
        shape: Shape of the tensor to create. If omitted, the result is
            ``torch.zeros_like(like)``.
        like: Optional tensor whose dtype and device are used for the
            result.

    Return:
        The zero-filled tensor.

    Raises:
        ValueError: If neither ``shape`` nor ``like`` is provided.
    """
    import torch

    if like is None:
        if shape is None:
            raise ValueError(
                "'zeros' requires at least one of the arguments " "'shape' and 'like'. "
            )
        # Without a reference tensor torch's default dtype/device apply.
        return torch.zeros(shape, device=None, dtype=None)

    if shape is None:
        return torch.zeros_like(like)
    return torch.zeros(shape, device=like.device, dtype=like.dtype)
169 |
@classmethod
def ones(cls, shape=None, like=None):
    """
    Create a tensor filled with ones.

    Args:
        shape: Shape of the tensor to create. If omitted, the result is
            ``torch.ones_like(like)``.
        like: Optional tensor whose dtype and device are used for the
            result.

    Return:
        The one-filled tensor.

    Raises:
        ValueError: If neither ``shape`` nor ``like`` is provided.
    """
    import torch

    if shape is None and like is None:
        raise ValueError(
            "'ones' requires at least one of the arguments " "'shape' and 'like'. "
        )
    dtype = None
    device = None
    if like is not None:
        if shape is None:
            return torch.ones_like(like)
        dtype = like.dtype
        # The device must come from 'like.device'; passing the dtype here
        # made every call with both 'shape' and 'like' fail.
        device = like.device

    return torch.ones(shape, device=device, dtype=dtype)
187 |
@classmethod
def softmax(cls, x, axis=None):
    """
    Apply the softmax function to ``x``.

    Args:
        x: The input tensor.
        axis: Index of the dimension along which the softmax is computed;
            forwarded as ``dim`` to ``torch.nn.functional.softmax``.

    Return:
        The tensor of softmax values.
    """
    import torch

    functional = torch.nn.functional
    return functional.softmax(x, dim=axis)
193 |
@classmethod
def where(cls, condition, x, y):
    """
    Select elements from ``x`` or ``y`` depending on ``condition``.

    Args:
        condition: Tensor selecting between ``x`` and ``y``.
        x: Values used where ``condition`` holds.
        y: Values used where ``condition`` does not hold.

    Return:
        Result of ``torch.where(condition, x, y)``.
    """
    import torch

    selected = torch.where(condition, x, y)
    return selected
199 |
--------------------------------------------------------------------------------
/quantnn/common.py:
--------------------------------------------------------------------------------
1 | """
2 | quantnn.common
3 | ==============
4 |
5 | Implements common features used by the other submodules of the ``quantnn``
6 | package.
7 | """
8 |
9 |
class QuantnnException(Exception):
    """Base class for all exceptions raised by the quantnn package."""
12 |
13 |
class UnknownArrayTypeException(QuantnnException):
    """Raised when a function receives an array type it does not support."""
16 |
17 |
class UnsupportedTensorType(QuantnnException):
    """
    Raised when quantnn is asked to handle a tensor type that it does not
    support.
    """
22 |
23 |
class UnknownModuleException(QuantnnException):
    """
    Raised when a generic array operation is given a backend that is not
    supported.
    """
29 |
30 |
class UnsupportedBackendException(QuantnnException):
    """
    Raised when quantnn is requested to load a backend that it does not
    support.
    """
35 |
36 |
class MissingBackendException(QuantnnException):
    """
    Raised when importing a requested backend fails.
    """
41 |
42 |
class InvalidDimensionException(QuantnnException):
    """Raised when the shape of an input array differs from the expected one."""
45 |
46 |
class ModelNotSupported(QuantnnException):
    """Raised when the chosen backend does not support the provided model."""
49 |
50 |
class MissingAuthenticationInfo(QuantnnException):
    """Raised when required authentication information is unavailable."""
53 |
54 |
class DatasetError(QuantnnException):
    """
    Raised when a dataset object does not provide the interface that is
    expected of it.
    """
59 |
60 |
class InvalidURL(QuantnnException):
    """
    Raised when the URL provided for a file is invalid.
    """
65 |
66 |
class InputDataError(QuantnnException):
    """
    Raised when the training data is not in the expected format.
    """
71 |
class ModelLoadError(QuantnnException):
    """
    Raised when loading a model fails.
    """
76 |
--------------------------------------------------------------------------------
/quantnn/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/simonpf/quantnn/c4d650cf0c6da5b4a704905b6c267d1ca996466f/quantnn/examples/__init__.py
--------------------------------------------------------------------------------
/quantnn/examples/gprof_conv.py:
--------------------------------------------------------------------------------
1 | """
2 | =======================
3 | quantnn.gprof_conv.py
4 | =======================
5 |
6 | This module provides download functions and dataset classes for the
7 | convolutional GPROF retrieval example.
8 | """
9 | from pathlib import Path
10 | from urllib.request import urlretrieve
11 |
12 |
def download_data(destination="data"):
    """
    Downloads the data for the convolutional GPROF retrieval example.

    Files that already exist in ``destination`` are not downloaded again.

    Args:
        destination: Where to store the downloaded data. The folder is
            created if it does not exist.
    """
    datasets = [
        "gprof_conv.npz",
    ]

    destination = Path(destination)
    destination.mkdir(parents=True, exist_ok=True)
    for file in datasets:
        # Honor 'destination' instead of always using the 'data' folder.
        file_path = destination / file
        if not file_path.exists():
            print(f"Downloading file {file}.")
            url = f"http://spfrnd.de/data/gprof/{file}"
            urlretrieve(url, file_path)
31 |
--------------------------------------------------------------------------------
/quantnn/examples/modis_ctp.py:
--------------------------------------------------------------------------------
1 | """
2 | quantnn.examples.ctp
3 | ====================
4 |
5 | This module implements helper functions for the MODIS cloud-top pressure
6 | example.
7 | """
8 | from pathlib import Path
9 | from urllib.request import urlretrieve
10 |
# Absolute folder holding the example input files. The path is hard-coded,
# so it has to be adapted when the example is run on another machine.
_DATA_PATH = "/home/simonpf/src/pansat/notebooks/products/data/"
# The two MODIS HDF files (MYD021KM and MYD03) located below _DATA_PATH.
_MODIS_FILES = [
    _DATA_PATH + "MODIS/MYD021KM.A2016286.1750.061.2018062214718.hdf",
    _DATA_PATH + "MODIS/MYD03.A2016286.1750.061.2018062032022.hdf",
]

# Public alias of the same list object as _MODIS_FILES.
modis_files = _MODIS_FILES
# NOTE(review): the two padding values are not used by the code in this
# module; presumably they are read by the example notebook -- confirm.
pad_along_orbit = 200
pad_across_orbit = 300
20 |
21 |
def prepare_input_data(modis_files):
    """
    Prepares validation data for the MODIS CTP retrieval.

    The function loads the MODIS scene, matches it with the pixels of a
    CALIPSO file and with ERA5 fields and assembles the retrieval input.
    As side effects it writes the file ``modis_true_color.png`` to the
    current working directory and downloads CALIPSO and ERA5 files through
    ``pansat``.

    Args:
        modis_files: List of filenames containing the MODIS input data to
            use as input.

    Returns:
        Dictionary with the following entries:

        - ``"input_data"``: Array with 16 columns holding, in this order,
          the ERA5 surface pressure, the ERA5 2m temperature, the ERA5
          temperature at the 950, 850, 700, 500 and 250 levels, the ERA5
          'tcwv' field and eight features derived from MODIS bands 31
          and 32.
        - ``"ctp"``: The CALIPSO layer top pressure of the matched pixels.
        - ``"latitude"``, ``"longitude"``: MODIS coordinates of the
          matched pixels.
        - ``"latitude_rgb"``, ``"longitude_rgb"``, ``"modis_rgb"``,
          ``"bt_11_rgb"``, ``"bt_12_rgb"``: Every fourth pixel of the
          scene, intended for plotting.
    """
    from datetime import datetime
    from satpy import Scene
    from scipy.interpolate import RegularGridInterpolator
    from scipy.ndimage import maximum_filter, minimum_filter
    from scipy.signal import convolve
    import numpy as np
    import xarray
    from pansat.products.reanalysis.era5 import ERA5Product
    from pansat.products.satellite.calipso import clay01km
    from pykdtree.kdtree import KDTree
    from PIL import Image

    #
    # Prepare MODIS data.
    #

    scene = Scene(filenames=modis_files, reader="modis_l1b")
    scene.load(["true_color", "31", "32", "latitude", "longitude"], resolution=1000)

    # Sub-sampled versions (every fourth pixel) of the scene.
    scene["true_color_small"] = scene["true_color"][:, ::4, ::4]
    scene["bt_11_small"] = scene["31"][::4, ::4]
    scene["bt_12_small"] = scene["32"][::4, ::4]
    scene["latitude_small"] = scene["latitude"][::4, ::4]
    scene["longitude_small"] = scene["longitude"][::4, ::4]
    # The RGB composite is written to disk and read back as an array.
    scene.save_dataset("true_color_small", "modis_true_color.png")
    image = Image.open("modis_true_color.png")
    modis_rgb = np.array(image)
    bt_11_rgb = scene["bt_11_small"].compute()
    bt_12_rgb = scene["bt_12_small"].compute()
    lats_rgb = scene["latitude_small"].compute()
    lons_rgb = scene["longitude_small"].compute()

    # MODIS data input features.

    lats_r = scene["latitude"].compute()
    lons_r = scene["longitude"].compute()
    bt_11 = scene["31"].compute()
    bt_12 = scene["32"].compute()

    def mean_filter(img):
        """Mean over a 5 x 5 window."""
        k = np.ones((5, 5)) / 25.0
        return convolve(img, k, mode="same")

    def std_filter(img):
        """Standard deviation over a 5 x 5 window, as sqrt(E[x^2] - E[x]^2)."""
        mu = mean_filter(img ** 2)
        mu2 = mean_filter(img) ** 2
        return np.sqrt(mu - mu2)

    # Maximum ('_w') and minimum ('_c') over a 5 x 5 neighborhood.
    bt_11_w = maximum_filter(bt_11, [5, 5])
    bt_11_c = minimum_filter(bt_11, [5, 5])
    bt_12_w = maximum_filter(bt_12, [5, 5])
    bt_12_c = minimum_filter(bt_12, [5, 5])
    bt_11_s = std_filter(bt_11)
    bt_1112_s = std_filter(bt_11 - bt_12)

    #
    # Calipso data
    #

    t_0 = datetime(2016, 10, 12, 17, 00)
    t_1 = datetime(2016, 10, 12, 17, 50)
    calipso_files = clay01km.download(t_0, t_1)

    # Bounding box of the MODIS scene; used for the ERA5 domain below.
    lat_min = lats_r.data.min()
    lat_max = lats_r.data.max()
    lon_min = lons_r.data.min()
    lon_max = lons_r.data.max()

    dataset = clay01km.open(calipso_files[0])
    lats_c = dataset["latitude"].data
    lons_c = dataset["longitude"].data
    ctp_c = dataset["layer_top_pressure"]
    cth_c = dataset["layer_top_altitude"]

    # For every CALIPSO pixel find the nearest MODIS pixel in
    # latitude/longitude space.
    points = np.hstack([lats_r.data.reshape(-1, 1), lons_r.data.reshape(-1, 1)])
    kd_tree = KDTree(points)

    points_c = np.hstack([lats_c.reshape(-1, 1), lons_c.reshape(-1, 1)])
    d, indices = kd_tree.query(points_c)
    # Keep only matches closer than 0.01 in latitude/longitude units.
    valid = d < 0.01
    indices = indices[valid]
    lats_c = lats_c[valid]
    lons_c = lons_c[valid]
    ctp_c = ctp_c[valid]
    cth_c = cth_c[valid]
    bt_11 = bt_11.data.ravel()[indices]
    bt_12 = bt_12.data.ravel()[indices]
    bt_11_w = bt_11_w.ravel()[indices]
    bt_12_w = bt_12_w.ravel()[indices]
    bt_11_c = bt_11_c.ravel()[indices]
    bt_12_c = bt_12_c.ravel()[indices]
    bt_11_s = bt_11_s.ravel()[indices]
    bt_1112_s = bt_1112_s.ravel()[indices]
    lats_r = lats_r.data.ravel()[indices]
    lons_r = lons_r.data.ravel()[indices]

    #
    # ERA 5 data.
    #

    t_0 = datetime(2016, 10, 12, 17, 45)
    t_1 = datetime(2016, 10, 12, 17, 50)

    surface_variables = ["surface_pressure", "2m_temperature", "tcwv"]
    domain = [lat_min - 2, lat_max + 2, lon_min - 2, lon_max + 2]
    surface_product = ERA5Product("hourly", "surface", surface_variables, domain)
    era_surface_files = surface_product.download(t_0, t_1)

    pressure_variables = ["temperature"]
    pressure_product = ERA5Product("hourly", "pressure", pressure_variables, domain)
    era_pressure_files = pressure_product.download(t_0, t_1)

    # interpolate pressure data.

    era5_data = xarray.open_dataset(era_pressure_files[0])
    # Latitudes (and the corresponding data axis below) are reversed before
    # they are passed to the interpolator.
    lats_era = era5_data["latitude"][::-1]
    lons_era = era5_data["longitude"]
    p_era = era5_data["level"]
    p_inds = [np.where(p_era == p)[0] for p in [950, 850, 700, 500, 250]]
    # Despite its name this list holds the temperature field 't'
    # interpolated to the matched pixels, one entry per level.
    pressures = []
    for ind in p_inds:
        p_interp = RegularGridInterpolator(
            [lats_era, lons_era], era5_data["t"].data[0, ind[0], ::-1, :]
        )
        pressures.append(p_interp((lats_r, lons_r)))

    era5_data = xarray.open_dataset(era_surface_files[0])
    lats_era = era5_data["latitude"][::-1]
    lons_era = era5_data["longitude"]
    t_interp = RegularGridInterpolator(
        [lats_era, lons_era], era5_data["t2m"].data[0, ::-1, :]
    )
    t_surf = t_interp((lats_r, lons_r))
    p_interp = RegularGridInterpolator(
        [lats_era, lons_era], era5_data["sp"].data[0, ::-1, :]
    )
    p_surf = p_interp((lats_r, lons_r))
    tcwv_interp = RegularGridInterpolator(
        [lats_era, lons_era], era5_data["tcwv"].data[0, ::-1, :]
    )
    tcwv = tcwv_interp((lats_r, lons_r))

    #
    # Assemble input data
    #

    x = np.zeros((lats_r.size, 16))
    x[:, 0] = p_surf
    x[:, 1] = t_surf
    for i, p in enumerate(pressures):
        x[:, 2 + i] = p
    x[:, 7] = tcwv

    x[:, 8] = bt_12
    x[:, 9] = bt_11 - bt_12
    x[:, 10] = bt_11_w - bt_12_w
    x[:, 11] = bt_11_c - bt_12_c
    x[:, 12] = bt_12_w - bt_12
    x[:, 13] = bt_12_c - bt_12

    x[:, 14] = bt_11_s
    x[:, 15] = bt_1112_s

    output_data = {
        "input_data": x,
        "ctp": ctp_c,
        "latitude": lats_r,
        "longitude": lons_r,
        "latitude_rgb": lats_rgb,
        "longitude_rgb": lons_rgb,
        "modis_rgb": modis_rgb,
        "bt_11_rgb": bt_11_rgb,
        "bt_12_rgb": bt_12_rgb,
    }
    return output_data
216 |
217 |
def download_data(destination="data"):
    """
    Downloads training and evaluation data for the CTP retrieval.

    Files that already exist in ``destination`` are not downloaded again.

    Args:
        destination: Where to store the downloaded data. The folder is
            created if it does not exist.
    """
    datasets = ["ctp_training_data.npz", "ctp_validation_data.npz"]

    destination = Path(destination)
    destination.mkdir(parents=True, exist_ok=True)

    for file in datasets:
        # Honor 'destination' instead of always using the 'data' folder.
        file_path = destination / file
        if not file_path.exists():
            print(f"Downloading file {file}.")
            url = f"http://spfrnd.de/data/ctp/{file}"
            urlretrieve(url, file_path)
235 |
--------------------------------------------------------------------------------
/quantnn/examples/simple.py:
--------------------------------------------------------------------------------
1 | """
2 | quantnn.examples.simple
3 | =======================
4 |
5 | This module provides a simple toy example to illustrate the basic
6 | functionality of quantile regression neural networks. The task is a simple
7 | 1-dimensional regression problem of a signal with heteroscedastic noise:
8 |
9 | .. math::
10 |
11 | y = \sin(x) + \cos(x) \cdot \mathcal{N}(0, 1)
12 |
13 | """
14 | import numpy as np
15 | import scipy as sp
16 | import matplotlib.pyplot as plt
17 | from matplotlib.cm import magma
18 | from matplotlib.colors import Normalize
19 |
20 |
def create_training_data(n=1_000_000):
    r"""
    Create training data by randomly sampling the range :math:`[-\pi, \pi]`
    and computing y.

    Args:
        n(int): How many samples to compute.

    Return:
        Tuple ``(x, y)`` containing the input samples ``x`` as a 1D array
        of length ``n`` and the corresponding :math:`y` values in ``y``.
    """
    # The docstring is a raw string because it contains LaTeX backslashes.
    x = 2.0 * np.pi * np.random.random(size=n) - np.pi
    y = np.sin(x) + 1.0 * np.cos(x) * np.random.randn(n)
    return x, y
37 |
38 |
def create_validation_data(x):
    """
    Compute noisy y values of the toy example for given x values.

    Unlike the training data generation, the x values are chosen by the
    caller, which allows evaluating results over an arbitrary domain.

    Args:
        x: Array with the x values for which to compute y values.

    Return:
        Numpy array with one y value for each value in ``x``.
    """
    noise = np.random.randn(*x.shape)
    return np.sin(x) + 1.0 * np.cos(x) * noise
56 |
57 |
def plot_histogram(x, y):
    """
    Plot 2D histogram of data.

    The histogram is drawn into a new figure together with the curve
    :math:`y = sin(x)`, a colorbar and a legend.

    Args:
        x: Array with the x values of the samples.
        y: Array with the y values of the samples.
    """
    # Calculate histogram
    bins_x = np.linspace(-np.pi, np.pi, 201)
    bins_y = np.linspace(-4, 4, 201)
    x_img, y_img = np.meshgrid(bins_x, bins_y)
    img, _, _ = np.histogram2d(x, y, bins=(bins_x, bins_y), density=True)

    # Plot results
    _, ax = plt.subplots(1, 1, figsize=(10, 6))
    m = ax.pcolormesh(x_img, y_img, img.T, vmin=0, vmax=0.3, cmap="magma")
    x_sin = np.linspace(-np.pi, np.pi, 1001)
    y_sin = np.sin(x_sin)
    # Raw string: '\s' is not a valid escape sequence in a normal literal.
    ax.plot(x_sin, y_sin, c="grey", label=r"$y=\sin(x)$", lw=3)
    ax.set_ylim([-2, 2])
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    plt.colorbar(m, label="Normalized frequency")
    plt.legend()
79 |
80 |
def plot_results(x_train, y_train, x_val, y_pred, y_mean, quantiles):
    """
    Plots the predicted quantiles against empirical quantiles.

    The empirical quantiles are derived from a 2D histogram of the training
    data and drawn as filled contours. The predictions are drawn on top as
    dashed lines.

    Args:
        x_train: x values of the training samples.
        y_train: y values of the training samples.
        x_val: x values at which the predictions were made.
        y_pred: 2D array with the predicted quantiles; one column per
            quantile.
        y_mean: The predicted mean at ``x_val``.
        quantiles: The quantile fractions, used as contour levels.
    """
    # Use the current SciPy/NumPy names and fall back to the legacy ones,
    # which have been deprecated or removed in recent releases.
    try:
        from scipy.integrate import cumulative_trapezoid
    except ImportError:
        from scipy.integrate import cumtrapz as cumulative_trapezoid
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    # Calculate histogram and empirical quantiles.
    bins_x = np.linspace(-np.pi, np.pi, 201)
    bins_y = np.linspace(-4, 4, 201)
    img, _, _ = np.histogram2d(x_train, y_train, bins=(bins_x, bins_y), density=True)
    y_mid = 0.5 * (bins_y[1:] + bins_y[:-1])
    # Normalize each x bin so that it integrates to one over y and
    # integrate cumulatively to obtain the empirical CDF.
    norm = trapezoid(img, x=y_mid, axis=1)
    img_normed = img / norm.reshape(-1, 1)
    img_cdf = cumulative_trapezoid(img_normed, x=y_mid, axis=1)

    x_centers = 0.5 * (bins_x[1:] + bins_x[:-1])
    # The cumulative integral has one element less than 'y_mid'.
    y_centers = 0.5 * (bins_y[2:] + bins_y[:-2])

    color_norm = Normalize(0, 1)
    plt.figure(figsize=(10, 6))
    img = plt.contourf(
        x_centers,
        y_centers,
        img_cdf.T,
        levels=quantiles,
        norm=color_norm,
        cmap="magma",
    )
    # One line per predicted quantile (previously fixed to 13 columns).
    for i in range(y_pred.shape[1]):
        l_q = plt.plot(x_val, y_pred[:, i], lw=2, ls="--", color="grey")
    handles = l_q
    handles += plt.plot(x_val, y_mean, c="k", ls="--", lw=2)
    labels = ["Predicted quantiles", "Predicted mean"]
    plt.legend(handles=handles, labels=labels)

    plt.xlim([-np.pi, np.pi])
    plt.ylim([-3, 3])
    plt.xlabel("x")
    plt.ylabel("y")
    plt.grid(False)
    plt.colorbar(img, label=r"Empirical quantiles")
122 |
--------------------------------------------------------------------------------
/quantnn/files/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | =============
3 | quantnn.files
4 | =============
5 |
6 | The :py:mod:`quantnn.files` module provides an abstraction layer to open files
7 | locally or via SFTP.
8 |
9 | Refer to the documentation of the :py:mod:`quantnn.files.sftp` module for
10 | information on how to set the username and password for the SFTP connection.
11 |
12 | Example
13 | -------
14 |
15 | .. code-block::
16 |
17 | # To load a local file:
18 | with read_file("my_file.txt") as file:
19 | content = file.read()
20 |
21 | # To open a remote file via SFTP:
22 | with read_file("sftp://129.16.35.202/my_file.txt") as file: content = file.read()
23 |
24 |
25 | """
26 | from contextlib import contextmanager
27 | from pathlib import PurePath, Path
28 | from urllib.parse import urlparse
29 |
30 | from quantnn.files import sftp
31 | from quantnn.common import InvalidURL
32 |
33 |
@contextmanager
def read_file(path, *args, **kwargs):
    """
    Generic function to open files. Currently supports opening files
    on the local system as well as on a remote machine via SFTP.

    The function is a context manager. The opened file is closed when the
    ``with`` block is left.

    Args:
        path: ``pathlib`` path or string pointing to a local file, or a
            string of the form ``sftp://<host>/<path>``.
        *args: Passed on to ``open``.
        **kwargs: Passed on to ``open``.

    Yields:
        The opened file object.

    Raises:
        InvalidURL: If ``path`` is an SFTP URL without a host or uses a
            protocol other than SFTP.
    """
    if isinstance(path, PurePath):
        with open(path, *args, **kwargs) as file:
            yield file
        return

    url = urlparse(path)
    if url.scheme == "sftp":
        # The host check must come before the local-file fallback,
        # otherwise it can never trigger.
        host = url.netloc
        if host == "":
            raise InvalidURL(
                "No host in SFTP URL. "
                "To load a file using SFTP, the URL must be of the form "
                "'sftp://<host>/<path>'."
            )

        with sftp.download_file(host, url.path) as local_copy:
            with open(local_copy, *args, **kwargs) as file:
                yield file
        return

    if url.netloc == "":
        with open(path, *args, **kwargs) as file:
            yield file
        return

    raise InvalidURL(f"The provided protocol '{url.scheme}' is not supported.")
63 |
64 |
class _DummyCache:
    """
    Stand-in cache for local files.

    Local files are not cached, so every method either does nothing or
    hands its input back.
    """

    def __init__(self):
        """Create dummy cache."""

    def download_files(self, host, files, pool):
        """Do nothing; all arguments are ignored."""
        return None

    def get(self, host, path):
        """Return ``path`` unchanged; ``host`` is ignored."""
        return path

    def cleanup(self):
        """Do nothing."""
        return None
85 |
86 |
class CachedDataFolder:
    """
    This class provides an interface to a generic folder containing
    dataset files. This folder can be accessed via the local file
    system or SFTP. If the folder is located on a remote SFTP server,
    the files are cached to avoid having to retransfer the files.

    Attributes:
        files: List of available files in the folder.
        host: The name of the host where the folder is located or
            "" if the folder is local.
        cache: Cache object used to cache data accesses.
    """

    def __init__(self, path, pattern="*", n_files=None):
        """
        Create a CachedDataFolder.

        Args:
            path: Path to the folder to load. Either a ``pathlib`` path or
                string pointing to a local folder or a string of the form
                ``sftp://<host>/<path>``.
            pattern: Glob pattern to select the files.
            n_files: If given only the first ``n_files`` matching files will
                be loaded.

        Raises:
            InvalidURL: If ``path`` is an SFTP URL without a host or uses
                a protocol other than SFTP.
        """
        if isinstance(path, PurePath):
            files = path.iterdir()
            self.host = ""
            self.cache = _DummyCache()
        else:
            url = urlparse(path)
            if url.scheme == "sftp":
                # The host check must come before the local-folder
                # fallback, otherwise it can never trigger.
                if url.netloc == "":
                    raise InvalidURL(
                        "No host in SFTP URL. "
                        "To load a file using SFTP, the URL must be of the "
                        "form 'sftp://<host>/<path>'."
                    )
                self.host = url.netloc
                files = sftp.list_files(self.host, url.path)
                self.cache = sftp.SFTPCache()
            elif url.netloc == "":
                files = Path(path).iterdir()
                self.host = ""
                self.cache = _DummyCache()
            else:
                raise InvalidURL(
                    f"The provided protocol '{url.scheme}' is not supported."
                )
        self.files = [file for file in files if file.match(pattern)]
        if n_files:
            self.files = self.files[:n_files]

    def download(self, pool):
        """
        This method downloads all files in the folder to populate the
        cache.

        Args:
            pool: The PoolExecutor to use for the concurrent download.
        """
        self.cache.download_files(self.host, self.files, pool)

    def get(self, path):
        """
        Retrieve file from folder.

        Args:
            path: The path of the file to retrieve.

        Return:
            Whatever the cache returns for the given file. For a local
            folder this is ``path`` itself.
        """
        return self.cache.get(self.host, path)

    def open(self, path, *args, **kwargs):
        """
        Retrieve file from cache and open.

        Args:
            path: The path of the file to retrieve.
            *args: Passed to open call if file is local.
            **kwargs: Passed to open call if file is local.

        Return:
            An opened file object if the cache returned a ``pathlib``
            path; otherwise the object returned by the cache.
        """
        file = self.get(path)
        if isinstance(file, PurePath):
            return open(file, *args, **kwargs)
        return file
178 |
--------------------------------------------------------------------------------
/quantnn/files/sftp.py:
--------------------------------------------------------------------------------
1 | """
2 | ==================
3 | quantnn.files.sftp
4 | ==================
5 |
6 | This module provides high-level functions to access file via
7 | SFTP.
8 | """
9 | from contextlib import contextmanager
10 | from concurrent.futures import Future
11 | from copy import copy
12 | import logging
13 | import os
14 | from pathlib import Path
15 | import tempfile
16 | import warnings
17 |
18 | with warnings.catch_warnings():
19 | warnings.filterwarnings("ignore")
20 | import paramiko
21 |
22 | from quantnn.common import MissingAuthenticationInfo
23 |
24 | _LOGGER = logging.getLogger("quantnn.files.sftp")
25 |
26 |
def get_login_info():
    """
    Read the SFTP credentials from the environment.

    The user name is taken from the 'QUANTNN_SFTP_USER' and the password
    from the 'QUANTNN_SFTP_PASSWORD' environment variable.

    Returns:

        Tuple ``(user_name, password)`` with the values of the two
        environment variables.

    Raises:

        MissingAuthenticationInfo exception when one of the two
        environment variables is not set.
    """
    variables = ("QUANTNN_SFTP_USER", "QUANTNN_SFTP_PASSWORD")
    user_name, password = (os.environ.get(name) for name in variables)
    if any(value is None for value in (user_name, password)):
        raise MissingAuthenticationInfo(
            "SFTPStream dataset requires the 'QUANTNN_SFTP_USER' and "
            "'QUANTNN_SFTP_PASSWORD' to be set."
        )
    return user_name, password
50 |
51 |
@contextmanager
def get_sftp_connection(host):
    """
    Context manager providing an open SFTP connection to a host.

    The credentials are obtained from ``get_login_info``, i.e. from the
    'QUANTNN_SFTP_USER' and 'QUANTNN_SFTP_PASSWORD' environment variables.
    Client and transport are closed when the context is left.

    Args:
        host: IP address of the host.

    Returns:
        ``paramiko.SFTP`` object providing access to the open SFTP connection.
    """
    login, secret = get_login_info()
    # Both start as None so that the cleanup only closes what was opened.
    connection = None
    client = None
    try:
        connection = paramiko.Transport(host)
        connection.connect(username=login, password=secret)
        client = paramiko.SFTPClient.from_transport(connection)
        yield client
    finally:
        if client:
            client.close()
        if connection:
            connection.close()
80 |
81 |
def list_files(host, path):
    """
    List the entries of a folder on an SFTP host.

    Args:
        host: IP address of the host.
        path: The path for which to list the files


    Returns:
        List of ``pathlib.Path`` objects, one per entry found under the
        given path, each prefixed with ``path``.
    """
    with get_sftp_connection(host) as sftp:
        names = sftp.listdir(path)
    folder = Path(path)
    return [folder / name for name in names]
98 |
99 |
@contextmanager
def download_file(host, path):
    """
    Context manager that downloads a file from a host into a temporary
    directory and yields the path of the local copy.

    The temporary directory is removed when the context is left.

    Args:
        host: IP address of the host from which to download the file.
        path: Path of the file on the host.

    Return:
        pathlib.Path object pointing to the downloaded file.
    """
    remote = Path(path)
    with tempfile.TemporaryDirectory() as directory, get_sftp_connection(
        host
    ) as sftp:
        local_copy = Path(directory) / remote.name
        _LOGGER.info("Downloading file %s to %s.", remote, local_copy)
        sftp.get(str(remote), str(local_copy))
        yield local_copy
120 |
121 |
def _download_file(host, path):
    """
    Wrapper function to concurrently download files via SFTP.

    Args:
        host: IP address of the host from which to download the file.
        path: Path of the file on the host.

    Return:
        The name of the temporary file holding the downloaded data.

    Raises:
        Any exception raised while connecting or downloading. The
        temporary file is removed before the exception is passed on.
    """
    handle, file = tempfile.mkstemp()
    # mkstemp returns an open file descriptor, which must be closed
    # explicitly to avoid leaking one descriptor per download.
    os.close(handle)
    try:
        with get_sftp_connection(host) as sftp:
            _LOGGER.info("Downloading file %s to %s.", path, file)
            sftp.get(str(path), file)
    except Exception:
        # Don't hand the name of a deleted file back to the caller.
        os.remove(file)
        raise
    return file
134 |
135 |
class SFTPCache:
    """
    Cache for SFTP files.

    Attributes:
        files: Dictionary mapping tuples ``(host, path)`` to the names of
            the temporary files holding the downloaded data.
    """

    def __init__(self):
        # Only the owner removes the temporary files. Pickled copies are
        # not owners (see __getstate__).
        self._owner = True
        self.files = {}

    def __del__(self):
        """Make sure temporary data is cleaned up."""
        # getattr guards against objects whose __init__ did not complete.
        if getattr(self, "_owner", False):
            self._cleanup()

    def _cleanup(self):
        """
        Clean up temporary files.

        Files that are already gone are skipped and the cache is emptied,
        so the method can be called more than once.
        """
        _LOGGER.info("Cleaning up SFTP cache.")
        for file in list(self.files.values()):
            if isinstance(file, Future):
                file = file.result()
            name = file if isinstance(file, str) else file.name
            try:
                os.remove(name)
            except FileNotFoundError:
                pass
        self.files.clear()

    def cleanup(self):
        """Public counterpart of ``_cleanup``, matching ``_DummyCache``."""
        self._cleanup()

    def download_files(self, host, paths, pool):
        """
        Download list of files concurrently.

        Args:
            host: The SFTP host from which to download the data.
            paths: List of paths to download from the host.
            pool: A PoolExecutor to use for parallelizing the download.
        """
        tasks = {}
        for path in paths:
            # Skip cached files and paths listed more than once, so no
            # file is downloaded twice.
            if (host, path) not in self.files and path not in tasks:
                tasks[path] = pool.submit(_download_file, host, path)
        for path, task in tasks.items():
            self.files[(host, path)] = task.result()

    def get(self, host, path):
        """
        Retrieve file from cache. If file is not found in cache it is
        retrieved via SFTP and stored in the cache.

        Args:
            host: The SFTP host from which to retrieve the file.
            path: The path of the file on the host.

        Return:
            The name of the temporary file containing the requested file.
        """
        key = (host, path)
        if key not in self.files:
            handle, file = tempfile.mkstemp()
            # Close the descriptor returned by mkstemp; only the name is
            # needed.
            os.close(handle)
            try:
                with get_sftp_connection(host) as sftp:
                    sftp.get(str(path), file)
            except Exception:
                os.remove(file)
                raise
            self.files[key] = file

        value = self.files[key]
        if isinstance(value, Future):
            return value.result()
        return value

    def __getstate__(self):
        """Set owner attribute to false when object is pickled."""
        dct = copy(self.__dict__)
        dct["_owner"] = False
        return dct
212 |
--------------------------------------------------------------------------------
/quantnn/logging/multiprocessing.py:
--------------------------------------------------------------------------------
1 | """
2 | ===============================
3 | quantnn.logging.multiprocessing
4 | ===============================
5 |
6 | This module defines utility function to handle logging from sub-processes.
7 | """
8 | import logging
9 | from logging import handlers
10 | import multiprocessing
11 | import threading
12 |
# Queue shared by the logging helpers of this module; created on first use.
_LOG_QUEUE = None


def get_log_queue():
    """
    Return the global logging queue.

    The queue is a ``multiprocessing.Queue`` that is created by the first
    call; later calls return the same object.
    """
    global _LOG_QUEUE
    if _LOG_QUEUE is not None:
        return _LOG_QUEUE
    _LOG_QUEUE = multiprocessing.Queue()
    return _LOG_QUEUE
24 |
25 |
class SubprocessLogging(multiprocessing.Process):
    """
    Base class for processes whose log messages should be forwarded to the
    global logging queue. Subprocesses should inherit from this class in
    order to have their log messages displayed cleanly.
    """

    def __init__(self):
        super().__init__()
        # Keep a reference to the global queue for use in 'run'.
        self.log_queue = get_log_queue()

    def run(self):
        """Replace the root logger's handlers with a single queue handler."""
        # NOTE(review): presumably imported for its side effects -- confirm.
        import quantnn.logging

        queue_handler = handlers.QueueHandler(self.log_queue)
        logging.getLogger().handlers = [queue_handler]
40 |
41 |
class LoggingThread(threading.Thread):
    """
    Thread that handles the log records other processes put on a queue.
    """

    def __init__(self, log_queue):
        """
        Args:
            log_queue: The queue to which other processes are logging their
                messages.

        """
        super().__init__()
        self.log_queue = log_queue

    def run(self):
        """
        Take records from the queue and pass each one to the logger with
        the record's name. A ``None`` entry ends the loop.
        """
        record = self.log_queue.get()
        while record is not None:
            logging.getLogger(record.name).handle(record)
            record = self.log_queue.get()
67 |
68 |
# Number of start_logging calls not yet matched by a stop_logging call.
_USERS = 0
# The running listener thread, or None when no listener is active.
_LOGGING_THREAD = None


def start_logging():
    """
    Register a user of the listener thread and start the thread if it is
    not running yet.
    """
    global _USERS, _LOGGING_THREAD
    _USERS += 1
    if _LOGGING_THREAD is not None:
        return
    _LOGGING_THREAD = LoggingThread(get_log_queue())
    _LOGGING_THREAD.start()
82 |
83 |
def stop_logging():
    """
    Signal that a user no longer expects messages from subprocesses.

    When no users are left, ``None`` is put on the log queue to end the
    listener thread, and the thread is joined.
    """
    global _USERS, _LOGGING_THREAD
    _USERS -= 1
    if _USERS > 0 or _LOGGING_THREAD is None:
        return
    get_log_queue().put(None)
    _LOGGING_THREAD.join()
    _LOGGING_THREAD = None
95 |
--------------------------------------------------------------------------------
/quantnn/models/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | quantnn.models.pytorch
3 | ======================
4 |
5 | This module provides PyTorch neural network models that can be used as a backend
6 | for the :py:class:`quantnn.QRNN` class.
7 | """
8 | from quantnn.models.pytorch.common import (
9 | BatchedDataset,
10 | save_model,
11 | load_model,
12 | )
13 | from torch.nn import CrossEntropyLoss
14 | from quantnn.models.pytorch.fully_connected import FullyConnected
15 | from quantnn.models.pytorch.unet import UNet
16 |
--------------------------------------------------------------------------------
/quantnn/models/keras/padding.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow import keras
3 | from tensorflow.keras import layers
4 |
5 |
class SymmetricPadding(layers.Layer):
    """
    Keras layer that pads the second and third axis of a rank-4 input
    using TensorFlow's "SYMMETRIC" padding mode.
    """

    def __init__(self, amount):
        """
        Args:
            amount: Number of elements added at both ends of the second
                and third axis. The first and last axis are not padded.
        """
        super().__init__()
        untouched = [0, 0]
        padded = [amount, amount]
        self.paddings = tf.constant([untouched, padded, padded, untouched])

    def call(self, input):
        """Return ``input`` with symmetric padding applied."""
        return tf.pad(input, paddings=self.paddings, mode="SYMMETRIC")
15 |
--------------------------------------------------------------------------------
/quantnn/models/keras/unet.py:
--------------------------------------------------------------------------------
1 | """
2 | =========================
3 | quantnn.models.keras.unet
4 | =========================
5 |
6 | This module provides an implementation of the UNet [unet]_
7 | architecture.
8 |
9 | .. [unet] O. Ronneberger, P. Fischer and T. Brox, "U-net: Convolutional networks
10 | for biomedical image segmentation", Proc. Int. Conf. Med. Image Comput.
11 | Comput.-Assist. Intervent. (MICCAI), pp. 234-241, 2015.
12 | """
13 | import tensorflow as tf
14 | from tensorflow import keras
15 | from tensorflow.keras import layers
16 | from tensorflow.keras import activations
17 | from tensorflow.keras import Input
18 |
19 | from quantnn.models.keras.padding import SymmetricPadding
20 |
21 |
class ConvolutionBlock(layers.Layer):
    """
    Two 3x3 convolutions, each preceded by symmetric padding of width one
    and followed by batch normalization and a ReLU activation.
    """

    def __init__(self, channels_in, channels_out):
        """
        Create new convolution block.

        Args:
            channels_in: The number of input channels.
            channels_out: The number of output channels.
        """
        super().__init__()
        self.block = keras.Sequential()
        stages = (
            SymmetricPadding(1),
            layers.Conv2D(
                channels_out,
                3,
                padding="valid",
                input_shape=(None, None, channels_in),
            ),
            layers.BatchNormalization(),
            layers.ReLU(),
            SymmetricPadding(1),
            layers.Conv2D(channels_out, 3, padding="valid"),
            layers.BatchNormalization(),
            layers.ReLU(),
        )
        for stage in stages:
            self.block.add(stage)

    def call(self, input):
        """Propagate ``input`` through the wrapped sequential model."""
        return self.block(input)
53 |
54 |
class DownsamplingBlock(keras.Sequential):
    """
    Sequential model made of a max pooling layer with stride 2 followed
    by a convolution block.
    """

    def __init__(self, channels_in, channels_out):
        """
        Create new downsampling block.

        Args:
            channels_in: The number of input channels.
            channels_out: The number of output channels.
        """
        super().__init__()
        pooling = layers.MaxPooling2D(strides=(2, 2))
        convolutions = ConvolutionBlock(channels_in, channels_out)
        for stage in (pooling, convolutions):
            self.add(stage)
73 |
74 |
class UpsamplingBlock(layers.Layer):
    """
    Upsampling block: bilinear upsampling by a factor of two, a 1x1
    convolution that halves the channel count, concatenation with the
    skip input and a final convolution block.
    """

    def __init__(self, channels_in, channels_out):
        """
        Create new upsampling block.

        Args:
            channels_in: The number of input channels.
            channels_out: The number of output channels.
        """
        super().__init__()
        self.upsample = layers.UpSampling2D(size=(2, 2), interpolation="bilinear")
        self.reduce = layers.Conv2D(
            channels_in // 2,
            1,
            padding="same",
            input_shape=(None, None, channels_in),
        )
        self.concat = layers.Concatenate()
        self.conv_block = ConvolutionBlock(channels_in, channels_out)

    def call(self, inputs):
        """
        Args:
            inputs: Pair ``(x, x_skip)`` of the tensor to upsample and the
                skip tensor it is concatenated with.
        """
        coarse, skip = inputs
        upsampled = self.upsample(coarse)
        reduced = self.reduce(upsampled)
        merged = self.concat([reduced, skip])
        return self.conv_block(merged)
106 |
107 |
class UNet(keras.Model):
    """
    Keras implementation of the UNet architecture, an input block followed
    by 4 encoder blocks and 4 decoder blocks.
    """

    def __init__(self, n_inputs, n_outputs):
        """
        Args:
            n_inputs: The number of channels of the network input.
            n_outputs: The number of channels of the network output.
        """
        super().__init__()

        self.in_block = keras.Sequential(
            [
                SymmetricPadding(1),
                # The declared input shape follows n_inputs; it used to be
                # hard-coded to 128 channels regardless of the argument.
                layers.Conv2D(
                    64, 3, padding="valid", input_shape=(None, None, n_inputs)
                ),
                layers.BatchNormalization(),
                layers.ReLU(),
                SymmetricPadding(1),
                layers.Conv2D(64, 3, padding="valid"),
                layers.BatchNormalization(),
                layers.ReLU(),
            ]
        )

        self.down_block_1 = DownsamplingBlock(64, 128)
        self.down_block_2 = DownsamplingBlock(128, 256)
        self.down_block_3 = DownsamplingBlock(256, 512)
        self.down_block_4 = DownsamplingBlock(512, 1024)

        self.up_block_1 = UpsamplingBlock(1024, 512)
        self.up_block_2 = UpsamplingBlock(512, 256)
        self.up_block_3 = UpsamplingBlock(256, 128)
        self.up_block_4 = UpsamplingBlock(128, 64)

        self.out_block = layers.Conv2D(
            n_outputs, 1, padding="same", input_shape=(None, None, 64)
        )

    def call(self, inputs):
        """
        Propagate ``inputs`` through the encoder and decoder. Each decoder
        block receives the encoder output with the matching channel count
        as its skip input.
        """
        d_64 = self.in_block(inputs)

        d_128 = self.down_block_1(d_64)
        d_256 = self.down_block_2(d_128)
        d_512 = self.down_block_3(d_256)
        d_1024 = self.down_block_4(d_512)

        u_512 = self.up_block_1([d_1024, d_512])
        u_256 = self.up_block_2([u_512, d_256])
        u_128 = self.up_block_3([u_256, d_128])
        u_64 = self.up_block_4([u_128, d_64])

        return self.out_block(u_64)
163 |
--------------------------------------------------------------------------------
/quantnn/models/pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | quantnn.models.pytorch
3 | ======================
4 |
5 | This module provides PyTorch neural network models that can be used as backend
6 | models for the :py:class:`quantnn.QRNN` class.
7 | """
8 | from quantnn.models.pytorch.common import (
9 | CrossEntropyLoss,
10 | QuantileLoss,
11 | MSELoss,
12 | BatchedDataset,
13 | save_model,
14 | load_model,
15 | )
16 | from quantnn.models.pytorch.common import PytorchModel as Model
17 | from quantnn.models.pytorch.fully_connected import FullyConnected
18 | from quantnn.models.pytorch.unet import UNet
19 |
--------------------------------------------------------------------------------
/quantnn/models/pytorch/base.py:
--------------------------------------------------------------------------------
1 | """
2 | quantnn.models.pytorch.base
3 | ===========================
4 |
5 | Helper classes for pytorch models.
6 | """
7 |
8 |
class ParamCount:
    """
    Mixin for pytorch modules that exposes the number of trainable
    parameters as the 'n_params' attribute.
    """

    @property
    def n_params(self):
        """Total element count of all parameters that require gradients."""
        total = 0
        for param in self.parameters():
            if param.requires_grad:
                total += param.numel()
        return total
18 |
--------------------------------------------------------------------------------
/quantnn/models/pytorch/downsampling.py:
--------------------------------------------------------------------------------
1 | """
2 | quantnn.models.pytorch.downsampling
3 | ===================================
4 |
5 | This module provides factory classes for downsampling modules.
6 | """
7 | from typing import Union, Tuple, Callable
8 |
9 | import torch
10 | from torch import nn
11 |
12 | from quantnn.models.pytorch.normalization import LayerNormFirst
13 |
14 |
class ConvNextDownsamplerFactory:
    """
    Factory producing downsamplers made of a ``LayerNormFirst`` layer
    followed by a strided convolution.
    """

    def __call__(
        self, channels_in: int, channels_out: int, f_dwn: Union[int, Tuple[int, int]]
    ):
        """
        Args:
            channels_in: The number of channels in the input.
            channels_out: The number of channels in the output.
            f_dwn: The downsampling factor, used as both kernel size and
                stride of the convolution. Can be a tuple if different
                factors should be applied along height and width.

        Return:
            An ``nn.Sequential`` holding the normalization and the
            convolution.
        """
        norm = LayerNormFirst(channels_in)
        conv = nn.Conv2d(channels_in, channels_out, kernel_size=f_dwn, stride=f_dwn)
        return nn.Sequential(norm, conv)
36 |
37 |
class PatchMergingBlock(nn.Module):
    """
    Patch-merging downsampler as employed in the Swin architecture:
    pixel-unshuffle, normalization and an optional 1x1 projection.
    """

    def __init__(
        self,
        channels_in: int,
        channels_out: int,
        f_dwn: Union[int, Tuple[int, int]],
        norm_factory: Callable[[int], nn.Module] = None,
    ):
        """
        Args:
            channels_in: The number of channels in the input.
            channels_out: The number of channels in the output.
            f_dwn: The downsampling factor. A tuple is accepted only if
                both of its factors are equal.
            norm_factory: Callable creating the normalization layer for a
                given number of channels. Defaults to ``LayerNormFirst``.

        Raises:
            ValueError: If ``f_dwn`` is a tuple with differing factors.
        """
        super().__init__()
        if isinstance(f_dwn, tuple):
            f_height, f_width = f_dwn[0], f_dwn[1]
            if f_height != f_width:
                raise ValueError(
                    "Downsampling by patch merging only supports homogeneous "
                    "downsampling factors."
                )
            f_dwn = f_height
        self.f_dwn = f_dwn

        # Pixel-unshuffle moves each f_dwn x f_dwn patch into the channels.
        channels_d = channels_in * f_dwn**2

        make_norm = LayerNormFirst if norm_factory is None else norm_factory
        self.norm = make_norm(channels_d)

        needs_projection = channels_d != channels_out
        self.proj = (
            nn.Conv2d(channels_d, channels_out, kernel_size=1)
            if needs_projection
            else nn.Identity()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Downsample tensor.
        """
        merged = nn.functional.pixel_unshuffle(x, self.f_dwn)
        normed = self.norm(merged)
        return self.proj(normed)
77 |
78 |
class PatchMergingFactory:
    """
    Factory that builds the patch merging downsamplers used by the Swin
    architecture.
    """

    def __call__(
        self, channels_in: int, channels_out: int, f_dwn: Union[int, Tuple[int, int]]
    ):
        """
        Args:
            channels_in: The number of channels in the input.
            channels_out: The number of channels in the output.
            f_dwn: The downsampling factor.

        Return:
            A ``PatchMergingBlock`` with its default normalization.
        """
        block = PatchMergingBlock(channels_in, channels_out, f_dwn)
        return block
89 |
--------------------------------------------------------------------------------
/quantnn/models/pytorch/factories.py:
--------------------------------------------------------------------------------
1 | """
2 | quantnn.models.pytorch.factories
3 | ================================
4 |
5 | Factory objects to build neural networks.
6 | """
7 | from torch import nn
8 |
class MaxPooling:
    """
    Factory for creating max-pooling downsampling layers.
    """

    def __init__(self, kernel_size=(2, 2), project_first=True):
        """
        Args:
            kernel_size: Kernel size handed to the max-pooling layer.
            project_first: Only relevant when the number of channels
                changes. If True the 1x1 projection is applied before the
                pooling, otherwise after it.
        """
        self.kernel_size = kernel_size
        self.project_first = project_first

    def __call__(self, channels_in, channels_out, f_down):
        """
        Args:
            channels_in: The number of input channels.
            channels_out: The number of output channels.
            f_down: The downsampling factor, used as stride of the
                pooling layer.

        Return:
            The pooling layer alone if the channel count is unchanged,
            otherwise an ``nn.Sequential`` combining it with a 1x1
            convolution.
        """
        pool = nn.MaxPool2d(kernel_size=self.kernel_size, stride=f_down)
        if channels_in != channels_out:
            project = nn.Conv2d(channels_in, channels_out, kernel_size=1)
            stages = (project, pool) if self.project_first else (pool, project)
            return nn.Sequential(*stages)
        return pool
60 |
--------------------------------------------------------------------------------
/quantnn/models/pytorch/generative.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn.functional import relu
4 |
5 | from quantnn.models.pytorch.xception import SymmetricPadding
6 |
7 |
class DownBlock(nn.Module):
    """
    Residual downsampling block: batch norm, then a two-convolution body
    and a 1x1-convolution skip path that both end in max pooling with
    stride 2. The outputs of the two paths are summed.
    """

    def __init__(self, channels_in, channels_out):
        """
        Args:
            channels_in: The number of input channels.
            channels_out: The number of output channels.
        """
        super().__init__()
        self.norm = nn.BatchNorm2d(channels_in)

        main_path = [
            nn.ReLU(),
            SymmetricPadding(1),
            nn.Conv2d(channels_in, channels_in, 3),
            nn.BatchNorm2d(channels_in),
            SymmetricPadding(1),
            nn.Conv2d(channels_in, channels_out, 3),
            nn.ReLU(),
            SymmetricPadding(1),
            nn.MaxPool2d(3, stride=2),
        ]
        self.body = nn.Sequential(*main_path)

        skip_path = [
            SymmetricPadding(1),
            nn.Conv2d(channels_in, channels_out, 1),
            nn.MaxPool2d(3, stride=2),
        ]
        self.skip = nn.Sequential(*skip_path)

    def forward(self, x):
        """Normalize ``x`` and return the sum of body and skip path."""
        normed = self.norm(x)
        main = self.body(normed)
        return main + self.skip(normed)
33 |
34 |
class DownBlockSpectral(nn.Module):
    """
    Residual downsampling block built from spectrally-normalized
    convolutions. Body and skip path both end in 2x2 average pooling and
    their outputs are summed.
    """

    def __init__(self, channels_in, channels_out, relu_in=True):
        """
        Args:
            channels_in: The number of input channels.
            channels_out: The number of output channels.
            relu_in: Whether the body starts with a ReLU activation.
        """
        super().__init__()
        main_path = [
            SymmetricPadding([1, 2, 1, 2]),
            nn.utils.spectral_norm(nn.Conv2d(channels_in, channels_in, 4)),
            nn.ReLU(),
            SymmetricPadding([1, 2, 1, 2]),
            nn.utils.spectral_norm(nn.Conv2d(channels_in, channels_out, 4)),
            nn.AvgPool2d(2),
        ]
        if relu_in:
            # The leading activation is the only difference between the
            # two body variants.
            main_path.insert(0, nn.ReLU())
        self.body = nn.Sequential(*main_path)

        self.skip = nn.Sequential(
            SymmetricPadding([0, 1, 0, 1]),
            nn.utils.spectral_norm(nn.Conv2d(channels_in, channels_out, 2)),
            nn.AvgPool2d(2),
        )

    def forward(self, x):
        """Return the sum of body and skip path applied to ``x``."""
        main = self.body(x)
        return main + self.skip(x)
69 |
70 |
class GeneratorBlock(nn.Module):
    """
    Residual generator block: two rounds of batch norm, leaky ReLU and a
    spectrally-normalized 4x4 convolution, added to a 1x1-convolution
    skip path.
    """

    def __init__(self, channels_in, channels_out):
        """
        Args:
            channels_in: The number of input channels.
            channels_out: The number of output channels.
        """
        super().__init__()
        main_path = [
            nn.BatchNorm2d(channels_in),
            nn.LeakyReLU(0.1),
            SymmetricPadding([1, 2, 1, 2]),
            nn.utils.spectral_norm(nn.Conv2d(channels_in, channels_out, 4)),
            nn.BatchNorm2d(channels_out),
            nn.LeakyReLU(0.1),
            SymmetricPadding([1, 2, 1, 2]),
            nn.utils.spectral_norm(nn.Conv2d(channels_out, channels_out, 4)),
        ]
        self.body = nn.Sequential(*main_path)
        self.skip = nn.Sequential(nn.Conv2d(channels_in, channels_out, 1))

    def forward(self, x):
        """Return the sum of body and skip path applied to ``x``."""
        main = self.body(x)
        shortcut = self.skip(x)
        return main + shortcut
90 |
class GeneratorBlockUp(nn.Module):
    """
    Residual generator block that doubles the spatial size: body and skip
    path each apply nearest-neighbor upsampling by a factor of two and
    their outputs are summed.
    """

    def __init__(self, channels_in, channels_out):
        """
        Args:
            channels_in: The number of input channels.
            channels_out: The number of output channels.
        """
        super().__init__()
        main_path = [
            nn.BatchNorm2d(channels_in),
            nn.LeakyReLU(0.1),
            nn.Upsample(scale_factor=2, mode="nearest"),
            SymmetricPadding([1, 2, 1, 2]),
            nn.utils.spectral_norm(nn.Conv2d(channels_in, channels_out, 4)),
            nn.BatchNorm2d(channels_out),
            nn.LeakyReLU(0.1),
            SymmetricPadding([1, 2, 1, 2]),
            nn.utils.spectral_norm(nn.Conv2d(channels_out, channels_out, 4)),
        ]
        self.body = nn.Sequential(*main_path)

        skip_path = [
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(channels_in, channels_out, 1),
        ]
        self.skip = nn.Sequential(*skip_path)

    def forward(self, x):
        """Return the sum of body and skip path applied to ``x``."""
        main = self.body(x)
        shortcut = self.skip(x)
        return main + shortcut
112 |
113 |
class ConditionalGenerator(nn.Module):
    """
    Conditional generator: the input is reduced by four ``DownBlock``
    stages, concatenated with a latent tensor along the channel axis and
    expanded again by four generator stages.
    """

    def __init__(self, channels_in, channels, n_lat=8, output_range=1):
        """
        Args:
            channels_in: The number of channels of the conditioning input.
            channels: The number of channels used inside the network.
            n_lat: The number of latent channels.
            output_range: Factor applied to the tanh output, so results
                lie in ``[-output_range, output_range]``.
        """
        super().__init__()

        self.n_lat = n_lat
        self.output_range = output_range

        self.conditioner = nn.ModuleList([
            DownBlock(channels_in, channels),
            DownBlock(channels, channels),
            DownBlock(channels, channels),
            DownBlock(channels, channels)
        ])

        self.generator = nn.ModuleList([
            nn.Sequential(
                GeneratorBlock(channels + n_lat, channels),
                GeneratorBlockUp(channels, channels)
            ),
            nn.Sequential(
                GeneratorBlock(channels, channels),
                GeneratorBlockUp(channels, channels)
            ),
            nn.Sequential(
                GeneratorBlock(channels, channels),
                GeneratorBlockUp(channels, channels)
            ),
            nn.Sequential(
                GeneratorBlock(channels, channels),
                GeneratorBlockUp(channels, channels)
            ),
        ])

        self.output = nn.Sequential(
            nn.BatchNorm2d(channels),
            nn.LeakyReLU(0.1),
            nn.utils.spectral_norm(nn.Conv2d(channels, 1, 1))
        )

    def forward(self, x, z=None):
        """
        Args:
            x: The conditioning input.
            z: Optional latent values. They are reshaped to
                ``(batch, n_lat, h, w)`` where ``h, w`` is the spatial size
                of the conditioner output (4 x 4 for 64 x 64 inputs).
                Standard-normal values are sampled if omitted.

        Return:
            The generated single-channel output scaled by ``output_range``.
        """
        y = x
        for layer in self.conditioner:
            y = layer(y)

        n = x.shape[0]
        # Take the latent size from the conditioner output instead of
        # fixing it to 4 x 4, so inputs of other sizes work too.
        latent_shape = (n, self.n_lat) + tuple(y.shape[-2:])

        if z is None:
            # Sampled on the CPU as before, then moved to the device of
            # the features so the concatenation also works on the GPU.
            z = torch.normal(0, 1, size=latent_shape).to(y.device)
        else:
            z = z.reshape(latent_shape)

        hidden = torch.cat([y, z], 1)
        for layer in self.generator:
            hidden = layer(hidden)

        return self.output_range * torch.tanh(self.output(hidden))
174 |
175 |
class Discriminator(nn.Module):
    """
    Discriminator that scores pairs of conditioning input and target using
    five spectrally-normalized downsampling blocks and a linear classifier.
    """

    def __init__(self, channels_in, channels=64):
        """
        Args:
            channels_in: The number of channels of the conditioning input.
                One more channel is added for the target.
            channels: The number of channels used inside the network.
        """
        super().__init__()

        stages = [DownBlockSpectral(channels_in + 1, channels, relu_in=False)]
        for _ in range(4):
            stages.append(DownBlockSpectral(channels, channels))
        # NOTE(review): nn.ReLU's first positional argument is ``inplace``,
        # so 0.1 acts as a truthy flag here -- confirm whether
        # nn.LeakyReLU(0.1) was intended.
        stages.append(nn.ReLU(0.1))
        self.body = nn.Sequential(*stages)

        self.classifier = nn.Sequential(
            nn.utils.spectral_norm(nn.Linear(channels, 1)),
        )

    def forward(self, x, y):
        """
        Args:
            x: The conditioning input.
            y: The target. A channel axis is inserted if it has fewer
                dimensions than ``x``.

        Return:
            The classifier output for the features summed over the last
            two axes.
        """
        if y.ndim < x.ndim:
            y = y.unsqueeze(1)
        stacked = torch.cat([x, y], 1)
        pooled = self.body(stacked).sum((-2, -1))
        return self.classifier(pooled)
199 |
200 |
def hinge_loss(d_real, d_fake):
    """
    Hinge loss for the discriminator.

    Args:
        d_real: Discriminator outputs for real samples.
        d_fake: Discriminator outputs for generated samples.

    Return:
        Mean of ``relu(1 - d_real)`` plus mean of ``relu(1 + d_fake)``.
    """
    real_term = relu(1.0 - d_real).mean()
    fake_term = relu(1.0 + d_fake).mean()
    return real_term + fake_term
203 |
def training_step(x, y, g, d, opt_g, opt_d):
    """
    Perform one discriminator update followed by one generator update.

    Args:
        x: Batch of conditioning inputs.
        y: Batch of targets without channel axis; one is inserted at
            position 1.
        g: The generator. Must provide an ``n_lat`` attribute and accept
            ``(x, z)``.
        d: The discriminator, called as ``d(x, y)``.
        opt_g: Optimizer of the generator.
        opt_d: Optimizer of the discriminator.

    Return:
        Tuple ``(g_loss, d_loss, mse_loss)`` of Python floats. The MSE
        between generated output and target is reported only; it is not
        part of the optimized generator loss.
    """
    y = y.unsqueeze(1)

    batch_size = x.size(0)
    latent_shape = (batch_size, g.n_lat, 4, 4)
    # Latents follow the device of the input instead of a hard-coded
    # .cuda() call, so the step also runs on CPU-only machines.
    z = torch.randn(latent_shape).to(x.device)

    # Train discriminator
    opt_d.zero_grad()
    opt_g.zero_grad()
    d_real = d(x, y)
    d_fake = d(x, g(x, z))
    d_loss = hinge_loss(d_real, d_fake)
    d_loss.backward()
    opt_d.step()

    # Train generator
    opt_d.zero_grad()
    opt_g.zero_grad()
    z = torch.randn(latent_shape).to(x.device)
    fake = g(x, z)
    g_loss = -d(x, fake).mean()
    mse_loss = ((fake - y) ** 2).mean()

    g_loss.backward()
    opt_g.step()

    return g_loss.item(), d_loss.item(), mse_loss.item()
232 |
--------------------------------------------------------------------------------
/quantnn/models/pytorch/logging.py:
--------------------------------------------------------------------------------
1 | """
2 | ==============================
3 | quantnn.models.pytorch.logging
4 | ==============================
5 |
6 | This module contains training loggers that are specific to the
7 | PyTorch backend.
8 | """
9 | import torch
10 | from torch.utils.tensorboard.writer import SummaryWriter
11 | from torch.utils.tensorboard.summary import hparams
12 | import xarray as xr
13 |
14 |
15 | from quantnn.logging import TrainingLogger
16 |
17 |
class SummaryWriter(SummaryWriter):
    """
    Specialization of torch's original SummaryWriter that overrides
    'add_hparams' to avoid creating a new directory to store the
    hyperparameters.

    Source: https://github.com/pytorch/pytorch/issues/32651
    """

    def add_hparams(self, hparam_dict, metric_dict, epoch):
        """
        Write hyperparameters and metrics into the writer's own log
        directory.

        Args:
            hparam_dict: Dictionary of hyperparameters.
            metric_dict: Dictionary of metric values, each also written as
                a scalar at step ``epoch``.
            epoch: The step at which the metric scalars are logged.

        Raises:
            TypeError: If either argument is not exactly a ``dict``.
        """
        torch._C._log_api_usage_once("tensorboard.logging.add_hparams")
        both_dicts = type(hparam_dict) is dict and type(metric_dict) is dict
        if not both_dicts:
            raise TypeError("hparam_dict and metric_dict should be dictionary.")
        exp, ssi, sei = hparams(hparam_dict, metric_dict)

        # Reuse the current log directory rather than a fresh sub-directory.
        logdir = self._get_file_writer().get_logdir()

        with SummaryWriter(log_dir=logdir) as w_hp:
            for summary in (exp, ssi, sei):
                w_hp.file_writer.add_summary(summary)
            for key, value in metric_dict.items():
                w_hp.add_scalar(key, value, epoch)
40 |
41 |
class TensorBoardLogger(TrainingLogger):
    """
    Logger that also logs information to tensor board.
    """

    def __init__(
        self, n_epochs, log_rate=100, log_directory=None, epoch_begin_callback=None
    ):
        """
        Create a new logger instance.

        Args:
            n_epochs: The number of epochs for which the training will last.
            log_rate: The message rate for output to standard out.
            log_directory: The directory to use for tensorboard output.
            epoch_begin_callback: Callback function that will be called with
                arguments ``writer, model, i_epoch``, where ``writer`` is the
                ``SummaryWriter`` object used to write output, ``model`` is
                the model being trained in its current state and
                ``i_epoch`` is the logger's current epoch index.
        """
        super().__init__(n_epochs, log_rate)
        self.writer = SummaryWriter(log_dir=log_directory)
        self.epoch_begin_callback = epoch_begin_callback
        self.attributes = None

    def set_attributes(self, attributes):
        """
        Stores attributes that describe the training in the logger.
        These will be stored in the logger history.

        Args:
            attributes: Dictionary of attributes to store in the history
                of the logger.
        """
        super().set_attributes(attributes)

    def epoch_begin(self, model):
        """
        Called at the beginning of each epoch.

        Args:
            model: The model that is trained in its current state.
        """
        TrainingLogger.epoch_begin(self, model)
        if self.epoch_begin_callback:
            self.epoch_begin_callback(self.writer, model, self.i_epoch)

    def training_step(self, loss, n_samples, of=None, losses=None):
        """
        Log processing of a training batch. This method should be called
        after each batch is processed so that the logger can keep track
        of training progress.

        Args:
            loss: The loss of the current batch.
            n_samples: The number of samples in the batch.
            of: If available the number of batches in the epoch.
            losses: Forwarded unchanged to the base class.
        """
        super().training_step(loss, n_samples, of=of, losses=losses)

    def validation_step(self, loss, n_samples, of=None, losses=None):
        """
        Log processing of a validation batch.

        Args:
            loss: The loss of the current batch.
            n_samples: The number of samples in the batch.
            of: If available the number of batches in the epoch.
            losses: Forwarded unchanged to the base class.
        """
        super().validation_step(loss, n_samples, of=of, losses=losses)

    def epoch(self, learning_rate=None, metrics=None):
        """
        Log processing of epoch.

        Args:
            learning_rate: If available the learning rate of the optimizer.
            metrics: Optional iterable of metric objects. Figures of
                metrics providing a ``get_figures`` method are added to
                tensor board.
        """
        TrainingLogger.epoch(self, learning_rate, metrics=metrics)
        # The learning rate is optional; add_scalar cannot convert None.
        if learning_rate is not None:
            self.writer.add_scalar("Learning rate", learning_rate, self.i_epoch)

        # Write the latest value of every one-dimensional history variable.
        for name, v in self.history.variables.items():
            if name == "epochs":
                continue
            if len(v.dims) == 1:
                value = v.data[-1]
                self.writer.add_scalar(name, value, self.i_epoch)

        if metrics is not None:
            for m in metrics:
                if hasattr(m, "get_figures"):
                    figures = m.get_figures()
                    if isinstance(figures, dict):
                        # One figure per target.
                        for target in figures.keys():
                            f = figures[target]
                            self.writer.add_figure(
                                f"{m.name} ({target})", f, self.i_epoch
                            )
                    else:
                        self.writer.add_figure(f"{m.name}", figures, self.i_epoch)

    def training_end(self):
        """
        Called to signal the end of the training to the logger.
        """
        if self.attributes is not None:
            if self.i_epoch >= self.n_epochs:
                # NOTE(review): the final metric values are collected here
                # but an empty dict is passed to add_hparams -- confirm
                # whether ``metrics`` was meant to be passed instead.
                metrics = {}
                for name, v in self.history.variables.items():
                    if name == "epochs":
                        continue
                    if len(v.dims) == 1:
                        metrics[name + "_final"] = v.data[-1]
                self.writer.add_hparams(self.attributes, {}, self.i_epoch)
                self.writer.flush()
163 |
--------------------------------------------------------------------------------
/quantnn/models/pytorch/normalization.py:
--------------------------------------------------------------------------------
1 | """
2 | quantnn.models.pytorch.normalization
3 | ====================================
4 |
5 | This module implements normalization layers.
6 | """
7 | import torch
8 | from torch import nn
9 |
10 |
class LayerNormFirst(nn.Module):
    """
    Layer norm that normalizes along dimension 1 of the input and applies
    a learnable scaling and bias per channel.
    """

    def __init__(self, n_channels, eps=1e-6):
        """
        Args:
            n_channels: The number of channels in the input.
            eps: Epsilon added to variance to avoid numerical issues.
        """
        super().__init__()
        self.scaling = nn.Parameter(torch.ones(n_channels), requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(n_channels), requires_grad=True)
        self.eps = eps

    def forward(self, x):
        """
        Apply normalization to x.
        """
        centered = x - x.mean(1, keepdim=True)
        variance = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        # Two trailing axes let the per-channel parameters broadcast over
        # the last two dimensions of the input.
        scale = self.scaling[..., None, None]
        shift = self.bias[..., None, None]
        return scale * normalized + shift
40 |
41 |
class GRN(nn.Module):
    """
    Global Response Normalization (GRN) as proposed in https://openaccess.thecvf.com/content/CVPR2023/html/Woo_ConvNeXt_V2_Co-Designing_and_Scaling_ConvNets_With_Masked_Autoencoders_CVPR_2023_paper.html
    """

    # Class-level fallback so that instances restored from pickles created
    # before ``eps`` was stored on the instance keep the former value.
    eps = 1e-6

    def __init__(self, n_channels, eps=1e-6):
        """
        Args:
            n_channels: Number of channels over which to normalize.
            eps: Epsilon added to the channel-mean of the norms to avoid
                division by zero.
        """
        super().__init__()
        self.scaling = nn.Parameter(torch.zeros(1, n_channels, 1, 1), requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(1, n_channels, 1, 1), requires_grad=True)
        # Previously the argument was ignored and 1e-6 was hard-coded.
        self.eps = eps

    def forward(self, x):
        """
        Apply normalization to x.
        """
        # L2 norm of every channel over the last two axes.
        x_l2 = torch.norm(x, p=2, dim=(-2, -1), keepdim=True)
        # Norm of each channel relative to the mean norm over dimension 1.
        rel_imp = x_l2 / (x_l2.mean(dim=1, keepdim=True) + self.eps)
        return self.scaling * (x * rel_imp) + self.bias + x
62 |
--------------------------------------------------------------------------------
/quantnn/models/pytorch/resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | quantnn.models.pytorch.resnet
3 | =============================
4 |
5 | This module provides an implementation of a fully-convolutional
6 | [ResNet]_-based decoder-encoder architecture.
7 |
8 |
9 | .. [ResNet] Deep Residual Learning for Image Recognition
10 | """
11 | import torch
12 | import torch.nn as nn
13 |
14 |
15 | def _conv2(channels_in, channels_out, kernel_size):
16 | """
17 | Convolution with reflective padding to keep image size constant.
18 | """
19 | return nn.Conv2d(
20 | channels_in,
21 | channels_out,
22 | kernel_size=kernel_size,
23 | padding=kernel_size // 2,
24 | padding_mode="reflect",
25 | )
26 |
27 |
28 | def _conv2_down(channels_in, channels_out, kernel_size):
29 | """
30 | Convolution combined with downsampling and reflective padding to
31 | decrease input size by a factor of 2.
32 | """
33 | return nn.Conv2d(
34 | channels_in,
35 | channels_out,
36 | kernel_size=kernel_size,
37 | padding=kernel_size // 2,
38 | stride=2,
39 | padding_mode="reflect",
40 | )
41 |
42 |
class ResidualBlock(nn.Module):
    """
    A residual block consists of either two or three convolution operations
    followed by batch norm and relu activation together with a shortcut
    connecting the input and the activation feeding into the last ReLU
    layer. The shortcut is the identity unless the block changes the
    number of channels or the image size.
    """

    def __init__(self, channels_in, channels_out, bottleneck=None, downsample=False):
        """
        Create new residual block.

        Args:
            channels_in: The number of input channels.
            channels_out: The number of output channels.
            bottleneck: If not ``None``, the number of channels of the
                bottleneck used to reduce the number of parameters. The
                block then consists of a 1x1, a 3x3 and a 1x1 convolution
                instead of two 3x3 convolutions.
            downsample: If true the image dimensions are reduced by
                a factor of two.
        """
        super().__init__()
        self.downsample = downsample
        # The last convolution of the block performs the downsampling.
        make_last_conv = _conv2_down if downsample else _conv2
        if bottleneck is None:
            self.block = nn.Sequential(
                _conv2(channels_in, channels_out, 3),
                nn.BatchNorm2d(channels_out),
                nn.ReLU(inplace=True),
                make_last_conv(channels_out, channels_out, 3),
                nn.BatchNorm2d(channels_out),
            )
        else:
            self.block = nn.Sequential(
                _conv2(channels_in, bottleneck, 1),
                nn.BatchNorm2d(bottleneck),
                nn.ReLU(inplace=True),
                _conv2(bottleneck, bottleneck, 3),
                nn.BatchNorm2d(bottleneck),
                nn.ReLU(inplace=True),
                make_last_conv(bottleneck, channels_out, 1),
                nn.BatchNorm2d(channels_out),
            )
        self.activation = nn.ReLU(inplace=True)

        # The shortcut must match the block output in channels AND size.
        # A downsampling block therefore needs a strided projection even
        # when the number of channels does not change; without it the
        # addition in forward() fails on mismatched image sizes.
        self.projection = None
        if downsample:
            self.projection = _conv2_down(channels_in, channels_out, 1)
        elif channels_in != channels_out:
            self.projection = _conv2(channels_in, channels_out, 1)

    def forward(self, x):
        """
        Propagate input through block.
        """
        y = self.block(x)
        if self.projection is not None:
            x = self.projection(x)
        return self.activation(y + x)
109 |
110 |
class DownSamplingBlock(nn.Module):
    """
    Downsampling block consisting of one downsampling residual block
    followed by ``n_blocks - 1`` ordinary residual blocks.
    """

    def __init__(self, channels_in, channels_out, n_blocks, bottleneck=None):
        """
        Args:
            channels_in: The number of input channels.
            channels_out: The number of output channels.
            n_blocks: The total number of residual blocks.
            bottleneck: Bottleneck setting forwarded to every residual
                block.
        """
        super().__init__()
        # Exactly one downsampling block. The former
        # ``[block] * (n_blocks - 1)`` dropped it for n_blocks == 1 and
        # repeated it for n_blocks > 2.
        modules = [
            ResidualBlock(
                channels_in, channels_out, bottleneck=bottleneck, downsample=True
            )
        ]
        # A comprehension creates a separate module per position; list
        # multiplication would reuse one instance and share its weights.
        modules += [
            ResidualBlock(channels_out, channels_out, bottleneck=bottleneck)
            for _ in range(n_blocks - 1)
        ]
        self.block = nn.Sequential(*modules)

    def forward(self, x):
        """Propagate input through block."""
        return self.block(x)
132 |
133 |
class UpSamplingBlock(nn.Module):
    """
    ResNet upsampling block consisting of bilinear interpolation,
    concatenation with the skip input and the given number of residual
    blocks.
    """

    def __init__(
        self, channels_in, channels_skip, channels_out, n_blocks, bottleneck=None
    ):
        """
        Args:
            channels_in: The number of channels of the input to upsample.
            channels_skip: The number of channels of the skip input.
            channels_out: The number of output channels.
            n_blocks: The total number of residual blocks.
            bottleneck: Bottleneck setting forwarded to every residual
                block.
        """
        super().__init__()
        self.upscaling = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=True
        )

        modules = [
            ResidualBlock(
                channels_in + channels_skip, channels_out, bottleneck=bottleneck
            )
        ]
        # A comprehension creates a separate module per position; list
        # multiplication would reuse one instance and share its weights.
        modules += [
            ResidualBlock(channels_out, channels_out, bottleneck=bottleneck)
            for _ in range(n_blocks - 1)
        ]
        self.block = nn.Sequential(*modules)

    def forward(self, x, x_skip):
        """Propagate input through block."""
        x = self.upscaling(x)
        x = torch.cat([x, x_skip], dim=1)
        return self.block(x)
163 |
164 |
class ResNet(nn.Module):
    """
    Decoder-encoder network using residual blocks.

    A fully-convolutional network for point-to-point regression tasks.
    The input passes through a 7x7 convolution with stride two followed
    by batch norm and ReLU activation, then through four downsampling
    blocks built from residual blocks with bottlenecks. Five upsampling
    blocks restore the input size, each combining its input with the
    matching encoder output; the last one uses the network input itself.
    """

    def __init__(self, n_inputs, n_outputs, blocks=2):
        """
        Args:
            n_inputs: The number of input channels.
            n_outputs: The number of output channels.
            blocks: The number of residual blocks per stage, either a
                single int used for all four stages or a sequence with
                one entry per stage.
        """
        super().__init__()
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs

        self.in_block = nn.Sequential(
            _conv2_down(n_inputs, 128, 7), nn.BatchNorm2d(128), nn.ReLU()
        )

        if type(blocks) is int:
            blocks = 4 * [blocks]
        n_1, n_2, n_3, n_4 = blocks[0], blocks[1], blocks[2], blocks[3]

        self.down_block_1 = DownSamplingBlock(128, 256, n_1, bottleneck=256)
        self.down_block_2 = DownSamplingBlock(256, 512, n_2, bottleneck=256)
        self.down_block_3 = DownSamplingBlock(512, 1024, n_3, bottleneck=256)
        self.down_block_4 = DownSamplingBlock(1024, 2048, n_4, bottleneck=512)

        self.up_block_1 = UpSamplingBlock(2048, 1024, 1024, n_4, bottleneck=256)
        self.up_block_2 = UpSamplingBlock(1024, 512, 512, n_3, bottleneck=256)
        self.up_block_3 = UpSamplingBlock(512, 256, 256, n_2, bottleneck=256)
        self.up_block_4 = UpSamplingBlock(256, 128, n_outputs, n_1)
        self.up_block_5 = UpSamplingBlock(n_outputs, n_inputs, n_outputs, n_1)

        self.out_block = nn.Sequential(
            _conv2(n_outputs, n_outputs, 1),
            nn.BatchNorm2d(n_outputs),
            nn.ReLU(),
            _conv2(n_outputs, n_outputs, 1),
        )

    def forward(self, x):
        """Propagate input through resnet."""
        # Encoder: keep every intermediate result for the skip connections.
        enc_0 = self.in_block(x)
        enc_1 = self.down_block_1(enc_0)
        enc_2 = self.down_block_2(enc_1)
        enc_3 = self.down_block_3(enc_2)
        enc_4 = self.down_block_4(enc_3)

        # Decoder: each stage merges with the encoder output one level up.
        dec = self.up_block_1(enc_4, enc_3)
        dec = self.up_block_2(dec, enc_2)
        dec = self.up_block_3(dec, enc_1)
        dec = self.up_block_4(dec, enc_0)
        dec = self.up_block_5(dec, x)

        return self.out_block(dec)
226 |
--------------------------------------------------------------------------------
/quantnn/models/pytorch/stages.py:
--------------------------------------------------------------------------------
1 | """
2 | quantnn.models.pytorch.stages
3 | =============================
4 |
5 | Implements generic stages used in back-bone models.
6 | """
7 | import numpy as np
8 | import torch
9 | from torch import nn
10 |
11 | from quantnn.models.pytorch import blocks
12 |
13 |
class AggregationTreeNode(nn.Module):
    """
    A node of an aggregation tree.

    Nodes with ``level <= 0`` are leaves holding a single block. Nodes with
    ``level == 1`` hold two blocks and an aggregator that merges their
    outputs together with everything handed down from the ancestors. Nodes
    at higher levels hold two child nodes and leave the aggregation to the
    level-1 node at the end of their right branch.
    """

    def __init__(
        self,
        channels_in,
        channels_out,
        level,
        block_factory,
        aggregator_factory,
        channels_agg=0,
        downsample=1,
        block_args=None,
        block_kwargs=None,
    ):
        """
        Args:
            channels_in: Number of channels going into the node.
            channels_out: Number of channels coming out of the node.
            level: Height of the node in the tree.
            block_factory: Callable invoked as
                ``block_factory(c_in, c_out, *block_args, downsample=...,
                **block_kwargs)`` to create the blocks.
            aggregator_factory: Callable invoked as
                ``aggregator_factory(channels_agg, channels_out)`` to create
                the aggregator.
            channels_agg: Number of channels entering the aggregator; ``0``
                selects ``2 * channels_out``.
            downsample: Downsampling factor passed to the left-most block.
            block_args: Extra positional arguments for ``block_factory``.
            block_kwargs: Extra keyword arguments for ``block_factory``.
        """
        super().__init__()

        block_args = [] if block_args is None else block_args
        block_kwargs = {} if block_kwargs is None else block_kwargs
        if channels_agg == 0:
            channels_agg = 2 * channels_out

        self.aggregator = None
        self.level = level

        def make_block(n_in, factor):
            # All blocks of a node produce ``channels_out`` channels.
            return block_factory(
                n_in, channels_out, *block_args, downsample=factor, **block_kwargs
            )

        if level <= 0:
            # Leaf: a single block and nothing to aggregate.
            self.left = make_block(channels_in, downsample)
            self.right = None
            return

        if level == 1:
            self.aggregator = aggregator_factory(channels_agg, channels_out)
            self.left = make_block(channels_in, downsample)
            self.right = make_block(channels_out, 1)
            return

        forwarded = {"block_args": block_args, "block_kwargs": block_kwargs}
        self.left = AggregationTreeNode(
            channels_in,
            channels_out,
            level - 1,
            block_factory,
            aggregator_factory,
            downsample=downsample,
            **forwarded,
        )
        # The right branch additionally aggregates the output of the left
        # branch, hence the extra ``channels_out`` channels.
        self.right = AggregationTreeNode(
            channels_out,
            channels_out,
            level - 1,
            block_factory,
            aggregator_factory,
            channels_agg=channels_agg + channels_out,
            downsample=1,
            **forwarded,
        )

    def forward(self, x, pass_through=None):
        """
        Forward input through tree and aggregate results from child nodes.

        Args:
            x: The input tensor.
            pass_through: Optional list of tensors from ancestor nodes that
                is concatenated in front of this node's own results before
                aggregation.
        """
        if self.level <= 0:
            return self.left(x)

        inherited = [] if pass_through is None else pass_through
        y_left = self.left(x)

        if self.aggregator is None:
            # Defer aggregation to the right branch.
            return self.right(y_left, inherited + [y_left])

        y_right = self.right(y_left)
        return self.aggregator(torch.cat(inherited + [y_left, y_right], 1))
113 |
114 |
class AggregationTreeRoot(AggregationTreeNode):
    """
    Root of an aggregation tree.

    In contrast to inner nodes, the root also feeds its own input into the
    aggregator and performs downsampling with a max-pooling layer placed in
    front of the tree instead of inside the first block.
    """

    def __init__(
        self,
        channels_in,
        channels_out,
        tree_height,
        block_factory,
        aggregator_factory,
        downsample=1,
        block_args=None,
        block_kwargs=None,
    ):
        """
        Args:
            channels_in: Number of channels going into the tree.
            channels_out: Number of channels coming out of the tree.
            tree_height: Height of the tree.
            block_factory: Factory used to create the blocks of the tree.
            aggregator_factory: Factory used to create the aggregator.
            downsample: Downsampling factor; values larger than 1 insert a
                max-pooling layer with that kernel size in front of the tree.
            block_args: Extra positional arguments for ``block_factory``.
            block_kwargs: Extra keyword arguments for ``block_factory``.
        """
        # The aggregator receives the outputs of two blocks plus the input.
        channels_agg = 2 * channels_out + channels_in
        super().__init__(
            channels_in,
            channels_out,
            tree_height,
            block_factory,
            aggregator_factory,
            channels_agg=channels_agg,
            # Downsampling is handled by the max pooling below.
            downsample=1,
            # Forward the block arguments; dropping them would silently
            # build every block with the factory's defaults.
            block_args=block_args,
            block_kwargs=block_kwargs,
        )

        self.downsampler = None
        if downsample > 1:
            self.downsampler = nn.MaxPool2d(kernel_size=downsample)

    def forward(self, x):
        """
        Forward input through tree and aggregate results from child nodes.
        """
        # Downsample before the leaf shortcut so that a requested
        # downsampling is also applied when the tree has height 0.
        if self.downsampler is not None:
            x = self.downsampler(x)

        if self.level <= 0:
            return self.left(x)

        y_1 = self.left(x)
        if self.aggregator is not None:
            y_2 = self.right(y_1)
            return self.aggregator(torch.cat([y_1, y_2, x], 1))

        return self.right(y_1, [y_1, x])
165 |
166 |
class AggregationTreeFactory:
    """
    An aggregation tree implementing hierarchical aggregation of blocks in a stage.
    """

    def __init__(self, aggregator_factory=None):
        """
        Args:
            aggregator_factory: Factory used to create the aggregators of
                the tree. Defaults to a ``blocks.ConvBlockFactory`` with
                batch norm and ReLU activation.
        """
        if aggregator_factory is None:
            aggregator_factory = blocks.ConvBlockFactory(
                norm_factory=nn.BatchNorm2d, activation_factory=nn.ReLU
            )
        self.aggregator_factory = aggregator_factory

    def __call__(
        self,
        channels_in,
        channels_out,
        n_blocks,
        block_factory,
        downsample=1,
        block_args=None,
        block_kwargs=None,
    ):
        """
        Create an aggregation tree.

        Args:
            channels_in: Number of channels going into the tree.
            channels_out: Number of channels coming out of the tree.
            n_blocks: Number of blocks in the stage; the tree height is
                ``log2(n_blocks)``.
            block_factory: Factory used to create the blocks of the tree.
            downsample: Downsampling factor of the stage.
            block_args: Extra positional arguments for ``block_factory``.
            block_kwargs: Extra keyword arguments for ``block_factory``.

        Return:
            The ``AggregationTreeRoot`` of the new tree.
        """
        # NOTE(review): the height is a float and is fractional when
        # n_blocks is not a power of two -- confirm that this is intended.
        n_levels = np.log2(n_blocks)
        return AggregationTreeRoot(
            channels_in,
            channels_out,
            n_levels,
            block_factory,
            self.aggregator_factory,
            downsample=downsample,
            # Pass the caller's block arguments on instead of discarding them.
            block_args=block_args,
            block_kwargs=block_kwargs,
        )
200 |
--------------------------------------------------------------------------------
/quantnn/models/pytorch/unet.py:
--------------------------------------------------------------------------------
1 | """
2 | quantnn.models.pytorch.unet
3 | ===========================
4 |
5 | This module provides an implementation of the UNet [unet]_
6 | architecture.
7 |
8 | .. [unet] O. Ronneberger, P. Fischer and T. Brox, "U-net: Convolutional networks
9 | for biomedical image segmentation", Proc. Int. Conf. Med. Image Comput.
10 | Comput.-Assist. Intervent. (MICCAI), pp. 234-241, 2015.
11 | """
12 | import torch
13 | import torch.nn as nn
14 |
15 |
16 | def _conv2(channels_in, channels_out, kernel_size):
17 | """2D convolution with padding to keep image size constant. """
18 | return nn.Conv2d(
19 | channels_in,
20 | channels_out,
21 | kernel_size=kernel_size,
22 | padding=kernel_size // 2,
23 | padding_mode="reflect",
24 | )
25 |
26 |
class ConvolutionBlock(nn.Module):
    """
    A convolution block consisting of two 3x3 convolutions, each
    followed by a batch normalization layer and a ReLU activation.
    """

    def __init__(self, channels_in, channels_out):
        """
        Create new convolution block.

        Args:
            channels_in: The number of input channels.
            channels_out: The number of output channels.
        """
        super().__init__()
        layers = []
        # The first convolution changes the channel count, the second keeps it.
        for n_in in (channels_in, channels_out):
            layers.append(_conv2(n_in, channels_out, 3))
            layers.append(nn.BatchNorm2d(channels_out))
            layers.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Propagate input through layer."""
        y = self.block(x)
        return y
55 |
56 |
class DownsamplingBlock(nn.Module):
    """
    UNet downsampling block: max pooling with kernel size 2 followed
    by a convolution block.
    """

    def __init__(self, channels_in, channels_out):
        """
        Args:
            channels_in: The number of input channels.
            channels_out: The number of output channels.
        """
        super().__init__()
        pooling = nn.MaxPool2d(2)
        convolutions = ConvolutionBlock(channels_in, channels_out)
        self.block = nn.Sequential(pooling, convolutions)

    def forward(self, x):
        """Propagate input through block."""
        y = self.block(x)
        return y
72 |
73 |
class UpsamplingBlock(nn.Module):
    """
    UNet upsampling block consisting of bilinear interpolation followed
    by a convolution with kernel size 3 that halves the number of channels,
    concatenation with the skip connection and a UNet convolution block.
    """

    def __init__(self, channels_in, channels_out):
        """
        Args:
            channels_in: The number of channels of the input that is
                upsampled.
            channels_out: The number of output channels.
        """
        super().__init__()
        self.upscaling = nn.Upsample(
            mode="bilinear", scale_factor=2, align_corners=True
        )
        n_reduced = channels_in // 2
        self.reduce = _conv2(channels_in, n_reduced, 3)
        # The convolution block sees the reduced input concatenated with
        # the skip connection.
        self.conv = ConvolutionBlock(channels_in, channels_out)

    def forward(self, x, x_skip):
        """Propagate input through block."""
        upsampled = self.upscaling(x)
        reduced = self.reduce(upsampled)
        merged = torch.cat([reduced, x_skip], dim=1)
        return self.conv(merged)
94 |
95 |
class UNet(nn.Module):
    """
    PyTorch implementation of UNet, consisting of 4 downsampling
    blocks followed by 4 upsampling blocks with skip connections between
    the down- and upsampling blocks of matching size.

    The core of each down- and upsampling block consists of two
    2D 3x3 convolutions followed by batch norm and ReLU activation
    functions.
    """

    def __init__(self, n_inputs, n_outputs):
        """
        Args:
            n_inputs: The number of input channels.
            n_outputs: The number of output channels.
        """
        super().__init__()
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs

        self.in_block = ConvolutionBlock(n_inputs, 64)

        # Channel counts before and after each of the four encoder stages.
        encoder_widths = (64, 128, 256, 512, 1024)
        for index in range(1, len(encoder_widths)):
            stage = DownsamplingBlock(encoder_widths[index - 1], encoder_widths[index])
            setattr(self, f"down_block_{index}", stage)

        # Channel counts before and after each of the four decoder stages.
        decoder_widths = (1024, 512, 256, 128, n_outputs)
        for index in range(1, len(decoder_widths)):
            stage = UpsamplingBlock(decoder_widths[index - 1], decoder_widths[index])
            setattr(self, f"up_block_{index}", stage)

        self.out_block = _conv2(n_outputs, n_outputs, 1)

    def forward(self, x):
        """Propagate input through network."""
        encoder = (
            self.down_block_1,
            self.down_block_2,
            self.down_block_3,
            self.down_block_4,
        )
        decoder = (
            self.up_block_1,
            self.up_block_2,
            self.up_block_3,
            self.up_block_4,
        )

        # Collect the output of every encoder stage for the skip connections.
        features = [self.in_block(x)]
        for stage in encoder:
            features.append(stage(features[-1]))

        # Walk back up, consuming the stored features from deep to shallow.
        y = features.pop()
        for stage in decoder:
            y = stage(y, features.pop())

        return self.out_block(y)
146 |
--------------------------------------------------------------------------------
/quantnn/models/pytorch/upsampling.py:
--------------------------------------------------------------------------------
1 | """
2 | quantnn.models.pytorch.upsampling
3 | =================================
4 |
5 | Upsampling factories.
6 | """
7 | from torch import nn
8 |
9 |
class BilinearFactory:
    """
    A factory for producing bilinear upsampling layers.
    """

    def __call__(self, channels_in, channels_out, factor):
        """
        Create a bilinear upsampling layer.

        Args:
            channels_in: Number of incoming channels or ``None``.
            channels_out: Number of outgoing channels or ``None``.
            factor: The upsampling factor.

        Return:
            An ``nn.UpsamplingBilinear2d`` layer. If both channel numbers are
            given and differ, the layer is preceded by a 1x1 convolution
            that adapts the number of channels.
        """
        upsampler = nn.UpsamplingBilinear2d(scale_factor=factor)
        channels_known = channels_in is not None and channels_out is not None
        if not channels_known or channels_in == channels_out:
            return upsampler
        projection = nn.Conv2d(channels_in, channels_out, kernel_size=1)
        return nn.Sequential(projection, upsampler)
23 |
24 |
class UpsampleFactory:
    """
    A factory for torch's generic upsampling layer.
    """

    def __init__(self, **kwargs):
        """
        Args:
            kwargs: Keyword arguments passed on to every ``nn.Upsample``
                layer created by the factory.
        """
        self.kwargs = kwargs

    def __call__(self, channels_in, channels_out, factor):
        """
        Create an upsampling layer.

        Args:
            channels_in: Number of incoming channels or ``None``.
            channels_out: Number of outgoing channels or ``None``.
            factor: The upsampling factor.

        Return:
            An ``nn.Upsample`` layer. If both channel numbers are given and
            differ, the layer is preceded by a 1x1 convolution that adapts
            the number of channels.
        """
        upsampler = nn.Upsample(scale_factor=factor, **self.kwargs)
        channels_known = channels_in is not None and channels_out is not None
        if not channels_known or channels_in == channels_out:
            return upsampler
        projection = nn.Conv2d(channels_in, channels_out, kernel_size=1)
        return nn.Sequential(projection, upsampler)
41 |
42 |
class UpConvolutionFactory:
    """
    Factory for up-convolution upsampling as used in the original
    U-Net architecture.
    """

    def __init__(self, **kwargs):
        """
        Args:
            kwargs: Keyword arguments passed on to every ``nn.Upsample``
                layer created by the factory.
        """
        self.kwargs = kwargs

    def __call__(self, channels_in, channels_out, factor):
        """
        Create an upsampling layer followed by a convolution.

        Args:
            channels_in: Number of incoming channels.
            channels_out: Number of outgoing channels.
            factor: The upsampling factor, which is also used as the kernel
                size of the convolution.

        Return:
            An ``nn.Sequential`` holding the upsampling layer and the
            convolution.
        """
        upsampler = nn.Upsample(scale_factor=factor, **self.kwargs)
        convolution = nn.Conv2d(
            channels_in, channels_out, kernel_size=factor, padding="same"
        )
        return nn.Sequential(upsampler, convolution)
57 |
--------------------------------------------------------------------------------
/quantnn/models/pytorch/xception.py:
--------------------------------------------------------------------------------
1 | """
2 | ===============================
3 | quantnn.models.pytorch.xception
4 | ===============================
5 |
6 | PyTorch neural network models based on the Xception architecture.
7 | """
8 | import torch
9 | from torch import nn
10 |
11 |
class SymmetricPadding(nn.Module):
    """
    Network module implementing symmetric padding.

    Thin wrapper around torch's ``nn.functional.pad`` called with mode
    'replicate'.
    """

    def __init__(self, amount):
        """
        Args:
            amount: The padding to apply. An integer is expanded to a list
                holding that value four times; anything else is passed to
                ``nn.functional.pad`` as is.
        """
        super().__init__()
        self.amount = [amount] * 4 if isinstance(amount, int) else amount

    def forward(self, x):
        """Pad ``x`` by the configured amount."""
        padded = nn.functional.pad(x, self.amount, "replicate")
        return padded
29 |
30 |
class SeparableConv3x3(nn.Sequential):
    """
    Depth-wise separable convolution with kernel size 3x3.

    A 3x3 convolution with one group per input channel is followed by
    a 1x1 convolution that produces the output channels.
    """

    def __init__(self, channels_in, channels_out):
        """
        Args:
            channels_in: The number of incoming channels.
            channels_out: The number of outgoing channels.
        """
        depthwise = nn.Conv2d(
            channels_in,
            channels_in,
            kernel_size=3,
            groups=channels_in,
            padding=1,
            padding_mode="replicate",
        )
        pointwise = nn.Conv2d(channels_in, channels_out, kernel_size=1)
        super().__init__(depthwise, pointwise)
48 |
49 |
class XceptionBlock(nn.Module):
    """
    Xception block consisting of two depth-wise separable convolutions,
    each followed by group normalization (with a single group) and a GELU
    activation. The result is added to a projection of the input.
    """

    def __init__(self, channels_in, channels_out, downsample=False):
        """
        Args:
            channels_in: The number of incoming channels.
            channels_out: The number of outgoing channels.
            downsample: Whether or not to insert a 3x3 max pooling with
                stride 2 after the first convolution.
        """
        super().__init__()

        first = [
            SeparableConv3x3(channels_in, channels_out),
            nn.GroupNorm(1, channels_out),
        ]
        if downsample:
            # Pad by one pixel before the strided 3x3 pooling.
            first += [SymmetricPadding(1), nn.MaxPool2d(kernel_size=3, stride=2)]
        first.append(nn.GELU())
        self.block_1 = nn.Sequential(*first)

        self.block_2 = nn.Sequential(
            SeparableConv3x3(channels_out, channels_out),
            nn.GroupNorm(1, channels_out),
            nn.GELU(),
        )

        # The shortcut has to follow every change in size or channel count.
        if downsample:
            self.projection = nn.Conv2d(channels_in, channels_out, 1, stride=2)
        elif channels_in != channels_out:
            self.projection = nn.Conv2d(channels_in, channels_out, 1)
        else:
            self.projection = None

    def forward(self, x):
        """
        Propagate input through block.
        """
        shortcut = x if self.projection is None else self.projection(x)
        features = self.block_1(x)
        features = self.block_2(features)
        return torch.add(shortcut, features)
104 |
105 |
class DownsamplingBlock(nn.Sequential):
    """
    Xception downsampling block: one downsampling Xception block followed
    by ``n_blocks`` Xception blocks that keep the size.
    """

    def __init__(self, n_channels, n_blocks):
        """
        Args:
            n_channels: The number of incoming and outgoing channels.
            n_blocks: The number of Xception blocks that follow the
                downsampling block.
        """
        stages = [XceptionBlock(n_channels, n_channels, downsample=True)]
        stages.extend(
            XceptionBlock(n_channels, n_channels) for _ in range(n_blocks)
        )
        super().__init__(*stages)
116 |
117 |
class UpsamplingBlock(nn.Module):
    """
    Xception upsampling block.

    The input is upsampled bilinearly by a factor of 2, optionally
    concatenated with a skip connection and passed through a separable
    convolution, whose output is added to a 1x1 projection of the merged
    input.
    """

    def __init__(self, n_channels, skip_connections=True):
        """
        Args:
            n_channels: The number of incoming and outgoing channels.
            skip_connections: If set, the block is built for inputs with
                ``2 * n_channels`` channels after merging.
        """
        super().__init__()
        self.upsample = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=False
        )
        if skip_connections:
            n_channels_in = n_channels * 2
        else:
            n_channels_in = n_channels
        self.block = nn.Sequential(
            SeparableConv3x3(n_channels_in, n_channels),
            nn.GroupNorm(1, n_channels),
            nn.GELU(),
        )
        self.projection = nn.Conv2d(n_channels_in, n_channels, 1)

    def forward(self, x, x_skip=None):
        """
        Propagate input through block.
        """
        merged = self.upsample(x)
        if x_skip is not None:
            merged = torch.cat([merged, x_skip], 1)
        features = self.block(merged)
        return torch.add(features, self.projection(merged))
151 |
152 |
class XceptionFpn(nn.Module):
    """
    Feature pyramid network (FPN) with 5 stages based on the xception
    architecture.
    """

    def __init__(self, n_inputs, n_outputs, n_features=128, blocks=2):
        """
        Args:
            n_inputs: Number of input channels.
            n_outputs: The number of output channels.
            n_features: The number of features in the xception blocks.
            blocks: The number of blocks per stage; an integer is used for
                all 5 stages.
        """
        super().__init__()

        if isinstance(blocks, int):
            blocks = [blocks] * 5

        self.in_block = nn.Conv2d(n_inputs, n_features, 1)

        # One encoder stage per scale; the attribute name carries the scale.
        for index, scale in enumerate((2, 4, 8, 16, 32)):
            stage = DownsamplingBlock(n_features, blocks[index])
            setattr(self, f"down_block_{scale}", stage)

        decoder_names = (
            "up_block_16",
            "up_block_8",
            "up_block_4",
            "up_block_2",
            "up_block",
        )
        for name in decoder_names:
            setattr(self, name, UpsamplingBlock(n_features))

        self.head = nn.Sequential(
            nn.Conv2d(2 * n_features, n_features, 1),
            nn.GroupNorm(1, n_features),
            nn.GELU(),
            nn.Conv2d(n_features, n_features, 1),
            nn.GroupNorm(1, n_features),
            nn.GELU(),
            nn.Conv2d(n_features, n_outputs, 1),
        )

    def forward(self, x):
        """
        Propagate input through block.
        """
        encoder = (
            self.down_block_2,
            self.down_block_4,
            self.down_block_8,
            self.down_block_16,
            self.down_block_32,
        )
        decoder = (
            self.up_block_16,
            self.up_block_8,
            self.up_block_4,
            self.up_block_2,
            self.up_block,
        )

        x_in = self.in_block(x)

        # Store the features of every scale for the skip connections.
        pyramid = [x_in]
        for stage in encoder:
            pyramid.append(stage(pyramid[-1]))

        # Decode from the coarsest scale back to the input resolution.
        y = pyramid.pop()
        for stage in decoder:
            y = stage(y, pyramid.pop())

        return self.head(torch.cat([x_in, y], 1))
214 |
--------------------------------------------------------------------------------
/quantnn/plotting.py:
--------------------------------------------------------------------------------
1 | """
2 | ================
3 | quantnn.plotting
4 | ================
5 |
6 | The plotting module provides some utility function for plotting QRNN results.
7 | """
8 | from copy import copy
9 | import pathlib
10 |
11 | from matplotlib import rc
12 | import matplotlib as mpl
13 | import matplotlib.pyplot as plt
14 | from matplotlib.colors import to_rgba
15 |
16 | _STYLE_FILE = pathlib.Path(__file__).parent / "data" / "matplotlib_style.rc"
17 |
18 |
def set_style(latex=False):
    """
    Activate the matplotlib style file shipped with the package.

    Args:
        latex: Whether or not to use latex to render text.
    """
    style_path = str(_STYLE_FILE)
    plt.style.use(style_path)
    rc("text", usetex=latex)
29 |
30 |
def plot_confidence_intervals(ax, x, y_pred, quantiles, color="C0"):
    """
    Plots symmetric confidence intervals using transparency to display uncertainty.

    This function plots a 1-dimensional sequence of predicted quantiles as
    confidence intervals. Every quantile is drawn as a filled region between
    the central curve and the quantile, with the opacity decreasing from 0.9
    at the 0.5 quantile to 0.1 at the quantiles 0 and 1.

    Arguments:
        ax: Matplotlib axes instance to use for plotting.
        x: The x values corresponding to the prediction in y_pred.
        y_pred: 2D array containing the predicted quantiles.
        quantiles: The quantiles corresponding to the second axis of y_pred.
        color: The color to use for filling.
    """
    n = y_pred.shape[1]
    if n == 0:
        # Nothing to draw.
        return

    if n % 2:
        c_0 = y_pred[:, n // 2]
    else:
        # For an even number of quantiles the center lies between the two
        # middle columns n // 2 - 1 and n // 2.
        c_0 = 0.5 * (y_pred[:, n // 2 - 1] + y_pred[:, n // 2])
    alpha_0 = 0.9
    alpha_min = 0.1
    d_alpha = alpha_0 - alpha_min

    # Quantiles below the center, from the innermost outwards.
    for i in range(n // 2 - 1, -1, -1):
        q = quantiles[i]
        alpha = alpha_0 - 2.0 * d_alpha * (0.5 - q)
        facecolor = to_rgba(color, alpha)
        ax.fill_between(x, c_0, y_pred[:, i], edgecolor=None, facecolor=facecolor)

    # Quantiles above the center. For odd n the first one is n // 2 + 1,
    # for even n it is n // 2; (n + 1) // 2 covers both cases.
    for i in range((n + 1) // 2, n):
        q = quantiles[i]
        alpha = alpha_0 - 2.0 * d_alpha * (q - 0.5)
        facecolor = to_rgba(color, alpha)
        ax.fill_between(x, c_0, y_pred[:, i], edgecolor=None, facecolor=facecolor)
68 |
69 |
def plot_quantiles(ax, x, y_pred, quantiles, cmap="magma"):
    """
    Plots predicted quantiles as filled bands colored by quantile.

    The region between each pair of neighboring quantiles is filled with the
    color that the colormap assigns to the midpoint of the two quantile
    fractions.

    Arguments:
        ax: Matplotlib axes instance to use for plotting.
        x: The x values corresponding to the prediction in y_pred.
        y_pred: 2D array containing the predicted quantiles.
        quantiles: The quantiles corresponding to the second axis of y_pred.
        cmap: Name of the colormap, or a colormap instance, used to color
            the bands.

    Returns:
        The ``ScalarMappable`` that maps quantile fractions to colors, which
        can be used to draw a colorbar.
    """
    # ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9; use the
    # colormap registry where it is available and fall back otherwise.
    if isinstance(cmap, mpl.colors.Colormap):
        base_cmap = cmap
    elif hasattr(mpl, "colormaps"):
        name = mpl.rcParams["image.cmap"] if cmap is None else cmap
        base_cmap = mpl.colormaps[name]
    else:
        base_cmap = mpl.cm.get_cmap(cmap)

    # Work on a copy so that the registered colormap is not modified.
    cmap = copy(base_cmap)
    norm = mpl.colors.BoundaryNorm(quantiles, cmap.N)
    mappable = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
    # Values outside of the quantile range are drawn fully transparent.
    cmap.set_under([0.0] * 4)
    cmap.set_over([0.0] * 4)

    n = len(quantiles)
    for i in range(n - 1):
        q = 0.5 * (quantiles[i] + quantiles[i + 1])
        color = mappable.to_rgba(q)
        y_low = y_pred[:, i]
        y_high = y_pred[:, i + 1]
        ax.fill_between(x, y_low, y_high, edgecolor=None, facecolor=color)

    return mappable
100 |
--------------------------------------------------------------------------------
/quantnn/transformations.py:
--------------------------------------------------------------------------------
1 | """
2 | =======================
3 | quantnn.transformations
4 | =======================
5 |
6 | This module defines transformations that can be applied to train
7 | a network in a transformed space but evaluate it in the original
8 | space.
9 | """
10 | import numpy as np
11 |
12 | from quantnn.backends import get_tensor_backend
13 |
14 |
class Log10:
    """
    Transforms values to log space using the base-10 logarithm.
    """

    def __init__(self):
        self.xp = None

    def _backend(self, tensor):
        """
        Return the tensor backend, determining it from ``tensor`` on first
        use and caching it in ``self.xp``.
        """
        if self.xp is None:
            self.xp = get_tensor_backend(tensor)
        return self.xp

    def __call__(self, x):
        """
        Transform tensor.

        Args:
            x: Tensor containing the values to transform.

        Return:
            Tensor containing the transformed values.

        """
        xp = self._backend(x)
        # Evaluate the logarithm in double precision and convert back.
        natural_log = xp.log(x.double()).float()
        return natural_log / np.log(10)

    def invert(self, y):
        """
        Transform transformed values back to original space.

        Args:
            y: Tensor containing the transformed values to transform
                back.

        Returns:
            Tensor containing the original values.
        """
        xp = self._backend(y)
        exponent = np.log(10) * y.double()
        return xp.exp(exponent).float()
58 |
59 |
class Log:
    """
    Transforms values to log space using the natural logarithm.
    """

    def __init__(self):
        self.xp = None

    def _backend(self, tensor):
        """
        Return the tensor backend, determining it from ``tensor`` on first
        use and caching it in ``self.xp``.
        """
        if self.xp is None:
            self.xp = get_tensor_backend(tensor)
        return self.xp

    def __call__(self, x):
        """
        Transform tensor.

        Args:
            x: Tensor containing the values to transform.

        Return:
            Tensor containing the transformed values.

        """
        xp = self._backend(x)
        # Evaluate the logarithm in double precision and convert back.
        result = xp.log(x.double())
        return result.float()

    def invert(self, y):
        """
        Transform transformed values back to original space.

        Args:
            y: Tensor containing the transformed values to transform
                back.

        Returns:
            Tensor containing the original values.
        """
        xp = self._backend(y)
        result = xp.exp(y.double())
        return result.float()
104 |
105 |
class Softplus:
    """
    Softplus-based transform.

    Calling the transform evaluates ``log(exp(x) - 1 + 1e-30)`` and
    ``invert`` evaluates ``log(exp(y) + 1)``. In both directions values
    larger than 10 are passed through unchanged.
    """

    def __init__(self):
        self.xp = None

    def _backend(self, tensor):
        """
        Return the tensor backend, determining it from ``tensor`` on first
        use and caching it in ``self.xp``.
        """
        if self.xp is None:
            self.xp = get_tensor_backend(tensor)
        return self.xp

    def __call__(self, x):
        """
        Transform tensor.

        Args:
            x: Tensor containing the values to transform.

        Return:
            Tensor containing the transformed values.

        """
        xp = self._backend(x)
        # The small offset keeps the argument of the logarithm positive.
        transformed = xp.log(xp.exp(x) - 1.0 + 1e-30)
        return xp.where(x > 10, x, transformed)

    def invert(self, y):
        """
        Transform transformed values back to original space.

        Args:
            y: Tensor containing the transformed values to transform
                back.

        Returns:
            Tensor containing the original values.
        """
        xp = self._backend(y)
        original = xp.log(xp.exp(y) + 1.0)
        return xp.where(y > 10, y, original)
150 |
151 |
class LogLinear:
    """
    Composition of a shifted linear transformation ``x - 1`` for
    x > 1 and the natural logarithm for x <= 1.
    """

    def __init__(self):
        self.xp = None

    def _backend(self, tensor):
        """
        Return the tensor backend, determining it from ``tensor`` on first
        use and caching it in ``self.xp``.
        """
        if self.xp is None:
            self.xp = get_tensor_backend(tensor)
        return self.xp

    def __call__(self, x):
        """
        Transform tensor.

        Args:
            x: Tensor containing the values to transform.

        Return:
            Tensor containing the transformed values.

        """
        xp = self._backend(x)
        linear_part = x - 1.0
        log_part = xp.log(x)
        return xp.where(x > 1, linear_part, log_part)

    def invert(self, y):
        """
        Transform transformed values back to original space.

        Args:
            y: Tensor containing the transformed values to transform
                back.

        Returns:
            Tensor containing the original values.
        """
        xp = self._backend(y)
        linear_part = y + 1.0
        exp_part = xp.exp(y)
        return xp.where(y > 0, linear_part, exp_part)
198 |
199 |
class Id:
    """
    Identity transform for testing.

    Both directions return their argument unchanged.
    """

    def __init__(self):
        # Kept for parity with the other transforms; never used here.
        self.xp = None

    def __call__(self, x):
        """
        Transform tensor.

        Args:
            x: Tensor containing the values to transform.

        Return:
            The input ``x`` itself.

        """
        return x

    def invert(self, y):
        """
        Transform transformed values back to original space.

        Args:
            y: Tensor containing the transformed values to transform
                back.

        Returns:
            The input ``y`` itself.

        """
        return y
234 |
--------------------------------------------------------------------------------
/quantnn/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | =============
3 | quantnn.utils
4 | =============
5 |
6 | This module provides helper functions that are used in multiple other modules.
7 | """
8 | import io
9 | from pathlib import Path
10 | from tempfile import NamedTemporaryFile
11 |
12 | import xarray as xr
13 |
14 |
def apply(f, *args):
    """
    Applies a function to values or key-wise to dicts of values.

    Args:
        f: The function to apply.
        *args: The arguments to be supplied to ``f``. If at least one of
            them is a dict, ``f`` is applied once for every key of the
            first dict argument: dict arguments are replaced by their value
            for that key while all other arguments are passed unchanged.

    Return:
        ``{k: f(x_1[k], x_2[k], ...) for k in x_1}`` if any of the arguments
        is a dict, otherwise ``f(x_1, x_2, ...)``.
    """
    # The first dict argument determines the keys of the result.
    template = next((arg for arg in args if isinstance(arg, dict)), None)
    if template is None:
        return f(*args)

    def select(key):
        return [arg[key] if isinstance(arg, dict) else arg for arg in args]

    return {key: f(*select(key)) for key in template}
39 |
40 |
def serialize_dataset(dataset):
    """
    Writes xarray dataset to a bytestream.

    The dataset is written to a temporary file with ``to_netcdf`` and the
    content of that file is read back. The temporary file is removed
    afterwards.

    Args:
        dataset: A xarray dataset to serialize.

    Returns:
        Bytes object containing the dataset as netcdf file.
    """
    # Only the name of the temporary file is needed, so it is closed
    # right away.
    with NamedTemporaryFile(delete=False) as handle:
        filename = handle.name
    path = Path(filename)
    try:
        dataset.to_netcdf(filename)
        return path.read_bytes()
    finally:
        path.unlink()
61 |
62 |
def deserialize_dataset(data):
    """
    Read xarray dataset from byte stream containing the
    dataset in NetCDF format.

    The bytes are written to a temporary file, which is loaded with
    ``xr.load_dataset`` and removed afterwards.

    Args:
        data: The bytes object containing the binary data of the
            NetCDF file.

    Returns:
        The deserialized xarray dataset.
    """
    # Only the name of the temporary file is needed, so it is closed
    # right away.
    with NamedTemporaryFile(delete=False) as handle:
        filename = handle.name
    try:
        with open(filename, "wb") as stream:
            stream.write(data)
        return xr.load_dataset(filename)
    finally:
        Path(filename).unlink()
85 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import setuptools

# Read the README explicitly as UTF-8 so that building the package does not
# depend on the default encoding of the platform.
with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

setuptools.setup(
    name="quantnn",
    version="0.0.5",
    author="Simon Pfreundschuh",
    description="Quantile regression neural networks.",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/simonpf/quantnn",
    packages=setuptools.find_packages(),
    classifiers=[
        "Programming Language :: Python :: 3",
        "Operating System :: OS Independent",
    ],
    install_requires=[
        "numpy",
        "scipy",
        "xarray",
        "paramiko",
        "einops",
        "matplotlib",
        "rich",
    ],
    python_requires=">=3.6",
    # Ship the non-Python files of the package as well.
    include_package_data=True,
)
30 |
--------------------------------------------------------------------------------
/test/conftest.py:
--------------------------------------------------------------------------------
1 | """
2 | Contains fixtures that are automatically available in all test files.
3 | """
4 | import pytest
5 | import numpy
6 |
# Tensor backends the tests are run with. numpy is always available; the
# other backends are added only if the respective package can be imported.
BACKENDS = [numpy]

# The TensorFlow backend is currently disabled:
# try:
#     import tensorflow as tf
#     BACKENDS.append(tf)
# except ModuleNotFoundError:
#     pass

try:
    import torch
except ModuleNotFoundError:
    pass
else:
    BACKENDS.append(torch)

try:
    import jax.numpy as jnp
except ModuleNotFoundError:
    pass
else:
    BACKENDS.append(jnp)


def pytest_configure():
    """Expose the list of available backends as ``pytest.backends``."""
    pytest.backends = BACKENDS
29 |
30 |
--------------------------------------------------------------------------------
/test/files/test_files.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for the generic functions in the :py:mod:`quantnn.files` module.
3 | """
4 | import os
5 |
6 | import pytest
7 | import numpy as np
8 | from quantnn.files import read_file, CachedDataFolder
9 | from concurrent.futures import ThreadPoolExecutor
10 |
# Currently no SFTP test data available, so the tests that are marked with
# ``skipif(not HAS_LOGIN_INFO)`` in this module are skipped.
HAS_LOGIN_INFO = False
13 |
def test_local_file(tmp_path):
    """
    Check that ``read_file`` opens a file on the local file system and that
    the content written to it can be read back.
    """
    target = tmp_path / "test.txt"
    with open(target, "w") as output:
        output.write("test")

    with read_file(target) as stream:
        text = stream.read()

    assert text == "test"
25 |
26 |
def test_local_folder(tmp_path):
    """
    Check that a ``CachedDataFolder`` on a local directory lists the files
    matching the given pattern and that a listed file can be opened and read.
    """
    with open(tmp_path / "test.txt", "w") as output:
        output.write("test")

    folder = CachedDataFolder(tmp_path, "*.txt")
    assert len(folder.files) == 1

    stream = folder.open(folder.files[0])
    assert stream.read() == "test"
40 |
41 |
@pytest.mark.skipif(not HAS_LOGIN_INFO, reason="No SFTP login info.")
def test_remote_file():
    """
    Ensures that opening a remote file through an ``sftp://`` URL works and
    that the content of the file can be loaded.
    """
    host = "129.16.35.202"
    path = "/mnt/array1/share/MLDatasets/test/data_0.npz"
    with read_file("sftp://" + host + path, "rb") as file:
        data = np.load(file)
        # Arrays in ``.npz`` archives are read lazily, so the data must be
        # accessed while the underlying file object is still open.
        assert np.all(np.isclose(data["x"], 0.0))
53 |
54 |
@pytest.mark.skipif(not HAS_LOGIN_INFO, reason="No SFTP login info.")
def test_remote_folder(tmp_path):
    """
    Ensures that a remote folder can be accessed through a
    ``CachedDataFolder``: files are listed, can be retrieved individually
    and remain retrievable after downloading the folder with a thread pool.
    """
    host = "129.16.35.202"
    path = "/mnt/array1/share/MLDatasets/test/"

    data_folder = CachedDataFolder("sftp://" + host + path, "*.npz")

    assert len(data_folder.files) == 8
    data_folder.files.sort()

    file = data_folder.get(data_folder.files[0])
    data = np.load(file)
    assert np.all(np.isclose(data["x"], 0.0))

    # Use the executor as a context manager so that its worker threads are
    # shut down (after pending downloads finished) instead of leaking.
    with ThreadPoolExecutor(max_workers=4) as pool:
        data_folder.download(pool)

    file = data_folder.get(data_folder.files[0])
    data = np.load(file)
    assert np.all(np.isclose(data["x"], 0.0))
78 |
79 |
80 |
--------------------------------------------------------------------------------
/test/files/test_sftp.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import pytest
5 | from quantnn.files import sftp
6 |
# Currently no SFTP test data available, so every test in this module that
# is marked with ``skipif(not HAS_LOGIN_INFO)`` is skipped.
HAS_LOGIN_INFO = False
9 |
@pytest.mark.skipif(not HAS_LOGIN_INFO, reason="No SFTP login info.")
def test_list_files():
    """
    Check that listing the remote test folder returns the expected
    number of files.
    """
    remote_files = sftp.list_files(
        "129.16.35.202", "/mnt/array1/share/MLDatasets/test"
    )
    assert len(remote_files) == 8
16 |
@pytest.mark.skipif(not HAS_LOGIN_INFO, reason="No SFTP login info.")
def test_download_file():
    """
    Check that a remote file can be downloaded and loaded and that the
    downloaded file no longer exists once the context manager is left.
    """
    host = "129.16.35.202"
    path = "/mnt/array1/share/MLDatasets/test/data_0.npz"
    local_copy = None
    with sftp.download_file(host, path) as downloaded:
        # Keep a reference to check for clean-up after leaving the context.
        local_copy = downloaded
        content = np.load(downloaded)
        assert np.all(np.isclose(content["x"], 0.0))

    assert not local_copy.exists()
32 |
@pytest.mark.skipif(not HAS_LOGIN_INFO, reason="No SFTP login info.")
def test_sftp_cache():
    """
    Check that a remote file can be retrieved through an ``SFTPCache`` and
    that the retrieved file contains the expected data.
    """
    host = "129.16.35.202"
    path = "/mnt/array1/share/MLDatasets/test/data_0.npz"

    cached_file = sftp.SFTPCache().get(host, path)
    # NOTE(review): allow_pickle=True permits unpickling of object arrays;
    # confirm that it is actually required for this test file.
    content = np.load(cached_file, allow_pickle=True)
    assert np.all(np.isclose(content["x"], 0.0))
46 |
--------------------------------------------------------------------------------
/test/models/pytorch/test_aggregators.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import torch
4 | from quantnn.packed_tensor import PackedTensor
5 | from quantnn.models.pytorch.aggregators import (
6 | AverageAggregatorFactory,
7 | BlockAggregatorFactory,
8 | SparseAggregator,
9 | SumAggregatorFactory,
10 | LinearAggregatorFactory,
11 | AttentionFusion
12 | )
13 | from quantnn.models.pytorch.torchvision import ResNetBlockFactory
14 |
15 |
def fill_tensor(t, indices):
    """
    Set the k-th entry along the first axis of ``t`` to the k-th value
    in ``indices``.

    Args:
        t: The tensor (or other indexable container) to fill in place.
        indices: Iterable of the values to write into ``t``.

    Return:
        The modified input ``t``.
    """
    position = 0
    for sample_index in indices:
        t[position] = sample_index
        position += 1
    return t
23 |
24 |
def make_random_packed_tensor(batch_size, samples, shape):
    """
    Create a packed tensor representing a training batch in which a random
    subset of the samples is missing. All elements of each sample that is
    present are set to the index of that sample in the batch.

    Args:
        batch_size: The nominal batch size of the training batch.
        samples: The number of non-missing samples.
        shape: The shape of each sample in the training batch.

    Return:
        A ``PackedTensor`` holding ``samples`` of the ``batch_size`` samples.
    """
    candidates = np.arange(batch_size)
    present = sorted(np.random.choice(candidates, size=samples, replace=False))
    data = fill_tensor(np.ones((samples,) + shape, dtype=np.float32), present)
    return PackedTensor(data, batch_size, present)
42 |
43 |
def test_sparse_aggregator():
    """
    Check that the sparse aggregator combines dense and packed inputs
    correctly, both with an averaging and with a linear aggregator.
    """

    def dense_batch():
        # Dense batch whose i-th sample is filled with the value i.
        full = torch.ones((100, 8, 32, 32), dtype=torch.float32)
        return fill_tensor(full, range(100))

    def packed_batch():
        # Packed batch in which 50 of the 100 samples are present.
        return make_random_packed_tensor(100, 50, (8, 32, 32))

    #
    # Averaging aggregator
    #
    aggregator = SparseAggregator((8, 8, 8), 8, AverageAggregatorFactory())

    # As soon as one input is dense the result must be dense as well.
    x_1 = dense_batch()
    x_2 = dense_batch()
    x_3 = packed_batch()
    y = aggregator(x_1, x_2, x_3)
    assert not isinstance(y, PackedTensor)
    assert torch.isclose(y, x_1).all()

    x_2 = dense_batch()
    x_1 = packed_batch()
    y = aggregator(x_1, x_2)
    assert not isinstance(y, PackedTensor)
    assert torch.isclose(y, x_2).all()

    # Two packed inputs yield a packed result covering the union of the
    # batch indices, with each sample still equal to its batch index.
    x_1 = packed_batch()
    x_2 = packed_batch()
    y = aggregator(x_1, x_2)
    assert isinstance(y, PackedTensor)
    expected_indices = sorted(set(x_1.batch_indices + x_2.batch_indices))
    assert y.batch_indices == expected_indices
    for position, batch_index in enumerate(y.batch_indices):
        reference = batch_index * torch.ones(1, 1, 1)
        assert torch.isclose(y._t[position], reference).all()

    #
    # Linear aggregator
    #
    aggregator = SparseAggregator((8, 8), 8, LinearAggregatorFactory())

    x_1 = dense_batch()
    x_2 = packed_batch()
    y = aggregator(x_1, x_2)
    assert not isinstance(y, PackedTensor)

    x_2 = dense_batch()
    x_1 = packed_batch()
    y = aggregator(x_1, x_2)
    assert not isinstance(y, PackedTensor)

    x_1 = packed_batch()
    x_2 = packed_batch()
    y = aggregator(x_1, x_2)
    assert isinstance(y, PackedTensor)
    expected_indices = sorted(set(x_1.batch_indices + x_2.batch_indices))
    assert y.batch_indices == expected_indices

    # A missing (None) input leaves the result packed.
    y = aggregator(x_1, None)
    assert isinstance(y, PackedTensor)
108 |
109 |
# Aggregator factories used to parametrize ``test_aggregators``.
AGGREGATORS = [
    SumAggregatorFactory(),
    AverageAggregatorFactory(),
    LinearAggregatorFactory(),
    BlockAggregatorFactory(ResNetBlockFactory()),
]
116 |
117 |
@pytest.mark.parametrize("aggregator", AGGREGATORS)
def test_aggregators(aggregator):
    """
    Check that aggregators created by the given factory merge two and three
    inputs into a tensor with the shape of a single input.
    """
    first = torch.ones(1, 10, 16, 16)
    second = torch.ones(1, 10, 16, 16)
    merge_two = aggregator((10, 10), 10)
    merged = merge_two(first, second)
    assert merged.shape == (1, 10, 16, 16)

    third = torch.ones(1, 10, 16, 16)
    merge_three = aggregator((10, 10, 10), 10)
    merged = merge_three(first, second, third)
    assert merged.shape == (1, 10, 16, 16)
130 |
131 |
def test_attention_fusion():
    """
    Check that ``AttentionFusion`` merges inputs with different numbers of
    channels into a tensor with the requested number of output channels.
    """
    inputs = [torch.ones(1, channels, 32, 32) for channels in (10, 20, 5)]

    fusion = AttentionFusion((10, 20, 5), 16, 8)
    fused = fusion(*inputs)

    assert fused.shape == (1, 16, 32, 32)
142 |
--------------------------------------------------------------------------------
/test/models/pytorch/test_base.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for the quantnn.models.pytorch.base module.
3 | """
4 | from torch import nn
5 |
6 | from quantnn.models.pytorch.base import ParamCount
7 |
8 |
class ConvModule(nn.Conv2d, ParamCount):
    """
    A 2D convolution layer extended with the ``ParamCount`` mixin. Used to
    test the mixin on a module with a known number of parameters.
    """

    def __init__(
        self, channels_in: int, channels_out: int, kernel_size: int, bias: bool = False
    ):
        """
        Args:
            channels_in: The number of input channels of the convolution.
            channels_out: The number of output channels of the convolution.
            kernel_size: The size of the convolution kernel.
            bias: Whether or not the convolution has a bias term.
        """
        # nn.Conv2d directly follows this class in the MRO.
        super().__init__(
            channels_in, channels_out, kernel_size=kernel_size, bias=bias
        )
21 |
22 |
def test_n_params():
    """
    Check that ``n_params`` equals the number of weights (plus the number
    of bias terms, if enabled) of a 2D convolution layer.
    """
    n_weights = 16 * 16 * 3 * 3
    for bias, expected in ((True, n_weights + 16), (False, n_weights)):
        conv = ConvModule(16, 16, 3, bias=bias)
        assert conv.n_params == expected
33 |
--------------------------------------------------------------------------------
/test/models/pytorch/test_blocks.py:
--------------------------------------------------------------------------------
1 | from quantnn.models.pytorch.blocks import (
2 | ConvBlockFactory,
3 | ConvTransposedBlockFactory,
4 | ResNeXtBlockFactory,
5 | ConvNextBlockFactory
6 | )
7 |
8 | import torch
9 | from torch import nn
10 |
11 |
def test_conv_block_factory():
    """
    Check that blocks produced by ``ConvBlockFactory`` yield outputs of the
    expected shape, with and without normalization, activation and
    downsampling.
    """
    # Default configuration.
    plain_factory = ConvBlockFactory(kernel_size=3)
    block = plain_factory(8, 16)
    x = torch.ones(8, 8, 8, 8)
    assert block(x).shape == (8, 16, 8, 8)

    # With normalization and activation layers.
    full_factory = ConvBlockFactory(
        kernel_size=3, norm_factory=nn.BatchNorm2d, activation_factory=nn.GELU
    )

    block = full_factory(8, 16)
    x = torch.ones(8, 8, 8, 8)
    assert block(x).shape == (8, 16, 8, 8)

    # Downsampling only along the last dimension.
    block = full_factory(8, 16, downsample=(1, 2))
    x = torch.ones(8, 8, 8, 8)
    assert block(x).shape == (8, 16, 8, 4)
38 |
39 |
def test_conv_transposed_block_factory():
    """
    Check that blocks produced by ``ConvTransposedBlockFactory`` yield
    outputs of the expected shape, with and without normalization and
    activation and with a ``downsample`` argument.
    """
    # Default configuration.
    plain_factory = ConvTransposedBlockFactory(kernel_size=3)
    block = plain_factory(8, 16)
    x = torch.ones(8, 8, 8, 8)
    assert block(x).shape == (8, 16, 8, 8)

    # With normalization and activation layers.
    full_factory = ConvTransposedBlockFactory(
        kernel_size=3, norm_factory=nn.BatchNorm2d, activation_factory=nn.GELU
    )
    block = full_factory(8, 16)
    x = torch.ones(8, 8, 8, 8)
    assert block(x).shape == (8, 16, 8, 8)

    # For the transposed block the last dimension grows from 8 to 15.
    block = full_factory(8, 16, downsample=(1, 2))
    x = torch.ones(8, 8, 8, 8)
    assert block(x).shape == (8, 16, 8, 15)
65 |
def test_resnext_block():
    """
    Check that blocks produced by the ResNeXt factory yield outputs with
    the requested number of channels and the expected spatial size.
    """
    factory = ResNeXtBlockFactory()
    single_channel = torch.ones((1, 1, 8, 8))

    out = factory(1, 64)(single_channel)
    assert out.shape == (1, 64, 8, 8)

    out = factory(1, 64, downsample=2)(single_channel)
    assert out.shape == (1, 64, 4, 4)

    # Downsampling only along the last dimension.
    block = factory(8, 64, downsample=(1, 2))
    multi_channel = torch.ones(8, 8, 8, 8)
    assert block(multi_channel).shape == (8, 64, 8, 4)
86 |
87 |
def test_convnext_block():
    """
    Check that blocks produced by the ConvNext factory (versions 1 and 2)
    yield outputs with the requested number of channels, with and without
    downsampling.
    """
    x = torch.ones((1, 1, 8, 8))

    for version in (1, 2):
        factory = ConvNextBlockFactory(version=version)

        block = factory(1, 64)
        assert block(x).shape == (1, 64, 8, 8)

        block = factory(1, 64, downsample=2)
        assert block(x).shape == (1, 64, 4, 4)
112 |
--------------------------------------------------------------------------------
/test/models/pytorch/test_decoders.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for quantnn.models.pytorch.decoders.
3 | """
4 | import pytest
5 |
6 | import torch
7 | from quantnn.packed_tensor import PackedTensor
8 | from quantnn.models.pytorch.encoders import (
9 | SpatialEncoder,
10 | MultiInputSpatialEncoder
11 | )
12 | from quantnn.models.pytorch.decoders import (
13 | SpatialDecoder,
14 | SparseSpatialDecoder,
15 | DLADecoderStage,
16 | DLADecoder
17 | )
18 | from quantnn.models.pytorch.torchvision import ResNetBlockFactory
19 | from quantnn.models.pytorch.aggregators import (
20 | LinearAggregatorFactory,
21 | SparseAggregatorFactory,
22 | BlockAggregatorFactory
23 | )
24 | from quantnn.models.pytorch.upsampling import BilinearFactory
25 |
26 |
def test_spatial_decoder():
    """
    Check the shapes of the outputs obtained by chaining a spatial encoder
    with symmetric and asymmetric spatial decoders.
    """
    block_factory = ResNetBlockFactory()
    shared = {
        "channel_scaling": 2,
        "max_channels": 8,
        "block_factory": block_factory,
    }

    #
    # Symmetric decoder without skip connections.
    #
    encoder = SpatialEncoder(channels=1, stages=[4] * 4, **shared)
    decoder = SpatialDecoder(
        channels=1, stages=[4] * 3, skip_connections=False, **shared
    )
    x = torch.ones((1, 1, 32, 32))
    y = decoder(encoder(x))
    # The output has the same shape as the input.
    assert y.shape == (1, 1, 32, 32)

    #
    # Asymmetric decoder using the skip connections of the encoder.
    #
    decoder = SpatialDecoder(
        channels=[8, 1, 1],
        stages=[4] * 2,
        skip_connections=encoder.skip_connections,
        **shared,
    )
    y = decoder(encoder(x, return_skips=True))
    # The decoder has fewer stages, so the output is smaller than the input.
    assert y.shape == (1, 1, 16, 16)

    #
    # Explicit down- and upsampling factors.
    #
    encoder = SpatialEncoder(
        channels=1, stages=[4] * 4, downsampling_factors=[2, 2, 4], **shared
    )
    decoder = SpatialDecoder(
        channels=1,
        stages=[4] * 3,
        skip_connections=False,
        upsampling_factors=[4, 2, 2],
        **shared,
    )
    x = torch.ones((1, 1, 128, 128))
    y = decoder(encoder(x))
    # The upsampling factors undo the downsampling of the encoder.
    assert y.shape == (1, 1, 128, 128)

    #
    # Decoder with a channel configuration differing from the encoder.
    #
    decoder = SpatialDecoder(
        channels=[8, 2, 2, 2],
        stages=[4] * 3,
        skip_connections=False,
        upsampling_factors=[4, 2, 2],
        **shared,
    )
    x = torch.ones((1, 1, 128, 128))
    y = decoder(encoder(x))
    # Same spatial size as the input, two output channels.
    assert y.shape == (1, 2, 128, 128)
111 |
112 |
def test_encoder_decoder_multi_scale_output():
    """
    Check the chaining of a multi-input encoder and sparse decoders when
    some of the inputs are missing.
    """

    def make_inputs():
        # "input_1" contains no samples; "input_2" only the first sample
        # of a batch with nominal size 4.
        return {
            "input_1": PackedTensor(
                torch.ones((0, 16, 16, 16)), batch_size=4, batch_indices=[]
            ),
            "input_2": PackedTensor(
                torch.ones((1, 32, 8, 8)), batch_size=4, batch_indices=[0]
            ),
        }

    block_factory = ResNetBlockFactory()
    aggregator_factory = SparseAggregatorFactory(LinearAggregatorFactory())
    encoder = MultiInputSpatialEncoder(
        inputs={"input_1": (1, 16), "input_2": (2, 32)},
        channels=4,
        stages=[4] * 4,
        block_factory=block_factory,
        channel_scaling=2,
        max_channels=16,
        aggregator_factory=aggregator_factory,
    )
    decoder_settings = {
        "channel_scaling": 2,
        "max_channels": 16,
        "block_factory": block_factory,
    }

    #
    # Multi-scale output: one output per scale, all with 16 channels.
    #
    decoder = SparseSpatialDecoder(
        channels=4,
        stages=[4] * 3,
        skip_connections=3,
        multi_scale_output=16,
        **decoder_settings,
    )
    outputs = decoder(encoder(make_inputs(), return_skips=True))
    assert len(outputs) == 4
    for output in outputs:
        assert output.shape[1] == 16
    assert outputs[-1].shape[2] == 32

    #
    # Ensure using different channels than encoder works.
    #
    decoder = SparseSpatialDecoder(
        channels=[16, 8, 8, 8],
        stages=[4] * 3,
        skip_connections=encoder.skip_connections,
        multi_scale_output=16,
        **decoder_settings,
    )
    outputs = decoder(encoder(make_inputs(), return_skips=True))
    assert outputs[0].shape[1] == 16
    assert outputs[-1].shape[1] == 16

    #
    # Test decoder with more stages than encoder.
    #
    decoder = SparseSpatialDecoder(
        channels=[16, 8, 8, 8, 2],
        stages=[4] * 4,
        skip_connections=encoder.skip_connections,
        base_scale=3,
        **decoder_settings,
    )
    output = decoder(encoder(make_inputs(), return_skips=True))
    assert output.shape[1] == 2
    assert output.shape[2] == 64
224 |
@pytest.mark.xfail
def test_dla_decoder():
    """
    Test implementation of the DLA decoder stages and full decoder.

    The test is currently marked as expected to fail.
    """
    x = [
        torch.ones((2, 16, 4, 4)),
        torch.ones((2, 8, 16, 16)),
        torch.ones((2, 4, 32, 32)),
        torch.ones((2, 2, 64, 64)),
    ]

    aggregator_factory = BlockAggregatorFactory(
        ResNetBlockFactory()
    )
    upsampler_factory = BilinearFactory()

    #
    # Single stage
    #

    decoder = DLADecoderStage(
        [16, 8, 4, 2],
        [16, 8, 4, 2],
        [16, 4, 2, 1],
        aggregator_factory,
        upsampler_factory
    )
    y = decoder(x)
    # Output should contain one less tensor than the input.
    assert len(y) == len(x) - 1
    # These shape checks were bare comparisons without effect; they are
    # now actually asserted.
    assert y[0].shape == (2, 4, 32, 32)
    assert y[1].shape == (2, 8, 16, 16)
    assert y[2].shape == (2, 16, 4, 4)

    #
    # Full decoder
    #

    decoder = DLADecoder(
        [16, 8, 4, 2],
        [16, 4, 2, 1],
        aggregator_factory,
        upsampler_factory
    )
    y = decoder(x[::-1])
    # Output should be a single tensor
    assert isinstance(y, torch.Tensor)
    assert y.shape == (2, 2, 64, 64)

    #
    # Test sparse input.
    #
    aggregator_factory = SparseAggregatorFactory(
        BlockAggregatorFactory(
            ResNetBlockFactory()
        )
    )
    decoder = DLADecoder(
        [16, 8, 4, 2],
        [16, 4, 2, 1],
        aggregator_factory,
        upsampler_factory
    )

    x = [
        PackedTensor(
            torch.ones((0, 16, 4, 4)),
            batch_size=2,
            batch_indices=[]
        ),
        torch.ones((2, 8, 16, 16)),
        torch.ones((2, 4, 32, 32)),
        PackedTensor(
            torch.ones((0, 16, 64, 64)),
            batch_size=2,
            batch_indices=[]
        )
    ]
    y = decoder(x[::-1])

    # Output is, again, a full tensor.
    assert isinstance(y, torch.Tensor)
    assert y.shape == (2, 2, 64, 64)
309 |
310 |
--------------------------------------------------------------------------------
/test/models/pytorch/test_downsampling.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for the quantnn.models.pytorch.downsampling module.
3 | """
4 | import torch
5 | from torch import nn
6 |
7 | from quantnn.models.pytorch.downsampling import (
8 | ConvNextDownsamplerFactory,
9 | PatchMergingFactory
10 | )
11 |
12 |
def test_convnext_downsampling():
    """
    Check that the ConvNext downsampler supports different downsampling
    factors along the two spatial dimensions.
    """
    downsampler = ConvNextDownsamplerFactory()(16, 16, (2, 4))
    downsampled = downsampler(torch.rand(2, 16, 32, 32))

    height, width = downsampled.shape[-2], downsampled.shape[-1]
    assert height == 16
    assert width == 8
24 |
25 |
def test_path_merging_downsampling():
    """
    Check that the block created by ``PatchMergingFactory`` halves both
    spatial dimensions when downsampling by a factor of 2.
    """
    downsampler = PatchMergingFactory()(16, 16, (2, 2))

    downsampled = downsampler(torch.rand(2, 16, 32, 32))
    height, width = downsampled.shape[-2], downsampled.shape[-1]
    assert height == 16
    assert width == 16
37 |
--------------------------------------------------------------------------------
/test/models/pytorch/test_fully_connected.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from quantnn.models.pytorch.fully_connected import MLP
4 |
5 |
def test_mlp():
    """
    Check the output shape of the MLP for 2D and 4D inputs without residual
    connections, with simple and with hyper residual connections.
    """
    settings = {
        "features_in": 8,
        "n_features": 128,
        "features_out": 16,
        "n_layers": 4,
    }
    # Default (no residuals), standard residuals, hyper residuals.
    residual_variants = ({}, {"residuals": "simple"}, {"residuals": "hyper"})
    cases = (
        # 2D input
        ((128, 8), (128, 16)),
        # 4D input
        ((128, 8, 8, 8), (128, 16, 8, 8)),
    )

    for input_shape, expected_shape in cases:
        x = torch.rand(*input_shape)
        for variant in residual_variants:
            mlp = MLP(**settings, **variant)
            assert mlp(x).shape == expected_shape
65 |
66 |
def test_mlp_output_shape():
    """
    Check that the ``output_shape`` argument of the MLP reshapes the output
    features for 2D and 4D inputs and for all types of residual connections.
    """
    settings = {
        "features_in": 8,
        "n_features": 128,
        "features_out": 16,
        "n_layers": 4,
        "output_shape": (4, 4),
    }
    # Default (no residuals), standard residuals, hyper residuals.
    residual_variants = ({}, {"residuals": "simple"}, {"residuals": "hyper"})
    cases = (
        # 2D input
        ((128, 8), (128, 4, 4)),
        # 4D input
        ((128, 8, 8, 8), (128, 4, 4, 8, 8)),
    )

    for input_shape, expected_shape in cases:
        x = torch.rand(*input_shape)
        for variant in residual_variants:
            mlp = MLP(**settings, **variant)
            assert mlp(x).shape == expected_shape
140 |
--------------------------------------------------------------------------------
/test/models/pytorch/test_normalization.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for the quantnn.models.pytorch.normalization module.
3 | """
4 | import torch
5 | import numpy as np
6 |
7 | from quantnn.models.pytorch.normalization import (
8 | LayerNormFirst,
9 | GRN
10 | )
11 |
12 |
def test_layer_norm_first():
    """
    Check that the output of ``LayerNormFirst`` has zero mean along
    dimension 1.
    """
    layer = LayerNormFirst(16)
    normalized = layer(torch.rand(10, 16, 24, 24))

    channel_means = normalized.mean(1).detach().numpy()
    assert np.all(np.isclose(channel_means, 0.0, atol=1e-5))
23 |
24 |
def test_grn():
    """
    Check that a newly created ``GRN`` layer returns its input unchanged.
    """
    layer = GRN(16)
    x = torch.rand(10, 16, 24, 24)
    result = layer(x).detach().numpy()
    reference = x.detach().numpy()
    assert np.all(np.isclose(reference, result))
35 |
--------------------------------------------------------------------------------
/test/models/pytorch/test_torchvision.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for the torchvision wrappers defined in
3 | quantnn.models.pytorch.torchvision.
4 | """
5 | import pytest
6 | import torch
7 |
# Guard the import of the wrappers so that, if it fails, the tests in this
# module are skipped instead of the whole module failing on import. The
# import must not be repeated outside of the ``try`` block.
try:
    import quantnn.models.pytorch.torchvision as tv

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
15 |
16 |
@pytest.mark.skipif(not HAS_TORCHVISION, reason="torchvision not available")
def test_resnet_block():
    """
    Check that blocks produced by the ResNet factory yield outputs with the
    requested number of channels, with and without downsampling.
    """
    x = torch.ones((1, 1, 8, 8))
    factory = tv.ResNetBlockFactory()

    cases = (
        ({}, (1, 2, 8, 8)),
        ({"downsample": 2}, (1, 2, 4, 4)),
        ({"downsample": (1, 2)}, (1, 2, 8, 4)),
    )
    for options, expected_shape in cases:
        block = factory(1, 2, **options)
        assert block(x).shape == expected_shape
37 |
38 |
@pytest.mark.skipif(not HAS_TORCHVISION, reason="torchvision not available")
def test_convnext_block():
    """
    Check that blocks produced by the ConvNeXt factory yield outputs with
    the requested number of channels, with and without downsampling.
    """
    x = torch.ones((1, 1, 8, 8))
    factory = tv.ConvNeXtBlockFactory()

    cases = (
        ({}, (1, 2, 8, 8)),
        ({"downsample": 2}, (1, 2, 4, 4)),
        ({"downsample": (1, 2)}, (1, 2, 8, 4)),
    )
    for options, expected_shape in cases:
        block = factory(1, 2, **options)
        assert block(x).shape == expected_shape
59 |
60 |
@pytest.mark.skipif(not HAS_TORCHVISION, reason="torchvision not available")
def test_swin_transformer_block():
    """
    Check that blocks produced by the Swin factory yield outputs with the
    requested number of channels, with and without downsampling.
    """
    x = torch.ones((1, 16, 32, 32))
    factory = tv.SwinBlockFactory()

    block = factory(16, 16)
    assert block(x).shape == (1, 16, 32, 32)

    for factor, expected_shape in ((2, (1, 32, 16, 16)), (4, (1, 32, 8, 8))):
        block = factory(16, 32, downsample=factor)
        assert block(x).shape == expected_shape
81 |
--------------------------------------------------------------------------------
/test/models/pytorch/test_upsampling.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for the quantnn.models.pytorch.upsampling module.
3 | """
4 | import torch
5 | from torch import nn
6 |
7 | from quantnn.models.pytorch.upsampling import (
8 | BilinearFactory,
9 | UpsampleFactory,
10 | UpConvolutionFactory,
11 | )
12 |
13 |
def test_bilinear():
    """
    Check that upsamplers created by ``BilinearFactory`` support different
    upsampling factors along the two spatial dimensions.
    """
    upsampler = BilinearFactory()(channels_in=16, channels_out=8, factor=(2, 4))
    upsampled = upsampler(torch.rand(2, 16, 32, 32))

    channels, height, width = upsampled.shape[1], upsampled.shape[-2], upsampled.shape[-1]
    assert channels == 8
    assert height == 64
    assert width == 128
26 |
27 |
def test_upsample():
    """
    Check that upsamplers created by the generic ``UpsampleFactory`` support
    different upsampling factors along the two spatial dimensions.
    """
    factory = UpsampleFactory(mode="nearest")
    upsampler = factory(channels_in=16, channels_out=8, factor=(2, 4))
    upsampled = upsampler(torch.rand(2, 16, 32, 32))

    channels, height, width = upsampled.shape[1], upsampled.shape[-2], upsampled.shape[-1]
    assert channels == 8
    assert height == 64
    assert width == 128
41 |
42 |
def test_upconvolution():
    """
    Check that upsamplers created by ``UpConvolutionFactory`` support
    different upsampling factors along the two spatial dimensions.
    """
    factory = UpConvolutionFactory(mode="nearest")
    upsampler = factory(channels_in=16, channels_out=8, factor=(2, 4))
    upsampled = upsampler(torch.rand(2, 16, 32, 32))

    channels, height, width = upsampled.shape[1], upsampled.shape[-2], upsampled.shape[-1]
    assert channels == 8
    assert height == 64
    assert width == 128
56 |
--------------------------------------------------------------------------------
/test/models/test_keras.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for the Keras NN backend.
3 | """
4 | import tensorflow as tf
5 | from tensorflow import keras
6 |
7 | from quantnn.models.keras import QuantileLoss, CrossEntropyLoss
8 | import numpy as np
9 | from quantnn import (QRNN,
10 | set_default_backend,
11 | get_default_backend)
12 |
def test_quantile_loss():
    """
    Check that the quantile loss for the median equals half of the mean
    absolute error and that masked values do not contribute to the loss.
    """
    median_loss = QuantileLoss([0.5], mask=-1e3)

    predictions = np.random.rand(10, 1, 10)
    targets = np.random.rand(10, 1, 10)

    result = median_loss(targets, predictions)
    reference = 0.5 * np.mean(np.abs(predictions - targets))
    assert np.isclose(result, reference)

    # Fill the second half of the samples with values below the mask
    # threshold.
    predictions = np.random.rand(20, 1, 10)
    predictions[10:] = -2e3
    targets = np.random.rand(20, 1, 10)
    targets[10:] = -2e3

    median_loss = QuantileLoss([0.5], mask=-1e3)
    result = median_loss(targets, predictions)
    # Loss computed on only the unmasked first half of the samples.
    reference = median_loss(targets[:10], predictions[:10])
    assert np.isclose(result, reference)
39 |
def test_cross_entropy_loss():
    """
    Check the cross-entropy loss and its masking against reference values.

    Covers targets given as continuous values together with bin boundaries,
    targets given as class indices, and the binary case. Note that Keras,
    by default, expects channels along the last axis.
    """
    logits = np.random.rand(10, 10, 10).astype(np.float32)
    targets = np.ones((10, 10, 1), dtype=np.float32)
    bin_edges = np.linspace(0, 1, 11)
    # With these bin boundaries the value 0.55 falls into bin 5.
    targets[:, :, 0] = 0.55

    #
    # Continuous targets with bin boundaries.
    #
    loss = CrossEntropyLoss(bin_edges, mask=-1.0)
    # NOTE(review): the normalization sums over axis 1 here but over the
    # last axis in the masked reference below -- confirm which is intended.
    expected = -logits[:, :, 5] + np.log(np.exp(logits).sum(1))
    assert np.all(np.isclose(loss(targets, logits), expected.mean(), rtol=1e-3))

    # Mask everything except the first five rows and columns.
    targets[5:, :, :] = -1.0
    targets[:, 5:, :] = -1.0
    unmasked = logits[:5, :5, :]
    expected = -logits[:5, :5, 5] + np.log(np.exp(unmasked).sum(-1))
    assert np.all(np.isclose(loss(targets, logits), expected.mean()))

    #
    # Targets given as class indices.
    #
    loss = CrossEntropyLoss(10, mask=-1.0)
    targets = np.ones((10, 10, 1), dtype=np.int32)
    targets[:, :, :] = 5
    expected = -logits[:, :, 5] + np.log(np.exp(logits).sum(1))
    assert np.all(np.isclose(loss(targets, logits), expected.mean(), rtol=1e-3))

    targets[5:, :, :] = -1.0
    targets[:, 5:, :] = -1.0

    unmasked = logits[:5, :5, :]
    expected = -logits[:5, :5, 5] + np.log(np.exp(unmasked).sum(-1))
    assert np.all(np.isclose(loss(targets, logits), expected.mean()))

    #
    # Binary case.
    #
    loss = CrossEntropyLoss(2, mask=-1.0)
    targets = np.ones((10, 10, 1), dtype=np.int32)
    logits = np.random.rand(10, 10, 1).astype(np.float32)
    targets[:, :, :] = 1
    log_probabilities = tf.math.log_sigmoid(logits)
    assert np.all(
        np.isclose(
            loss(targets, logits),
            -tf.math.reduce_mean(log_probabilities),
            rtol=1e-3,
        )
    )

    targets[5:, :, :] = -1.0
    targets[:, 5:, :] = -1.0

    negative_log_probabilities = -tf.math.log_sigmoid(logits[:5, :5])
    assert np.all(
        np.isclose(
            loss(targets, logits),
            tf.math.reduce_mean(negative_log_probabilities),
            rtol=1e-3,
        )
    )
99 |
def test_training_with_dataloader():
    """
    Ensure that training on an iterable of pre-batched data works. A
    plain list of dict batches stands in for the data loader here.
    """
    set_default_backend("keras")
    x = np.random.rand(1024, 16)
    y = np.random.rand(1024)

    batch_size = 128
    batched_data = []
    for start in range(0, 1024, batch_size):
        stop = start + batch_size
        batched_data.append({"x": x[start:stop], "y": y[start:stop]})

    quantiles = np.linspace(0.05, 0.95, 10)
    qrnn = QRNN(quantiles, n_inputs=x.shape[1])
    qrnn.train(batched_data, n_epochs=1)
118 |
119 |
def test_training_with_dict():
    """
    Ensure that training works when every batch is a dict holding the
    input under "x" and the target under "y".
    """
    set_default_backend("keras")
    x = np.random.rand(1024, 16)
    y = np.random.rand(1024)

    batch_size = 128
    batched_data = []
    for start in range(0, 1024, batch_size):
        stop = start + batch_size
        batched_data.append({"x": x[start:stop], "y": y[start:stop]})

    quantiles = np.linspace(0.05, 0.95, 10)
    qrnn = QRNN(quantiles, n_inputs=x.shape[1])

    qrnn.train(batched_data, n_epochs=1)
139 |
140 |
def test_training_with_dict_and_keys():
    """
    Ensure that training with dict batches works when the keys of input
    and target are passed explicitly and the batches hold an extra entry.
    """
    set_default_backend("keras")
    x = np.random.rand(1024, 16)
    y = np.random.rand(1024)

    batch_size = 128
    batched_data = []
    for start in range(0, 1024, batch_size):
        stop = start + batch_size
        batched_data.append({
            "x": x[start:stop],
            # Additional entry that is not named in 'keys'.
            "x_2": x[start:stop],
            "y": y[start:stop],
        })

    quantiles = np.linspace(0.05, 0.95, 10)
    qrnn = QRNN(quantiles, n_inputs=x.shape[1])
    qrnn.train(batched_data, n_epochs=1, keys=("x", "y"))
161 |
162 |
def test_training_multiple_outputs():
    """
    Ensure that training works for a model that returns a dict with two
    outputs when the targets are provided as a dict with matching keys.
    """
    set_default_backend("keras")

    class MultipleOutputModel(keras.Model):
        """Shared hidden layer followed by two independent linear heads."""

        def __init__(self):
            super().__init__()
            self.hidden = keras.layers.Dense(128, "relu", input_shape=(16,))
            self.head_1 = keras.layers.Dense(11, None)
            self.head_2 = keras.layers.Dense(11, None)

        def call(self, x):
            features = self.hidden(x)
            return {
                "y_1": self.head_1(features),
                "y_2": self.head_2(features),
            }

    x = np.random.rand(1024, 16)
    y = np.random.rand(1024)

    batch_size = 128
    batched_data = []
    for start in range(0, 1024, batch_size):
        stop = start + batch_size
        targets = {"y_1": y[start:stop], "y_2": y[start:stop]}
        batched_data.append({"x": x[start:stop], "y": targets})

    model = MultipleOutputModel()
    qrnn = QRNN(np.linspace(0.05, 0.95, 11), model=model)
    qrnn.train(batched_data, n_epochs=10, keys=("x", "y"))
203 |
--------------------------------------------------------------------------------
/test/test_data.py:
--------------------------------------------------------------------------------
1 | """
2 | Test for the quantnn.data module.
3 |
4 | """
5 | import logging
6 | import os
7 |
8 | import numpy as np
9 | import pytest
10 | try:
11 | import torch
12 | TORCH_AVAILABLE = True
13 | except ImportError:
14 | TORCH_AVAILABLE = False
15 |
16 | from quantnn.data import DataFolder, LazyDataFolder
17 |
18 |
19 | # Currently no SFTP test data available.
20 | HAS_LOGIN_INFO = False
21 |
22 |
23 | LOGGER = logging.getLogger(__file__)
24 |
25 |
class Dataset:
    """
    Minimal batched dataset used to test streaming of data files.

    Loads the arrays stored under "x" and "y" from a single file and
    serves them as consecutive batches of NumPy arrays.
    """
    def __init__(self, filename, batch_size=1):
        """
        Load the data for a new dataset.

        Args:
            filename: Path of the file to load the data from.
            batch_size: Number of samples in each returned batch.
        """
        self.batch_size = batch_size
        contents = np.load(filename)
        self.x = contents["x"]
        # Targets are stored as a column vector.
        self.y = contents["y"].reshape(-1, 1)
        LOGGER.info("Loaded data from file %s.", filename)

    def _shuffle(self):
        """
        Apply the same random permutation to x and y.
        """
        order = np.random.permutation(self.x.shape[0])
        self.x = self.x[order]
        self.y = self.y[order]

    def __len__(self):
        """ Number of complete batches in the dataset. """
        n_samples = self.x.shape[0]
        return n_samples // self.batch_size

    def __getitem__(self, index):
        """ Return the batch with the given index as tuple (x, y). """
        if index >= len(self):
            raise IndexError()

        # Requesting the first batch reshuffles the data.
        if index == 0:
            self._shuffle()

        start = index * self.batch_size
        stop = start + self.batch_size
        return (self.x[start:stop], self.y[start:stop])
70 |
71 |
@pytest.mark.xfail()
@pytest.mark.skipif(not HAS_LOGIN_INFO, reason="No SFTP login info.")
def test_sftp_stream():
    """
    Assert that streaming via SFTP yields all data in the given folder
    and that kwargs are correctly passed on to the dataset class.

    Both streams are iterated several times; the sums are checked for
    the final pass only.
    """
    url = "sftp://" + "129.16.35.202" + "/mnt/array1/share/MLDatasets/test/"
    stream = DataFolder(url,
                        Dataset,
                        kwargs={"batch_size": 2},
                        aggregate=None,
                        n_workers=16)
    stream_2 = DataFolder(url,
                          Dataset,
                          kwargs={"batch_size": 2},
                          n_workers=16)

    def drain(source):
        """Iterate over all batches, check batch size and sum x and y."""
        total_x = 0.0
        total_y = 0.0
        for x, y in source:
            total_x += x.sum()
            total_y += y.sum()
            assert x.shape[0] == 2
        return total_x, total_y

    next(iter(stream))
    next(iter(stream_2))

    drain(stream)
    drain(stream_2)
    drain(stream)
    x_sum, y_sum = drain(stream_2)

    assert np.isclose(x_sum, 7 * 8 / 2 * 10 * 10)
    assert np.isclose(y_sum, 7 * 8 / 2 * 10)
124 |
125 |
@pytest.mark.skipif(not HAS_LOGIN_INFO, reason="No SFTP login info.")
def test_lazy_datafolder():
    """
    Iterate once over a LazyDataFolder and check the batch size as well
    as the sums over all returned inputs and targets.
    """
    url = "sftp://" + "129.16.35.202" + "/mnt/array1/share/MLDatasets/test/"
    stream = LazyDataFolder(url,
                            Dataset,
                            kwargs={"batch_size": 2},
                            n_workers=2,
                            batch_queue_size=1)

    total_x = 0.0
    total_y = 0.0
    for x, y in stream:
        total_x += x.sum()
        total_y += y.sum()
        assert x.shape[0] == 2

    assert np.isclose(total_x, 7 * 8 / 2 * 10 * 10)
    assert np.isclose(total_y, 7 * 8 / 2 * 10)
145 |
class TensorDataset:
    """
    Minimal batched dataset that returns its batches as torch tensors.

    Loads the arrays stored under "x" and "y" from a single file and
    serves them as consecutive batches.
    """
    def __init__(self, filename, batch_size=1):
        """
        Load the data for a new dataset.

        Args:
            filename: Path of the file to load the data from.
            batch_size: Number of samples in each returned batch.
        """
        self.batch_size = batch_size
        contents = np.load(filename)
        self.x = contents["x"]
        # Targets are stored as a column vector.
        self.y = contents["y"].reshape(-1, 1)

    def _shuffle(self):
        """
        Apply the same random permutation to x and y.
        """
        order = np.random.permutation(self.x.shape[0])
        self.x = self.x[order]
        self.y = self.y[order]

    def __len__(self):
        """ Number of complete batches in the dataset. """
        n_samples = self.x.shape[0]
        return n_samples // self.batch_size

    def __getitem__(self, index):
        """ Return the batch with the given index as tuple of tensors. """
        if index >= len(self):
            raise IndexError()

        # Requesting the first batch reshuffles the data.
        if index == 0:
            self._shuffle()

        start = index * self.batch_size
        stop = start + self.batch_size
        x_batch = torch.tensor(self.x[start:stop])
        y_batch = torch.tensor(self.y[start:stop])
        return (x_batch, y_batch)
190 |
191 |
@pytest.mark.skipif(not HAS_LOGIN_INFO, reason="No SFTP login info.")
@pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not installed.")
def test_aggregation():
    """
    Assert that aggregation of tensor batches works as expected.

    Both streams are iterated several times; the sums are checked for
    the final pass only.
    """
    url = "sftp://" + "129.16.35.202" + "/mnt/array1/share/MLDatasets/test/"
    stream = DataFolder(url,
                        TensorDataset,
                        kwargs={"batch_size": 1},
                        aggregate=2,
                        n_workers=16)
    stream_2 = DataFolder(url,
                          TensorDataset,
                          kwargs={"batch_size": 1},
                          aggregate=2,
                          n_workers=16)

    def drain(source, check_pairing=False):
        """Iterate over all batches, check batch size and sum x and y."""
        total_x = 0.0
        total_y = 0.0
        for x, y in source:
            total_x += x.sum()
            total_y += y.sum()
            if check_pairing:
                # First input column must match the target of each sample.
                assert np.all(np.isclose(x[:, 0].detach().numpy().ravel(),
                                         y.detach().numpy().ravel()))
            assert x.shape[0] == 2
        return total_x, total_y

    next(iter(stream))
    next(iter(stream_2))

    drain(stream)
    drain(stream_2, check_pairing=True)
    drain(stream)
    x_sum, y_sum = drain(stream_2)

    assert np.isclose(x_sum, 7 * 8 / 2 * 10 * 10)
    assert np.isclose(y_sum, 7 * 8 / 2 * 10)
248 |
--------------------------------------------------------------------------------
/test/test_data/x_train.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/simonpf/quantnn/c4d650cf0c6da5b4a704905b6c267d1ca996466f/test/test_data/x_train.npy
--------------------------------------------------------------------------------
/test/test_data/y_train.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/simonpf/quantnn/c4d650cf0c6da5b4a704905b6c267d1ca996466f/test/test_data/y_train.npy
--------------------------------------------------------------------------------
/test/test_drnn.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for quantnn.drnn module.
3 | """
4 | import os
5 | import tempfile
6 |
7 | import pytest
8 |
9 | import numpy as np
10 | from quantnn import drnn, set_default_backend, get_default_backend
11 | from quantnn.drnn import DRNN
12 |
13 | #
14 | # Import available backends.
15 | #
16 |
# Names of all backends whose model module can be imported. Any regular
# error raised during the import marks the backend as unavailable, but
# KeyboardInterrupt and SystemExit are no longer swallowed.
backends = []
try:
    import quantnn.models.keras

    backends.append("keras")
except Exception:
    pass

try:
    import quantnn.models.pytorch

    backends.append("pytorch")
except Exception:
    pass
31 |
class TestDrnn:
    """
    Tests for training, prediction and persistence of DRNN models.
    """
    def setup_method(self):
        """
        Load the training data stored next to this file and normalize
        the inputs with their global mean and standard deviation.
        """
        test_dir = os.path.dirname(os.path.realpath(__file__))
        data_path = os.path.join(test_dir, "test_data")
        x_train = np.load(os.path.join(data_path, "x_train.npy"))
        x_mean = np.mean(x_train, keepdims=True)
        x_sigma = np.std(x_train, keepdims=True)
        self.x_train = (x_train - x_mean) / x_sigma
        self.bins = np.logspace(0, 3, 21)
        self.y_train = np.load(os.path.join(data_path, "y_train.npy"))

    def test_to_categorical(self):
        """
        Assert that converting a continuous target variable to binned
        representation works as expected.
        """
        bins = np.linspace(0, 10, 11)

        # One value below the first bin edge, one in each of the ten bins
        # and one above the last bin edge.
        y = np.arange(12) - 0.5
        y_cat = drnn._to_categorical(y, bins)

        assert y_cat[0] == 0
        assert np.all(np.isclose(y_cat[1:-1], np.arange(10)))
        assert y_cat[-1] == 9

    @pytest.mark.parametrize("backend", backends)
    def test_drnn(self, backend):
        """
        Test training of DRNN using numpy arrays as input.
        """
        set_default_backend(backend)
        model = DRNN(self.bins,
                     n_inputs=self.x_train.shape[1])
        model.train((self.x_train, self.y_train),
                    validation_data=(self.x_train, self.y_train),
                    n_epochs=2)

        model.predict(self.x_train)

        mu = model.posterior_mean(self.x_train[:2, :])
        assert len(mu.shape) == 1

        r = model.sample_posterior(self.x_train[:4, :], n_samples=2)
        assert r.shape == (4, 2)

    @pytest.mark.parametrize("backend", backends)
    def test_drnn_dict_iterable(self, backend):
        """
        Test training with dataset object that yields dicts instead of
        tuples.
        """
        set_default_backend(backend)
        backend = get_default_backend()

        class DictWrapper:
            """Wraps a dataset of (x, y) tuples to yield dicts instead."""

            def __init__(self, data):
                self.data = data

            def __iter__(self):
                for x, y in self.data:
                    yield {"x": x, "y": y}

            def __len__(self):
                return len(self.data)

        data = backend.BatchedDataset((self.x_train, self.y_train), 256)
        model = DRNN(self.bins,
                     n_inputs=self.x_train.shape[1])
        model.train(DictWrapper(data), n_epochs=2, keys=("x", "y"))

    @pytest.mark.parametrize("backend", backends)
    def test_drnn_datasets(self, backend):
        """
        Provide data as dataset object instead of numpy arrays.
        """
        set_default_backend(backend)
        backend = get_default_backend()
        data = backend.BatchedDataset((self.x_train, self.y_train), 256)
        model = DRNN(self.bins,
                     n_inputs=self.x_train.shape[1])
        model.train(data, n_epochs=2)

    @pytest.mark.parametrize("backend", backends)
    def test_save_drnn(self, backend):
        """
        Test saving and loading of DRNNs: the loaded model must yield
        the same predictions as the model that was saved.
        """
        set_default_backend(backend)
        model = DRNN(self.bins,
                     n_inputs=self.x_train.shape[1])

        # The temporary directory is removed again even if the test fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = os.path.join(tmp_dir, "drnn")
            model.save(path)
            model_loaded = DRNN.load(path)

            x_pred = model.predict(self.x_train)
            x_pred_loaded = model_loaded.predict(self.x_train)

        # Predictions that are not numpy arrays are detached before the
        # comparison with numpy.
        if not isinstance(x_pred, np.ndarray):
            x_pred = x_pred.detach()
        if not isinstance(x_pred_loaded, np.ndarray):
            x_pred_loaded = x_pred_loaded.detach()

        assert np.allclose(x_pred, x_pred_loaded)
134 |
--------------------------------------------------------------------------------
/test/test_generic.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for generic array manipulation functions.
3 | """
4 | import numpy as np
5 | import pytest
6 | from quantnn.generic import (get_array_module, to_array, sample_uniform,
7 | sample_gaussian, numel, concatenate, expand_dims,
8 | pad_zeros, pad_zeros_left, as_type, arange,
9 | reshape, trapz, cumsum, cumtrapz, ones, zeros,
10 | softmax, exp, tensordot, argmax, take_along_axis,
11 | digitize, scatter_add)
12 |
13 |
@pytest.mark.parametrize("backend", pytest.backends)
def test_get_array_module(backend):
    """
    An array created with the 'ones' function of a backend module must
    be mapped back to that same module by get_array_module.
    """
    assert get_array_module(backend.ones(10)) == backend
24 |
25 |
@pytest.mark.parametrize("backend", pytest.backends)
def test_to_array(backend):
    """
    A numpy array converted with to_array must belong to the module of
    the requested backend.
    """
    converted = to_array(backend, np.arange(10))
    assert get_array_module(converted) == backend
35 |
36 |
@pytest.mark.parametrize("backend", pytest.backends)
def test_sample_uniform(backend):
    """
    Uniform random samples must be returned as array of the backend
    module that was requested.
    """
    shape = (10, )
    assert get_array_module(sample_uniform(backend, shape)) == backend
45 |
46 |
@pytest.mark.parametrize("backend", pytest.backends)
def test_sample_gaussian(backend):
    """
    Gaussian random samples must be returned as array of the backend
    module that was requested.
    """
    shape = (10, )
    assert get_array_module(sample_gaussian(backend, shape)) == backend
55 |
56 |
@pytest.mark.parametrize("backend", pytest.backends)
def test_numel(backend):
    """
    numel must report the number of elements of an array.
    """
    n_elements = numel(backend.ones(10))
    assert n_elements == 10
64 |
65 |
@pytest.mark.parametrize("backend", pytest.backends)
def test_concatenate(backend):
    """
    Concatenating a (10, 1) and a (10, 2) array along axis 1 must yield
    an array with 30 elements.
    """
    parts = [backend.ones((10, 1)), backend.ones((10, 2))]
    joined = concatenate(backend, parts, 1)
    assert numel(joined) == 30
75 |
76 |
@pytest.mark.parametrize("backend", pytest.backends)
def test_expand_dims(backend):
    """
    Expanding a 1D array along axis 1 must add a trailing dimension of
    size one.
    """
    expanded = expand_dims(backend, backend.ones((10,)), 1)
    assert len(expanded.shape) == 2
    assert expanded.shape[1] == 1
86 |
87 |
@pytest.mark.parametrize("backend", pytest.backends)
def test_pad_zeros(backend):
    """
    Zero padding must grow the padded axis on both ends and fill the
    new entries with zeros.
    """
    padded = pad_zeros(backend, backend.ones((10, 10)), 2, 1)
    padded = pad_zeros(backend, padded, 1, 0)

    # One element per side along axis 0, two per side along axis 1.
    assert padded.shape[0] == 12
    assert padded.shape[1] == 14
    assert padded[0, 1] == 0.0
    assert padded[-1, -2] == 0.0
100 |
101 |
@pytest.mark.parametrize("backend", pytest.backends)
def test_pad_zeros_left(backend):
    """
    Left zero padding must grow the padded axis only at its start and
    leave the entries at its end untouched.
    """
    padded = pad_zeros_left(backend, backend.ones((10, 10)), 2, 1)
    padded = pad_zeros_left(backend, padded, 1, 0)

    # One extra element along axis 0, two extra elements along axis 1.
    assert padded.shape[0] == 11
    assert padded.shape[1] == 12
    assert padded[0, 1] == 0.0
    assert padded[-1, -2] != 0.0
114 |
115 |
@pytest.mark.parametrize("backend", pytest.backends)
def test_as_type(backend):
    """
    Converting a comparison mask with as_type must yield the dtype of
    the reference array.
    """
    reference = backend.ones((10, 10))
    converted = as_type(backend, reference > 0.0, reference)
    assert reference.dtype == converted.dtype
125 |
126 |
@pytest.mark.parametrize("backend", pytest.backends)
def test_arange(backend):
    """
    A range from 0 to 10.1 in steps of 1 must start at 0 and end at 10.
    """
    values = arange(backend, 0, 10.1, 1)
    first, last = values[0], values[-1]
    assert first == 0.0
    assert last == 10.0
135 |
136 |
@pytest.mark.parametrize("backend", pytest.backends)
def test_reshape(backend):
    """
    Reshaping an 11-element range to (1, 11, 1) must yield that shape.
    """
    reshaped = reshape(backend, arange(backend, 0, 10.1, 1), (1, 11, 1))
    assert reshaped.shape[0] == 1
    assert reshaped.shape[1] == 11
    assert reshaped.shape[2] == 1
144 |
145 |
@pytest.mark.parametrize("backend", pytest.backends)
def test_trapz(backend):
    """
    Integrating y = x from 0 to 10 with the trapezoidal rule must give 50.
    """
    grid = arange(backend, 0, 10.1, 1)
    integral = trapz(backend, grid, grid, 0)
    assert integral == 50
151 |
152 |
@pytest.mark.parametrize("backend", pytest.backends)
def test_cumsum(backend):
    """
    Check the cumulative sum of the column vector 0, ..., 10 along both
    of its axes.
    """
    column = reshape(backend, arange(backend, 0, 10.1, 1), (11, 1))

    along_rows = cumsum(backend, column, 0)
    assert along_rows[-1, 0] == 55

    # Axis 1 has length one, so the values stay unchanged.
    along_columns = cumsum(backend, column, 1)
    assert along_columns[-1, 0] == 10
160 |
161 |
@pytest.mark.parametrize("backend", pytest.backends)
def test_cumtrapz(backend):
    """
    The cumulative trapezoidal integral of y = x must start at zero and
    scale with the spacing of the integration grid.
    """
    y = reshape(backend, arange(backend, 0, 10.1, 1), (11, 1))
    x = arange(backend, 0, 10.1, 1)

    integral = cumtrapz(backend, y, x, 0)
    assert integral[0, 0] == 0.0
    assert integral[-1, 0] == 50.0

    # Doubling the grid spacing doubles the integral.
    integral = cumtrapz(backend, y, 2.0 * x, 0)
    assert integral[0, 0] == 0.0
    assert integral[-1, 0] == 100.0
174 |
175 |
@pytest.mark.parametrize("backend", pytest.backends)
def test_ones(backend):
    """
    Ensures that 'ones' creates an array that is filled with ones.

    This test used to be named 'test_zeros' and was therefore replaced
    by the following test of the same name, so it was never collected.
    """
    x = ones(backend, (1, 1))
    assert x[0, 0] == 1.0
180 |
181 |
@pytest.mark.parametrize("backend", pytest.backends)
def test_zeros(backend):
    """
    Ensures that 'zeros' creates an array that is filled with zeros.
    """
    array = zeros(backend, (1, 1))
    assert array[0, 0] == 0.0
186 |
187 |
@pytest.mark.parametrize("backend", pytest.backends)
def test_softmax(backend):
    """
    Smoke test: softmax must run on a 1D range without raising. The
    result is not checked.
    """
    logits = arange(backend, 0, 10.1, 1)
    y = softmax(backend, logits)
192 |
193 |
@pytest.mark.parametrize("backend", pytest.backends)
def test_tensordot(backend):
    """
    Contracting the range 0, ..., 10 with a vector of ones must give 55.
    """
    values = arange(backend, 0, 10.1, 1)
    weights = ones(backend, 11)
    contraction_axes = ((0,), (0,))
    total = tensordot(backend, values, weights, contraction_axes)
    assert np.isclose(total, 55)
200 |
201 |
@pytest.mark.parametrize("backend", pytest.backends)
def test_argmax(backend):
    """
    Check argmax of the column vector 0, ..., 10 without axis and along
    axis 1.
    """
    column = arange(backend, 0, 10.1, 1).reshape(-1, 1)

    overall = argmax(backend, column)
    assert overall == 10

    # Every row holds a single element, so each index must be zero.
    per_row = argmax(backend, column, 1)
    assert all(per_row == 0)
210 |
211 |
@pytest.mark.parametrize("backend", pytest.backends)
def test_take_along_axis(backend):
    """
    Taking the first four rows of the column vector 0, ..., 10 along
    axis 0 must return the values 0 to 3.
    """
    column = arange(backend, 0, 10.1, 1).reshape(-1, 1)
    indices = to_array(backend, [[0], [1], [2], [3]])
    taken = take_along_axis(backend, column, indices, 0)

    for position in range(4):
        assert taken[position] == position
222 |
223 |
@pytest.mark.parametrize("backend", pytest.backends)
def test_digitize(backend):
    """
    Values at the bin centers 0.5, ..., 9.5 must be assigned the indices
    1 to 10 for bin edges 0, ..., 10.
    """
    centers = arange(backend, 0, 10.0, 1).reshape(-1, 1) + 0.5
    edges = arange(backend, 0, 10.1, 1)

    bin_indices = digitize(backend, centers, edges)
    assert bin_indices[0] == 1
    assert bin_indices[-1] == 10
232 |
233 |
@pytest.mark.parametrize("backend", pytest.backends)
def test_scatter_add(backend):
    """
    Adding ones at the indices 0 and 2 of a zero array must leave index
    1 untouched, both along axis 0 and along axis 1.
    """
    indices = to_array(backend, [0, 2])

    # Scatter rows.
    target = zeros(backend, (3, 3))
    rows = ones(backend, (2, 3))
    result = scatter_add(backend, target, indices, rows, 0)
    assert np.isclose(result[0, 0], 1.0)
    assert np.isclose(result[1, 0], 0.0)
    assert np.isclose(result[2, 0], 1.0)

    # Scatter columns.
    target = zeros(backend, (3, 3))
    columns = ones(backend, (3, 2))
    result = scatter_add(backend, target, indices, columns, 1)
    assert np.isclose(result[0, 0], 1.0)
    assert np.isclose(result[0, 1], 0.0)
    assert np.isclose(result[0, 2], 1.0)
251 |
--------------------------------------------------------------------------------
/test/test_normalizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from quantnn.normalizer import Normalizer, MinMaxNormalizer
3 |
def test_normalizer_2d():
    """
    Normalized features must have zero mean and unit std. dev. while
    features listed in 'exclude_indices' must keep their statistics.
    """
    offsets = np.arange(10)
    x = np.random.normal(size=(100000, 10)) + offsets.reshape(1, -1)
    normalizer = Normalizer(x, exclude_indices=range(1, 10, 2))

    x_normed = normalizer(x)

    # Even feature indices are normalized.
    included = x_normed[:, ::2]
    assert np.all(np.isclose(included.mean(axis=0), 0.0, atol=1e-1))
    assert np.all(np.isclose(included.std(axis=0), 1.0, rtol=1e-1))

    # Odd feature indices are excluded and keep their original mean.
    excluded = x_normed[:, 1::2]
    assert np.all(np.isclose(excluded.mean(axis=0),
                             offsets[1::2].reshape(1, -1),
                             rtol=1e-2))
    assert np.all(np.isclose(excluded.std(axis=0), 1.0, rtol=1e-2))

    # Channels without variation should be set to -1.0
    constant = np.zeros((100, 10))
    normalizer = Normalizer(constant)
    assert np.all(np.isclose(normalizer(constant), -1.0))
36 |
def test_min_max_normalizer_2d():
    """
    Features normalized by the MinMaxNormalizer must span the range
    [-1, 1], and NaN inputs must be mapped to -1.5.
    """
    x = np.random.normal(size=(100000, 11)) + np.arange(11).reshape(1, -1)
    normalizer = MinMaxNormalizer(x, exclude_indices=range(1, 10, 2))
    # The last feature is set to NaN only after the normalizer was fitted.
    x[:, 10] = np.nan

    x_normed = normalizer(x)

    # Included indices should have minimum value -1.0 and
    # maximum value 1.0.
    included = x_normed[:, :10:2]
    assert np.all(np.isclose(included.min(axis=0), -1.0))
    assert np.all(np.isclose(included.max(axis=0), 1.0))
    # NaN values should be set to -1.5.
    assert np.all(np.isclose(x_normed[:, -1], -1.5))

    # Channels without variation should be set to -1.0
    constant = np.zeros((100, 10))
    normalizer = MinMaxNormalizer(constant)
    assert np.all(np.isclose(normalizer(constant), -1.0))
62 |
def test_invert():
    """
    Inverting normalized data must recover the per-feature means of the
    original data.
    """
    x = np.random.normal(size=(100000, 10)) + np.arange(10).reshape(1, -1)
    normalizer = Normalizer(x, exclude_indices=[0, 1, 2])

    recovered = normalizer.invert(normalizer(x))

    expected_means = np.arange(10, dtype=np.float32)
    assert np.all(np.isclose(np.mean(recovered, axis=0),
                             expected_means,
                             atol=1e-2))
76 |
def test_save_and_load(tmp_path):
    """
    A normalizer that was saved to disk and loaded again must yield the
    same results as the original.
    """
    x = np.random.normal(size=(100000, 10)) + np.arange(10).reshape(1, -1)
    normalizer = Normalizer(x, exclude_indices=range(1, 10, 2))

    filename = tmp_path / "normalizer.pckl"
    normalizer.save(filename)
    loaded = Normalizer.load(filename)

    expected = normalizer(x)
    result = loaded(x)

    assert np.all(np.isclose(expected, result))
92 |
93 |
def test_load_sftp(tmp_path):
    """
    A normalizer that was saved and loaded again must yield the same
    results as the original, up to a relative tolerance of 1e-3.

    NOTE(review): despite its name, this test saves to and loads from
    the local 'tmp_path' and does not involve SFTP.
    """
    x = np.random.normal(size=(100000, 10)) + np.arange(10).reshape(1, -1)
    normalizer = Normalizer(x, exclude_indices=range(1, 10, 2))

    filename = tmp_path / "normalizer.pckl"
    normalizer.save(filename)
    loaded = Normalizer.load(filename)

    expected = normalizer(x)
    result = loaded(x)

    assert np.all(np.isclose(expected, result, rtol=1e-3))
110 |
--------------------------------------------------------------------------------
/test/test_packed_tensor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from quantnn.packed_tensor import PackedTensor
5 |
def fill_tensor(t, indices):
    """
    Overwrite the i-th entry of 't' with the i-th value of 'indices'.

    The container is modified in place and returned.
    """
    position = 0
    for sample_index in indices:
        t[position] = sample_index
        position += 1
    return t
13 |
14 |
def make_random_packed_tensor(batch_size, samples, shape=(1,)):
    """
    Create a packed tensor representing a training batch in which a
    random subset of the samples is missing. Each element of the tensor
    is set to the batch index of the sample it belongs to.

    Args:
        batch_size: The nominal batch size of the training batch.
        samples: The number of non-missing samples.
        shape: The shape of each sample in the training batch.
    """
    candidates = np.arange(batch_size)
    present = sorted(np.random.choice(candidates, size=samples, replace=False))
    data = fill_tensor(np.ones((samples,) + shape, dtype=np.float32), present)
    return PackedTensor(data, batch_size, present)
32 |
33 |
def test_attributes():
    """
    A packed tensor must expose its batch size, its batch indices and
    the shape of the wrapped tensor.
    """
    packed = PackedTensor(torch.ones((2, 2)), 4, [0, 1])

    assert packed.batch_size == 4
    assert packed.batch_indices == [0, 1]
    assert packed.shape == (2, 2)
44 |
45 |
def test_stack():
    """
    Test stacking of a list of tensors into a batch. Entries that are
    None mark missing samples and must expand to zeros.
    """
    t = torch.ones((2, 2))
    u = 2.0 * torch.ones((2, 2))

    tensors = [None, t, u, None]
    b = PackedTensor.stack(tensors)

    assert b.batch_indices == [1, 2]
    assert b.batch_size == 4

    b_e = b.expand()
    assert (b_e[0] == 0.0).all()
    assert (b_e[1] == 1.0).all()
    assert (b_e[2] == 2.0).all()
    # The last sample is missing as well; the original test checked
    # index 0 a second time instead.
    assert (b_e[3] == 0.0).all()
64 |
65 |
def test_set_get_item():
    """
    Test item assignment on and item access to a packed tensor.
    """
    packed = PackedTensor(torch.ones((2, 2)), 4, [2, 3])

    # Setting a channel only affects the samples that are present.
    packed[:, 0] = 4.0
    expanded = packed.expand()
    assert expanded[0, 0] == 0.0
    assert expanded[2, 0] == 4.0
    assert expanded[2, 1] == 1.0

    # Item access returns an object holding the selected data in its
    # 'tensor' attribute.
    first_channel = packed[:, 0]
    assert (first_channel.tensor == 4.0).all()
    first_channel = packed[..., 0]
    assert (first_channel.tensor == 4.0).all()
83 |
84 |
def test_expand():
    """
    Expanding a packed tensor must yield a regular tensor of the full
    batch size with the present samples at their batch indices.
    """
    present = [0, 3]
    data = fill_tensor(torch.ones((2, 2)), present)
    expanded = PackedTensor(data, 4, present).expand()
    assert not isinstance(expanded, PackedTensor)
    assert expanded.shape[0] == 4
    assert (expanded[0] == 0.0).all()
    assert (expanded[3] == 3.0).all()

    # Without any batch indices the expanded tensor is all zeros.
    expanded = PackedTensor(torch.ones((2, 2)), 4, []).expand()
    assert not isinstance(expanded, PackedTensor)
    assert expanded.shape[0] == 4
    assert (expanded == 0.0).all()
105 |
106 |
def test_intersection():
    """
    Test the intersection of packed tensors for overlapping, disjoint
    and randomly chosen batch indices.
    """
    # Sample 2 is the only one present in both tensors.
    u_indices = [0, 2]
    u_p = PackedTensor(fill_tensor(torch.ones((2, 2)), u_indices), 4, u_indices)

    v_indices = [2, 3]
    v = 2.0 * fill_tensor(torch.ones((2, 2)), v_indices)
    v_p = PackedTensor(v, 4, v_indices)

    u_common, v_common = u_p.intersection(v_p)
    assert u_common.batch_indices == [2]
    assert u_common.batch_size == 4

    assert (u_common.expand()[2] == 2.0).all()
    assert (v_common.expand()[2] == 2 * 2.0).all()

    # Disjoint batch indices yield no intersection.
    u_p = PackedTensor(torch.ones((2, 2)), 4, [0, 1])
    v_p = PackedTensor(2.0 * torch.ones((2, 2)), 4, [2, 3])
    u_common, v_common = u_p.intersection(v_p)
    assert u_common is None
    assert v_common is None

    # Random tensors whose elements hold the batch index of their sample.
    for _ in range(100):
        left = make_random_packed_tensor(100, 50)
        right = make_random_packed_tensor(100, 50)
        shared = sorted(set(left.batch_indices) & set(right.batch_indices))
        left, right = left.intersection(right)
        for position, index in enumerate(shared):
            assert (left.tensor[position] == index).all()
            assert (right.tensor[position] == index).all()
145 |
def test_splitting():
    """
    Test splitting of two packed tensors into the samples that are
    present in only one of them and the samples present in both.
    """
    u_indices = [0, 2]
    u_p = PackedTensor(fill_tensor(torch.ones((2, 2)), u_indices), 4, u_indices)

    v_indices = [2, 3]
    v = 2.0 * fill_tensor(torch.ones((2, 2)), v_indices)
    v_p = PackedTensor(v, 4, v_indices)

    u_only, v_only, u_both, v_both = u_p.split_parts(v_p)
    assert u_only.batch_indices == [0]
    assert v_only.batch_indices == [3]
    assert u_both.batch_indices == [2]
    assert v_both.batch_indices == [2]

    # Random tensors whose elements hold the batch index of their sample.
    for _ in range(100):
        l = make_random_packed_tensor(100, 50)
        r = make_random_packed_tensor(100, 50)
        l_set = set(l.batch_indices)
        r_set = set(r.batch_indices)
        l_only, r_only, l_both, r_both = l.split_parts(r)

        for l_val, r_val in zip(l_both.batch_indices, r_both.batch_indices):
            assert l_val == r_val

        for position, index in enumerate(sorted(l_set & r_set)):
            assert (l_both.tensor[position] == index).all()
            assert (r_both.tensor[position] == index).all()

        for position, index in enumerate(sorted(l_set - r_set)):
            assert l_only.tensor[position] == index

        for position, index in enumerate(sorted(r_set - l_set)):
            assert r_only.tensor[position] == index

        # The union of all parts must still hold consistent sample indices.
        lr = l_only.union(r_only.union(l_both.union(r_both)))
        for position, index in enumerate(lr.batch_indices):
            assert np.isclose(lr.tensor[position], index)
187 |
188 |
def test_intersection():
    """
    Verify the intersection of two packed tensors.

    NOTE(review): A function with the same name is already defined earlier
    in this module; this later definition is the one that stays bound to
    the name after import.
    """
    # Tensors sharing exactly one sample (batch index 2).
    idx_u = [0, 2]
    packed_u = PackedTensor(fill_tensor(torch.ones((2, 2)), idx_u), 4, idx_u)
    idx_v = [2, 3]
    packed_v = PackedTensor(
        2.0 * fill_tensor(torch.ones((2, 2)), idx_v), 4, idx_v
    )

    common_u, common_v = packed_u.intersection(packed_v)
    assert common_u.batch_indices == [2]
    assert common_u.batch_size == 4

    expanded_u = common_u.expand()
    assert (expanded_u[2] == 2.0).all()
    expanded_v = common_v.expand()
    assert (expanded_v[2] == 2 * 2.0).all()

    # Tensors without any common sample yield None for both parts.
    packed_u = PackedTensor(torch.ones((2, 2)), 4, [0, 1])
    packed_v = PackedTensor(2.0 * torch.ones((2, 2)), 4, [2, 3])
    common_u, common_v = packed_u.intersection(packed_v)
    assert common_u is None
    assert common_v is None

    # Random tensors: the value checks rely on each row being equal to its
    # batch index -- presumably guaranteed by make_random_packed_tensor.
    for _ in range(100):
        lhs = make_random_packed_tensor(100, 50)
        rhs = make_random_packed_tensor(100, 50)
        expected = sorted(set(lhs.batch_indices) & set(rhs.batch_indices))
        lhs, rhs = lhs.intersection(rhs)
        for row, index in enumerate(expected):
            assert (lhs.tensor[row] == index).all()
            assert (rhs.tensor[row] == index).all()
226 |
def test_difference():
    """
    Check the difference of packed tensors.

    Both the fixed example and the randomized check expect the result to
    contain the batch indices present in exactly one of the two operands.
    """
    first_data = torch.zeros((2, 2))
    first_data[1] = 2.0
    first = PackedTensor(first_data, 4, [1, 3])
    second = PackedTensor(1.0 * torch.ones((2, 2)), 4, [0, 1])

    diff = first.difference(second)
    assert diff.batch_indices == [0, 3]
    assert diff.batch_size == 4

    # The checks on the expanded values are currently disabled:
    #
    # d_e = d_p.expand()
    # assert (d_e[0] == 1.0).all()
    # assert (d_e[1] == 0.0).all()
    # assert (d_e[2] == 0.0).all()
    # assert (d_e[3] == 2.0).all()

    # Random tensors: the value checks rely on each row being equal to its
    # batch index -- presumably guaranteed by make_random_packed_tensor.
    for _ in range(10):
        left = make_random_packed_tensor(100, 50)
        right = make_random_packed_tensor(100, 50)
        expected = sorted(set(left.batch_indices) ^ set(right.batch_indices))
        diff = left.difference(right)
        for position, index in enumerate(expected):
            assert (diff.tensor[position] == index).all()
254 |
255 |
def test_union():
    """
    Check that the union of packed tensors contains the samples of both
    operands, independent of the operand order.
    """
    ones = PackedTensor(torch.ones((2, 2)), 4, [3])
    twos = PackedTensor(2.0 * torch.ones((2, 2)), 4, [1, 2])

    # Expected content of the expanded union, one value per batch sample.
    expected_rows = (0.0, 2.0, 2.0, 1.0)

    for first, second in ((ones, twos), (twos, ones)):
        merged = first.union(second)
        assert merged.batch_indices == [1, 2, 3]
        assert merged.batch_size == 4

        expanded = merged.expand()
        for row, value in enumerate(expected_rows):
            assert (expanded[row] == value).all()

    # Random tensors: the value checks rely on each row being equal to its
    # batch index -- presumably guaranteed by make_random_packed_tensor.
    for _ in range(100):
        left = make_random_packed_tensor(100, 50)
        right = make_random_packed_tensor(100, 50)
        expected = sorted(set(left.batch_indices) | set(right.batch_indices))
        merged = left.union(right)

        for position, index in enumerate(expected):
            assert (merged.tensor[position] == index).all()
293 |
294 |
def test_apply_batch_norm():
    """
    Apply a 2D batch norm layer to a packed tensor.

    The layer needs the dim() member function of the packed tensor.
    """
    packed = make_random_packed_tensor(100, 50, (8, 16, 16))

    layer = nn.BatchNorm2d(8)
    # With zero scale and unit shift the layer output is 1 everywhere.
    nn.init.zeros_(layer.weight)
    nn.init.ones_(layer.bias)

    result = layer(packed)
    assert (result.tensor == 1.0).all()
307 |
--------------------------------------------------------------------------------
/test/test_qrnn.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests the QRNN implementation for all available backends.
3 | """
4 | from quantnn import (QRNN,
5 | set_default_backend,
6 | get_default_backend)
7 | import numpy as np
8 | import os
9 | import importlib
10 | import pytest
11 | import tempfile
12 |
#
# Import available backends.
#

# Both backends are optional. A backend whose import fails is simply left
# out of the parametrized tests instead of aborting collection of this
# module. ``Exception`` (rather than a bare ``except``) keeps that
# best-effort behavior without swallowing KeyboardInterrupt or SystemExit.
backends = []
try:
    import quantnn.models.keras

    backends += ["keras"]
except Exception:
    pass

try:
    import quantnn.models.pytorch

    backends += ["pytorch"]
except Exception:
    pass
31 |
32 |
class TestQrnn:
    """
    Tests training, evaluation and serialization of QRNN models for all
    backends that could be imported.
    """

    def setup_method(self):
        """
        Load the training data stored next to this file and normalize the
        inputs using the mean and standard deviation over all elements.
        """
        test_dir = os.path.dirname(os.path.realpath(__file__))
        path = os.path.join(test_dir, "test_data")
        x_train = np.load(os.path.join(path, "x_train.npy"))
        x_mean = np.mean(x_train, keepdims=True)
        x_sigma = np.std(x_train, keepdims=True)
        self.x_train = (x_train - x_mean) / x_sigma
        self.y_train = np.load(os.path.join(path, "y_train.npy"))

    @pytest.mark.parametrize("backend", backends)
    def test_qrnn(self, backend):
        """
        Test training of QRNNs using numpy arrays as input.
        """
        set_default_backend(backend)
        qrnn = QRNN(np.linspace(0.05, 0.95, 10),
                    n_inputs=self.x_train.shape[1])
        qrnn.train((self.x_train, self.y_train),
                   validation_data=(self.x_train, self.y_train),
                   n_epochs=2)

        qrnn.predict(self.x_train)

        # The returned CDF values must span the full range from 0 to 1.
        _, qs = qrnn.cdf(self.x_train[:2, :])
        assert qs[0] == 0.0
        assert qs[-1] == 1.0

        x, y = qrnn.pdf(self.x_train[:2, :])
        assert x.shape == y.shape

        mu = qrnn.posterior_mean(self.x_train[:2, :])
        assert len(mu.shape) == 1

        r = qrnn.sample_posterior(self.x_train[:4, :], n_samples=2)
        assert r.shape == (4, 2)

        r = qrnn.sample_posterior_gaussian_fit(self.x_train[:4, :], n_samples=2)
        assert r.shape == (4, 2)

    @pytest.mark.parametrize("backend", backends)
    def test_qrnn_dict_iterable(self, backend):
        """
        Test training with dataset object that yields dicts instead of
        tuples.
        """
        set_default_backend(backend)
        backend_module = get_default_backend()

        class DictWrapper:
            """Re-emit the ``(x, y)`` batches of a dataset as dicts."""

            def __init__(self, data):
                self.data = data

            def __iter__(self):
                for x, y in self.data:
                    yield {"x": x, "y": y}

            def __len__(self):
                return len(self.data)

        data = backend_module.BatchedDataset((self.x_train, self.y_train), 256)
        qrnn = QRNN(np.linspace(0.05, 0.95, 10),
                    n_inputs=self.x_train.shape[1])
        qrnn.train(DictWrapper(data), n_epochs=2, keys=("x", "y"))

    @pytest.mark.parametrize("backend", backends)
    def test_qrnn_datasets(self, backend):
        """
        Provide data as dataset object instead of numpy arrays.
        """
        set_default_backend(backend)
        backend_module = get_default_backend()
        data = backend_module.BatchedDataset((self.x_train, self.y_train), 256)
        qrnn = QRNN(np.linspace(0.05, 0.95, 10),
                    n_inputs=self.x_train.shape[1])
        qrnn.train(data, n_epochs=2)

    @pytest.mark.parametrize("backend", backends)
    def test_save_qrnn(self, backend):
        """
        Test saving and loading of QRNNs.

        The predictions of the loaded model must match those of the
        original model.
        """
        set_default_backend(backend)
        qrnn = QRNN(np.linspace(0.05, 0.95, 10),
                    n_inputs=self.x_train.shape[1])

        # The context manager closes and removes the temporary file once
        # the comparison data has been computed.
        with tempfile.NamedTemporaryFile() as f:
            qrnn.save(f.name)
            qrnn_loaded = QRNN.load(f.name)

            x_pred = qrnn.predict(self.x_train)
            # Predict with the *loaded* model; comparing the original model
            # with itself would not test serialization at all.
            x_pred_loaded = qrnn_loaded.predict(self.x_train)

        # Non-numpy results (backend tensors) are detached before they are
        # handed to numpy.
        if not isinstance(x_pred, np.ndarray):
            x_pred = x_pred.detach()
        if not isinstance(x_pred_loaded, np.ndarray):
            x_pred_loaded = x_pred_loaded.detach()

        assert np.allclose(x_pred, x_pred_loaded)

    @pytest.mark.skipif("pytorch" not in backends,
                        reason="No PyTorch backend.")
    def test_save_qrnn_pytorch_model(self):
        """
        Test saving and loading of a trained QRNN that wraps a
        user-provided PyTorch model.
        """
        from torch import nn
        quantiles = np.linspace(0.05, 0.95, 10)
        model = nn.Sequential(nn.Linear(self.x_train.shape[1], quantiles.size))
        qrnn = QRNN(quantiles, model=model)

        # Train the model
        data = quantnn.models.pytorch.BatchedDataset((self.x_train, self.y_train), 256)
        qrnn.train(data, n_epochs=2)

        # Save the model and load it again while the temporary file exists.
        with tempfile.NamedTemporaryFile() as f:
            qrnn.save(f.name)
            qrnn_loaded = QRNN.load(f.name)

            # Compare predictions from saved and loaded model.
            x_pred = qrnn.predict(self.x_train)
            x_pred_loaded = qrnn_loaded.predict(self.x_train)

        if not isinstance(x_pred, np.ndarray):
            x_pred = x_pred.detach()
        if not isinstance(x_pred_loaded, np.ndarray):
            x_pred_loaded = x_pred_loaded.detach()

        assert np.allclose(x_pred, x_pred_loaded)
157 |
--------------------------------------------------------------------------------
/test/test_tensor_backends.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import scipy as sp
4 |
5 | from quantnn.backends import TENSOR_BACKENDS
6 |
@pytest.mark.parametrize("backend", TENSOR_BACKENDS)
def test_conversion(backend):
    """
    Check that a numpy array survives the round trip through the backend.
    """
    original = np.random.rand(10, 10)
    round_trip = backend.to_numpy(backend.from_numpy(original))

    assert np.allclose(original, round_trip)
18 |
@pytest.mark.parametrize("backend", TENSOR_BACKENDS)
def test_sample_uniform(backend):
    """
    Check that uniform samples lie within the interval [0, 1].
    """
    samples = backend.to_numpy(backend.sample_uniform((100, 100)))
    in_range = (samples >= 0.0) & (samples <= 1.0)
    assert np.all(in_range)
23 |
@pytest.mark.parametrize("backend", TENSOR_BACKENDS)
def test_sample_gaussian(backend):
    """
    Check that Gaussian samples have approximately zero mean and unit
    standard deviation.
    """
    samples = backend.to_numpy(backend.sample_gaussian((100, 100)))
    sample_mean = samples.mean()
    sample_std = samples.std()

    assert np.isclose(sample_mean, 0.0, atol=1e-1)
    assert np.isclose(sample_std, 1.0, atol=1e-1)
29 |
@pytest.mark.parametrize("backend", TENSOR_BACKENDS)
def test_size(backend):
    """
    Check that size() reports the number of elements of a tensor.
    """
    tensor = backend.from_numpy(np.arange(10))
    assert backend.size(tensor) == 10
36 |
@pytest.mark.parametrize("backend", TENSOR_BACKENDS)
def test_concatenate(backend):
    """
    Stack ten copies of the same row and check the resulting columns.
    """
    row = backend.expand_dims(backend.from_numpy(np.arange(10)), 0)
    stacked = backend.to_numpy(backend.concatenate([row] * 10, 0))

    # Every row is 0..9, so column ``column`` must be constant ``column``.
    for column in range(10):
        assert np.allclose(stacked[:, column], column)
47 |
@pytest.mark.parametrize("backend", TENSOR_BACKENDS)
def test_expand_dims(backend):
    """
    Check that expand_dims adds a leading dimension of size one.
    """
    tensor = backend.from_numpy(np.arange(10))
    expanded = backend.expand_dims(tensor, 0)

    rank_before = len(tensor.shape)
    rank_after = len(expanded.shape)
    assert rank_after == rank_before + 1
    assert expanded.shape[0] == 1
56 |
@pytest.mark.parametrize("backend", TENSOR_BACKENDS)
def test_exp(backend):
    """
    Compare the backend's exp against the numpy reference.
    """
    # Keep the numpy input around so that the reference value is computed
    # by numpy on a numpy array rather than on the backend's tensor type.
    x_np = np.arange(10).astype(np.float32)
    y = backend.exp(backend.from_numpy(x_np))

    assert np.all(np.isclose(backend.to_numpy(y), np.exp(x_np)))
64 |
@pytest.mark.parametrize("backend", TENSOR_BACKENDS)
def test_log(backend):
    """
    Compare the backend's log against the numpy reference.
    """
    # Keep the numpy input around so that the reference value is computed
    # by numpy on a numpy array rather than on the backend's tensor type.
    x_np = np.arange(10).astype(np.float32)
    y = backend.log(backend.from_numpy(x_np))

    # The input contains 0, whose logarithm is -inf; silence the expected
    # divide-by-zero warning of the numpy reference.
    with np.errstate(divide="ignore"):
        y_ref = np.log(x_np)

    assert np.all(np.isclose(backend.to_numpy(y), y_ref))
72 |
73 |
@pytest.mark.parametrize("backend", TENSOR_BACKENDS)
def test_pad_zeros(backend):
    """
    Check that pad_zeros adds zeros at both ends of the padded axis.
    """
    row = backend.expand_dims(backend.from_numpy(np.arange(10)), 0)
    tensor = backend.concatenate([row] * 10, 0)

    # Pad two elements along axis 0, then one element along axis 1.
    tensor = backend.pad_zeros(tensor, 2, 0)
    tensor = backend.pad_zeros(tensor, 1, 1)
    padded = backend.to_numpy(tensor)

    borders = (
        padded[:2, :],
        padded[-2:, :],
        padded[:, :1],
        padded[:, -1:],
    )
    for border in borders:
        assert np.all(border == 0.0)
91 |
@pytest.mark.parametrize("backend", TENSOR_BACKENDS)
def test_pad_zeros_left(backend):
    """
    Check that pad_zeros_left adds zeros only at the start of the padded
    axis.
    """
    row = backend.expand_dims(backend.from_numpy(np.arange(10)), 0)
    tensor = backend.concatenate([row] * 10, 0)

    # Pad two elements along axis 0, then one element along axis 1.
    tensor = backend.pad_zeros_left(tensor, 2, 0)
    tensor = backend.pad_zeros_left(tensor, 1, 1)
    padded = backend.to_numpy(tensor)

    # Leading borders are zero, trailing borders still hold data.
    leading_rows = padded[:2, :]
    trailing_rows = padded[-2:, :]
    assert np.all(leading_rows == 0.0)
    assert not np.all(trailing_rows == 0.0)

    leading_column = padded[:, :1]
    trailing_column = padded[:, -1:]
    assert np.all(leading_column == 0.0)
    assert not np.all(trailing_column == 0.0)
108 |
@pytest.mark.parametrize("backend", TENSOR_BACKENDS)
def test_arange(backend):
    """
    Check that the backend's arange agrees with numpy's.
    """
    result = backend.to_numpy(backend.arange(2, 10, 2))
    expected = np.arange(2, 10, 2)

    assert np.allclose(result, expected)
115 |
@pytest.mark.parametrize("backend", TENSOR_BACKENDS)
def test_reshape(backend):
    """
    Check that the backend's reshape agrees with numpy's.
    """
    column = backend.reshape(backend.arange(2, 10, 2), (-1, 1))
    result = backend.to_numpy(column)
    expected = np.arange(2, 10, 2).reshape(-1, 1)

    assert np.allclose(result, expected)
122 |
@pytest.mark.parametrize("backend", TENSOR_BACKENDS)
def test_trapz(backend):
    """
    Integrate a constant function of value one with the backend's trapz and
    check that the result is 6.
    """
    grid = backend.arange(2, 10, 2)
    values = backend.ones(like=grid)

    result = backend.trapz(values, grid, 0)
    assert np.allclose(backend.to_numpy(result), 6)
131 |
@pytest.mark.parametrize("backend", TENSOR_BACKENDS)
def test_cumsum(backend):
    """
    Compare the backend's cumsum against the numpy reference.
    """
    x = backend.arange(2, 10, 2)
    cumsum = backend.to_numpy(backend.cumsum(x, 0))

    # Compute the reference on a numpy array instead of passing the
    # backend's tensor type to numpy.
    cumsum_ref = np.cumsum(np.arange(2, 10, 2))
    assert np.all(np.isclose(cumsum, cumsum_ref))
138 |
@pytest.mark.parametrize("backend", TENSOR_BACKENDS)
def test_zeros(backend):
    """
    Check creation of zero tensors from a shape and from a template tensor.
    """
    from_shape = backend.zeros((2, 2, 2))
    from_template = backend.zeros(like=from_shape)

    from_shape = backend.to_numpy(from_shape)
    from_template = backend.to_numpy(from_template)

    assert np.allclose(from_shape, 0.0)
    assert np.allclose(from_shape, from_template)
150 |
@pytest.mark.parametrize("backend", TENSOR_BACKENDS)
def test_ones(backend):
    """
    Check creation of one-filled tensors from a shape and from a template
    tensor.
    """
    from_shape = backend.ones((2, 2, 2))
    from_template = backend.ones(like=from_shape)

    from_shape = backend.to_numpy(from_shape)
    from_template = backend.to_numpy(from_template)

    assert np.allclose(from_shape, 1.0)
    assert np.allclose(from_shape, from_template)
162 |
@pytest.mark.parametrize("backend", TENSOR_BACKENDS)
def test_softmax(backend):
    """
    Compare the backend's softmax along axis 1 against SciPy's
    implementation.
    """
    # Import the submodule explicitly: a plain ``import scipy`` does not
    # guarantee that ``scipy.special`` is available as an attribute.
    from scipy.special import softmax

    x = backend.sample_uniform((10, 10))
    # Convert the backend result to numpy before comparing it, as is done
    # for the input of the reference computation.
    x_sf = backend.to_numpy(backend.softmax(x, 1))
    x_sf_ref = softmax(backend.to_numpy(x), 1)

    assert np.all(np.isclose(x_sf, x_sf_ref))
172 |
@pytest.mark.parametrize("backend", TENSOR_BACKENDS)
def test_where(backend):
    """
    Check element selection with where() for a condition that always holds
    and for one that never holds.
    """
    # ones > 0 is true everywhere, so the zeros are selected;
    # ones > 1 is false everywhere, so the ones are selected.
    for threshold, expected in ((0, 0.0), (1, 1.0)):
        ones = backend.ones((10, 10))
        zeros = backend.zeros((10, 10))
        selected = backend.where(ones > threshold, zeros, ones)
        assert np.allclose(backend.to_numpy(selected), expected)
186 |
--------------------------------------------------------------------------------
/test/test_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import xarray as xr
3 |
4 | from quantnn.utils import (apply,
5 | serialize_dataset,
6 | deserialize_dataset)
7 |
def test_apply():
    """
    Test application of functions to plain values and to dicts of values.
    """

    def double(x):
        return 2 * x

    def add(x, y):
        return x + y

    d = {i: i for i in range(5)}

    # Plain arguments: compare against the expected results instead of
    # comparing each result with itself, which can never fail.
    a = apply(double, 1)
    b = apply(add, 1, 1)
    assert a == 2
    assert b == 2

    # Dict arguments: the function is expected to be applied per key.
    d_a = apply(double, d)
    d_b = apply(add, d, d)
    for k in d:
        assert k == d_a[k] // 2
        assert k == d_b[k] // 2
25 |
def test_serialization():
    """
    Check that an xarray dataset survives serialization followed by
    deserialization.
    """
    reference = xr.Dataset({"x": (("a", "b"), np.ones((10, 10)))})
    restored = deserialize_dataset(serialize_dataset(reference))

    assert np.allclose(restored["x"].data, reference["x"].data)
36 |
37 |
38 |
--------------------------------------------------------------------------------