├── .github └── workflows │ ├── install_and_test.yml │ └── release.yml ├── LICENSE ├── MANIFEST.in ├── README.md ├── docs └── source │ ├── _static │ └── logo.svg │ ├── _templates │ └── logo.html │ ├── api │ ├── density.rst │ ├── drnn.rst │ ├── normalizer.rst │ ├── qrnn.rst │ ├── quantiles.rst │ └── quantnn.rst │ ├── api_reference.rst │ ├── cdf_qrnn.png │ ├── conf.py │ ├── drnn.rst │ ├── examples.rst │ ├── index.rst │ ├── notebooks │ ├── cloud_top_pressure_retrieval.ipynb │ ├── rain_rate_retrieval.ipynb │ └── simple_example.ipynb │ ├── pdf_drnn.png │ ├── qrnn.rst │ ├── qrnn.svg │ ├── quantiles.svg │ └── user_guide.rst ├── notebooks ├── cloud_top_pressure_retrieval.ipynb ├── comparison_xgboost.ipynb ├── convolutional_rain_rate_retrieval.ipynb ├── migrating_from_mse_regression.ipynb ├── rain_rate_retrieval.ipynb ├── simple_example.ipynb └── tensor_board_logging.ipynb ├── quantnn.yml ├── quantnn ├── __init__.py ├── a_priori.py ├── backends │ ├── __init__.py │ ├── pytorch.py │ └── tensor.py ├── common.py ├── data.py ├── data │ └── matplotlib_style.rc ├── density.py ├── drnn.py ├── examples │ ├── __init__.py │ ├── gprof_conv.py │ ├── gprof_simple.py │ ├── modis_ctp.py │ └── simple.py ├── files │ ├── __init__.py │ └── sftp.py ├── generic │ └── __init__.py ├── logging │ ├── __init__.py │ └── multiprocessing.py ├── metrics.py ├── models │ ├── __init__.py │ ├── keras │ │ ├── __init__.py │ │ ├── padding.py │ │ ├── unet.py │ │ └── xception.py │ └── pytorch │ │ ├── __init__.py │ │ ├── aggregators.py │ │ ├── base.py │ │ ├── blocks.py │ │ ├── common.py │ │ ├── decoders.py │ │ ├── downsampling.py │ │ ├── encoders.py │ │ ├── factories.py │ │ ├── fully_connected.py │ │ ├── generative.py │ │ ├── lightning.py │ │ ├── logging.py │ │ ├── normalization.py │ │ ├── resnet.py │ │ ├── stages.py │ │ ├── torchvision.py │ │ ├── unet.py │ │ ├── upsampling.py │ │ └── xception.py ├── mrnn.py ├── neural_network_model.py ├── normalizer.py ├── packed_tensor.py ├── plotting.py ├── qrnn.py ├── 
quantiles.py ├── transformations.py └── utils.py ├── setup.py └── test ├── conftest.py ├── files ├── test_files.py └── test_sftp.py ├── models ├── pytorch │ ├── test_aggregators.py │ ├── test_base.py │ ├── test_blocks.py │ ├── test_decoders.py │ ├── test_downsampling.py │ ├── test_encoders.py │ ├── test_fully_connected.py │ ├── test_normalization.py │ ├── test_torchvision.py │ └── test_upsampling.py ├── test_keras.py └── test_pytorch.py ├── test_data.py ├── test_data ├── x_train.npy └── y_train.npy ├── test_density.py ├── test_drnn.py ├── test_generic.py ├── test_normalizer.py ├── test_packed_tensor.py ├── test_qrnn.py ├── test_quantiles.py ├── test_tensor_backends.py └── test_utils.py /.github/workflows/install_and_test.yml: -------------------------------------------------------------------------------- 1 | name: install_and_test 2 | on: [push] 3 | jobs: 4 | install_and_test_job: 5 | strategy: 6 | matrix: 7 | os: [ubuntu-latest] 8 | python: [3.8, 3.9] 9 | runs-on: ${{ matrix.os }} 10 | steps: 11 | - uses: actions/checkout@v2 12 | - uses: actions/setup-python@v2 13 | with: 14 | python-version: ${{ matrix.python }} 15 | - run: pip install torch torchvision pytest tensorflow 16 | - run: pip install . 17 | - run: pytest test 18 | env: 19 | QUANTNN_SFTP_USER: ${{ secrets.QUANTNN_SFTP_USER }} 20 | QUANTNN_SFTP_PASSWORD: ${{ secrets.QUANTNN_SFTP_PASSWORD }} 21 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: release 2 | on: 3 | push: 4 | tags: 5 | - '*' 6 | jobs: 7 | release_job: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v2 11 | with: 12 | ref: 'main' 13 | - uses: actions/setup-python@v2 14 | with: 15 | python-version: '3.8' 16 | - run: pip install . 
17 | - run: pip install wheel twine 18 | - run: python setup.py sdist bdist_wheel 19 | - run: python -m twine upload -u __token__ -p ${{ secrets.TWINE_TOKEN }} dist/* 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2020 Simon Pfreundschuh 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | 9 | End license text. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include quantnn/data/matplotlib_style.rc 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # quantnn 2 | 3 | The ``quantnn`` package provides an implementation of quantile regression neural 4 | networks on top of Keras and Pytorch. 
5 | -------------------------------------------------------------------------------- /docs/source/_templates/logo.html: -------------------------------------------------------------------------------- 1 |
6 |
7 |
15 | quantnn logo quantnn
16 |
17 | -------------------------------------------------------------------------------- /docs/source/api/density.rst: -------------------------------------------------------------------------------- 1 | .. automodule:: quantnn.density 2 | :members: 3 | -------------------------------------------------------------------------------- /docs/source/api/drnn.rst: -------------------------------------------------------------------------------- 1 | .. automodule:: quantnn.drnn 2 | :members: 3 | -------------------------------------------------------------------------------- /docs/source/api/normalizer.rst: -------------------------------------------------------------------------------- 1 | .. automodule:: quantnn.normalizer 2 | :members: 3 | -------------------------------------------------------------------------------- /docs/source/api/qrnn.rst: -------------------------------------------------------------------------------- 1 | .. automodule:: quantnn.qrnn 2 | :members: 3 | :inherited-members: 4 | -------------------------------------------------------------------------------- /docs/source/api/quantiles.rst: -------------------------------------------------------------------------------- 1 | .. automodule:: quantnn.quantiles 2 | :members: 3 | -------------------------------------------------------------------------------- /docs/source/api/quantnn.rst: -------------------------------------------------------------------------------- 1 | .. automodule:: quantnn 2 | :members: 3 | 4 | .. 
toctree:: 5 | :maxdepth: 1 6 | :caption: Submodules 7 | 8 | backends 9 | qrnn 10 | quantiles 11 | drnn 12 | density 13 | data 14 | normalizer 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /docs/source/api_reference.rst: -------------------------------------------------------------------------------- 1 | API Reference 2 | ============= 3 | 4 | This part of the documentation contains the source code documentation of 5 | the **quantnn**. It provides detailed information on the usage of specific 6 | components of the package. 7 | 8 | .. toctree:: 9 | :maxdepth: 2 10 | 11 | api/quantnn 12 | api/data 13 | -------------------------------------------------------------------------------- /docs/source/cdf_qrnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simonpf/quantnn/c4d650cf0c6da5b4a704905b6c267d1ca996466f/docs/source/cdf_qrnn.png -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | # -- Project information ----------------------------------------------------- 18 | 19 | project = 'quantnn' 20 | copyright = '2020, Simon Pfreundschuh' 21 | author = 'Simon Pfreundschuh' 22 | 23 | # The full version, including alpha/beta/rc tags 24 | release = '0.0.1' 25 | 26 | 27 | # -- General configuration --------------------------------------------------- 28 | 29 | # Add any Sphinx extension module names here, as strings. They can be 30 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 31 | # ones. 32 | extensions = ["nbsphinx", "sphinx.ext.autodoc", "sphinx.ext.napoleon", 33 | "sphinx.ext.autosummary"] 34 | 35 | # Add any paths that contain templates here, relative to this directory. 36 | templates_path = ['_templates'] 37 | 38 | # List of patterns, relative to source directory, that match files and 39 | # directories to ignore when looking for source files. 40 | # This pattern also affects html_static_path and html_extra_path. 41 | exclude_patterns = [] 42 | 43 | 44 | # -- Options for HTML output ------------------------------------------------- 45 | 46 | # The theme to use for HTML and HTML Help pages. See the documentation for 47 | # a list of builtin themes. 
48 | # 49 | html_theme = 'smpl' 50 | html_theme_options = { 51 | "navigation_bar_minimum_height": "15vh", 52 | "navigation_bar_targets": ["index.html", 53 | "user_guide.html", 54 | "examples.html", 55 | "api_reference.html"], 56 | "navigation_bar_names": ["Home", "User guide", "Examples", "API Reference"], 57 | "navigation_bar_element_padding": "40px", 58 | "navigation_bar_background_color": "#333333", 59 | "navigation_bar_element_hover_color": "#ff5050", 60 | "navigation_bar_border_color": "#ff5050", 61 | "navigation_bar_border_style": "solid", 62 | "navigation_bar_border_width": "0px 0px 0px 0px", 63 | 64 | "link_color": "#ff5050", 65 | "link_visited_color": "#ff5050", 66 | "link_hover_color": "#990000", 67 | 68 | "sidebars_right": [], 69 | "sidebars_left":["localtoc.html", "globaltoc.html"], 70 | "globaltoc_maxdepth": 1, 71 | "inline_code_border_radius": "2px", 72 | "sidebar_left_border_color": "#ff5050", 73 | 74 | "highlight_border_color": "#ff5050", 75 | } 76 | 77 | # Add any paths that contain custom static files (such as style sheets) here, 78 | # relative to this directory. They are copied after the builtin static files, 79 | # so a file named "default.css" will overwrite the builtin "default.css". 
80 | html_static_path = ['_static'] 81 | -------------------------------------------------------------------------------- /docs/source/drnn.rst: -------------------------------------------------------------------------------- 1 | ========================================== 2 | Density regression neural networks (DRNNs) 3 | ========================================== 4 | 5 | How it works 6 | ------------ 7 | 8 | Defining a model 9 | ---------------- 10 | 11 | Training 12 | -------- 13 | 14 | Evaluation 15 | ---------- 16 | 17 | Handling output 18 | --------------- 19 | 20 | Loading and saving models 21 | ------------------------- 22 | -------------------------------------------------------------------------------- /docs/source/examples.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | ======== 3 | 4 | This section contains examples of applications of the **quantnn** package. 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | :caption: Contents: 9 | 10 | notebooks/simple_example 11 | notebooks/cloud_top_pressure_retrieval 12 | notebooks/rain_rate_retrieval 13 | 14 | 15 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. quantnn documentation master file, created by 2 | sphinx-quickstart on Sun Dec 6 10:44:55 2020. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | quantnn 7 | ======= 8 | 9 | **quantnn** is a Python package for solving probabilistic regression problems using 10 | (deterministic) deep neural networks, i.e. to predict a conditional distribution 11 | :math:`P(y|x)` for given input :math:`x`. 12 | 13 | It currently provides implementations of two neural-network based methods to solve 14 | these type of problems: 15 | 16 | 1. 
**Quantile regression neural networks (QRNNs)**: QRNNs learn to predict the quantiles 17 | of the conditional distribution :math:`P(y|x)`, which can be used to estimate its 18 | cumulative distribution function (CDF). 19 | 20 | .. figure:: cdf_qrnn.png 21 | :width: 800 22 | :alt: A conditional cumulative distribution function predicted using a QRNN. 23 | 24 | Example of a QRNN applied to predict the quantiles of a function with 25 | heteroscedastic noise. 26 | 27 | 2. **Density regression neural networks (DRNNs)**: DRNNs learn to predict a binned 28 | version of the probability density function (PDF) of :math:`P(y|x)`. 29 | 30 | .. figure:: pdf_drnn.png 31 | :width: 800 32 | :alt: A conditional probability density function predicted using a DRNN. 33 | 34 | Example of a DRNN applied to predict the binned PDF of a function with 35 | heteroscedastic noise. 36 | 37 | 38 | 39 | Features 40 | -------- 41 | 42 | - A flexible, high-level implementation of QRNNs and DRNNs supporting both PyTorch and Keras 43 | as backends. 44 | - Generic functions to manipulate and process QRNN and DRNN predictions such as computing the 45 | posterior mean or classifying inputs. 46 | 47 | Installation 48 | ------------ 49 | 50 | The currently recommended way of installing the **quantnn** package is to check out the source from 51 | `GitHub <https://github.com/simonpf/quantnn>`_ and install it in editable mode using ``pip``: 52 | 53 | .. code-block:: bash 54 | 55 | pip install -e . 56 | 57 | Content 58 | ------- 59 | 60 | .. 
toctree:: 61 | :maxdepth: 2 62 | 63 | user_guide 64 | examples 65 | api_reference 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /docs/source/notebooks/cloud_top_pressure_retrieval.ipynb: -------------------------------------------------------------------------------- 1 | ../../../notebooks/cloud_top_pressure_retrieval.ipynb -------------------------------------------------------------------------------- /docs/source/notebooks/rain_rate_retrieval.ipynb: -------------------------------------------------------------------------------- 1 | ../../../notebooks/rain_rate_retrieval.ipynb -------------------------------------------------------------------------------- /docs/source/notebooks/simple_example.ipynb: -------------------------------------------------------------------------------- 1 | ../../../notebooks/simple_example.ipynb -------------------------------------------------------------------------------- /docs/source/pdf_drnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simonpf/quantnn/c4d650cf0c6da5b4a704905b6c267d1ca996466f/docs/source/pdf_drnn.png -------------------------------------------------------------------------------- /docs/source/qrnn.rst: -------------------------------------------------------------------------------- 1 | =========================================== 2 | Quantile regression neural networks (QRNNs) 3 | =========================================== 4 | 5 | Consider first the case of a simple, one-dimensional input vector :math:`x` with 6 | input features :math:`x_1, \ldots, x_n` and a corresponding scalar output value 7 | :math:`y`. If the problem of mapping :math:`x` to :math:`y` does not admit a unique 8 | solution a more suitable approach than predicting a single value for :math:`y` 9 | is to instead predict its conditional distribution :math:`P(y|x)` for given 10 | input :math:`x`. 
11 | 12 | QRNNs do this by learning to predict a sequence of quantiles :math:`y_\tau` of 13 | the distribution :math:`P(y | x)`. The quantile :math:`y_\tau` for 14 | :math:`\tau \in [0, 1]` is defined as the value for which :math:`P(y \leq y_\tau 15 | | x) = \tau`. Since the quantiles :math:`y_{\tau_0}, y_{\tau_1}, \ldots` 16 | correspond to the values of the inverse :math:`F^{-1}` of the cumulative 17 | distribution function :math:`F(y') = P(y \leq y' | x)`, the QRNN output can be 18 | interpreted as a piece-wise linear approximation of the CDF of :math:`P(y|x)`. 19 | 20 | .. figure:: qrnn.svg 21 | :width: 600 22 | :alt: Illustration of a QRNN 23 | 24 | A QRNN is a neural network which predicts a sequence of quantiles of 25 | the posterior distribution :math:`P(y|x)`. 26 | 27 | How it works 28 | ------------ 29 | 30 | QRNNs make use of quantile regression to learn to predict the quantiles of the 31 | distribution :math:`P(y|x)`. This works because the quantile :math:`y_\tau` for 32 | a given quantile fraction :math:`\tau \in [0, 1]` corresponds to a minimum of the 33 | expected value :math:`\mathbf{E}_y {\mathcal{L}_\tau(y, y_\tau)}` of the quantile 34 | loss function 35 | 36 | .. math:: 37 | 38 | \mathcal{L}_\tau(y, y_\tau) = 39 | \begin{cases} 40 | \tau (y - y_\tau) & \text{if } y > y_\tau \\ 41 | (1 - \tau) (y_\tau - y) & \text{otherwise}. 42 | \end{cases} 43 | 44 | A proof of this can be found on `Wikipedia <https://en.wikipedia.org/wiki/Quantile_regression>`_. 45 | 46 | Because of this property, training a neural network using the quantile loss 47 | function :math:`\mathcal{L}_\tau` will teach the network to predict the 48 | corresponding quantile :math:`y_\tau`. QRNNs extend this principle to a 49 | sequence of quantiles corresponding to an arbitrary selection of quantile 50 | fractions :math:`\tau_1, \tau_2, \ldots` which are optimized simultaneously. 
51 | 52 | Defining a model 53 | ---------------- 54 | 55 | To define a QRNN model you need to specify the quantiles to predict as well as the 56 | architecture of the underlying network to use. The code snippet below shows how 57 | to create a QRNN model to predict the 1st through 99th percentiles. 58 | 59 | As its neural network model, the QRNN uses a fully-connected neural network with four hidden 60 | layers with 128 neurons each and ReLU activation functions. Since this is taken 61 | from the simple example notebook, the number of input features is set to 1. 62 | 63 | .. code :: 64 | 65 | from quantnn import QRNN 66 | quantiles = np.linspace(0.01, 0.99, 99) 67 | 68 | layers = 4 69 | neurons = 128 70 | activation = "relu" 71 | model = (layers, neurons, activation) 72 | 73 | qrnn = QRNN(quantiles, input_dimensions=1, model=model) 74 | 75 | QRNN provides a simplified interface to create QRNNs with simple fully-connected 76 | architectures. The QRNN class, however, doesn't restrict you to using a fully-connected 77 | network but supports any other suitable architecture such as, for example, 78 | a fully-convolutional DenseNet-type architecture: 79 | 80 | .. code :: 81 | 82 | model = ... # Define model as PyTorch Module or Keras model object. 83 | qrnn = QRNN(quantiles, model) 84 | 85 | Training 86 | -------- 87 | 88 | Evaluation 89 | ---------- 90 | 91 | Saving and loading 92 | ------------------ 93 | 94 | -------------------------------------------------------------------------------- /docs/source/user_guide.rst: -------------------------------------------------------------------------------- 1 | User guide 2 | ========== 3 | 4 | This section describes the basic usage of the **quantnn** package. 5 | 6 | Overview 7 | -------- 8 | 9 | The main functionality of quantnn is implemented by two model classes: 10 | :py:class:`quantnn.qrnn.QRNN` and :py:class:`quantnn.drnn.DRNN`. 
The 11 | :py:class:`~quantnn.qrnn.QRNN` class provides an implementation of quantile 12 | regression neural networks (QRNNs), whereas the :py:class:`~quantnn.drnn.DRNN` 13 | implements density regression neural networks (DRNNs). 14 | 15 | Basic workflow 16 | -------------- 17 | 18 | The basic usage of both the :py:class:`~quantnn.QRNN` and :py:class:`~quantnn.DRNN` 19 | classes is similar and follows the generic machine learning workflow: 20 | 21 | 1. **Defining the model**: A model is defined by instantiating the 22 | corresponding model class. For this, you need to define the architecture of 23 | the underlying neural network and specify which quantiles to predict (QRNN) 24 | or the binning of the PDF (DRNN). 25 | 26 | 2. **Training the model**: The training phase is similar to that of any other deep 27 | neural network. The neural network backend (PyTorch or Keras) takes care of 28 | the heavy lifting; you only have to choose the training parameters. 29 | 30 | 3. **Evaluating the model**: Finally, you will of course want to make 31 | predictions with your model. This is done using the model's 32 | :py:meth:`~quantnn.QRNN.predict` method, which will produce a tensor of 33 | either quantiles (QRNN) or the binned PDF (DRNN) of the posterior 34 | distribution. To further process these predictions, the 35 | :py:mod:`quantnn.quantiles` and :py:mod:`quantnn.density` modules provide 36 | functions that can be used to derive statistics of the probabilistic 37 | results. 38 | 39 | 4. **Loading and saving the model**: To reuse your trained model you can save 40 | and load it using the corresponding class methods of the 41 | :py:class:`~quantnn.qrnn.QRNN` and :py:class:`~quantnn.drnn.DRNN` classes. 42 | 43 | .. note :: 44 | 45 | Care has been taken to design the interfaces of the :py:class:`~quantnn.qrnn.QRNN` 46 | and :py:class:`~quantnn.drnn.DRNN` classes as consistently as possible so that 47 | both classes can be used interchangeably to the largest extent possible. 
48 | 49 | Content 50 | ------- 51 | 52 | .. toctree:: 53 | :maxdepth: 2 54 | 55 | qrnn 56 | drnn 57 | 58 | 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /quantnn.yml: -------------------------------------------------------------------------------- 1 | name: quantnn 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - python=3.9 7 | - torchvision 8 | - torchaudio 9 | - cudatoolkit=10.2 10 | - pytorch 11 | - tensorflow 12 | - tensorboard 13 | prefix: /home/simonpf/miniconda3/envs/quantnn 14 | -------------------------------------------------------------------------------- /quantnn/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | ======= 3 | quantnn 4 | ======= 5 | 6 | The quantnn package provides functionality for probabilistic modeling and prediction 7 | using deep neural networks. 8 | 9 | The two main features of the quantnn package are implemented by the 10 | :py:class:`~quantnn.qrnn.QRNN` and :py:class:`~quantnn.drnn.DRNN` classes, which implement 11 | quantile regression neural networks (QRNNs) and density regression neural networks (DRNNs), 12 | respectively. 13 | 14 | The modules :py:mod:`quantnn.quantiles` and :py:mod:`quantnn.density` provide generic 15 | (backend agnostic) functions to manipulate probabilistic predictions. 
class LookupTable:
    """
    Tensor formulation of a simple piece-wise linear lookup table.

    The a priori is described as a univariate function represented by its
    values at a sequence of nodes and the corresponding (non-normalized)
    densities. Evaluation interpolates linearly between neighboring nodes.
    """

    def __init__(self, x, y):
        """
        Create a new lookup table a priori instance.

        Args:
            x: The x values at which the value of the a priori is known.
            y: The corresponding non-normalized density.
        """
        self.x = x
        self.y = y

    def __call__(self, x, dist_axis=1):
        """
        Evaluate the a priori.

        Args:
            x: Tensor containing the values at which to evaluate the a priori.
            dist_axis: The axis along which the tensor x is sorted. Forced to
                0 when 'x' is one-dimensional.

        Returns:
            Tensor with the same size as 'x' containing the values of the a
            priori at 'x' obtained by linear interpolation. Values that fall
            into none of the half-open intervals ``(x_l, x_r]`` between
            neighboring nodes evaluate to 0.
        """
        if len(x.shape) == 1:
            dist_axis = 0
        xp = get_array_module(x)
        n_dims = len(x.shape)

        n = x.shape[dist_axis]
        x_index = [slice(0, None)] * n_dims

        # Index expressions that pick the left and the right node of every
        # table interval along 'dist_axis'.
        selection_l = [slice(0, None)] * n_dims
        selection_l[dist_axis] = slice(0, -1)
        selection_l = tuple(selection_l)
        selection_r = [slice(0, None)] * n_dims
        selection_r[dist_axis] = slice(1, None)
        selection_r = tuple(selection_r)

        # Reshape the table so that its nodes lie along 'dist_axis' and
        # broadcast against all other dimensions of 'x'.
        r_shape = [1] * n_dims
        r_shape[dist_axis] = -1
        r_x = self.x.reshape(r_shape)
        r_y = self.y.reshape(r_shape)

        r_x_l = r_x[selection_l]
        r_x_r = r_x[selection_r]
        r_y_l = r_y[selection_l]
        r_y_r = r_y[selection_r]

        rs = []

        for i in range(0, n):
            # Keep the sliced axis (length 1) so that x_i broadcasts against
            # the table intervals.
            x_index[dist_axis] = slice(i, i + 1)
            index = tuple(x_index)
            x_i = x[index]

            # 1 for the interval (x_l, x_r] containing x_i, 0 elsewhere.
            mask = as_type(xp, (r_x_l < x_i) * (r_x_r >= x_i), x_i)
            r = r_y_l * (r_x_r - x_i) * mask
            r += r_y_r * (x_i - r_x_l) * mask
            # For masked-out intervals the denominator is 1, which keeps the
            # (zero) numerator from being divided by a meaningless width.
            r /= mask * (r_x_r - r_x_l) + (1.0 - mask)
            r = expand_dims(xp, r.sum(dist_axis), dist_axis)
            rs.append(r)

        r = concatenate(xp, rs, dist_axis)
        return r


class Gaussian:
    """
    A priori given by a non-normalized Gaussian.

    Evaluates ``exp(-0.5 * dx^T s dx)`` with ``dx = x - x_a`` along the
    distribution axis.
    """

    def __init__(self, x_a, s, dist_axis=-1):
        """
        Create a new Gaussian a priori instance.

        Args:
            x_a: Tensor holding the a priori mean; it is reshaped to lie
                along the distribution axis when the a priori is evaluated.
            s: Matrix contracted with ``x - x_a`` in the exponent
                (presumably the inverse covariance matrix -- TODO confirm).
            dist_axis: The axis of the evaluated tensors along which the
                distribution is defined.
        """
        self.x_a = x_a
        self.s = s
        # The original code stored a hard-coded -1 here and silently ignored
        # the 'dist_axis' argument.
        self.dist_axis = dist_axis

    def __call__(self, x, dist_axis=1):
        """
        Evaluate the a priori.

        Args:
            x: Tensor containing the values at which to evaluate the a priori.
            dist_axis: Not used; the axis passed to the constructor
                determines the distribution axis.

        Returns:
            Tensor containing ``exp(-0.5 * dx^T s dx)``, with the
            distribution axis summed out.
        """
        xp = get_array_module(x)
        n_dims = len(x.shape)
        shape = [1] * n_dims
        shape[self.dist_axis] = -1
        x_a = self.x_a.reshape(shape)

        dx = x - x_a

        # NOTE(review): if 'tensordot' follows NumPy semantics the contracted
        # result ends up on the last axis of 'sdx', so a 'dist_axis' other
        # than the last one probably needs an additional axis move -- confirm.
        sdx = tensordot(xp, dx, self.s, ((self.dist_axis,), (-1,)))
        log_p = -0.5 * (dx * sdx).sum(self.dist_axis)

        return exp(xp, log_p)
-------------------------------------------------------------------------------- 1 | from quantnn.backends.pytorch import PyTorch 2 | from quantnn.common import UnsupportedTensorType 3 | 4 | _TENSOR_BACKEND_CLASSES = [PyTorch] 5 | 6 | TENSOR_BACKENDS = [b for b in _TENSOR_BACKEND_CLASSES if b.available()] 7 | 8 | 9 | def get_tensor_backend(tensor): 10 | """ 11 | Determine the tensor backend for a given tensor. 12 | 13 | Args: 14 | tensor: A tensor type of any of the supported backends. 15 | 16 | Return: 17 | The backend class which providing the interface to the tensor library 18 | corresponding to ``tensor``. 19 | 20 | Raises: 21 | :py:class:`~quantnn.common.UnsupportedTensorType` when the tensor type is not 22 | supported by quantnn. 23 | """ 24 | for backend in TENSOR_BACKENDS: 25 | if backend.matches_tensor(tensor): 26 | return backend 27 | raise UnsupportedTensorType( 28 | f"The provided tensor of type {type(tensor)} is not supported by quantnn." 29 | ) 30 | -------------------------------------------------------------------------------- /quantnn/backends/pytorch.py: -------------------------------------------------------------------------------- 1 | from quantnn.backends.tensor import TensorBackend 2 | 3 | 4 | class PyTorch(TensorBackend): 5 | """ 6 | TensorBackend implementation using torch tensors. 
7 | """ 8 | 9 | @classmethod 10 | def available(cls): 11 | try: 12 | import torch 13 | except ImportError: 14 | return False 15 | return True 16 | 17 | @classmethod 18 | def matches_tensor(cls, tensor): 19 | import torch 20 | 21 | return isinstance(tensor, torch.Tensor) 22 | 23 | @classmethod 24 | def from_numpy(cls, array, like=None): 25 | import torch 26 | 27 | tensor = torch.from_numpy(array) 28 | if like is not None: 29 | tensor = tensor.type(like.dtype).to(like.device) 30 | return tensor 31 | 32 | @classmethod 33 | def to_numpy(cls, array): 34 | import torch 35 | if array.dtype in [torch.bfloat16]: 36 | array = array.float() 37 | return array.cpu().detach().numpy() 38 | 39 | @classmethod 40 | def as_type(cls, tensor, like): 41 | return tensor.type_as(like) 42 | 43 | @classmethod 44 | def sample_uniform(cls, shape=None, like=None): 45 | import torch 46 | 47 | if shape is None and like is None: 48 | raise ValueError( 49 | "'sample_uniform' requires at least one of the arguments " 50 | "'shape' and 'like'. " 51 | ) 52 | dtype = None 53 | device = None 54 | if like is not None: 55 | dtype = like.dtype 56 | device = like.device 57 | if shape is None: 58 | shape = like.shape 59 | 60 | return torch.rand(shape, dtype=dtype, device=device) 61 | 62 | @classmethod 63 | def sample_gaussian(cls, shape=None, like=None): 64 | import torch 65 | 66 | if shape is None and like is None: 67 | raise ValueError( 68 | "'sample_uniform' requires at least one of the arguments " 69 | "'shape' and 'like'. 
" 70 | ) 71 | dtype = None 72 | device = None 73 | if like is not None: 74 | dtype = like.dtype 75 | device = like.device 76 | if shape is None: 77 | shape = like.shape 78 | 79 | return torch.normal(0, 1, shape, dtype=dtype, device=device) 80 | 81 | @classmethod 82 | def size(cls, tensor): 83 | return tensor.numel() 84 | 85 | @classmethod 86 | def concatenate(cls, tensors, dimension): 87 | import torch 88 | 89 | return torch.cat(tensors, dimension) 90 | 91 | @classmethod 92 | def expand_dims(cls, tensor, dimension_index): 93 | return tensor.unsqueeze(dimension_index) 94 | 95 | @classmethod 96 | def exp(cls, tensor): 97 | return tensor.exp() 98 | 99 | @classmethod 100 | def log(cls, tensor): 101 | return tensor.log() 102 | 103 | @classmethod 104 | def pad_zeros(cls, tensor, n, dimension_index): 105 | import torch 106 | 107 | n_dims = len(tensor.shape) 108 | dimension_index = dimension_index % n_dims 109 | pad = [0] * 2 * n_dims 110 | pad[2 * n_dims - 2 - 2 * dimension_index] = n 111 | pad[2 * n_dims - 1 - 2 * dimension_index] = n 112 | return torch.nn.functional.pad(tensor, pad, "constant", 0.0) 113 | 114 | @classmethod 115 | def pad_zeros_left(cls, tensor, n, dimension_index): 116 | import torch 117 | 118 | n_dims = len(tensor.shape) 119 | dimension_index = dimension_index % n_dims 120 | pad = [0] * 2 * n_dims 121 | pad[2 * n_dims - 2 - 2 * dimension_index] = n 122 | pad[2 * n_dims - 1 - 2 * dimension_index] = 0 123 | return torch.nn.functional.pad(tensor, pad, "constant", 0.0) 124 | 125 | @classmethod 126 | def arange(cls, start, end, step, like=None): 127 | import torch 128 | 129 | device = None 130 | dtype = torch.float32 131 | if like is not None: 132 | dtype = like.dtype 133 | device = like.device 134 | return torch.arange(start, end, step, dtype=dtype, device=device) 135 | 136 | @classmethod 137 | def reshape(cls, tensor, shape): 138 | return tensor.reshape(shape) 139 | 140 | @classmethod 141 | def trapz(cls, y, x, dimension): 142 | import torch 143 | 144 | 
# Methods of the PyTorch tensor backend. The enclosing class header lies
# before this span; every function below is a classmethod of that class.


@classmethod
def cumsum(cls, y, dimension):
    """Return the cumulative sum of ``y`` along ``dimension``."""
    import torch

    return torch.cumsum(y, dimension)


@classmethod
def zeros(cls, shape=None, like=None):
    """
    Create a tensor filled with zeros.

    Args:
        shape: Shape of the tensor to create. Defaults to ``like.shape``.
        like: Optional tensor providing dtype, device and default shape.

    Raises:
        ValueError: If neither ``shape`` nor ``like`` is given.
    """
    import torch

    if shape is None and like is None:
        raise ValueError(
            "'zeros' requires at least one of the arguments " "'shape' and 'like'. "
        )
    dtype = None
    device = None
    if like is not None:
        if shape is None:
            return torch.zeros_like(like)
        dtype = like.dtype
        device = like.device

    return torch.zeros(shape, device=device, dtype=dtype)


@classmethod
def ones(cls, shape=None, like=None):
    """
    Create a tensor filled with ones.

    Args:
        shape: Shape of the tensor to create. Defaults to ``like.shape``.
        like: Optional tensor providing dtype, device and default shape.

    Raises:
        ValueError: If neither ``shape`` nor ``like`` is given.
    """
    import torch

    if shape is None and like is None:
        raise ValueError(
            "'ones' requires at least one of the arguments " "'shape' and 'like'. "
        )
    dtype = None
    device = None
    if like is not None:
        if shape is None:
            return torch.ones_like(like)
        dtype = like.dtype
        # Previously ``like.dtype`` was passed as the device, which made
        # ``torch.ones`` fail whenever both ``shape`` and ``like`` were given.
        device = like.device

    return torch.ones(shape, device=device, dtype=dtype)


@classmethod
def softmax(cls, x, axis=None):
    """Apply the softmax function to ``x`` along ``axis``."""
    import torch

    return torch.nn.functional.softmax(x, dim=axis)


@classmethod
def where(cls, condition, x, y):
    """Select elements from ``x`` where ``condition`` holds, else from ``y``."""
    import torch

    return torch.where(condition, x, y)
# Exception hierarchy of the quantnn package. Every exception derives from
# ``QuantnnException`` so that callers can catch all package errors at once.


class QuantnnException(Exception):
    """Base class for all exceptions raised by the quantnn package."""


class UnknownArrayTypeException(QuantnnException):
    """Thrown when a function is called with an unsupported array type."""


class UnsupportedTensorType(QuantnnException):
    """
    Thrown when quantnn is asked to handle a tensor type it doesn't support.
    """


class UnknownModuleException(QuantnnException):
    """
    Thrown when an unsupported backend is passed to a generic array
    operation.
    """


class UnsupportedBackendException(QuantnnException):
    """
    Thrown when quantnn is requested to load a backend that is not supported.
    """


class MissingBackendException(QuantnnException):
    """
    Thrown when a requested backend could not be imported.
    """


class InvalidDimensionException(QuantnnException):
    """Thrown when an input array doesn't match expected shape."""


class ModelNotSupported(QuantnnException):
    """Thrown when a provided model isn't supported by the chosen backend."""


class MissingAuthenticationInfo(QuantnnException):
    """Thrown when required authentication information is not available."""


class DatasetError(QuantnnException):
    """
    Thrown when a given dataset object does not provide the expected interface.
    """


class InvalidURL(QuantnnException):
    """
    Thrown when a provided file URL is invalid.
    """


class InputDataError(QuantnnException):
    """
    Thrown when the training data does not match the expected format.
    """


class ModelLoadError(QuantnnException):
    """
    Thrown when an error occurs while a model is loaded.
    """
def download_data(destination="data"):
    """
    Download the data for the convolutional GPROF retrieval example.

    Files that already exist in ``destination`` are not downloaded again.

    Args:
        destination: Folder in which to store the downloaded data. The
            folder is created if it does not exist yet.
    """
    datasets = [
        "gprof_conv.npz",
    ]

    destination = Path(destination)
    destination.mkdir(parents=True, exist_ok=True)
    for file in datasets:
        # Use the requested destination; the path was previously hard-coded
        # to "data", which ignored the 'destination' argument.
        file_path = destination / file
        if not file_path.exists():
            print(f"Downloading file {file}.")
            url = f"http://spfrnd.de/data/gprof/{file}"
            urlretrieve(url, file_path)
# NOTE(review): absolute path on the original author's machine; the files
# below are only available there.
_DATA_PATH = "/home/simonpf/src/pansat/notebooks/products/data/"
_MODIS_FILES = [
    _DATA_PATH + "MODIS/MYD021KM.A2016286.1750.061.2018062214718.hdf",
    _DATA_PATH + "MODIS/MYD03.A2016286.1750.061.2018062032022.hdf",
]

modis_files = _MODIS_FILES
pad_along_orbit = 200
pad_across_orbit = 300


def prepare_input_data(modis_files):
    """
    Prepares validation data for the MODIS CTP retrieval.

    MODIS brightness temperatures are matched to the nearest CALIOP
    observations and combined with ERA5 fields into the retrieval input.

    Args:
        modis_files: List of filenames containing the MODIS input data to
            use as input.

    Returns:
        Dictionary containing the retrieval input (``input_data``), the
        collocated CALIOP cloud-top pressure (``ctp``), the coordinates of
        the collocations and down-sampled MODIS fields for plotting.
    """
    # Heavy optional dependencies are imported lazily so that importing the
    # module does not require them.
    from datetime import datetime
    from satpy import Scene
    import scipy as sp
    from scipy.interpolate import RegularGridInterpolator
    from scipy.ndimage import maximum_filter, minimum_filter
    from scipy.signal import convolve
    import numpy as np
    import xarray
    from pansat.products.reanalysis.era5 import ERA5Product
    from pansat.products.satellite.calipso import clay01km
    from pykdtree.kdtree import KDTree
    from PIL import Image

    #
    # Prepare MODIS data.
    #

    scene = Scene(filenames=modis_files, reader="modis_l1b")
    scene.load(["true_color", "31", "32", "latitude", "longitude"], resolution=1000)

    # Every fourth pixel is kept for the "small" fields used for plotting.
    scene["true_color_small"] = scene["true_color"][:, ::4, ::4]
    scene["bt_11_small"] = scene["31"][::4, ::4]
    scene["bt_12_small"] = scene["32"][::4, ::4]
    scene["latitude_small"] = scene["latitude"][::4, ::4]
    scene["longitude_small"] = scene["longitude"][::4, ::4]
    scene.save_dataset("true_color_small", "modis_true_color.png")
    image = Image.open("modis_true_color.png")
    modis_rgb = np.array(image)
    bt_11_rgb = scene["bt_11_small"].compute()
    bt_12_rgb = scene["bt_12_small"].compute()
    lats_rgb = scene["latitude_small"].compute()
    lons_rgb = scene["longitude_small"].compute()

    # MODIS data input features.

    lats_r = scene["latitude"].compute()
    lons_r = scene["longitude"].compute()
    bt_11 = scene["31"].compute()
    bt_12 = scene["32"].compute()

    def mean_filter(img):
        # 5 x 5 box average.
        k = np.ones((5, 5)) / 25.0
        return convolve(img, k, mode="same")

    def std_filter(img):
        # Local standard deviation from E[x^2] - E[x]^2.
        mu = mean_filter(img ** 2)
        mu2 = mean_filter(img) ** 2
        return np.sqrt(mu - mu2)

    # Maximum ('_w') and minimum ('_c') of the fields in 5 x 5 windows.
    bt_11_w = maximum_filter(bt_11, [5, 5])
    bt_11_c = minimum_filter(bt_11, [5, 5])
    bt_12_w = maximum_filter(bt_12, [5, 5])
    bt_12_c = minimum_filter(bt_12, [5, 5])
    bt_11_s = std_filter(bt_11)
    bt_1112_s = std_filter(bt_11 - bt_12)

    #
    # Calipso data
    #

    t_0 = datetime(2016, 10, 12, 17, 00)
    t_1 = datetime(2016, 10, 12, 17, 50)
    calipso_files = clay01km.download(t_0, t_1)

    lat_min = lats_r.data.min()
    lat_max = lats_r.data.max()
    lon_min = lons_r.data.min()
    lon_max = lons_r.data.max()

    dataset = clay01km.open(calipso_files[0])
    lats_c = dataset["latitude"].data
    lons_c = dataset["longitude"].data
    ctp_c = dataset["layer_top_pressure"]
    cth_c = dataset["layer_top_altitude"]

    # NOTE(review): this result is overwritten by the KD-tree query below.
    indices = np.where(
        (lats_c > lat_min)
        * (lats_c <= lat_max)
        * (lons_c > lon_min)
        * (lons_c <= lon_max)
    )
    points = np.hstack([lats_r.data.reshape(-1, 1), lons_r.data.reshape(-1, 1)])
    kd_tree = KDTree(points)

    # Nearest MODIS pixel for each CALIOP observation; only matches closer
    # than 0.01 (in lat/lon units) are kept.
    points_c = np.hstack([lats_c.reshape(-1, 1), lons_c.reshape(-1, 1)])
    d, indices = kd_tree.query(points_c)
    valid = d < 0.01
    indices = indices[valid]
    lats_c = lats_c[valid]
    lons_c = lons_c[valid]
    ctp_c = ctp_c[valid]
    cth_c = cth_c[valid]
    bt_11 = bt_11.data.ravel()[indices]
    bt_12 = bt_12.data.ravel()[indices]
    bt_11_w = bt_11_w.ravel()[indices]
    bt_12_w = bt_12_w.ravel()[indices]
    bt_11_c = bt_11_c.ravel()[indices]
    bt_12_c = bt_12_c.ravel()[indices]
    bt_11_s = bt_11_s.ravel()[indices]
    bt_1112_s = bt_1112_s.ravel()[indices]
    lats_r = lats_r.data.ravel()[indices]
    lons_r = lons_r.data.ravel()[indices]

    #
    # ERA 5 data.
    #

    t_0 = datetime(2016, 10, 12, 17, 45)
    t_1 = datetime(2016, 10, 12, 17, 50)

    surface_variables = ["surface_pressure", "2m_temperature", "tcwv"]
    domain = [lat_min - 2, lat_max + 2, lon_min - 2, lon_max + 2]
    surface_product = ERA5Product("hourly", "surface", surface_variables, domain)
    era_surface_files = surface_product.download(t_0, t_1)

    pressure_variables = ["temperature"]
    pressure_product = ERA5Product("hourly", "pressure", pressure_variables, domain)
    era_pressure_files = pressure_product.download(t_0, t_1)

    # interpolate pressure data.

    era5_data = xarray.open_dataset(era_pressure_files[0])
    # Latitudes are reversed here and in the data slices below.
    lats_era = era5_data["latitude"][::-1]
    lons_era = era5_data["longitude"]
    p_era = era5_data["level"]
    p_inds = [np.where(p_era == p)[0] for p in [950, 850, 700, 500, 250]]
    # NOTE(review): despite the name, this list holds the field "t"
    # interpolated at the five pressure levels.
    pressures = []
    for ind in p_inds:
        p_interp = RegularGridInterpolator(
            [lats_era, lons_era], era5_data["t"].data[0, ind[0], ::-1, :]
        )
        pressures.append(p_interp((lats_r, lons_r)))

    era5_data = xarray.open_dataset(era_surface_files[0])
    lats_era = era5_data["latitude"][::-1]
    lons_era = era5_data["longitude"]
    t_interp = RegularGridInterpolator(
        [lats_era, lons_era], era5_data["t2m"].data[0, ::-1, :]
    )
    t_surf = t_interp((lats_r, lons_r))
    p_interp = RegularGridInterpolator(
        [lats_era, lons_era], era5_data["sp"].data[0, ::-1, :]
    )
    p_surf = p_interp((lats_r, lons_r))
    tcwv_interp = RegularGridInterpolator(
        [lats_era, lons_era], era5_data["tcwv"].data[0, ::-1, :]
    )
    tcwv = tcwv_interp((lats_r, lons_r))

    #
    # Assemble input data
    #

    x = np.zeros((lats_r.size, 16))
    x[:, 0] = p_surf
    x[:, 1] = t_surf
    for i, p in enumerate(pressures):
        x[:, 2 + i] = p
    x[:, 7] = tcwv

    x[:, 8] = bt_12
    x[:, 9] = bt_11 - bt_12
    x[:, 10] = bt_11_w - bt_12_w
    x[:, 11] = bt_11_c - bt_12_c
    x[:, 12] = bt_12_w - bt_12
    x[:, 13] = bt_12_c - bt_12

    x[:, 14] = bt_11_s
    x[:, 15] = bt_1112_s

    output_data = {
        "input_data": x,
        "ctp": ctp_c,
        "latitude": lats_r,
        "longitude": lons_r,
        "latitude_rgb": lats_rgb,
        "longitude_rgb": lons_rgb,
        "modis_rgb": modis_rgb,
        "bt_11_rgb": bt_11_rgb,
        "bt_12_rgb": bt_12_rgb,
    }
    return output_data


def download_data(destination="data"):
    """
    Downloads training and evaluation data for the CTP retrieval.

    Files that already exist in ``destination`` are not downloaded again.

    Args:
        destination: Folder in which to store the downloaded data. The
            folder is created if it does not exist yet.
    """
    datasets = ["ctp_training_data.npz", "ctp_validation_data.npz"]

    destination = Path(destination)
    destination.mkdir(parents=True, exist_ok=True)

    for file in datasets:
        # Use the requested destination; the path was previously hard-coded
        # to "data", which ignored the 'destination' argument.
        file_path = destination / file
        if not file_path.exists():
            print(f"Downloading file {file}.")
            url = f"http://spfrnd.de/data/ctp/{file}"
            urlretrieve(url, file_path)
def create_training_data(n=1_000_000):
    r"""
    Create training data by randomly sampling the range :math:`[-\pi, \pi]`
    and computing

    .. math::

        y = \sin(x) + \cos(x) \cdot \mathcal{N}(0, 1)

    Args:
        n(int): How many samples to compute.

    Return:
        Tuple ``(x, y)`` of 1D arrays with ``n`` elements each containing
        the sampled input values and the corresponding :math:`y` values.
    """
    x = 2.0 * np.pi * np.random.random(size=n) - np.pi
    y = np.sin(x) + 1.0 * np.cos(x) * np.random.randn(n)
    return x, y


def create_validation_data(x):
    """
    Creates validation data for the toy example.

    In contrast to the generation of the training data this function allows
    specifying the x value of the data which allows plotting the predicted
    result over an arbitrary domain.

    Args:
        x: Arbitrary array containing the x values for which to compute
            corresponding y values.
    Return:
        Numpy array containing the y values corresponding to the given x
        values.
    """
    y = np.sin(x) + 1.0 * np.cos(x) * np.random.randn(*x.shape)
    return y


def plot_histogram(x, y):
    """
    Plot 2D histogram of data.

    Args:
        x: 1D array with the x values of the samples.
        y: 1D array with the corresponding y values.
    """
    # Calculate histogram
    bins_x = np.linspace(-np.pi, np.pi, 201)
    bins_y = np.linspace(-4, 4, 201)
    x_img, y_img = np.meshgrid(bins_x, bins_y)
    img, _, _ = np.histogram2d(x, y, bins=(bins_x, bins_y), density=True)

    # Plot results
    f, ax = plt.subplots(1, 1, figsize=(10, 6))
    m = ax.pcolormesh(x_img, y_img, img.T, vmin=0, vmax=0.3, cmap="magma")
    x_sin = np.linspace(-np.pi, np.pi, 1001)
    y_sin = np.sin(x_sin)
    # Raw string: '\s' is not a valid escape sequence in a normal string.
    ax.plot(x_sin, y_sin, c="grey", label=r"$y=\sin(x)$", lw=3)
    ax.set_ylim([-2, 2])
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    plt.colorbar(m, label="Normalized frequency")
    plt.legend()


def plot_results(x_train, y_train, x_val, y_pred, y_mean, quantiles):
    """
    Plots the predicted quantiles against empirical quantiles.

    Args:
        x_train: 1D array with the x values of the training samples.
        y_train: 1D array with the y values of the training samples.
        x_val: 1D array with the x values of the predictions.
        y_pred: 2D array with the predicted quantiles along the second
            dimension. At most the first 13 quantiles are drawn.
        y_mean: 1D array with the predicted mean.
        quantiles: The quantile fractions used as contour levels for the
            empirical CDF.
    """
    # 'scipy.integrate' is imported explicitly because 'import scipy' alone
    # does not guarantee that the sub-module is loaded. 'cumtrapz' was renamed
    # to 'cumulative_trapezoid' and has been removed from recent SciPy.
    try:
        from scipy.integrate import cumulative_trapezoid
    except ImportError:
        from scipy.integrate import cumtrapz as cumulative_trapezoid
    # 'np.trapz' is deprecated in NumPy 2 in favor of 'np.trapezoid'.
    trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz

    # Calculate histogram and empirical quantiles.
    bins_x = np.linspace(-np.pi, np.pi, 201)
    bins_y = np.linspace(-4, 4, 201)
    img, _, _ = np.histogram2d(x_train, y_train, bins=(bins_x, bins_y), density=True)
    y_bin_centers = 0.5 * (bins_y[1:] + bins_y[:-1])
    norm = trapezoid(img, x=y_bin_centers, axis=1)
    img_normed = img / norm.reshape(-1, 1)
    img_cdf = cumulative_trapezoid(img_normed, x=y_bin_centers, axis=1)

    x_centers = 0.5 * (bins_x[1:] + bins_x[:-1])
    # The cumulative integral has one element less than the histogram.
    y_centers = 0.5 * (bins_y[2:] + bins_y[:-2])

    norm = Normalize(0, 1)
    plt.figure(figsize=(10, 6))
    img = plt.contourf(
        x_centers,
        y_centers,
        img_cdf.T,
        levels=quantiles,
        norm=norm,
        cmap="magma",
    )
    # Draw at most 13 quantiles, but never more than were predicted.
    handles = []
    for i in range(min(13, y_pred.shape[1])):
        handles = plt.plot(x_val, y_pred[:, i], lw=2, ls="--", color="grey")
    handles = handles + plt.plot(x_val, y_mean, c="k", ls="--", lw=2)
    labels = ["Predicted quantiles", "Predicted mean"]
    plt.legend(handles=handles, labels=labels)

    plt.xlim([-np.pi, np.pi])
    plt.ylim([-3, 3])
    plt.xlabel("x")
    plt.ylabel("y")
    plt.grid(False)
    plt.colorbar(img, label=r"Empirical quantiles")
@contextmanager
def read_file(path, *args, **kwargs):
    """
    Generic function to open files. Currently supports opening files
    on the local system as well as on a remote machine via SFTP.

    The file is closed when the context is left.

    Args:
        path: ``pathlib`` path or string pointing to a local file, or a
            URL of the form ``sftp://<host>/<path>``.
        *args: Passed on to ``open``.
        **kwargs: Passed on to ``open``.

    Raises:
        InvalidURL: If an SFTP URL contains no host or the protocol of the
            URL is not supported.
    """
    if isinstance(path, PurePath):
        with open(path, *args, **kwargs) as file:
            yield file
        return

    url = urlparse(path)
    # The SFTP scheme is checked first: otherwise an SFTP URL without a
    # host would be mistaken for a local path and the check below could
    # never trigger.
    if url.scheme == "sftp":
        host = url.netloc
        if host == "":
            raise InvalidURL(
                "No host in SFTP URL. "
                "To load a file using SFTP, the URL must be of the form "
                "'sftp://<host>/<path>'."
            )

        with sftp.download_file(host, url.path) as file_path:
            with open(file_path, *args, **kwargs) as file:
                yield file
        return

    if url.netloc == "":
        with open(path, *args, **kwargs) as file:
            yield file
        return

    raise InvalidURL(f"The provided protocol '{url.scheme}' is not supported.")


class _DummyCache:
    """
    A dummy cache for local files, which are not cached
    at all.

    """

    def __init__(self):
        """Create dummy cache."""
        pass

    def download_files(self, host, files, pool):
        """Do nothing; local files need no download."""
        pass

    def get(self, host, path):
        """Get file from cache. For local files this is the path itself."""
        return path

    def cleanup(self):
        """Do nothing; there are no temporary files."""
        pass


class CachedDataFolder:
    """
    This class provides an interface to a generic folder containing
    dataset files. This folder can be accessed via the local file
    system or SFTP. If the folder is located on a remote SFTP server,
    the files are cached to avoid having to retransfer the files.

    Attributes:
        files: List of available files in the folder.
        host: The name of the host where the folder is located or
            "" if the folder is local.
        cache: Cache object used to cache data accesses.
    """

    def __init__(self, path, pattern="*", n_files=None):
        """
        Create a CachedDataFolder.

        Args:
            path: Path to the folder to load.
            pattern: Glob pattern to select the files.
            n_files: If given only the first ``n_files`` matching files will
                be loaded.

        Raises:
            InvalidURL: If an SFTP URL contains no host or the protocol of
                the URL is not supported.
        """
        url = None if isinstance(path, PurePath) else urlparse(path)
        if url is not None and url.scheme == "sftp":
            self.host = url.netloc
            if self.host == "":
                raise InvalidURL(
                    "No host in SFTP URL. "
                    "To load a file using SFTP, the URL must be of the "
                    "form 'sftp://<host>/<path>'."
                )
            files = sftp.list_files(self.host, url.path)
            self.cache = sftp.SFTPCache()
        elif url is None or url.netloc == "":
            files = Path(path).iterdir()
            self.host = ""
            self.cache = _DummyCache()
        else:
            raise InvalidURL(
                f"The provided protocol '{url.scheme}' is not supported."
            )
        self.files = [file for file in files if file.match(pattern)]
        if n_files:
            self.files = self.files[:n_files]

    def download(self, pool):
        """
        This method downloads all files in the folder to populate the
        cache.

        Args:
            pool: The PoolExecutor to use for the concurrent download.
        """
        self.cache.download_files(self.host, self.files, pool)

    def get(self, path):
        """
        Retrieve file from folder.

        Args:
            path: The path of the file to retrieve.

        Return:
            Whatever the cache returns for the file: for local folders
            this is ``path`` itself.
        """
        return self.cache.get(self.host, path)

    def open(self, path, *args, **kwargs):
        """
        Retrieve file from cache and open.

        Args:
            path: The path of the file to retrieve.
            *args: Passed to open call if file is local.
            **kwargs: Passed to open call if file is local.

        Return:
            An open file object if the cache returned a ``pathlib`` path,
            otherwise the object returned by the cache.
        """
        file = self.get(path)
        if isinstance(file, PurePath):
            return open(file, *args, **kwargs)
        return file
_LOGGER = logging.getLogger("quantnn.files.sftp")


def get_login_info():
    """
    Retrieves SFTP login info from the 'QUANTNN_SFTP_USER' AND
    'QUANTNN_SFTP_PASSWORD' environment variables.

    Returns:

        Tuple ``(user_name, password)`` containing the SFTP user name and
        password retrieved from the environment variables.

    Raises:

        MissingAuthenticationInfo exception when required information is
        not provided as environment variable.
    """
    user_name = os.environ.get("QUANTNN_SFTP_USER")
    password = os.environ.get("QUANTNN_SFTP_PASSWORD")
    if user_name is None or password is None:
        raise MissingAuthenticationInfo(
            "SFTPStream dataset requires the 'QUANTNN_SFTP_USER' and "
            "'QUANTNN_SFTP_PASSWORD' to be set."
        )
    return user_name, password


@contextmanager
def get_sftp_connection(host):
    """
    Contextmanager to open and close an SFTP connection to
    a given host.

    Login credentials for the SFTP server are retrieved from the
    'QUANTNN_SFTP_USER' and 'QUANTNN_SFTP_PASSWORD' environment variables.

    Args:
        host: IP address of the host.

    Returns:
        ``paramiko.SFTP`` object providing access to the open SFTP connection.
    """
    user_name, password = get_login_info()
    transport = None
    sftp = None
    try:
        transport = paramiko.Transport(host)
        transport.connect(username=user_name, password=password)
        sftp = paramiko.SFTPClient.from_transport(transport)
        yield sftp
    finally:
        # Close whatever was opened, also if connecting failed half-way.
        if sftp:
            sftp.close()
        if transport:
            transport.close()


def list_files(host, path):
    """
    List files in SFTP folder.

    Args:
        host: IP address of the host.
        path: The path for which to list the files


    Returns:
        List of ``pathlib.Path`` objects combining ``path`` with the name
        of each entry found in the folder.
    """
    with get_sftp_connection(host) as sftp:
        files = sftp.listdir(path)
    return [Path(path) / f for f in files]


@contextmanager
def download_file(host, path):
    """
    Downloads file from host to a temporary directory and
    return the path of this file.

    The temporary directory, and with it the file, is removed when the
    context is left.

    Args:
        host: IP address of the host from which to download the file.
        path: Path of the file on the host.

    Return:
        pathlib.Path object pointing to the downloaded file.
    """
    path = Path(path)
    with tempfile.TemporaryDirectory() as directory:
        destination = Path(directory) / path.name
        with get_sftp_connection(host) as sftp:
            _LOGGER.info("Downloading file %s to %s.", path, destination)
            sftp.get(str(path), str(destination))
        yield destination


def _download_file(host, path):
    """
    Download a file via SFTP into a new temporary file.

    Used as task function for concurrent downloads.

    Args:
        host: The SFTP host from which to download the file.
        path: The path of the file on the host.

    Return:
        The name of the temporary file holding the downloaded data. The
        caller is responsible for deleting it.

    Raises:
        Any exception raised while connecting or downloading; the
        temporary file is removed before the exception is propagated.
    """
    handle, file = tempfile.mkstemp()
    # 'mkstemp' returns an open OS-level handle. It must be closed, otherwise
    # every download leaks a file descriptor.
    os.close(handle)
    try:
        with get_sftp_connection(host) as sftp:
            _LOGGER.info("Downloading file %s to %s.", path, file)
            sftp.get(str(path), file)
    except Exception:
        # Don't hand out the name of a deleted file: clean up and let the
        # caller see the error.
        os.remove(file)
        raise
    return file


class SFTPCache:
    """
    Cache for SFTP files.

    Attributes:
        files: Dictionary mapping tuples ``(host, path)`` to the names of
            the temporary files holding the downloaded data.
    """

    def __init__(self):
        # Only the owner removes the temporary files; copies created by
        # pickling are not owners (see __getstate__).
        self._owner = True
        self.files = {}

    def __del__(self):
        """Make sure temporary data is cleaned up."""
        if self._owner:
            self._cleanup()

    def _cleanup(self):
        """ Clean up temporary files. """
        _LOGGER.info("Cleaning up SFTP cache.")
        for file in self.files.values():
            if isinstance(file, Future):
                file = file.result()
            if isinstance(file, str):
                os.remove(file)
            else:
                os.remove(file.name)

    def download_files(self, host, paths, pool):
        """
        Download list of file concurrently.

        Args:
            host: The SFTP host from which to download the data.
            paths: List of paths to download from the host.
            pool: A PoolExecutor to use for parallelizing the download.
        """
        tasks = {}
        for path in paths:
            if (host, path) not in self.files:
                task = pool.submit(_download_file, host, path)
                tasks[path] = task
        # Wait for the downloads and register the resulting files.
        for path in paths:
            if (host, path) not in self.files:
                self.files[(host, path)] = tasks[path].result()

    def get(self, host, path):
        """
        Retrieve file from cache. If file is not found in cache it is
        retrieved via SFTP and stored in the cache.

        Args:
            host: The SFTP host from which to retrieve the file.
            path: The path of the file on the host.

        Return:
            The name of the temporary file containing the requested file.
        """
        key = (host, path)
        if key not in self.files:
            self.files[key] = _download_file(host, path)

        value = self.files[key]
        if isinstance(value, Future):
            return value.result()
        return value

    def __getstate__(self):
        """Set owner attribute to false when object is pickled. """
        dct = copy(self.__dict__)
        dct["_owner"] = False
        return dct
# Queue shared by all processes; created lazily by get_log_queue().
_LOG_QUEUE = None


def get_log_queue():
    """
    Return global logging queue.

    The queue is created on the first call; later calls return the same
    ``multiprocessing.Queue`` object.
    """
    global _LOG_QUEUE
    if _LOG_QUEUE is None:
        _LOG_QUEUE = multiprocessing.Queue()
    return _LOG_QUEUE


class SubprocessLogging(multiprocessing.Process):
    """
    Base class to handle logging from subprocesses. Subprocesses should
    inherit from this class in order to have their log messages displayed
    cleanly.
    """
    def __init__(self):
        super().__init__()
        # Keep a reference to the global queue on the process object.
        self.log_queue = get_log_queue()

    def run(self):
        """
        Replace the handlers of the root logger with a single handler that
        puts all records on the log queue.
        """
        import quantnn.logging
        root = logging.getLogger()
        root.handlers = [handlers.QueueHandler(self.log_queue)]


class LoggingThread(threading.Thread):
    """
    Thread to log messages from working processes.
    """
    def __init__(self,
                 log_queue):
        """
        Args:
            log_queue: The queue to which other processes are logging their
                messages.

        """
        super().__init__()
        self.log_queue = log_queue

    def run(self):
        """
        Listen on queue and print incoming messages.

        Each record is passed to the logger with the record's name. The
        loop ends when ``None`` is received from the queue.
        """
        while True:
            record = self.log_queue.get()
            if record is None:
                break
            logger = logging.getLogger(record.name)
            logger.handle(record)


# Number of start_logging() calls not yet matched by stop_logging().
_USERS = 0
# The running listener thread or None if logging is not active.
_LOGGING_THREAD = None


def start_logging():
    """
    Starts the listener thread that prints messages from the subprocesses.

    The thread is only started if it isn't running yet; every call
    increases the user count.
    """
    global _USERS, _LOGGING_THREAD
    _USERS += 1
    if _LOGGING_THREAD is None:
        _LOGGING_THREAD = LoggingThread(get_log_queue())
        _LOGGING_THREAD.start()


def stop_logging():
    """
    Signal that no more messages are expected from subprocesses.

    Decreases the user count. When it drops to zero or below, ``None`` is
    put on the queue to stop the listener thread, which is then joined.
    """
    global _USERS, _LOGGING_THREAD
    _USERS -= 1
    if _USERS <= 0:
        if _LOGGING_THREAD is not None:
            get_log_queue().put(None)
            _LOGGING_THREAD.join()
            _LOGGING_THREAD = None
87 | """ 88 | global _USERS, _LOGGING_THREAD 89 | _USERS -= 1 90 | if _USERS <= 0: 91 | if _LOGGING_THREAD is not None: 92 | get_log_queue().put(None) 93 | _LOGGING_THREAD.join() 94 | _LOGGING_THREAD = None 95 | -------------------------------------------------------------------------------- /quantnn/models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | quantnn.models.pytorch 3 | ====================== 4 | 5 | This moudles provides Pytorch neural network models that can be used a backend 6 | for the :py:class:`quantnn.QRNN` class. 7 | """ 8 | from quantnn.models.pytorch.common import ( 9 | BatchedDataset, 10 | save_model, 11 | load_model, 12 | ) 13 | from torch.nn import CrossEntropyLoss 14 | from quantnn.models.pytorch.fully_connected import FullyConnected 15 | from quantnn.models.pytorch.unet import UNet 16 | -------------------------------------------------------------------------------- /quantnn/models/keras/padding.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow import keras 3 | from tensorflow.keras import layers 4 | 5 | 6 | class SymmetricPadding(layers.Layer): 7 | def __init__(self, amount): 8 | super().__init__() 9 | self.paddings = tf.constant( 10 | [[0, 0], [amount, amount], [amount, amount], [0, 0]] 11 | ) 12 | 13 | def call(self, input): 14 | return tf.pad(input, self.paddings, "SYMMETRIC") 15 | -------------------------------------------------------------------------------- /quantnn/models/keras/unet.py: -------------------------------------------------------------------------------- 1 | """ 2 | ========================= 3 | quantnn.models.keras.unet 4 | ========================= 5 | 6 | This module provides an implementation of the UNet [unet]_ 7 | architecture. 8 | 9 | .. [unet] O. Ronneberger, P. Fischer and T. Brox, "U-net: Convolutional networks 10 | for biomedical image segmentation", Proc. Int. Conf. 
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import activations
from tensorflow.keras import Input

from quantnn.models.keras.padding import SymmetricPadding


class ConvolutionBlock(layers.Layer):
    """
    A convolution block consisting of a pair of 3x3 convolutions, each
    followed by batch normalization and a ReLU activation.

    Symmetric padding by one pixel precedes each "valid" convolution.
    """

    def __init__(self, channels_in, channels_out):
        """
        Create new convolution block.

        Args:
            channels_in: The number of input channels.
            channels_out: The number of output channels.
        """
        super().__init__()
        input_shape = (None, None, channels_in)
        self.block = keras.Sequential()
        self.block.add(SymmetricPadding(1))
        self.block.add(
            layers.Conv2D(channels_out, 3, padding="valid", input_shape=input_shape)
        )
        self.block.add(layers.BatchNormalization())
        self.block.add(layers.ReLU())
        self.block.add(SymmetricPadding(1))
        self.block.add(layers.Conv2D(channels_out, 3, padding="valid"))
        self.block.add(layers.BatchNormalization())
        self.block.add(layers.ReLU())

    def call(self, input):
        """Propagate ``input`` through the block."""
        return self.block(input)


class DownsamplingBlock(keras.Sequential):
    """
    A downsampling block consisting of a max pooling layer and a
    convolution block.
    """

    def __init__(self, channels_in, channels_out):
        """
        Create new downsampling block.

        Args:
            channels_in: The number of input channels.
            channels_out: The number of output channels.
        """
        super().__init__()
        self.add(layers.MaxPooling2D(strides=(2, 2)))
        self.add(ConvolutionBlock(channels_in, channels_out))


class UpsamplingBlock(layers.Layer):
    """
    An upsampling block which uses bilinear interpolation to increase the
    input size. This is followed by a 1x1 convolution to reduce the number
    of channels, concatenation of the skip inputs from the corresponding
    downsampling layer and a convolution block.
    """

    def __init__(self, channels_in, channels_out):
        """
        Create new upsampling block.

        Args:
            channels_in: The number of input channels.
            channels_out: The number of output channels.
        """
        super().__init__()
        self.upsample = layers.UpSampling2D(size=(2, 2), interpolation="bilinear")
        input_shape = (None, None, channels_in)
        # The 1x1 convolution halves the number of channels before the
        # skip connection is concatenated.
        self.reduce = layers.Conv2D(
            channels_in // 2, 1, padding="same", input_shape=input_shape
        )
        self.concat = layers.Concatenate()
        self.conv_block = ConvolutionBlock(channels_in, channels_out)

    def call(self, inputs):
        """
        Args:
            inputs: Pair ``(x, x_skip)`` containing the tensor to upsample
                and the skip tensor to concatenate to it.
        """
        x, x_skip = inputs
        x_up = self.reduce(self.upsample(x))
        x = self.concat([x_up, x_skip])
        return self.conv_block(x)


class UNet(keras.Model):
    """
    Keras implementation of the UNet architecture, an input block followed
    by 4 encoder blocks and 4 decoder blocks.
    """

    def __init__(self, n_inputs, n_outputs):
        """
        Args:
            n_inputs: The number of channels of the network input.
            n_outputs: The number of channels of the network output.
        """
        super().__init__()

        self.in_block = keras.Sequential(
            [
                SymmetricPadding(1),
                # Shape hint follows the declared number of input channels
                # instead of a hard-coded 128.
                layers.Conv2D(
                    64, 3, padding="valid", input_shape=(None, None, n_inputs)
                ),
                layers.BatchNormalization(),
                layers.ReLU(),
                SymmetricPadding(1),
                layers.Conv2D(64, 3, padding="valid"),
                layers.BatchNormalization(),
                layers.ReLU(),
            ]
        )

        self.down_block_1 = DownsamplingBlock(64, 128)
        self.down_block_2 = DownsamplingBlock(128, 256)
        self.down_block_3 = DownsamplingBlock(256, 512)
        self.down_block_4 = DownsamplingBlock(512, 1024)

        self.up_block_1 = UpsamplingBlock(1024, 512)
        self.up_block_2 = UpsamplingBlock(512, 256)
        self.up_block_3 = UpsamplingBlock(256, 128)
        self.up_block_4 = UpsamplingBlock(128, 64)

        self.out_block = layers.Conv2D(
            n_outputs, 1, padding="same", input_shape=(None, None, 64)
        )

    def call(self, inputs):
        """Propagate ``inputs`` through encoder, decoder and output layer."""
        d_64 = self.in_block(inputs)

        d_128 = self.down_block_1(d_64)
        d_256 = self.down_block_2(d_128)
        d_512 = self.down_block_3(d_256)
        d_1024 = self.down_block_4(d_512)

        # Each decoder stage receives the matching encoder output as skip.
        u_512 = self.up_block_1([d_1024, d_512])
        u_256 = self.up_block_2([u_512, d_256])
        u_128 = self.up_block_3([u_256, d_128])
        u_64 = self.up_block_4([u_128, d_64])

        return self.out_block(u_64)
class ParamCount:
    """
    Mixin for pytorch modules that exposes the number of trainable
    parameters as the read-only attribute ``n_params``.

    The host class must provide a ``parameters()`` method.
    """

    @property
    def n_params(self):
        """Total number of elements in all parameters that require a gradient."""
        total = 0
        for param in self.parameters():
            if param.requires_grad:
                total += param.numel()
        return total
class PatchMergingBlock(nn.Module):
    """
    Implements patch merging as employed in the Swin architecture.

    The input is rearranged with ``pixel_unshuffle`` so that every
    ``f_dwn x f_dwn`` patch is folded into the channel dimension, then
    normalized and, if required, projected to ``channels_out`` channels.
    """

    def __init__(
        self,
        channels_in: int,
        channels_out: int,
        f_dwn: Union[int, Tuple[int, int]],
        norm_factory: Union[Callable[[int], nn.Module], None] = None,
    ):
        """
        Args:
            channels_in: The number of channels in the input.
            channels_out: The number of channels in the output.
            f_dwn: The downsampling factor. A tuple or list is accepted as
                long as both of its entries are identical.
            norm_factory: Callable that takes a number of channels and
                returns the normalization layer applied to the merged
                patches. Defaults to ``LayerNormFirst``.

        Raises:
            ValueError: If ``f_dwn`` holds two different factors.
        """
        super().__init__()
        # Lists are handled like tuples so that factors read from
        # configuration files work as well.
        if isinstance(f_dwn, (tuple, list)):
            if f_dwn[0] != f_dwn[1]:
                raise ValueError(
                    "Downsampling by patch merging only supports homogeneous "
                    "downsampling factors."
                )
            f_dwn = f_dwn[0]
        self.f_dwn = f_dwn
        # Number of channels after folding each patch into the channels.
        channels_d = channels_in * f_dwn**2

        if norm_factory is None:
            norm_factory = LayerNormFirst

        self.norm = norm_factory(channels_d)

        if channels_d != channels_out:
            self.proj = nn.Conv2d(channels_d, channels_out, kernel_size=1)
        else:
            self.proj = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Downsample tensor.
        """
        x_d = nn.functional.pixel_unshuffle(x, self.f_dwn)
        return self.proj(self.norm(x_d))
class MaxPooling:
    """
    Factory for creating max-pooling downsampling layers.
    """

    def __init__(self, kernel_size=(2, 2), project_first=True):
        """
        Args:
            kernel_size: The kernel size of the max-pooling layer.
            project_first: Only relevant when the number of channels
                changes during downsampling. If True, the 1x1 projection
                is applied prior to the pooling, otherwise after it.
        """
        self.kernel_size = kernel_size
        self.project_first = project_first

    def __call__(self, channels_in, channels_out, f_down):
        """
        Args:
            channels_in: The number of input channels.
            channels_out: The number of output channels.
            f_down: The downsampling factor; used as stride of the pooling.

        Return:
            The pooling layer alone if the number of channels is unchanged,
            otherwise a sequential module combining it with a 1x1
            convolution.
        """
        pooling = nn.MaxPool2d(kernel_size=self.kernel_size, stride=f_down)
        if channels_in != channels_out:
            projection = nn.Conv2d(channels_in, channels_out, kernel_size=1)
            if self.project_first:
                stages = (projection, pooling)
            else:
                stages = (pooling, projection)
            return nn.Sequential(*stages)
        return pooling
class ConditionalGenerator(nn.Module):
    """
    Generator network that is conditioned on an input image.

    The input is passed through four ``DownBlock`` stages. The result is
    concatenated with a latent tensor along the channel dimension and
    passed through four generator stages, each consisting of a
    ``GeneratorBlock`` followed by a ``GeneratorBlockUp``. The final layer
    produces a single channel that is squashed with ``tanh``.
    """

    def __init__(self, channels_in, channels, n_lat=8, output_range=1):
        """
        Args:
            channels_in: The number of channels of the conditioning input.
            channels: The number of channels used inside the network.
            n_lat: The number of channels of the latent tensor.
            output_range: Factor applied to the ``tanh`` output.
        """
        super().__init__()

        self.n_lat = n_lat
        self.output_range = output_range

        self.conditioner = nn.ModuleList([
            DownBlock(channels_in, channels),
            DownBlock(channels, channels),
            DownBlock(channels, channels),
            DownBlock(channels, channels)
        ])

        self.generator = nn.ModuleList([
            nn.Sequential(
                GeneratorBlock(channels + n_lat, channels),
                GeneratorBlockUp(channels, channels)
            ),
            nn.Sequential(
                GeneratorBlock(channels, channels),
                GeneratorBlockUp(channels, channels)
            ),
            nn.Sequential(
                GeneratorBlock(channels, channels),
                GeneratorBlockUp(channels, channels)
            ),
            nn.Sequential(
                GeneratorBlock(channels, channels),
                GeneratorBlockUp(channels, channels)
            ),
        ])

        self.output = nn.Sequential(
            nn.BatchNorm2d(channels),
            nn.LeakyReLU(0.1),
            nn.utils.spectral_norm(nn.Conv2d(channels, 1, 1))
        )

    def forward(self, x, z=None):
        """
        Args:
            x: The conditioning input. The output of the conditioner must
                be of spatial size 4 x 4, because the latent tensor, which
                is concatenated to it, is fixed to that size.
            z: Optional latent values that can be reshaped to
                ``(batch, n_lat, 4, 4)``. Standard-normal noise is drawn
                if not given.

        Return:
            The generated single-channel output scaled by ``output_range``.
        """
        y = x
        for layer in self.conditioner:
            y = layer(y)

        n = x.shape[0]

        if z is None:
            # Sample on the device and with the dtype of the features so
            # that the concatenation below also works for GPU inputs.
            z = torch.randn(n, self.n_lat, 4, 4, device=y.device, dtype=y.dtype)
        else:
            z = z.reshape((n, self.n_lat, 4, 4))

        hidden = torch.cat([y, z], 1)
        for layer in self.generator:
            hidden = layer(hidden)

        return self.output_range * torch.tanh(self.output(hidden))
def hinge_loss(d_real, d_fake):
    """
    Hinge loss for the discriminator.

    Args:
        d_real: Discriminator output for real samples.
        d_fake: Discriminator output for generated samples.

    Return:
        Scalar tensor ``mean(relu(1 - d_real)) + mean(relu(1 + d_fake))``.
    """
    return relu(1.0 - d_real).mean() + relu(1.0 + d_fake).mean()


def training_step(x, y, g, d, opt_g, opt_d):
    """
    Perform one discriminator update followed by one generator update.

    Args:
        x: Batch of conditioning inputs.
        y: Batch of reference outputs without channel dimension; a channel
            dimension is inserted at position 1.
        g: The generator; called as ``g(x, z)`` and expected to provide
            the attribute ``n_lat``.
        d: The discriminator; called as ``d(x, y)``.
        opt_g: Optimizer for the generator.
        opt_d: Optimizer for the discriminator.

    Return:
        Tuple ``(g_loss, d_loss, mse_loss)`` of Python floats. The MSE
        between generated and reference output is reported only; it is not
        part of the loss that is optimized.
    """
    y = y.unsqueeze(1)

    batch_size = x.size(0)
    # Draw the latents on the device of the input instead of forcing CUDA,
    # so that the step also runs on CPU-only hosts.
    z = torch.randn(batch_size, g.n_lat, 4, 4, device=x.device)

    # Train discriminator
    opt_d.zero_grad()
    opt_g.zero_grad()
    d_real = d(x, y)
    d_fake = d(x, g(x, z))
    d_loss = hinge_loss(d_real, d_fake)
    d_loss.backward()
    opt_d.step()

    # Train generator. Gradients are reset first because the discriminator
    # update above also back-propagated into the generator.
    opt_d.zero_grad()
    opt_g.zero_grad()
    z = torch.randn(batch_size, g.n_lat, 4, 4, device=x.device)
    fake = g(x, z)
    g_loss = -d(x, fake).mean()
    mse_loss = ((fake - y) ** 2).mean()

    g_loss.backward()
    opt_g.step()

    return g_loss.item(), d_loss.item(), mse_loss.item()
import torch
from torch.utils.tensorboard.writer import SummaryWriter
from torch.utils.tensorboard.summary import hparams
import xarray as xr


from quantnn.logging import TrainingLogger


class SummaryWriter(SummaryWriter):
    """
    Specialization of torch original SummaryWriter that overrides
    'add_hparams' to avoid creating a new directory to store the
    hyperparameters.

    Source: https://github.com/pytorch/pytorch/issues/32651
    """

    def add_hparams(self, hparam_dict, metric_dict, epoch):
        """
        Write hyperparameters and metrics to the log directory of this
        writer.

        Args:
            hparam_dict: Dictionary of hyperparameters.
            metric_dict: Dictionary mapping metric names to values.
            epoch: The step at which the metric values are logged.

        Raises:
            TypeError: If one of the first two arguments is not a dict.
        """
        torch._C._log_api_usage_once("tensorboard.logging.add_hparams")
        if not isinstance(hparam_dict, dict) or not isinstance(metric_dict, dict):
            raise TypeError("hparam_dict and metric_dict should be dictionary.")
        exp, ssi, sei = hparams(hparam_dict, metric_dict)

        # Re-use the directory of this writer instead of a new sub-folder.
        logdir = self._get_file_writer().get_logdir()

        with SummaryWriter(log_dir=logdir) as w_hp:
            w_hp.file_writer.add_summary(exp)
            w_hp.file_writer.add_summary(ssi)
            w_hp.file_writer.add_summary(sei)
            for k, v in metric_dict.items():
                w_hp.add_scalar(k, v, epoch)


class TensorBoardLogger(TrainingLogger):
    """
    Logger that also logs information to tensor board.
    """

    def __init__(
        self, n_epochs, log_rate=100, log_directory=None, epoch_begin_callback=None
    ):
        """
        Create a new logger instance.

        Args:
            n_epochs: The number of epochs for which the training will last.
            log_rate: The message rate for output to standard out.
            log_directory: The directory to use for tensorboard output.
            epoch_begin_callback: Callback function that will be called
                with arguments ``writer, model, i_epoch``, where ``writer``
                is the ``SummaryWriter`` object used to write output,
                ``model`` is the model in its current state and
                ``i_epoch`` is the epoch counter of the logger.
        """
        super().__init__(n_epochs, log_rate)
        self.writer = SummaryWriter(log_dir=log_directory)
        self.epoch_begin_callback = epoch_begin_callback
        self.attributes = None

    def set_attributes(self, attributes):
        """
        Stores attributes that describe the training in the logger.
        These will be stored in the logger history.

        Args:
            attributes: Dictionary of attributes to store in the history
                of the logger.
        """
        super().set_attributes(attributes)

    def epoch_begin(self, model):
        """
        Called at the beginning of each epoch.

        Args:
            model: The model that is trained in its current state.
        """
        TrainingLogger.epoch_begin(self, model)
        if self.epoch_begin_callback:
            self.epoch_begin_callback(self.writer, model, self.i_epoch)

    def training_step(self, loss, n_samples, of=None, losses=None):
        """
        Log processing of a training batch. This method should be called
        after each batch is processed so that the logger can keep track
        of training progress.

        Args:
            loss: The loss of the current batch.
            n_samples: The number of samples in the batch.
            of: If available the number of batches in the epoch.
            losses: Forwarded unchanged to the base class.
        """
        super().training_step(loss, n_samples, of=of, losses=losses)

    def validation_step(self, loss, n_samples, of=None, losses=None):
        """
        Log processing of a validation batch.

        Args:
            loss: The loss of the current batch.
            n_samples: The number of samples in the batch.
            of: If available the number of batches in the epoch.
            losses: Forwarded unchanged to the base class.
        """
        super().validation_step(loss, n_samples, of=of, losses=losses)

    def epoch(self, learning_rate=None, metrics=None):
        """
        Log processing of epoch.

        Args:
            learning_rate: If available the learning rate of the optimizer.
            metrics: Optional iterable of metric objects. Figures of
                metrics that provide a ``get_figures`` method are added to
                the tensor board output.
        """
        TrainingLogger.epoch(self, learning_rate, metrics=metrics)
        # The learning rate is optional; the writer can't log 'None'.
        if learning_rate is not None:
            self.writer.add_scalar("Learning rate", learning_rate, self.i_epoch)

        # Log the most recent value of every one-dimensional variable in
        # the training history.
        for name, v in self.history.variables.items():
            if name == "epochs":
                continue
            if len(v.dims) == 1:
                self.writer.add_scalar(name, v.data[-1], self.i_epoch)

        if metrics is not None:
            for m in metrics:
                if not hasattr(m, "get_figures"):
                    continue
                figures = m.get_figures()
                if isinstance(figures, dict):
                    # One figure per target.
                    for target, f in figures.items():
                        self.writer.add_figure(
                            f"{m.name} ({target})", f, self.i_epoch
                        )
                else:
                    self.writer.add_figure(f"{m.name}", figures, self.i_epoch)

    def training_end(self):
        """
        Called to signal the end of the training to the logger.

        If attributes are set and all epochs have been processed, the
        attributes are written as hyperparameters. No metric values are
        attached to them.
        """
        if self.attributes is not None and self.i_epoch >= self.n_epochs:
            self.writer.add_hparams(self.attributes, {}, self.i_epoch)
        self.writer.flush()

    def __del__(self):
        # NOTE(review): relies on the base class defining '__del__' --
        # confirm against quantnn.logging.TrainingLogger.
        super().__del__()
class GRN(nn.Module):
    """
    Global Response Normalization (GRN) as proposed in https://openaccess.thecvf.com/content/CVPR2023/html/Woo_ConvNeXt_V2_Co-Designing_and_Scaling_ConvNets_With_Masked_Autoencoders_CVPR_2023_paper.html

    Scaling and bias are initialized with zeros, so a freshly created
    layer returns its input unchanged.
    """

    def __init__(self, n_channels, eps=1e-6):
        """
        Args:
            n_channels: Number of channels over which to normalize.
            eps: Epsilon added to the mean of the channel norms to avoid
                numerical issues.
        """
        super().__init__()
        self.scaling = nn.Parameter(torch.zeros(1, n_channels, 1, 1), requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(1, n_channels, 1, 1), requires_grad=True)
        # Stored so that 'forward' honors the value passed by the caller.
        self.eps = eps

    def forward(self, x):
        """
        Apply normalization to x.
        """
        # L2 norm of every channel over the last two dimensions.
        x_l2 = torch.norm(x, p=2, dim=(-2, -1), keepdim=True)
        # Norm of each channel relative to the mean over all channels.
        rel_imp = x_l2 / (x_l2.mean(dim=1, keepdim=True) + self.eps)
        return self.scaling * (x * rel_imp) + self.bias + x
import torch
import torch.nn as nn


def _conv2(channels_in, channels_out, kernel_size):
    """
    Convolution with reflective padding to keep image size constant.
    """
    return nn.Conv2d(
        channels_in,
        channels_out,
        kernel_size=kernel_size,
        padding=kernel_size // 2,
        padding_mode="reflect",
    )


def _conv2_down(channels_in, channels_out, kernel_size):
    """
    Convolution combined with downsampling and reflective padding to
    decrease input size by a factor of 2.
    """
    return nn.Conv2d(
        channels_in,
        channels_out,
        kernel_size=kernel_size,
        padding=kernel_size // 2,
        stride=2,
        padding_mode="reflect",
    )


class ResidualBlock(nn.Module):
    """
    A residual block consists of either two or three convolution operations
    followed by batch norm and relu activation together with an identity
    mapping connecting the input and the activation feeding into the
    last ReLU layer.
    """

    def __init__(self, channels_in, channels_out, bottleneck=None, downsample=False):
        """
        Create new residual block.

        Args:
            channels_in: The number of input channels.
            channels_out: The number of output channels.
            bottleneck: If given, the number of channels of the bottle
                neck that is applied to reduce the number of parameters.
            downsample: If true the image dimensions are reduced by
                a factor of two.
        """
        super().__init__()
        self.downsample = downsample
        if bottleneck is None:
            self.block = nn.Sequential(
                _conv2(channels_in, channels_out, 3),
                nn.BatchNorm2d(channels_out),
                nn.ReLU(inplace=True),
                (
                    _conv2_down(channels_out, channels_out, 3)
                    if downsample
                    else _conv2(channels_out, channels_out, 3)
                ),
                nn.BatchNorm2d(channels_out),
            )
        else:
            self.block = nn.Sequential(
                _conv2(channels_in, bottleneck, 1),
                nn.BatchNorm2d(bottleneck),
                nn.ReLU(inplace=True),
                _conv2(bottleneck, bottleneck, 3),
                nn.BatchNorm2d(bottleneck),
                nn.ReLU(inplace=True),
                (
                    _conv2_down(bottleneck, channels_out, 1)
                    if downsample
                    else _conv2(bottleneck, channels_out, 1)
                ),
                nn.BatchNorm2d(channels_out),
            )
        self.activation = nn.ReLU(inplace=True)

        # The skip connection must match the output of the block both in
        # the number of channels and in size. A projection is therefore
        # also required when only the size changes.
        self.projection = None
        if downsample:
            self.projection = _conv2_down(channels_in, channels_out, 1)
        elif channels_in != channels_out:
            self.projection = _conv2(channels_in, channels_out, 1)

    def forward(self, x):
        """
        Propagate input through block.
        """
        y = self.block(x)
        if self.projection is not None:
            x = self.projection(x)
        return self.activation(y + x)


class DownSamplingBlock(nn.Module):
    """
    Downsampling block consisting of one residual block that halves the
    image size followed by ``n_blocks - 1`` residual blocks that keep it.
    """

    def __init__(self, channels_in, channels_out, n_blocks, bottleneck=None):
        """
        Args:
            channels_in: The number of input channels.
            channels_out: The number of output channels.
            n_blocks: The total number of residual blocks.
            bottleneck: Passed on to every residual block.
        """
        super().__init__()
        # Exactly one block changes resolution and number of channels.
        modules = [
            ResidualBlock(
                channels_in, channels_out, bottleneck=bottleneck, downsample=True
            )
        ]
        # Each following block is a separate instance. Multiplying a list
        # would repeat the same module and thus share its weights.
        modules += [
            ResidualBlock(channels_out, channels_out, bottleneck=bottleneck)
            for _ in range(n_blocks - 1)
        ]
        self.block = nn.Sequential(*modules)

    def forward(self, x):
        """Propagate input through block."""
        return self.block(x)


class UpSamplingBlock(nn.Module):
    """
    ResNet upsampling block consisting of linear interpolation
    followed by given number of residual blocks.
    """

    def __init__(
        self, channels_in, channels_skip, channels_out, n_blocks, bottleneck=None
    ):
        """
        Args:
            channels_in: The number of channels of the input to upsample.
            channels_skip: The number of channels of the skip connection.
            channels_out: The number of output channels.
            n_blocks: The total number of residual blocks.
            bottleneck: Passed on to every residual block.
        """
        super().__init__()
        self.upscaling = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=True
        )

        # The first block merges upsampled input and skip connection.
        modules = [
            ResidualBlock(
                channels_in + channels_skip, channels_out, bottleneck=bottleneck
            )
        ]
        # Separate instances so that the blocks don't share weights.
        modules += [
            ResidualBlock(channels_out, channels_out, bottleneck=bottleneck)
            for _ in range(n_blocks - 1)
        ]
        self.block = nn.Sequential(*modules)

    def forward(self, x, x_skip):
        """Propagate input through block."""
        x = self.upscaling(x)
        x = torch.cat([x, x_skip], dim=1)
        return self.block(x)


class ResNet(nn.Module):
    """
    Decoder-encoder network using residual blocks.

    The ResNet class implements a fully-convolutional decoder-encoder
    network for point-to-point regression tasks.

    The network consists of 5 downsampling blocks followed by the
    same number of upsampling blocks. The first downsampling block
    consists of a 7x7 convolution with stride two followed by batch
    norm and ReLU activation. All following downsampling blocks are
    built from residual blocks with bottlenecks.
    """

    def __init__(self, n_inputs, n_outputs, blocks=2):
        """
        Args:
            n_inputs: The number of input channels.
            n_outputs: The number of output channels.
            blocks: The number of residual blocks per stage, either as a
                single integer or as a sequence with one entry for each
                of the four residual stages.
        """
        super().__init__()
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs

        self.in_block = nn.Sequential(
            _conv2_down(n_inputs, 128, 7), nn.BatchNorm2d(128), nn.ReLU()
        )

        if isinstance(blocks, int):
            blocks = 4 * [blocks]

        self.down_block_1 = DownSamplingBlock(128, 256, blocks[0], bottleneck=256)
        self.down_block_2 = DownSamplingBlock(256, 512, blocks[1], bottleneck=256)
        self.down_block_3 = DownSamplingBlock(512, 1024, blocks[2], bottleneck=256)
        self.down_block_4 = DownSamplingBlock(1024, 2048, blocks[3], bottleneck=512)

        self.up_block_1 = UpSamplingBlock(2048, 1024, 1024, blocks[3], bottleneck=256)
        self.up_block_2 = UpSamplingBlock(1024, 512, 512, blocks[2], bottleneck=256)
        self.up_block_3 = UpSamplingBlock(512, 256, 256, blocks[1], bottleneck=256)
        self.up_block_4 = UpSamplingBlock(256, 128, n_outputs, blocks[0])
        # The last stage uses the network input itself as skip connection.
        self.up_block_5 = UpSamplingBlock(n_outputs, n_inputs, n_outputs, blocks[0])

        self.out_block = nn.Sequential(
            _conv2(n_outputs, n_outputs, 1),
            nn.BatchNorm2d(n_outputs),
            nn.ReLU(),
            _conv2(n_outputs, n_outputs, 1),
        )

    def forward(self, x):
        """Propagate input through resnet."""

        d_0 = self.in_block(x)
        d_1 = self.down_block_1(d_0)
        d_2 = self.down_block_2(d_1)
        d_3 = self.down_block_3(d_2)
        d_4 = self.down_block_4(d_3)

        u_4 = self.up_block_1(d_4, d_3)
        u_3 = self.up_block_2(u_4, d_2)
        u_2 = self.up_block_3(u_3, d_1)
        u_1 = self.up_block_4(u_2, d_0)
        u_0 = self.up_block_5(u_1, x)

        return self.out_block(u_0)
quantnn.models.pytorch.stages 3 | ============================= 4 | 5 | Implements generic stages used in back-bone models. 6 | """ 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from quantnn.models.pytorch import blocks 12 | 13 | 14 | class AggregationTreeNode(nn.Module): 15 | """ 16 | Represents a node in an aggregation tree. 17 | 18 | Aggregation in all right child nodes are combined so that 19 | aggregation is in principle only performed in the nodes at 20 | the lowest level. 21 | """ 22 | 23 | def __init__( 24 | self, 25 | channels_in, 26 | channels_out, 27 | level, 28 | block_factory, 29 | aggregator_factory, 30 | channels_agg=0, 31 | downsample=1, 32 | block_args=None, 33 | block_kwargs=None, 34 | ): 35 | super().__init__() 36 | 37 | if block_args is None: 38 | block_args = [] 39 | if block_kwargs is None: 40 | block_kwargs = {} 41 | 42 | if channels_agg == 0: 43 | channels_agg = 2 * channels_out 44 | 45 | self.aggregator = None 46 | self.level = level 47 | if level <= 0: 48 | self.left = block_factory( 49 | channels_in, 50 | channels_out, 51 | *block_args, 52 | downsample=downsample, 53 | **block_kwargs, 54 | ) 55 | self.right = None 56 | elif level == 1: 57 | self.aggregator = aggregator_factory(channels_agg, channels_out) 58 | self.left = block_factory( 59 | channels_in, 60 | channels_out, 61 | *block_args, 62 | downsample=downsample, 63 | **block_kwargs, 64 | ) 65 | self.right = block_factory( 66 | channels_out, 67 | channels_out, 68 | *block_args, 69 | downsample=1, 70 | **block_kwargs, 71 | ) 72 | else: 73 | self.aggregator = None 74 | self.left = AggregationTreeNode( 75 | channels_in, 76 | channels_out, 77 | level - 1, 78 | block_factory, 79 | aggregator_factory, 80 | downsample=downsample, 81 | block_args=block_args, 82 | block_kwargs=block_kwargs, 83 | ) 84 | self.right = AggregationTreeNode( 85 | channels_out, 86 | channels_out, 87 | level - 1, 88 | block_factory, 89 | aggregator_factory, 90 | channels_agg=channels_agg + 
class AggregationTreeRoot(AggregationTreeNode):
    """
    Root of an aggregation tree.

    Unlike inner nodes, the root also feeds its own (optionally
    max-pooled) input into the aggregation.
    """

    def __init__(
        self,
        channels_in,
        channels_out,
        tree_height,
        block_factory,
        aggregator_factory,
        downsample=1,
        block_args=None,
        block_kwargs=None,
    ):
        """
        Args:
            channels_in: Number of channels of the input.
            channels_out: Number of channels of the output.
            tree_height: Height of the tree; passed to the base class
                as ``level``.
            block_factory: Factory used to create the blocks of the tree.
            aggregator_factory: Factory used to create the aggregator.
            downsample: If larger than 1, the input is max-pooled with
                this kernel size before it enters the tree.
            block_args: Positional arguments forwarded to the base class
                for the block factory.
            block_kwargs: Keyword arguments forwarded to the base class
                for the block factory.
        """
        # The aggregator receives y_1, y_2 (channels_out each) and the
        # root input x (channels_in), concatenated along dim 1.
        channels_agg = 2 * channels_out + channels_in
        # Blocks are built with downsample=1: the input is concatenated
        # with the block outputs in forward(), so downsampling is done
        # once, up front, by the pooling layer below.
        super().__init__(
            channels_in,
            channels_out,
            tree_height,
            block_factory,
            aggregator_factory,
            channels_agg=channels_agg,
            downsample=1,
            block_args=block_args,
            block_kwargs=block_kwargs,
        )

        self.downsampler = None
        if downsample > 1:
            self.downsampler = nn.MaxPool2d(kernel_size=downsample)

    def forward(self, x):
        """
        Forward input through tree and aggregate results from child nodes.
        """
        # Pool first so that the requested downsampling also takes effect
        # for a tree that consists of a single block.
        if self.downsampler is not None:
            x = self.downsampler(x)

        # Same leaf condition as the base class constructor, which only
        # creates ``self.left`` when ``level <= 0``.
        if self.level <= 0:
            return self.left(x)

        y_1 = self.left(x)
        if self.aggregator is not None:
            y_2 = self.right(y_1)
            return self.aggregator(torch.cat([y_1, y_2, x], 1))

        return self.right(y_1, [y_1, x])
class ConvolutionBlock(nn.Module):
    """
    Two consecutive 3x3 convolutions, each followed by batch
    normalization and a ReLU activation.
    """

    def __init__(self, channels_in, channels_out):
        """
        Create new convolution block.

        Args:
            channels_in: The number of input channels.
            channels_out: The number of output channels of both
                convolutions.
        """
        super().__init__()
        layers = []
        # First unit maps channels_in -> channels_out, the second keeps
        # channels_out.
        for n_in in (channels_in, channels_out):
            layers.append(_conv2(n_in, channels_out, 3))
            layers.append(nn.BatchNorm2d(channels_out))
            layers.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Propagate input through layer."""
        return self.block(x)
79 | """ 80 | 81 | def __init__(self, channels_in, channels_out): 82 | super().__init__() 83 | self.upscaling = nn.Upsample( 84 | scale_factor=2, mode="bilinear", align_corners=True 85 | ) 86 | self.reduce = _conv2(channels_in, channels_in // 2, 3) 87 | self.conv = ConvolutionBlock(channels_in, channels_out) 88 | 89 | def forward(self, x, x_skip): 90 | """Propagate input through block.""" 91 | x = self.reduce(self.upscaling(x)) 92 | x = torch.cat([x, x_skip], dim=1) 93 | return self.conv(x) 94 | 95 | 96 | class UNet(nn.Module): 97 | """ 98 | PyTorch implementation of UNet, consisting of 4 downsampling 99 | blocks followed by 4 upsampling blocks and skip connection between 100 | down- and upsampling blocks of matching output and input size. 101 | 102 | The core of each down and upsampling block consists of two 103 | 2D 3x3 convolution followed by batch norm and ReLU activation 104 | functions. 105 | """ 106 | 107 | def __init__(self, n_inputs, n_outputs): 108 | """ 109 | Args: 110 | n_input: The number of input channels. 111 | n_outputs: The number of output channels. 
112 | """ 113 | super().__init__() 114 | self.n_inputs = n_inputs 115 | self.n_outputs = n_outputs 116 | 117 | self.in_block = ConvolutionBlock(n_inputs, 64) 118 | 119 | self.down_block_1 = DownsamplingBlock(64, 128) 120 | self.down_block_2 = DownsamplingBlock(128, 256) 121 | self.down_block_3 = DownsamplingBlock(256, 512) 122 | self.down_block_4 = DownsamplingBlock(512, 1024) 123 | 124 | self.up_block_1 = UpsamplingBlock(1024, 512) 125 | self.up_block_2 = UpsamplingBlock(512, 256) 126 | self.up_block_3 = UpsamplingBlock(256, 128) 127 | self.up_block_4 = UpsamplingBlock(128, n_outputs) 128 | 129 | self.out_block = _conv2(n_outputs, n_outputs, 1) 130 | 131 | def forward(self, x): 132 | """Propagate input through network.""" 133 | 134 | d_64 = self.in_block(x) 135 | d_128 = self.down_block_1(d_64) 136 | d_256 = self.down_block_2(d_128) 137 | d_512 = self.down_block_3(d_256) 138 | d_1024 = self.down_block_4(d_512) 139 | 140 | u_512 = self.up_block_1(d_1024, d_512) 141 | u_256 = self.up_block_2(u_512, d_256) 142 | u_128 = self.up_block_3(u_256, d_128) 143 | u_out = self.up_block_4(u_128, d_64) 144 | 145 | return self.out_block(u_out) 146 | -------------------------------------------------------------------------------- /quantnn/models/pytorch/upsampling.py: -------------------------------------------------------------------------------- 1 | """ 2 | quantnn.models.pytorch.upsampling 3 | ================================= 4 | 5 | Upsampling factories. 6 | """ 7 | from torch import nn 8 | 9 | 10 | class BilinearFactory: 11 | """ 12 | A factory for producing bilinear upsampling layers. 
class BilinearFactory:
    """
    Factory creating bilinear upsampling layers, preceded by a 1x1
    convolution when the channel count has to change.
    """

    def __call__(self, channels_in, channels_out, factor):
        """
        Args:
            channels_in: Number of incoming channels or ``None``.
            channels_out: Number of outgoing channels or ``None``.
            factor: The scale factor of the upsampling layer.

        Return:
            A bilinear upsampling layer. If both channel counts are given
            and differ, a ``nn.Sequential`` applying a 1x1 convolution
            before the upsampling.
        """
        channels_known = channels_in is not None and channels_out is not None
        if not (channels_known and channels_in != channels_out):
            return nn.UpsamplingBilinear2d(scale_factor=factor)

        projection = nn.Conv2d(channels_in, channels_out, kernel_size=1)
        upsampler = nn.UpsamplingBilinear2d(scale_factor=factor)
        return nn.Sequential(projection, upsampler)
class SeparableConv3x3(nn.Sequential):
    """
    Depth-wise separable convolution: a 3x3 convolution applied to each
    channel separately followed by a 1x1 convolution mixing the channels.
    """

    def __init__(self, channels_in, channels_out):
        """
        Args:
            channels_in: The number of incoming channels.
            channels_out: The number of outgoing channels.
        """
        # groups == channels_in gives every input channel its own filter.
        depthwise = nn.Conv2d(
            channels_in,
            channels_in,
            kernel_size=3,
            groups=channels_in,
            padding=1,
            padding_mode="replicate",
        )
        pointwise = nn.Conv2d(channels_in, channels_out, kernel_size=1)
        super().__init__(depthwise, pointwise)
class DownsamplingBlock(nn.Sequential):
    """
    Xception downsampling block: one downsampling Xception block followed
    by ``n_blocks`` Xception blocks without downsampling.
    """

    def __init__(self, n_channels, n_blocks):
        """
        Args:
            n_channels: The number of incoming and outgoing channels of
                every Xception block.
            n_blocks: The number of Xception blocks appended after the
                initial downsampling block.
        """
        stages = [XceptionBlock(n_channels, n_channels, downsample=True)]
        stages.extend(
            XceptionBlock(n_channels, n_channels) for _ in range(n_blocks)
        )
        super().__init__(*stages)
127 | """ 128 | super().__init__() 129 | self.upsample = nn.Upsample(mode="bilinear", 130 | scale_factor=2, 131 | align_corners=False) 132 | n_channels_in = n_channels * 2 if skip_connections else n_channels 133 | self.block = nn.Sequential( 134 | SeparableConv3x3(n_channels_in, n_channels), 135 | nn.GroupNorm(1, n_channels), 136 | nn.GELU(), 137 | ) 138 | self.projection = nn.Conv2d(n_channels_in, n_channels, 1,) 139 | 140 | 141 | def forward(self, x, x_skip=None): 142 | """ 143 | Propagate input through block. 144 | """ 145 | x_up = self.upsample(x) 146 | if x_skip is not None: 147 | x_merged = torch.cat([x_up, x_skip], 1) 148 | else: 149 | x_merged = x_up 150 | return torch.add(self.block(x_merged), self.projection(x_merged)) 151 | 152 | 153 | class XceptionFpn(nn.Module): 154 | """ 155 | Feature pyramid network (FPN) with 5 stages based on xception 156 | architecture. 157 | """ 158 | 159 | def __init__(self, n_inputs, n_outputs, n_features=128, blocks=2): 160 | """ 161 | Args: 162 | n_inputs: Number of input channels. 163 | n_outputs: The number of output channels, 164 | n_features: The number of features in the xception blocks. 
165 | blocks: The number of blocks per stage 166 | """ 167 | super().__init__() 168 | 169 | if isinstance(blocks, int): 170 | blocks = [blocks] * 5 171 | 172 | self.in_block = nn.Conv2d(n_inputs, n_features, 1) 173 | 174 | self.down_block_2 = DownsamplingBlock(n_features, blocks[0]) 175 | self.down_block_4 = DownsamplingBlock(n_features, blocks[1]) 176 | self.down_block_8 = DownsamplingBlock(n_features, blocks[2]) 177 | self.down_block_16 = DownsamplingBlock(n_features, blocks[3]) 178 | self.down_block_32 = DownsamplingBlock(n_features, blocks[4]) 179 | 180 | self.up_block_16 = UpsamplingBlock(n_features) 181 | self.up_block_8 = UpsamplingBlock(n_features) 182 | self.up_block_4 = UpsamplingBlock(n_features) 183 | self.up_block_2 = UpsamplingBlock(n_features) 184 | self.up_block = UpsamplingBlock(n_features) 185 | 186 | self.head = nn.Sequential( 187 | nn.Conv2d(2 * n_features, n_features, 1), 188 | nn.GroupNorm(1, n_features), 189 | nn.GELU(), 190 | nn.Conv2d(n_features, n_features, 1), 191 | nn.GroupNorm(1, n_features), 192 | nn.GELU(), 193 | nn.Conv2d(n_features, n_outputs, 1), 194 | ) 195 | 196 | def forward(self, x): 197 | """ 198 | Propagate input through block. 
def plot_confidence_intervals(ax, x, y_pred, quantiles, color="C0"):
    """
    Plots symmetric confidence intervals using transparency to display uncertainty.

    This function plots a 1-dimensional sequence of predicted quantiles as
    confidence intervals. Each quantile is drawn as a filled region between
    the central (median) curve and the quantile curve. The opacity decreases
    linearly from 0.9 at the 0.5 quantile to 0.1 at the quantiles 0 and 1.

    Arguments:
        ax: Matplotlib axes instance to use for plotting.
        x: The x values corresponding to the prediction in y_pred.
        y_pred: 2D array containing the predicted quantiles.
        quantiles: The quantiles corresponding to the second axis of y_pred.
        color: The color to use for filling.
    """
    n = y_pred.shape[1]
    if n == 0:
        # Nothing to draw.
        return

    if n % 2:
        c_0 = y_pred[:, n // 2]
    else:
        # For an even number of quantiles the two central columns are
        # n // 2 - 1 and n // 2; n // 2 + 1 is out of range for n == 2.
        c_0 = 0.5 * (y_pred[:, n // 2 - 1] + y_pred[:, n // 2])
    alpha_0 = 0.9
    alpha_min = 0.1
    d_alpha = alpha_0 - alpha_min

    # Quantiles below the center, from the innermost outwards.
    for i in range(n // 2 - 1, -1, -1):
        alpha = alpha_0 - 2.0 * d_alpha * (0.5 - quantiles[i])
        ax.fill_between(
            x, c_0, y_pred[:, i], edgecolor=None, facecolor=to_rgba(color, alpha)
        )

    # Quantiles above the center. (n + 1) // 2 skips the median column for
    # odd n and starts at the upper central column for even n.
    for i in range((n + 1) // 2, n):
        alpha = alpha_0 - 2.0 * d_alpha * (quantiles[i] - 0.5)
        ax.fill_between(
            x, c_0, y_pred[:, i], edgecolor=None, facecolor=to_rgba(color, alpha)
        )
84 | """ 85 | cmap = copy(mpl.cm.get_cmap(cmap)) 86 | norm = mpl.colors.BoundaryNorm(quantiles, cmap.N) 87 | mappable = mpl.cm.ScalarMappable(norm=norm, cmap=cmap) 88 | cmap.set_under([0.0] * 4) 89 | cmap.set_over([0.0] * 4) 90 | 91 | n = len(quantiles) 92 | for i in range(n - 1): 93 | q = 0.5 * (quantiles[i] + quantiles[i + 1]) 94 | color = mappable.to_rgba(q) 95 | y_low = y_pred[:, i] 96 | y_high = y_pred[:, i + 1] 97 | ax.fill_between(x, y_low, y_high, edgecolor=None, facecolor=color) 98 | 99 | return mappable 100 | -------------------------------------------------------------------------------- /quantnn/transformations.py: -------------------------------------------------------------------------------- 1 | """ 2 | ======================= 3 | quantnn.transformations 4 | ======================= 5 | 6 | This module defines transformations that can be applied to train 7 | a network in a transformed space but evaluate it in the original 8 | space. 9 | """ 10 | import numpy as np 11 | 12 | from quantnn.backends import get_tensor_backend 13 | 14 | 15 | class Log10: 16 | """ 17 | Transforms values to log space. 18 | """ 19 | 20 | def __init__(self): 21 | self.xp = None 22 | 23 | def __call__(self, x): 24 | """ 25 | Transform tensor. 26 | 27 | Args: 28 | x: Tensor containing the values to transform. 29 | 30 | Return: 31 | Tensor containing the transformed values. 32 | 33 | """ 34 | if self.xp is None: 35 | xp = get_tensor_backend(x) 36 | self.xp = xp 37 | else: 38 | xp = self.xp 39 | return xp.log(x.double()).float() / np.log(10) 40 | 41 | def invert(self, y): 42 | """ 43 | Transform transformed values back to original space. 44 | 45 | Args: 46 | y: Tensor containing the transformed values to transform 47 | back. 48 | 49 | Returns: 50 | Tensor containing the original values. 
class Log:
    """
    Natural-logarithm transformation.

    The tensor backend is determined from the first tensor passed to
    the transformation and reused for all later calls.
    """

    def __init__(self):
        self.xp = None

    def __call__(self, x):
        """
        Transform tensor.

        Args:
            x: Tensor containing the values to transform.

        Return:
            Tensor containing the natural logarithm of ``x``, computed in
            double precision and converted back to single precision.
        """
        backend = self.xp
        if backend is None:
            backend = get_tensor_backend(x)
            self.xp = backend
        log_x = backend.log(x.double())
        return log_x.float()

    def invert(self, y):
        """
        Transform transformed values back to original space.

        Args:
            y: Tensor containing the transformed values to transform
               back.

        Returns:
            Tensor containing the exponential of ``y``, computed in
            double precision and converted back to single precision.
        """
        backend = self.xp
        if backend is None:
            backend = get_tensor_backend(y)
            self.xp = backend
        exp_y = backend.exp(y.double())
        return exp_y.float()
class LogLinear:
    """
    Piece-wise transformation that applies the natural logarithm for
    x <= 1 and the shifted identity ``x - 1`` for x > 1.

    The tensor backend is determined from the first tensor passed to
    the transformation and reused for all later calls.
    """

    def __init__(self):
        self.xp = None

    def __call__(self, x):
        """
        Transform tensor.

        Args:
            x: Tensor containing the values to transform.

        Return:
            Tensor containing ``x - 1`` where ``x > 1`` and ``log(x)``
            elsewhere.
        """
        backend = self.xp
        if backend is None:
            backend = get_tensor_backend(x)
            self.xp = backend

        above_one = x > 1
        return backend.where(above_one, x - 1.0, backend.log(x))

    def invert(self, y):
        """
        Transform transformed values back to original space.

        Args:
            y: Tensor containing the transformed values to transform
               back.

        Returns:
            Tensor containing ``y + 1`` where ``y > 0`` and ``exp(y)``
            elsewhere.
        """
        backend = self.xp
        if backend is None:
            backend = get_tensor_backend(y)
            self.xp = backend

        positive = y > 0
        return backend.where(positive, y + 1.0, backend.exp(y))
def apply(f, *args):
    """
    Applies a function to values or key-wise to dicts of values.

    Args:
        f: The function to apply.
        *args: The arguments to be supplied to ``f``. If at least one
            argument is a dict, ``f`` is applied once for every key of
            the first dict argument; dict arguments are indexed with the
            key and all other arguments are passed on unchanged.
            Otherwise ``f`` is applied directly to the arguments.

    Return:
        ``{k: f(x_1[k], x_2[k], ...) for k in x}`` where ``x`` is the
        first dict argument, or ``f(x_1, x_2, ...)`` if none of the
        arguments is a dict.
    """
    dict_args = [arg for arg in args if isinstance(arg, dict)]
    if not dict_args:
        return f(*args)

    def select(key):
        # Arguments for one key: index dicts, broadcast everything else.
        return [arg[key] if isinstance(arg, dict) else arg for arg in args]

    return {key: f(*select(key)) for key in dict_args[0]}
67 | 68 | Args: 69 | data: The bytes object containing the binary data of the 70 | NetCDf file. 71 | 72 | Returns: 73 | The deserialized xarray dataset. 74 | """ 75 | tmp = NamedTemporaryFile(delete=False) 76 | tmp.close() 77 | filename = tmp.name 78 | try: 79 | with open(filename, "wb") as file: 80 | buffer = file.write(data) 81 | dataset = xr.load_dataset(filename) 82 | finally: 83 | Path(filename).unlink() 84 | return dataset 85 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="quantnn", 8 | version="0.0.5", 9 | author="Simon Pfreundschuh", 10 | description="Quantile regression neural networks.", 11 | long_description=long_description, 12 | long_description_content_type="text/markdown", 13 | url="https://github.com/simonpf/quantnn", 14 | packages=setuptools.find_packages(), 15 | classifiers=[ 16 | "Programming Language :: Python :: 3", 17 | "Operating System :: OS Independent", 18 | ], 19 | install_requires=[ 20 | "numpy", 21 | "scipy", 22 | "xarray", 23 | "paramiko", 24 | "einops", 25 | "matplotlib", 26 | "rich"], 27 | python_requires=">=3.6", 28 | include_package_data=True, 29 | ) 30 | -------------------------------------------------------------------------------- /test/conftest.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains fixtures that are automatically available in all test files. 
3 | """ 4 | import pytest 5 | import numpy 6 | 7 | BACKENDS = [numpy] 8 | 9 | #try: 10 | # import tensorflow as tf 11 | # BACKENDS.append(tf) 12 | #except ModuleNotFoundError: 13 | # pass 14 | 15 | try: 16 | import torch 17 | BACKENDS += [torch] 18 | except ModuleNotFoundError: 19 | pass 20 | 21 | try: 22 | import jax.numpy as jnp 23 | BACKENDS.append(jnp) 24 | except ModuleNotFoundError: 25 | pass 26 | 27 | def pytest_configure(): 28 | pytest.backends = BACKENDS 29 | 30 | -------------------------------------------------------------------------------- /test/files/test_files.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test for the generic function in the :py:mod:`quantnn.files module`. 3 | """ 4 | import os 5 | 6 | import pytest 7 | import numpy as np 8 | from quantnn.files import read_file, CachedDataFolder 9 | from concurrent.futures import ThreadPoolExecutor 10 | 11 | # Currently no SFTP test data available. 12 | HAS_LOGIN_INFO = False 13 | 14 | def test_local_file(tmp_path): 15 | """ 16 | Ensures that opening a local file works. 17 | """ 18 | with open(tmp_path / "test.txt", "w") as file: 19 | file.write("test") 20 | 21 | with read_file(tmp_path / "test.txt") as file: 22 | content = file.read() 23 | 24 | assert content == "test" 25 | 26 | 27 | def test_local_folder(tmp_path): 28 | """ 29 | Ensures that opening a local file works. 30 | """ 31 | with open(tmp_path / "test.txt", "w") as file: 32 | file.write("test") 33 | 34 | data_folder = CachedDataFolder(tmp_path, "*.txt") 35 | 36 | assert len(data_folder.files) == 1 37 | 38 | file = data_folder.open(data_folder.files[0]) 39 | assert file.read() == "test" 40 | 41 | 42 | @pytest.mark.skipif(not HAS_LOGIN_INFO, reason="No SFTP login info.") 43 | def test_remote_file(): 44 | """ 45 | Ensures that opening a local file works. 
46 | """ 47 | host = "129.16.35.202" 48 | path = "/mnt/array1/share/MLDatasets/test/data_0.npz" 49 | with read_file("sftp://" + host + path, "rb") as file: 50 | data = np.load(file) 51 | 52 | assert np.all(np.isclose(data["x"], 0.0)) 53 | 54 | 55 | @pytest.mark.skipif(not HAS_LOGIN_INFO, reason="No SFTP login info.") 56 | def test_remote_folder(tmp_path): 57 | """ 58 | Ensures that opening a local file works. 59 | """ 60 | host = "129.16.35.202" 61 | path = "/mnt/array1/share/MLDatasets/test/" 62 | 63 | data_folder = CachedDataFolder("sftp://" + host + path, "*.npz") 64 | 65 | assert len(data_folder.files) == 8 66 | data_folder.files.sort() 67 | 68 | file = data_folder.get(data_folder.files[0]) 69 | data = np.load(file) 70 | assert np.all(np.isclose(data["x"], 0.0)) 71 | 72 | pool = ThreadPoolExecutor(max_workers=4) 73 | data_folder.download(pool) 74 | 75 | file = data_folder.get(data_folder.files[0]) 76 | data = np.load(file) 77 | assert np.all(np.isclose(data["x"], 0.0)) 78 | 79 | 80 | -------------------------------------------------------------------------------- /test/files/test_sftp.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pytest 5 | from quantnn.files import sftp 6 | 7 | # Currently no SFTP test data available. 8 | HAS_LOGIN_INFO = False 9 | 10 | @pytest.mark.skipif(not HAS_LOGIN_INFO, reason="No SFTP login info.") 11 | def test_list_files(): 12 | host = "129.16.35.202" 13 | path = "/mnt/array1/share/MLDatasets/test" 14 | files = sftp.list_files(host, path) 15 | assert len(files) == 8 16 | 17 | @pytest.mark.skipif(not HAS_LOGIN_INFO, reason="No SFTP login info.") 18 | def test_download_file(): 19 | """ 20 | Ensure that downloading of files work and the data is cleaned up after 21 | usage. 
22 | """ 23 | host = "129.16.35.202" 24 | path = "/mnt/array1/share/MLDatasets/test/data_0.npz" 25 | tmp_file = None 26 | with sftp.download_file(host, path) as file: 27 | tmp_file = file 28 | data = np.load(file) 29 | assert np.all(np.isclose(data["x"], 0.0)) 30 | 31 | assert not tmp_file.exists() 32 | 33 | @pytest.mark.skipif(not HAS_LOGIN_INFO, reason="No SFTP login info.") 34 | def test_sftp_cache(): 35 | """ 36 | Ensure that downloading of files work and the data is cleaned up after 37 | usage. 38 | """ 39 | host = "129.16.35.202" 40 | path = "/mnt/array1/share/MLDatasets/test/data_0.npz" 41 | 42 | cache = sftp.SFTPCache() 43 | file = cache.get(host, path) 44 | data = np.load(file, allow_pickle=True) 45 | assert np.all(np.isclose(data["x"], 0.0)) 46 | -------------------------------------------------------------------------------- /test/models/pytorch/test_aggregators.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | from quantnn.packed_tensor import PackedTensor 5 | from quantnn.models.pytorch.aggregators import ( 6 | AverageAggregatorFactory, 7 | BlockAggregatorFactory, 8 | SparseAggregator, 9 | SumAggregatorFactory, 10 | LinearAggregatorFactory, 11 | AttentionFusion 12 | ) 13 | from quantnn.models.pytorch.torchvision import ResNetBlockFactory 14 | 15 | 16 | def fill_tensor(t, indices): 17 | """ 18 | Fills tensor with corresponding sample indices. 19 | """ 20 | for i, ind in enumerate(indices): 21 | t[i] = ind 22 | return t 23 | 24 | 25 | def make_random_packed_tensor(batch_size, samples, shape): 26 | """ 27 | Create a sparse tensor representing a training batch with 28 | missing samples. Which samples are missing is randomized. 29 | The elements of the tensor correspond to the sample index. 30 | 31 | Args: 32 | batch_size: The nominal batch size of the training batch. 33 | samples: The number of non-missing samples. 
34 | shape: The shape of each sample in the training batch. 35 | """ 36 | indices = sorted( 37 | np.random.choice(np.arange(batch_size), size=samples, replace=False) 38 | ) 39 | t = np.ones((samples,) + shape, dtype=np.float32) 40 | t = fill_tensor(t, indices) 41 | return PackedTensor(t, batch_size, indices) 42 | 43 | 44 | def test_sparse_aggregator(): 45 | """ 46 | Tests the sparse aggregator by ensuring that sparse inputs are correctly 47 | combined. 48 | """ 49 | aggregator_factory = AverageAggregatorFactory() 50 | aggregator = SparseAggregator((8, 8, 8), 8, aggregator_factory) 51 | 52 | # Ensure full tensor is returned if only one of the provided 53 | # tensors is sparse. 54 | x_1 = torch.ones((100, 8, 32, 32), dtype=torch.float32) 55 | fill_tensor(x_1, range(100)) 56 | x_2 = torch.ones((100, 8, 32, 32), dtype=torch.float32) 57 | fill_tensor(x_2, range(100)) 58 | x_3 = make_random_packed_tensor(100, 50, (8, 32, 32)) 59 | y = aggregator(x_1, x_2, x_3) 60 | assert not isinstance(y, PackedTensor) 61 | assert torch.isclose(y, x_1).all() 62 | 63 | x_2 = torch.ones((100, 8, 32, 32), dtype=torch.float32) 64 | fill_tensor(x_2, range(100)) 65 | x_1 = make_random_packed_tensor(100, 50, (8, 32, 32)) 66 | y = aggregator(x_1, x_2) 67 | assert not isinstance(y, PackedTensor) 68 | assert torch.isclose(y, x_2).all() 69 | 70 | # Make sure merging works with two packed tensors. 71 | x_1 = make_random_packed_tensor(100, 50, (8, 32, 32)) 72 | x_2 = make_random_packed_tensor(100, 50, (8, 32, 32)) 73 | y = aggregator(x_1, x_2) 74 | assert isinstance(y, PackedTensor) 75 | batch_indices_union = sorted(list(set(x_1.batch_indices + x_2.batch_indices))) 76 | assert y.batch_indices == batch_indices_union 77 | for ind, batch_ind in enumerate(y.batch_indices): 78 | assert torch.isclose(y._t[ind], batch_ind * torch.ones(1, 1, 1)).all() 79 | 80 | # Test aggregation with linear layer.
81 | aggregator_factory = LinearAggregatorFactory() 82 | aggregator = SparseAggregator((8, 8), 8, aggregator_factory) 83 | 84 | # Ensure full tensor is returned if only one of the provided 85 | # tensors is sparse. 86 | x_1 = torch.ones((100, 8, 32, 32), dtype=torch.float32) 87 | fill_tensor(x_1, range(100)) 88 | x_2 = make_random_packed_tensor(100, 50, (8, 32, 32)) 89 | y = aggregator(x_1, x_2) 90 | assert not isinstance(y, PackedTensor) 91 | 92 | x_2 = torch.ones((100, 8, 32, 32), dtype=torch.float32) 93 | fill_tensor(x_2, range(100)) 94 | x_1 = make_random_packed_tensor(100, 50, (8, 32, 32)) 95 | y = aggregator(x_1, x_2) 96 | assert not isinstance(y, PackedTensor) 97 | 98 | # Make sure merging works with two packed tensors. 99 | x_1 = make_random_packed_tensor(100, 50, (8, 32, 32)) 100 | x_2 = make_random_packed_tensor(100, 50, (8, 32, 32)) 101 | y = aggregator(x_1, x_2) 102 | assert isinstance(y, PackedTensor) 103 | batch_indices_union = sorted(list(set(x_1.batch_indices + x_2.batch_indices))) 104 | assert y.batch_indices == batch_indices_union 105 | 106 | y = aggregator(x_1, None) 107 | assert isinstance(y, PackedTensor) 108 | 109 | 110 | AGGREGATORS = [ 111 | SumAggregatorFactory(), 112 | AverageAggregatorFactory(), 113 | LinearAggregatorFactory(), 114 | BlockAggregatorFactory(ResNetBlockFactory()), 115 | ] 116 | 117 | 118 | @pytest.mark.parametrize("aggregator", AGGREGATORS) 119 | def test_aggregators(aggregator): 120 | a = torch.ones(1, 10, 16, 16) 121 | b = torch.ones(1, 10, 16, 16) 122 | agg = aggregator((10, 10), 10) 123 | c = agg(a, b) 124 | assert c.shape == (1, 10, 16, 16) 125 | 126 | c = torch.ones(1, 10, 16, 16) 127 | agg = aggregator((10, 10, 10), 10) 128 | c = agg(a, b, c) 129 | assert c.shape == (1, 10, 16, 16) 130 | 131 | 132 | def test_attention_fusion(): 133 | 134 | x = torch.ones(1, 10, 32, 32) 135 | y = torch.ones(1, 20, 32, 32) 136 | z = torch.ones(1, 5, 32, 32) 137 | 138 | mha = AttentionFusion((10, 20, 5), 16, 8) 139 | merged = mha(x, y, 
z) 140 | 141 | assert merged.shape == (1, 16, 32, 32) 142 | -------------------------------------------------------------------------------- /test/models/pytorch/test_base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for the quantnn.models.pytorch.base module. 3 | """ 4 | from torch import nn 5 | 6 | from quantnn.models.pytorch.base import ParamCount 7 | 8 | 9 | class ConvModule(nn.Conv2d, ParamCount): 10 | """ 11 | Simple wrapper class that adds ParamCount mixin to 12 | a Conv2d module. 13 | """ 14 | 15 | def __init__( 16 | self, channels_in: int, channels_out: int, kernel_size: int, bias: bool = False 17 | ): 18 | nn.Conv2d.__init__( 19 | self, channels_in, channels_out, kernel_size=kernel_size, bias=bias 20 | ) 21 | 22 | 23 | def test_n_params(): 24 | """ 25 | Ensure that n_params returns the right number of parameters for 26 | a 2D convolution layer. 27 | """ 28 | conv = ConvModule(16, 16, 3, bias=True) 29 | assert conv.n_params == 16 * 16 * 3 * 3 + 16 30 | 31 | conv = ConvModule(16, 16, 3, bias=False) 32 | assert conv.n_params == 16 * 16 * 3 * 3 33 | -------------------------------------------------------------------------------- /test/models/pytorch/test_blocks.py: -------------------------------------------------------------------------------- 1 | from quantnn.models.pytorch.blocks import ( 2 | ConvBlockFactory, 3 | ConvTransposedBlockFactory, 4 | ResNeXtBlockFactory, 5 | ConvNextBlockFactory 6 | ) 7 | 8 | import torch 9 | from torch import nn 10 | 11 | 12 | def test_conv_block_factory(): 13 | """ 14 | Ensures that the basic ConvBlockFactory produces a working 15 | block. 
16 | """ 17 | block_factory = ConvBlockFactory(kernel_size=3) 18 | block = block_factory(8, 16) 19 | input = torch.ones(8, 8, 8, 8) 20 | output = block(input) 21 | assert output.shape == (8, 16, 8, 8) 22 | 23 | block_factory = ConvBlockFactory( 24 | kernel_size=3, 25 | norm_factory=nn.BatchNorm2d, 26 | activation_factory=nn.GELU 27 | ) 28 | 29 | block = block_factory(8, 16) 30 | input = torch.ones(8, 8, 8, 8) 31 | output = block(input) 32 | assert output.shape == (8, 16, 8, 8) 33 | 34 | block = block_factory(8, 16, downsample=(1, 2)) 35 | input = torch.ones(8, 8, 8, 8) 36 | output = block(input) 37 | assert output.shape == (8, 16, 8, 4) 38 | 39 | 40 | def test_conv_transposed_block_factory(): 41 | """ 42 | Ensures that the ConvTransposedBlockFactory produces a working 43 | block. 44 | """ 45 | block_factory = ConvTransposedBlockFactory(kernel_size=3) 46 | block = block_factory(8, 16) 47 | input = torch.ones(8, 8, 8, 8) 48 | output = block(input) 49 | assert output.shape == (8, 16, 8, 8) 50 | 51 | block_factory = ConvTransposedBlockFactory( 52 | kernel_size=3, 53 | norm_factory=nn.BatchNorm2d, 54 | activation_factory=nn.GELU 55 | ) 56 | block = block_factory(8, 16) 57 | input = torch.ones(8, 8, 8, 8) 58 | output = block(input) 59 | assert output.shape == (8, 16, 8, 8) 60 | 61 | block = block_factory(8, 16, downsample=(1, 2)) 62 | input = torch.ones(8, 8, 8, 8) 63 | output = block(input) 64 | assert output.shape == (8, 16, 8, 15) 65 | 66 | def test_resnext_block(): 67 | """ 68 | Ensure that the ResNext factory produces an nn.Module and that 69 | the output has the specified number of channels.
70 | """ 71 | x = torch.ones((1, 1, 8, 8)) 72 | 73 | factory = ResNeXtBlockFactory() 74 | block = factory(1, 64) 75 | y = block(x) 76 | assert y.shape == (1, 64, 8, 8) 77 | 78 | block = factory(1, 64, downsample=2) 79 | y = block(x) 80 | assert y.shape == (1, 64, 4, 4) 81 | 82 | block = factory(8, 64, downsample=(1, 2)) 83 | input = torch.ones(8, 8, 8, 8) 84 | output = block(input) 85 | assert output.shape == (8, 64, 8, 4) 86 | 87 | 88 | def test_convnext_block(): 89 | """ 90 | Ensure that the ConvNext factory produces an nn.Module and that 91 | the output has the specified number of channels. 92 | """ 93 | x = torch.ones((1, 1, 8, 8)) 94 | 95 | # ConvNext V1 96 | factory = ConvNextBlockFactory(version=1) 97 | block = factory(1, 64) 98 | y = block(x) 99 | assert y.shape == (1, 64, 8, 8) 100 | block = factory(1, 64, downsample=2) 101 | y = block(x) 102 | assert y.shape == (1, 64, 4, 4) 103 | 104 | # ConvNext V2 105 | factory = ConvNextBlockFactory(version=2) 106 | block = factory(1, 64) 107 | y = block(x) 108 | assert y.shape == (1, 64, 8, 8) 109 | block = factory(1, 64, downsample=2) 110 | y = block(x) 111 | assert y.shape == (1, 64, 4, 4) 112 | -------------------------------------------------------------------------------- /test/models/pytorch/test_decoders.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for quantnn.models.pytorch.decoders. 
3 | """ 4 | import pytest 5 | 6 | import torch 7 | from quantnn.packed_tensor import PackedTensor 8 | from quantnn.models.pytorch.encoders import ( 9 | SpatialEncoder, 10 | MultiInputSpatialEncoder 11 | ) 12 | from quantnn.models.pytorch.decoders import ( 13 | SpatialDecoder, 14 | SparseSpatialDecoder, 15 | DLADecoderStage, 16 | DLADecoder 17 | ) 18 | from quantnn.models.pytorch.torchvision import ResNetBlockFactory 19 | from quantnn.models.pytorch.aggregators import ( 20 | LinearAggregatorFactory, 21 | SparseAggregatorFactory, 22 | BlockAggregatorFactory 23 | ) 24 | from quantnn.models.pytorch.upsampling import BilinearFactory 25 | 26 | 27 | def test_spatial_decoder(): 28 | """ 29 | Test that chaining an encoder and corresponding decoder reproduces 30 | output of the same spatial dimensions as the input. 31 | """ 32 | block_factory = ResNetBlockFactory() 33 | encoder = SpatialEncoder( 34 | channels=1, 35 | stages=[4] * 4, 36 | channel_scaling=2, 37 | max_channels=8, 38 | block_factory=block_factory, 39 | ) 40 | decoder = SpatialDecoder( 41 | channels=1, 42 | stages=[4] * 3, 43 | channel_scaling=2, 44 | max_channels=8, 45 | block_factory=block_factory, 46 | skip_connections=False 47 | ) 48 | # Test forward without skip connections. 49 | x = torch.ones((1, 1, 32, 32)) 50 | y = decoder(encoder(x)) 51 | 52 | # Shape of y should be same as before. 53 | assert y.shape == (1, 1, 32, 32) 54 | 55 | # 56 | # Test asymmetric decoder 57 | # 58 | decoder = SpatialDecoder( 59 | channels=[8, 1, 1], 60 | stages=[4] * 2, 61 | channel_scaling=2, 62 | max_channels=8, 63 | block_factory=block_factory, 64 | skip_connections=encoder.skip_connections, 65 | ) 66 | # Test forward width skips returned. 67 | y = decoder(encoder(x, return_skips=True)) 68 | # Size should be less than input. 
69 | assert y.shape == (1, 1, 16, 16) 70 | 71 | encoder = SpatialEncoder( 72 | channels=1, 73 | stages=[4] * 4, 74 | channel_scaling=2, 75 | max_channels=8, 76 | block_factory=block_factory, 77 | downsampling_factors=[2, 2, 4] 78 | ) 79 | decoder = SpatialDecoder( 80 | channels=1, 81 | stages=[4] * 3, 82 | channel_scaling=2, 83 | max_channels=8, 84 | block_factory=block_factory, 85 | skip_connections=False, 86 | upsampling_factors=[4, 2, 2], 87 | ) 88 | # Test forward without skip connections. 89 | x = torch.ones((1, 1, 128, 128)) 90 | y = decoder(encoder(x)) 91 | # Width and height should be the same as those of the input. 92 | # Number of channels should be 1, as in the input. 93 | assert y.shape == (1, 1, 128, 128) 94 | 95 | # Test decoder with different channel config. 96 | decoder = SpatialDecoder( 97 | channels=[8, 2, 2, 2], 98 | stages=[4] * 3, 99 | channel_scaling=2, 100 | max_channels=8, 101 | block_factory=block_factory, 102 | skip_connections=False, 103 | upsampling_factors=[4, 2, 2], 104 | ) 105 | # Test forward without skip connections. 106 | x = torch.ones((1, 1, 128, 128)) 107 | y = decoder(encoder(x)) 108 | # Width and height should be the same as those of the input. 109 | # The output is expected to have 2 channels. 110 | assert y.shape == (1, 2, 128, 128) 111 | 112 | 113 | def test_encoder_decoder_multi_scale_output(): 114 | """ 115 | Tests the chaining of a multi-input encoder and a decoder with 116 | potentially missing input.
117 | """ 118 | block_factory = ResNetBlockFactory() 119 | aggregator_factory = SparseAggregatorFactory( 120 | LinearAggregatorFactory() 121 | ) 122 | inputs = { 123 | "input_1": (1, 16), 124 | "input_2": (2, 32) 125 | } 126 | encoder = MultiInputSpatialEncoder( 127 | inputs=inputs, 128 | channels=4, 129 | stages=[4] * 4, 130 | block_factory=block_factory, 131 | channel_scaling=2, 132 | max_channels=16, 133 | aggregator_factory=aggregator_factory 134 | ) 135 | 136 | decoder = SparseSpatialDecoder( 137 | channels=4, 138 | stages=[4] * 3, 139 | channel_scaling=2, 140 | max_channels=16, 141 | block_factory=block_factory, 142 | skip_connections=3, 143 | multi_scale_output=16 144 | ) 145 | # Test forward without skip connections. 146 | x = { 147 | "input_1": PackedTensor( 148 | torch.ones((0, 16, 16, 16)), 149 | batch_size=4, 150 | batch_indices=[] 151 | ), 152 | "input_2": PackedTensor( 153 | torch.ones((1, 32, 8, 8)), 154 | batch_size=4, 155 | batch_indices=[0] 156 | ) 157 | } 158 | y = encoder(x, return_skips=True) 159 | y = decoder(y) 160 | assert len(y) == 4 161 | for y_i in y: 162 | assert y_i.shape[1] == 16 163 | assert y[-1].shape[2] == 32 164 | 165 | # 166 | # Ensure using different channels than encoder works. 167 | # 168 | 169 | decoder = SparseSpatialDecoder( 170 | channels=[16, 8, 8, 8], 171 | stages=[4] * 3, 172 | channel_scaling=2, 173 | max_channels=16, 174 | block_factory=block_factory, 175 | skip_connections=encoder.skip_connections, 176 | multi_scale_output=16 177 | ) 178 | x = { 179 | "input_1": PackedTensor( 180 | torch.ones((0, 16, 16, 16)), 181 | batch_size=4, 182 | batch_indices=[] 183 | ), 184 | "input_2": PackedTensor( 185 | torch.ones((1, 32, 8, 8)), 186 | batch_size=4, 187 | batch_indices=[0] 188 | ) 189 | } 190 | y = encoder(x, return_skips=True) 191 | y = decoder(y) 192 | assert y[0].shape[1] == 16 193 | assert y[-1].shape[1] == 16 194 | 195 | # 196 | # Test decoer with more stages than encoder. 
197 | # 198 | 199 | decoder = SparseSpatialDecoder( 200 | channels=[16, 8, 8, 8, 2], 201 | stages=[4] * 4, 202 | channel_scaling=2, 203 | max_channels=16, 204 | block_factory=block_factory, 205 | skip_connections=encoder.skip_connections, 206 | base_scale=3 207 | ) 208 | x = { 209 | "input_1": PackedTensor( 210 | torch.ones((0, 16, 16, 16)), 211 | batch_size=4, 212 | batch_indices=[] 213 | ), 214 | "input_2": PackedTensor( 215 | torch.ones((1, 32, 8, 8)), 216 | batch_size=4, 217 | batch_indices=[0] 218 | ) 219 | } 220 | y = encoder(x, return_skips=True) 221 | y = decoder(y) 222 | assert y.shape[1] == 2 223 | assert y.shape[2] == 64 224 | 225 | @pytest.mark.xfail 226 | def test_dla_decoder(): 227 | """ 228 | Test implementation of the DLA decoder stages and full decoder. 229 | """ 230 | x = [ 231 | torch.ones((2, 16, 4, 4)), 232 | torch.ones((2, 8, 16, 16)), 233 | torch.ones((2, 4, 32, 32)), 234 | torch.ones((2, 2, 64, 64)), 235 | ] 236 | 237 | aggregator_factory = BlockAggregatorFactory( 238 | ResNetBlockFactory() 239 | ) 240 | upsampler_factory = BilinearFactory() 241 | 242 | # 243 | # Single stage 244 | # 245 | 246 | decoder = DLADecoderStage( 247 | [16, 8, 4, 2], 248 | [16, 8, 4, 2], 249 | [16, 4, 2, 1], 250 | aggregator_factory, 251 | upsampler_factory 252 | ) 253 | y = decoder(x) 254 | # Output should contain one less tensor than the input. 255 | assert len(y) == len(x) - 1 256 | y[0].shape == (2, 4, 32, 32) 257 | y[1].shape == (2, 8, 16, 16) 258 | y[2].shape == (2, 16, 4, 4) 259 | 260 | # 261 | # Full decoder 262 | # 263 | 264 | decoder = DLADecoder( 265 | [16, 8, 4, 2], 266 | [16, 4, 2, 1], 267 | aggregator_factory, 268 | upsampler_factory 269 | ) 270 | y = decoder(x[::-1]) 271 | # Output should be a single tensor 272 | assert isinstance(y, torch.Tensor) 273 | assert y.shape == (2, 2, 64, 64) 274 | 275 | # 276 | # Test sparse input. 
277 | # 278 | aggregator_factory = SparseAggregatorFactory( 279 | BlockAggregatorFactory( 280 | ResNetBlockFactory() 281 | ) 282 | ) 283 | decoder = DLADecoder( 284 | [16, 8, 4, 2], 285 | [16, 4, 2, 1], 286 | aggregator_factory, 287 | upsampler_factory 288 | ) 289 | 290 | x = [ 291 | PackedTensor( 292 | torch.ones((0, 16, 4, 4)), 293 | batch_size=2, 294 | batch_indices=[] 295 | ), 296 | torch.ones((2, 8, 16, 16)), 297 | torch.ones((2, 4, 32, 32)), 298 | PackedTensor( 299 | torch.ones((0, 16, 64, 64)), 300 | batch_size=2, 301 | batch_indices=[] 302 | ) 303 | ] 304 | y = decoder(x[::-1]) 305 | 306 | # Output is, again, a full tensor. 307 | assert isinstance(y, torch.Tensor) 308 | assert y.shape == (2, 2, 64, 64) 309 | 310 | -------------------------------------------------------------------------------- /test/models/pytorch/test_downsampling.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for the quantnn.models.pytorch.downsampling module. 3 | """ 4 | import torch 5 | from torch import nn 6 | 7 | from quantnn.models.pytorch.downsampling import ( 8 | ConvNextDownsamplerFactory, 9 | PatchMergingFactory 10 | ) 11 | 12 | 13 | def test_convnext_downsampling(): 14 | """ 15 | Test that ConvNext downsampling works with different downsampling factors 16 | along dimensions. 17 | """ 18 | down = ConvNextDownsamplerFactory()(16, 16, (2, 4)) 19 | x = torch.rand(2, 16, 32, 32) 20 | y = down(x) 21 | 22 | assert y.shape[-2] == 16 23 | assert y.shape[-1] == 8 24 | 25 | 26 | def test_path_merging_downsampling(): 27 | """ 28 | Test PathMerging block for downsampling. 
29 | 30 | """ 31 | down = PatchMergingFactory()(16, 16, (2, 2)) 32 | 33 | x = torch.rand(2, 16, 32, 32) 34 | y = down(x) 35 | assert y.shape[-2] == 16 36 | assert y.shape[-1] == 16 37 | -------------------------------------------------------------------------------- /test/models/pytorch/test_fully_connected.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from quantnn.models.pytorch.fully_connected import MLP 4 | 5 | 6 | def test_mlp(): 7 | # 8 | # 2D Input 9 | # 10 | 11 | x = torch.rand(128, 8) 12 | 13 | # No residual connections. 14 | mlp = MLP( 15 | features_in=8, 16 | n_features=128, 17 | features_out=16, 18 | n_layers=4, 19 | ) 20 | y = mlp(x) 21 | assert y.shape == (128, 16) 22 | 23 | # Standard residual connections. 24 | mlp = MLP( 25 | features_in=8, n_features=128, features_out=16, n_layers=4, residuals="simple" 26 | ) 27 | y = mlp(x) 28 | assert y.shape == (128, 16) 29 | 30 | # Hyper residual connections. 31 | mlp = MLP( 32 | features_in=8, n_features=128, features_out=16, n_layers=4, residuals="hyper" 33 | ) 34 | y = mlp(x) 35 | assert y.shape == (128, 16) 36 | 37 | # 38 | # 4D input 39 | # 40 | x = torch.rand(128, 8, 8, 8) 41 | 42 | # No residual connections. 43 | mlp = MLP( 44 | features_in=8, 45 | n_features=128, 46 | features_out=16, 47 | n_layers=4, 48 | ) 49 | y = mlp(x) 50 | assert y.shape == (128, 16, 8, 8) 51 | 52 | # Standard residual connections. 53 | mlp = MLP( 54 | features_in=8, n_features=128, features_out=16, n_layers=4, residuals="simple" 55 | ) 56 | y = mlp(x) 57 | assert y.shape == (128, 16, 8, 8) 58 | 59 | # Hyper residual connections. 60 | mlp = MLP( 61 | features_in=8, n_features=128, features_out=16, n_layers=4, residuals="hyper" 62 | ) 63 | y = mlp(x) 64 | assert y.shape == (128, 16, 8, 8) 65 | 66 | 67 | def test_mlp_output_shape(): 68 | # 69 | # 2D Input 70 | # 71 | 72 | x = torch.rand(128, 8) 73 | 74 | # No residual connections. 
75 | mlp = MLP( 76 | features_in=8, n_features=128, features_out=16, n_layers=4, output_shape=(4, 4) 77 | ) 78 | y = mlp(x) 79 | assert y.shape == (128, 4, 4) 80 | 81 | # Standard residual connections. 82 | mlp = MLP( 83 | features_in=8, 84 | n_features=128, 85 | features_out=16, 86 | n_layers=4, 87 | residuals="simple", 88 | output_shape=(4, 4), 89 | ) 90 | y = mlp(x) 91 | assert y.shape == (128, 4, 4) 92 | 93 | # Hyper residual connections. 94 | mlp = MLP( 95 | features_in=8, 96 | n_features=128, 97 | features_out=16, 98 | n_layers=4, 99 | residuals="hyper", 100 | output_shape=(4, 4), 101 | ) 102 | y = mlp(x) 103 | assert y.shape == (128, 4, 4) 104 | 105 | # 106 | # 4D input 107 | # 108 | x = torch.rand(128, 8, 8, 8) 109 | 110 | # No residual connections. 111 | mlp = MLP( 112 | features_in=8, n_features=128, features_out=16, n_layers=4, output_shape=(4, 4) 113 | ) 114 | y = mlp(x) 115 | assert y.shape == (128, 4, 4, 8, 8) 116 | 117 | # Standard residual connections. 118 | mlp = MLP( 119 | features_in=8, 120 | n_features=128, 121 | features_out=16, 122 | n_layers=4, 123 | residuals="simple", 124 | output_shape=(4, 4), 125 | ) 126 | y = mlp(x) 127 | assert y.shape == (128, 4, 4, 8, 8) 128 | 129 | # Hyper residual connections. 130 | mlp = MLP( 131 | features_in=8, 132 | n_features=128, 133 | features_out=16, 134 | n_layers=4, 135 | residuals="hyper", 136 | output_shape=(4, 4), 137 | ) 138 | y = mlp(x) 139 | assert y.shape == (128, 4, 4, 8, 8) 140 | -------------------------------------------------------------------------------- /test/models/pytorch/test_normalization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for the quantnn.models.pytorch.normalization module. 
3 | """ 4 | import torch 5 | import numpy as np 6 | 7 | from quantnn.models.pytorch.normalization import ( 8 | LayerNormFirst, 9 | GRN 10 | ) 11 | 12 | 13 | def test_layer_norm_first(): 14 | """ 15 | Assert that layer norm with channels along first dimensions works. 16 | """ 17 | norm = LayerNormFirst(16) 18 | x = torch.rand(10, 16, 24, 24) 19 | y = norm(x) 20 | 21 | mu = y.mean(1).detach().numpy() 22 | assert np.all(np.isclose(mu, 0.0, atol=1e-5)) 23 | 24 | 25 | def test_grn(): 26 | """ 27 | Assert that GRN works. 28 | """ 29 | norm = GRN(16) 30 | x = torch.rand(10, 16, 24, 24) 31 | y = norm(x) 32 | x = x.detach().numpy() 33 | y = y.detach().numpy() 34 | assert np.all(np.isclose(x, y)) 35 | -------------------------------------------------------------------------------- /test/models/pytorch/test_torchvision.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for the torchvision wrappers defined in 3 | quantnn.models.pytorch.torchvision. 4 | """ 5 | import pytest 6 | import torch 7 | 8 | import quantnn.models.pytorch.torchvision as tv 9 | try: 10 | import quantnn.models.pytorch.torchvision as tv 11 | 12 | HAS_TORCHVISION = True 13 | except ImportError as e: 14 | HAS_TORCHVISION = False 15 | 16 | 17 | @pytest.mark.skipif(not HAS_TORCHVISION, reason="torchvision not available") 18 | def test_resnet_block(): 19 | """ 20 | Ensure that the ResNet factory produces an nn.Module and that 21 | the output has the specified number of channels. 
22 | """ 23 | x = torch.ones((1, 1, 8, 8)) 24 | 25 | factory = tv.ResNetBlockFactory() 26 | block = factory(1, 2) 27 | y = block(x) 28 | assert y.shape == (1, 2, 8, 8) 29 | 30 | block = factory(1, 2, downsample=2) 31 | y = block(x) 32 | assert y.shape == (1, 2, 4, 4) 33 | 34 | block = factory(1, 2, downsample=(1, 2)) 35 | y = block(x) 36 | assert y.shape == (1, 2, 8, 4) 37 | 38 | 39 | @pytest.mark.skipif(not HAS_TORCHVISION, reason="torchvision not available") 40 | def test_convnext_block(): 41 | """ 42 | Ensure that the ConvNeXt factory produces an nn.Module and that 43 | the output has the specified number of channels. 44 | """ 45 | x = torch.ones((1, 1, 8, 8)) 46 | 47 | factory = tv.ConvNeXtBlockFactory() 48 | block = factory(1, 2) 49 | y = block(x) 50 | assert y.shape == (1, 2, 8, 8) 51 | 52 | block = factory(1, 2, downsample=2) 53 | y = block(x) 54 | assert y.shape == (1, 2, 4, 4) 55 | 56 | block = factory(1, 2, downsample=(1, 2)) 57 | y = block(x) 58 | assert y.shape == (1, 2, 8, 4) 59 | 60 | 61 | @pytest.mark.skipif(not HAS_TORCHVISION, reason="torchvision not available") 62 | def test_swin_transformer_block(): 63 | """ 64 | Ensure that the Swin factory produces an nn.Module and that 65 | the output has the specified number of channels. 66 | """ 67 | x = torch.ones((1, 16, 32, 32)) 68 | 69 | factory = tv.SwinBlockFactory() 70 | block = factory(16, 16) 71 | y = block(x) 72 | assert y.shape == (1, 16, 32, 32) 73 | 74 | block = factory(16, 32, downsample=2) 75 | y = block(x) 76 | assert y.shape == (1, 32, 16, 16) 77 | 78 | block = factory(16, 32, downsample=4) 79 | y = block(x) 80 | assert y.shape == (1, 32, 8, 8) 81 | -------------------------------------------------------------------------------- /test/models/pytorch/test_upsampling.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for the quantnn.models.pytorch.upsampling module.
3 | """ 4 | import torch 5 | from torch import nn 6 | 7 | from quantnn.models.pytorch.upsampling import ( 8 | BilinearFactory, 9 | UpsampleFactory, 10 | UpConvolutionFactory, 11 | ) 12 | 13 | 14 | def test_bilinear(): 15 | """ 16 | Ensure bilinear upsampler factory supports upsampling with different 17 | factors along different dimensions. 18 | """ 19 | up = BilinearFactory()(channels_in=16, channels_out=8, factor=(2, 4)) 20 | x = torch.rand(2, 16, 32, 32) 21 | y = up(x) 22 | 23 | assert y.shape[1] == 8 24 | assert y.shape[-2] == 64 25 | assert y.shape[-1] == 128 26 | 27 | 28 | def test_upsample(): 29 | """ 30 | Ensure that generic upsampling factory supports upsampling with different 31 | factors along different dimensions. 32 | """ 33 | fac = UpsampleFactory(mode="nearest") 34 | up = fac(channels_in=16, channels_out=8, factor=(2, 4)) 35 | x = torch.rand(2, 16, 32, 32) 36 | y = up(x) 37 | 38 | assert y.shape[1] == 8 39 | assert y.shape[-2] == 64 40 | assert y.shape[-1] == 128 41 | 42 | 43 | def test_upconvolution(): 44 | """ 45 | Ensure that the up-convolution factory supports upsampling with different 46 | factors along different dimensions. 47 | """ 48 | fac = UpConvolutionFactory(mode="nearest") 49 | up = fac(channels_in=16, channels_out=8, factor=(2, 4)) 50 | x = torch.rand(2, 16, 32, 32) 51 | y = up(x) 52 | 53 | assert y.shape[1] == 8 54 | assert y.shape[-2] == 64 55 | assert y.shape[-1] == 128 56 | -------------------------------------------------------------------------------- /test/models/test_keras.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for the PyTorch NN backend. 
3 | """ 4 | import tensorflow as tf 5 | from tensorflow import keras 6 | 7 | from quantnn.models.keras import QuantileLoss, CrossEntropyLoss 8 | import numpy as np 9 | from quantnn import (QRNN, 10 | set_default_backend, 11 | get_default_backend) 12 | 13 | def test_quantile_loss(): 14 | """ 15 | Ensure that quantile loss corresponds to half of absolute error 16 | loss and that masking works as expected. 17 | """ 18 | loss = QuantileLoss([0.5], mask=-1e3) 19 | 20 | y_pred = np.random.rand(10, 1, 10) 21 | y = np.random.rand(10, 1, 10) 22 | 23 | l = loss(y, y_pred) 24 | 25 | dy = (y_pred - y) 26 | l_ref = 0.5 * np.mean(np.abs(dy)) 27 | 28 | assert np.isclose(l, l_ref) 29 | 30 | y_pred = np.random.rand(20, 1, 10) 31 | y_pred[10:] = -2e3 32 | y = np.random.rand(20, 1, 10) 33 | y[10:] = -2e3 34 | 35 | loss = QuantileLoss([0.5], mask=-1e3) 36 | l = loss(y, y_pred) 37 | l_ref = loss(y[:10], y_pred[:10]) 38 | assert np.isclose(l, l_ref) 39 | 40 | def test_cross_entropy_loss(): 41 | """ 42 | Test masking for cross entropy loss. 43 | 44 | Need to take into account that Keras, by default, expects channels along last axis. 
45 | """ 46 | y_pred = np.random.rand(10, 10, 10).astype(np.float32) 47 | y = np.ones((10, 10, 1), dtype=np.float32) 48 | bins = np.linspace(0, 1, 11) 49 | y[:, :, 0] = 0.55 50 | 51 | loss = CrossEntropyLoss(bins, mask=-1.0) 52 | ref = -y_pred[:, :, 5] + np.log(np.exp(y_pred).sum(1)) 53 | assert np.all(np.isclose(loss(y, y_pred), 54 | ref.mean(), 55 | rtol=1e-3)) 56 | 57 | y[5:, :, :] = -1.0 58 | y[:, 5:, :] = -1.0 59 | ref = -y_pred[:5, :5, 5] + np.log(np.exp(y_pred[:5, :5, :]).sum(-1)) 60 | assert np.all(np.isclose(loss(y, y_pred), 61 | ref.mean())) 62 | 63 | loss = CrossEntropyLoss(10, mask=-1.0) 64 | y = np.ones((10, 10, 1), dtype=np.int32) 65 | y[:, :, :] = 5 66 | ref = -y_pred[:, :, 5] + np.log(np.exp(y_pred).sum(1)) 67 | assert np.all(np.isclose(loss(y, y_pred), 68 | ref.mean(), 69 | rtol=1e-3)) 70 | 71 | y[5:, :, :] = -1.0 72 | y[:, 5:, :] = -1.0 73 | 74 | ref = -y_pred[:5, :5, 5] + np.log(np.exp(y_pred[:5, :5, :]).sum(-1)) 75 | assert np.all(np.isclose(loss(y, y_pred), 76 | ref.mean())) 77 | 78 | # Test binary case. 79 | loss = CrossEntropyLoss(2, mask=-1.0) 80 | y = np.ones((10, 10, 1), dtype=np.int32) 81 | y_pred = np.random.rand(10, 10, 1).astype(np.float32) 82 | y[:, :, :] = 1 83 | ref = tf.math.log_sigmoid(y_pred) 84 | assert np.all(np.isclose( 85 | loss(y, y_pred), 86 | -tf.math.reduce_mean(ref), 87 | rtol=1e-3 88 | )) 89 | 90 | y[5:, :, :] = -1.0 91 | y[:, 5:, :] = -1.0 92 | 93 | ref = -tf.math.log_sigmoid(y_pred[:5, :5]) 94 | assert np.all(np.isclose( 95 | loss(y, y_pred), 96 | tf.math.reduce_mean(ref), 97 | rtol=1e-3 98 | )) 99 | 100 | def test_training_with_dataloader(): 101 | """ 102 | Ensure that training with a pytorch dataloader works. 
103 | """ 104 | set_default_backend("keras") 105 | x = np.random.rand(1024, 16) 106 | y = np.random.rand(1024) 107 | 108 | batched_data = [ 109 | { 110 | "x": x[i * 128: (i + 1) * 128], 111 | "y": y[i * 128: (i + 1) * 128], 112 | } 113 | for i in range(1024 // 128) 114 | ] 115 | 116 | qrnn = QRNN(np.linspace(0.05, 0.95, 10), n_inputs=x.shape[1]) 117 | qrnn.train(batched_data, n_epochs=1) 118 | 119 | 120 | def test_training_with_dict(): 121 | """ 122 | Ensure that training with batch objects as dicts works. 123 | """ 124 | set_default_backend("keras") 125 | x = np.random.rand(1024, 16) 126 | y = np.random.rand(1024) 127 | 128 | batched_data = [ 129 | { 130 | "x": x[i * 128: (i + 1) * 128], 131 | "y": y[i * 128: (i + 1) * 128], 132 | } 133 | for i in range(1024 // 128) 134 | ] 135 | 136 | qrnn = QRNN(np.linspace(0.05, 0.95, 10), n_inputs=x.shape[1]) 137 | 138 | qrnn.train(batched_data, n_epochs=1) 139 | 140 | 141 | def test_training_with_dict_and_keys(): 142 | """ 143 | Ensure that training with batch objects as dicts and provided keys 144 | argument works. 145 | """ 146 | set_default_backend("keras") 147 | x = np.random.rand(1024, 16) 148 | y = np.random.rand(1024) 149 | 150 | batched_data = [ 151 | { 152 | "x": x[i * 128: (i + 1) * 128], 153 | "x_2": x[i * 128: (i + 1) * 128], 154 | "y": y[i * 128: (i + 1) * 128], 155 | } 156 | for i in range(1024 // 128) 157 | ] 158 | 159 | qrnn = QRNN(np.linspace(0.05, 0.95, 10), n_inputs=x.shape[1]) 160 | qrnn.train(batched_data, n_epochs=1, keys=("x", "y")) 161 | 162 | 163 | def test_training_multiple_outputs(): 164 | """ 165 | Ensure that training with batch objects as dicts and provided keys 166 | argument works. 
class Dataset:
    """
    Minimal batch-serving dataset used by the data streaming tests.

    Reads the ``x`` and ``y`` entries of a NumPy file and hands them out
    as consecutive ``(x, y)`` batches.
    """
    def __init__(self, filename, batch_size=1):
        """
        Load the samples stored in ``filename``.

        Args:
            filename: Path of the file to load the data from.
            batch_size: Number of samples in each returned batch.
        """
        self.batch_size = batch_size
        archive = np.load(filename)
        self.x = archive["x"]
        # Targets are stored as a column vector.
        self.y = archive["y"].reshape(-1, 1)
        LOGGER.info("Loaded data from file %s.", filename)

    def _shuffle(self):
        """
        Apply one random permutation to both ``x`` and ``y`` so that
        inputs and targets stay aligned.
        """
        order = np.random.permutation(self.x.shape[0])
        self.x, self.y = self.x[order], self.y[order]

    def __len__(self):
        """ Number of complete batches in the dataset. """
        n_samples = self.x.shape[0]
        return n_samples // self.batch_size

    def __getitem__(self, index):
        """
        Return the batch with the given index as an ``(x, y)`` tuple.

        Raises:
            IndexError: If ``index`` is not smaller than ``len(self)``.
        """
        if index >= len(self):
            raise IndexError()

        # Requesting the first batch starts a new pass over the data, so
        # the samples are reshuffled at that point.
        if index == 0:
            self._shuffle()

        window = slice(index * self.batch_size,
                       (index + 1) * self.batch_size)
        return (self.x[window], self.y[window])
class TensorDataset:
    """
    Batch-serving test dataset that returns its batches as torch tensors.

    Reads the ``x`` and ``y`` entries of a NumPy file and hands them out
    as consecutive ``(x, y)`` batches.
    """
    def __init__(self, filename, batch_size=1):
        """
        Load the samples stored in ``filename``.

        Args:
            filename: Path of the file to load the data from.
            batch_size: Number of samples in each returned batch.
        """
        self.batch_size = batch_size
        archive = np.load(filename)
        self.x = archive["x"]
        # Targets are stored as a column vector.
        self.y = archive["y"].reshape(-1, 1)

    def _shuffle(self):
        """
        Apply one random permutation to both ``x`` and ``y`` so that
        inputs and targets stay aligned.
        """
        order = np.random.permutation(self.x.shape[0])
        self.x, self.y = self.x[order], self.y[order]

    def __len__(self):
        """ Number of complete batches in the dataset. """
        n_samples = self.x.shape[0]
        return n_samples // self.batch_size

    def __getitem__(self, index):
        """
        Return the batch with the given index as a tuple of two tensors.

        Raises:
            IndexError: If ``index`` is not smaller than ``len(self)``.
        """
        if index >= len(self):
            raise IndexError()

        # Requesting the first batch starts a new pass over the data, so
        # the samples are reshuffled at that point.
        if index == 0:
            self._shuffle()

        window = slice(index * self.batch_size,
                       (index + 1) * self.batch_size)
        x_batch = torch.tensor(self.x[window])
        y_batch = torch.tensor(self.y[window])
        return (x_batch, y_batch)
197 | """ 198 | host = "129.16.35.202" 199 | path = "/mnt/array1/share/MLDatasets/test/" 200 | stream = DataFolder("sftp://" + host + path, 201 | TensorDataset, 202 | kwargs={"batch_size": 1}, 203 | aggregate=2, 204 | n_workers=16) 205 | stream_2 = DataFolder("sftp://" + host + path, 206 | TensorDataset, 207 | kwargs={"batch_size": 1}, 208 | aggregate=2, 209 | n_workers=16) 210 | 211 | next(iter(stream)) 212 | next(iter(stream_2)) 213 | 214 | x_sum = 0.0 215 | y_sum = 0.0 216 | for x, y in stream: 217 | x_sum += x.sum() 218 | y_sum += y.sum() 219 | assert x.shape[0] == 2 220 | 221 | x_sum = 0.0 222 | y_sum = 0.0 223 | for x, y in stream_2: 224 | x_sum += x.sum() 225 | y_sum += y.sum() 226 | assert np.all(np.isclose(x[:, 0].detach().numpy().ravel(), 227 | y.detach().numpy().ravel() 228 | )) 229 | assert x.shape[0] == 2 230 | 231 | x_sum = 0.0 232 | y_sum = 0.0 233 | for x, y in stream: 234 | x_sum += x.sum() 235 | y_sum += y.sum() 236 | assert x.shape[0] == 2 237 | 238 | x_sum = 0.0 239 | y_sum = 0.0 240 | 241 | for x, y in stream_2: 242 | x_sum += x.sum() 243 | y_sum += y.sum() 244 | assert x.shape[0] == 2 245 | 246 | assert np.isclose(x_sum, 7 * 8 / 2 * 10 * 10) 247 | assert np.isclose(y_sum, 7 * 8 / 2 * 10) 248 | -------------------------------------------------------------------------------- /test/test_data/x_train.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simonpf/quantnn/c4d650cf0c6da5b4a704905b6c267d1ca996466f/test/test_data/x_train.npy -------------------------------------------------------------------------------- /test/test_data/y_train.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simonpf/quantnn/c4d650cf0c6da5b4a704905b6c267d1ca996466f/test/test_data/y_train.npy -------------------------------------------------------------------------------- /test/test_drnn.py: 
#
# Import available backends.
#

# Names of the backends whose quantnn model modules could be imported. The
# list is used to parametrize the backend-specific tests in this module.
backends = []
try:
    import quantnn.models.keras

    backends += ["keras"]
except ImportError:
    # Backend (or one of its dependencies) is not installed: its tests are
    # simply not generated. Only ImportError is caught so that Ctrl-C and
    # interpreter exit are no longer swallowed by a bare ``except``.
    pass

try:
    import quantnn.models.pytorch

    backends += ["pytorch"]
except ImportError:
    pass
62 | """ 63 | set_default_backend(backend) 64 | drnn = DRNN(self.bins, 65 | n_inputs=self.x_train.shape[1]) 66 | drnn.train((self.x_train, self.y_train), 67 | validation_data=(self.x_train, self.y_train), 68 | n_epochs=2) 69 | 70 | drnn.predict(self.x_train) 71 | 72 | mu = drnn.posterior_mean(self.x_train[:2, :]) 73 | assert len(mu.shape) == 1 74 | 75 | r = drnn.sample_posterior(self.x_train[:4, :], n_samples=2) 76 | assert r.shape == (4, 2) 77 | 78 | @pytest.mark.parametrize("backend", backends) 79 | def test_drnn_dict_iterable(self, backend): 80 | """ 81 | Test training with dataset object that yields dicts instead of 82 | tuples. 83 | """ 84 | set_default_backend(backend) 85 | backend = get_default_backend() 86 | 87 | class DictWrapper: 88 | def __init__(self, data): 89 | self.data = data 90 | 91 | def __iter__(self): 92 | for x, y in self.data: 93 | yield {"x": x, "y": y} 94 | 95 | def __len__(self): 96 | return len(self.data) 97 | 98 | data = backend.BatchedDataset((self.x_train, self.y_train), 256) 99 | drnn = DRNN(self.bins, 100 | n_inputs=self.x_train.shape[1]) 101 | drnn.train(DictWrapper(data), n_epochs=2, keys=("x", "y")) 102 | 103 | @pytest.mark.parametrize("backend", backends) 104 | def test_drnn_datasets(self, backend): 105 | """ 106 | Provide data as dataset object instead of numpy arrays. 107 | """ 108 | set_default_backend(backend) 109 | backend = get_default_backend() 110 | data = backend.BatchedDataset((self.x_train, self.y_train), 256) 111 | drnn = DRNN(self.bins, 112 | n_inputs=self.x_train.shape[1]) 113 | drnn.train(data, n_epochs=2) 114 | 115 | @pytest.mark.parametrize("backend", backends) 116 | def test_save_drnn(self, backend): 117 | """ 118 | Test saving and loading of DRNNs. 
119 | """ 120 | set_default_backend(backend) 121 | drnn = DRNN(self.bins, 122 | n_inputs=self.x_train.shape[1]) 123 | f = tempfile.NamedTemporaryFile() 124 | drnn.save(f.name) 125 | drnn_loaded = DRNN.load(f.name) 126 | 127 | x_pred = drnn.predict(self.x_train) 128 | x_pred_loaded = drnn.predict(self.x_train) 129 | 130 | if not type(x_pred) == np.ndarray: 131 | x_pred = x_pred.detach() 132 | 133 | assert np.allclose(x_pred, x_pred_loaded) 134 | -------------------------------------------------------------------------------- /test/test_generic.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for generic array manipulation functions. 3 | """ 4 | import numpy as np 5 | import pytest 6 | from quantnn.generic import (get_array_module, to_array, sample_uniform, 7 | sample_gaussian, numel, concatenate, expand_dims, 8 | pad_zeros, pad_zeros_left, as_type, arange, 9 | reshape, trapz, cumsum, cumtrapz, ones, zeros, 10 | softmax, exp, tensordot, argmax, take_along_axis, 11 | digitize, scatter_add) 12 | 13 | 14 | @pytest.mark.parametrize("backend", pytest.backends) 15 | def test_get_array_module(backend): 16 | """ 17 | Ensures that get_array_module returns right array object 18 | when given an array created using the arange method of the 19 | corresponding module object. 20 | """ 21 | x = backend.ones(10) 22 | module = get_array_module(x) 23 | assert module == backend 24 | 25 | 26 | @pytest.mark.parametrize("backend", pytest.backends) 27 | def test_to_array(backend): 28 | """ 29 | Converts numpy array to array of given backend and ensures 30 | that corresponding module object matches the backend. 31 | """ 32 | x = np.arange(10) 33 | array = to_array(backend, x) 34 | assert get_array_module(array) == backend 35 | 36 | 37 | @pytest.mark.parametrize("backend", pytest.backends) 38 | def test_sample_uniform(backend): 39 | """ 40 | Ensures that array of random samples has array type 41 | corresponding to the right backend module. 
@pytest.mark.parametrize("backend", pytest.backends)
def test_pad_zeros(backend):
    """
    Pads a 10 x 10 array of ones twice with ``pad_zeros`` and checks that
    the result is 12 x 14 and holds zeros in the padded border.
    """
    ones_block = backend.ones((10, 10))
    padded = pad_zeros(backend,
                       pad_zeros(backend, ones_block, 2, 1),
                       1,
                       0)

    n_rows, n_cols = padded.shape[0], padded.shape[1]
    assert n_rows == 12
    assert n_cols == 14

    # One element from the leading and one from the trailing border.
    assert padded[0, 1] == 0.0
    assert padded[-1, -2] == 0.0
@pytest.mark.parametrize("backend", pytest.backends)
def test_cumsum(backend):
    """
    Checks the last row of the cumulative sum of the column vector
    0, 1, ..., 10 along each of its two axes.
    """
    column = reshape(backend, arange(backend, 0, 10.1, 1), (11, 1))

    # Along axis 0 the values accumulate to 55; along the length-one
    # axis 1 every element is left unchanged.
    for axis, expected_last in ((0, 55), (1, 10)):
        accumulated = cumsum(backend, column, axis)
        assert accumulated[-1, 0] == expected_last
@pytest.mark.parametrize("backend", pytest.backends)
def test_ones(backend):
    """
    Ensures that ``ones`` creates an array filled with ones.

    This test was previously also named ``test_zeros`` and therefore
    shadowed by the definition below, so it was never collected.
    """
    x = ones(backend, (1, 1))
    assert x[0, 0] == 1.0


@pytest.mark.parametrize("backend", pytest.backends)
def test_zeros(backend):
    """
    Ensures that ``zeros`` creates an array filled with zeros.
    """
    x = zeros(backend, (1, 1))
    assert x[0, 0] == 0.0
def test_min_max_normalizer_2d():
    """
    Checks that the MinMaxNormalizer maps the features that are not
    excluded to the range [-1, 1], and checks the values produced for
    NaN input and for input without variation.
    """
    x = np.random.normal(size=(100000, 11)) + np.arange(11).reshape(1, -1)
    normalizer = MinMaxNormalizer(x, exclude_indices=range(1, 10, 2))
    # The last feature is replaced by NaN only after the normalizer has
    # been constructed from the finite data.
    x[:, 10] = np.nan

    x_normed = normalizer(x)

    # Included indices (0, 2, 4, 6, 8) should have minimum value -1.0 and
    # maximum value 1.0.
    assert np.all(np.isclose(x_normed[:, :10:2].min(axis=0),
                             -1.0))
    assert np.all(np.isclose(x_normed[:, :10:2].max(axis=0),
                             1.0))
    # NaN values are expected to be mapped to -1.5.
    assert np.all(np.isclose(x_normed[:, -1], -1.5))

    # Channels without variation should be set to -1.0
    x = np.zeros((100, 10))
    normalizer = MinMaxNormalizer(x)
    x_normed = normalizer(x)
    assert np.all(np.isclose(x_normed, -1.0))
def test_stack():
    """
    Test stacking of list of tensor into batch.

    ``None`` entries in the list mark missing samples. After expansion the
    rows of the missing samples (first and last) must be zero and the rows
    of the present samples must hold the stacked tensors.
    """
    t = torch.ones((2, 2))
    u = 2.0 * torch.ones((2, 2))

    tensors = [None, t, u, None]
    b = PackedTensor.stack(tensors)

    assert b.batch_indices == [1, 2]
    assert b.batch_size == 4

    b_e = b.expand()
    assert (b_e[0] == 0.0).all()
    assert (b_e[1] == 1.0).all()
    assert (b_e[2] == 2.0).all()
    # The second missing sample is the last one; the original assertion
    # re-checked row 0 and left row 3 untested.
    assert (b_e[3] == 0.0).all()
110 | """ 111 | indices = [0, 2] 112 | u = fill_tensor(torch.ones((2, 2)), indices) 113 | u_p = PackedTensor(u, 4, indices) 114 | 115 | indices = [2, 3] 116 | v = 2.0 * fill_tensor(torch.ones((2, 2)), indices) 117 | v_p = PackedTensor(v, 4, indices) 118 | u_i_p, v_i_p = u_p.intersection(v_p) 119 | assert u_i_p.batch_indices == [2] 120 | assert u_i_p.batch_size == 4 121 | 122 | u_i_e = u_i_p.expand() 123 | assert (u_i_e[2] == 2.0).all() 124 | 125 | v_i_e = v_i_p.expand() 126 | assert (v_i_e[2] == 2 * 2.0).all() 127 | 128 | u = torch.ones((2, 2)) 129 | u_p = PackedTensor(u, 4, [0, 1]) 130 | v = 2.0 * torch.ones((2, 2)) 131 | v_p = PackedTensor(v, 4, [2, 3]) 132 | u_i_p, v_i_p = u_p.intersection(v_p) 133 | assert u_i_p is None 134 | assert v_i_p is None 135 | 136 | for i in range(100): 137 | l = make_random_packed_tensor(100, 50) 138 | r = make_random_packed_tensor(100, 50) 139 | indices = sorted(list(set(l.batch_indices) & set(r.batch_indices))) 140 | l, r = l.intersection(r) 141 | for i, index in enumerate(indices): 142 | assert (l.tensor[i] == index).all() 143 | assert (r.tensor[i] == index).all() 144 | 145 | 146 | def test_splitting(): 147 | """ 148 | Test splitting of packed tensors into parts. 
def test_intersection():
    """
    Test intersection of packed tensors.

    Covers partially overlapping tensors, disjoint tensors (for which
    ``None`` is expected for both results) and randomly populated tensors.

    NOTE(review): a function with the same name and the same body is
    defined earlier in this module; this later definition replaces it, so
    only one of the two is collected. Consider removing one of them.
    """
    # Partially overlapping tensors: only sample 2 is present in both.
    u_indices = [0, 2]
    u_p = PackedTensor(fill_tensor(torch.ones((2, 2)), u_indices),
                       4,
                       u_indices)
    v_indices = [2, 3]
    v_p = PackedTensor(2.0 * fill_tensor(torch.ones((2, 2)), v_indices),
                       4,
                       v_indices)

    u_common, v_common = u_p.intersection(v_p)
    assert u_common.batch_indices == [2]
    assert u_common.batch_size == 4
    assert (u_common.expand()[2] == 2.0).all()
    assert (v_common.expand()[2] == 2 * 2.0).all()

    # Disjoint tensors.
    u_p = PackedTensor(torch.ones((2, 2)), 4, [0, 1])
    v_p = PackedTensor(2.0 * torch.ones((2, 2)), 4, [2, 3])
    u_common, v_common = u_p.intersection(v_p)
    assert u_common is None
    assert v_common is None

    # Random tensors whose elements equal their sample index.
    for _ in range(100):
        left = make_random_packed_tensor(100, 50)
        right = make_random_packed_tensor(100, 50)
        shared = sorted(set(left.batch_indices) & set(right.batch_indices))
        left_common, right_common = left.intersection(right)
        for row, index in enumerate(shared):
            assert (left_common.tensor[row] == index).all()
            assert (right_common.tensor[row] == index).all()
def test_union():
    """
    Test union of packed tensors.

    The union is formed in both argument orders and must give the same
    batch indices and expanded values each time. Afterwards, unions of
    randomly populated tensors are checked.
    """
    u_p = PackedTensor(torch.ones((2, 2)), 4, [3])
    v_p = PackedTensor(2.0 * torch.ones((2, 2)), 4, [1, 2])

    # Expected value of each row of the expanded union.
    expected_rows = (0.0, 2.0, 2.0, 1.0)
    for first, second in ((u_p, v_p), (v_p, u_p)):
        merged = first.union(second)
        assert merged.batch_indices == [1, 2, 3]
        assert merged.batch_size == 4

        expanded = merged.expand()
        for row, value in enumerate(expected_rows):
            assert (expanded[row] == value).all()

    # Random tensors whose elements equal their sample index.
    for _ in range(100):
        left = make_random_packed_tensor(100, 50)
        right = make_random_packed_tensor(100, 50)
        combined = sorted(set(left.batch_indices) | set(right.batch_indices))
        merged = left.union(right)

        for row, index in enumerate(combined):
            assert (merged.tensor[row] == index).all()
class TestQrnn:
    """
    Training, inference and serialization tests for the QRNN class.

    Tests that take a ``backend`` argument are run once for every entry of
    the module-level ``backends`` list.
    """

    @staticmethod
    def _detach(prediction):
        """
        Return ``prediction`` in a form that can be passed to NumPy.

        NumPy arrays are returned unchanged. Any other object that provides
        a ``detach`` method (e.g. a PyTorch tensor) is detached first.
        """
        if isinstance(prediction, np.ndarray):
            return prediction
        detach = getattr(prediction, "detach", None)
        return detach() if detach is not None else prediction

    def setup_method(self):
        """
        Load the training data from the ``test_data`` folder next to this
        file and normalize the inputs with their global mean and standard
        deviation.
        """
        test_dir = os.path.dirname(os.path.realpath(__file__))
        data_path = os.path.join(test_dir, "test_data")
        x_train = np.load(os.path.join(data_path, "x_train.npy"))
        x_mean = np.mean(x_train, keepdims=True)
        x_sigma = np.std(x_train, keepdims=True)
        self.x_train = (x_train - x_mean) / x_sigma
        self.y_train = np.load(os.path.join(data_path, "y_train.npy"))

    @pytest.mark.parametrize("backend", backends)
    def test_qrnn(self, backend):
        """
        Test training of QRNNs using numpy arrays as input.
        """
        set_default_backend(backend)
        qrnn = QRNN(np.linspace(0.05, 0.95, 10),
                    n_inputs=self.x_train.shape[1])
        qrnn.train((self.x_train, self.y_train),
                   validation_data=(self.x_train, self.y_train),
                   n_epochs=2)

        qrnn.predict(self.x_train)

        # The CDF is expected to span the full probability range.
        x, qs = qrnn.cdf(self.x_train[:2, :])
        assert qs[0] == 0.0
        assert qs[-1] == 1.0

        x, y = qrnn.pdf(self.x_train[:2, :])
        assert x.shape == y.shape

        mu = qrnn.posterior_mean(self.x_train[:2, :])
        assert len(mu.shape) == 1

        r = qrnn.sample_posterior(self.x_train[:4, :], n_samples=2)
        assert r.shape == (4, 2)

        r = qrnn.sample_posterior_gaussian_fit(self.x_train[:4, :], n_samples=2)
        assert r.shape == (4, 2)

    @pytest.mark.parametrize("backend", backends)
    def test_qrnn_dict_iterable(self, backend):
        """
        Test training with dataset object that yields dicts instead of
        tuples.
        """
        set_default_backend(backend)
        backend_module = get_default_backend()

        class DictWrapper:
            """Re-yield the ``(x, y)`` batches of a dataset as dicts."""

            def __init__(self, data):
                self.data = data

            def __iter__(self):
                for x, y in self.data:
                    yield {"x": x, "y": y}

            def __len__(self):
                return len(self.data)

        data = backend_module.BatchedDataset((self.x_train, self.y_train), 256)
        qrnn = QRNN(np.linspace(0.05, 0.95, 10),
                    n_inputs=self.x_train.shape[1])
        qrnn.train(DictWrapper(data), n_epochs=2, keys=("x", "y"))

    @pytest.mark.parametrize("backend", backends)
    def test_qrnn_datasets(self, backend):
        """
        Provide data as dataset object instead of numpy arrays.
        """
        set_default_backend(backend)
        backend_module = get_default_backend()
        data = backend_module.BatchedDataset((self.x_train, self.y_train), 256)
        qrnn = QRNN(np.linspace(0.05, 0.95, 10),
                    n_inputs=self.x_train.shape[1])
        qrnn.train(data, n_epochs=2)

    @pytest.mark.parametrize("backend", backends)
    def test_save_qrnn(self, backend):
        """
        Test saving and loading of QRNNs.

        The predictions of the loaded model must match those of the model
        that was saved.
        """
        set_default_backend(backend)
        qrnn = QRNN(np.linspace(0.05, 0.95, 10),
                    n_inputs=self.x_train.shape[1])

        # The context manager ensures that the temporary file is closed and
        # removed even if saving or loading fails.
        with tempfile.NamedTemporaryFile() as f:
            qrnn.save(f.name)
            qrnn_loaded = QRNN.load(f.name)

        x_pred = self._detach(qrnn.predict(self.x_train))
        # Predict with the *loaded* model; predicting twice with the
        # original model would make the comparison trivially true.
        x_pred_loaded = self._detach(qrnn_loaded.predict(self.x_train))

        assert np.allclose(x_pred, x_pred_loaded)

    @pytest.mark.skipif("pytorch" not in backends,
                        reason="No PyTorch backend.")
    def test_save_qrnn_pytorch_model(self):
        """
        Test saving and loading of a QRNN that wraps a user-provided PyTorch
        model which has been trained before saving.
        """
        from torch import nn
        quantiles = np.linspace(0.05, 0.95, 10)
        model = nn.Sequential(nn.Linear(self.x_train.shape[1], quantiles.size))
        qrnn = QRNN(quantiles, model=model)

        # Train the model
        data = quantnn.models.pytorch.BatchedDataset((self.x_train, self.y_train), 256)
        qrnn.train(data, n_epochs=2)

        # Save and reload the model. The context manager closes and removes
        # the temporary file.
        with tempfile.NamedTemporaryFile() as f:
            qrnn.save(f.name)
            qrnn_loaded = QRNN.load(f.name)

        # Compare predictions from saved and loaded model.
        x_pred = self._detach(qrnn.predict(self.x_train))
        x_pred_loaded = self._detach(qrnn_loaded.predict(self.x_train))

        assert np.allclose(x_pred, x_pred_loaded)
@pytest.mark.parametrize("backend", TENSOR_BACKENDS)
def test_exp(backend):
    """
    Ensure that the backend's ``exp`` agrees with ``numpy.exp``.
    """
    x_np = np.arange(10).astype(np.float32)
    y = backend.exp(backend.from_numpy(x_np))

    # The reference is computed from the NumPy input so that the check does
    # not depend on NumPy being able to coerce backend tensors.
    assert np.all(np.isclose(backend.to_numpy(y), np.exp(x_np)))


@pytest.mark.parametrize("backend", TENSOR_BACKENDS)
def test_log(backend):
    """
    Ensure that the backend's ``log`` agrees with ``numpy.log``.
    """
    x_np = np.arange(10).astype(np.float32)
    y = backend.log(backend.from_numpy(x_np))

    # The input contains 0, for which the reference is -inf; silence the
    # corresponding NumPy divide-by-zero warning.
    with np.errstate(divide="ignore"):
        y_ref = np.log(x_np)

    assert np.all(np.isclose(backend.to_numpy(y), y_ref))


@pytest.mark.parametrize("backend", TENSOR_BACKENDS)
def test_pad_zeros(backend):
    """
    Ensure that ``pad_zeros`` adds zeros at both ends of the padded axis.
    """
    x = backend.from_numpy(np.arange(10))
    xs = backend.concatenate([backend.expand_dims(x, 0)] * 10, 0)

    # Pad two elements along axis 0 and one element along axis 1.
    xs = backend.pad_zeros(xs, 2, 0)
    xs = backend.pad_zeros(xs, 1, 1)
    xs = backend.to_numpy(xs)

    assert np.all(xs[:2, :] == 0.0)
    assert np.all(xs[-2:, :] == 0.0)

    assert np.all(xs[:, :1] == 0.0)
    assert np.all(xs[:, -1:] == 0.0)


@pytest.mark.parametrize("backend", TENSOR_BACKENDS)
def test_pad_zeros_left(backend):
    """
    Ensure that ``pad_zeros_left`` adds zeros only at the start of the
    padded axis.
    """
    x = backend.from_numpy(np.arange(10))
    xs = backend.concatenate([backend.expand_dims(x, 0)] * 10, 0)

    # Pad two elements along axis 0 and one element along axis 1.
    xs = backend.pad_zeros_left(xs, 2, 0)
    xs = backend.pad_zeros_left(xs, 1, 1)
    xs = backend.to_numpy(xs)

    assert np.all(xs[:2, :] == 0.0)
    assert not np.all(xs[-2:, :] == 0.0)

    assert np.all(xs[:, :1] == 0.0)
    assert not np.all(xs[:, -1:] == 0.0)


@pytest.mark.parametrize("backend", TENSOR_BACKENDS)
def test_arange(backend):
    """
    Ensure that the backend's ``arange`` agrees with ``numpy.arange``.
    """
    x = backend.to_numpy(backend.arange(2, 10, 2))

    assert np.all(np.isclose(x, np.arange(2, 10, 2)))
@pytest.mark.parametrize("backend", TENSOR_BACKENDS)
def test_trapz(backend):
    """
    Ensure that trapezoidal integration of a constant function works.
    """
    x = backend.arange(2, 10, 2)
    y = backend.ones(like=x)

    # Integrating 1 over the grid [2, 4, 6, 8] yields 8 - 2 = 6.
    integral = backend.to_numpy(backend.trapz(y, x, 0))
    assert np.all(np.isclose(integral, 6))


@pytest.mark.parametrize("backend", TENSOR_BACKENDS)
def test_cumsum(backend):
    """
    Ensure that the backend's ``cumsum`` agrees with ``numpy.cumsum``.
    """
    x = backend.arange(2, 10, 2)
    cumsum = backend.to_numpy(backend.cumsum(x, 0))

    # Convert explicitly instead of passing a backend tensor to NumPy.
    cumsum_ref = np.cumsum(backend.to_numpy(x))
    assert np.all(np.isclose(cumsum, cumsum_ref))


@pytest.mark.parametrize("backend", TENSOR_BACKENDS)
def test_zeros(backend):
    """
    Ensure that ``zeros`` works with an explicit shape as well as with the
    ``like`` argument.
    """
    zeros_1 = backend.zeros((2, 2, 2))
    zeros_2 = backend.zeros(like=zeros_1)
    zeros_1 = backend.to_numpy(zeros_1)
    zeros_2 = backend.to_numpy(zeros_2)

    assert np.all(np.isclose(zeros_1, 0.0))
    assert np.all(np.isclose(zeros_1,
                             zeros_2))


@pytest.mark.parametrize("backend", TENSOR_BACKENDS)
def test_ones(backend):
    """
    Ensure that ``ones`` works with an explicit shape as well as with the
    ``like`` argument.
    """
    ones_1 = backend.ones((2, 2, 2))
    ones_2 = backend.ones(like=ones_1)
    ones_1 = backend.to_numpy(ones_1)
    ones_2 = backend.to_numpy(ones_2)

    assert np.all(np.isclose(ones_1, 1.0))
    assert np.all(np.isclose(ones_1,
                             ones_2))


@pytest.mark.parametrize("backend", TENSOR_BACKENDS)
def test_softmax(backend):
    """
    Ensure that the backend's ``softmax`` agrees with
    ``scipy.special.softmax``.
    """
    # Import the function directly: ``import scipy`` alone does not
    # guarantee that the ``special`` submodule has been loaded.
    from scipy.special import softmax

    x = backend.sample_uniform((10, 10))
    x_sf = backend.to_numpy(backend.softmax(x, 1))
    x_sf_ref = softmax(backend.to_numpy(x), 1)

    # Both operands are NumPy arrays at this point.
    assert np.all(np.isclose(x_sf, x_sf_ref))
def test_apply():
    """
    Ensure that ``apply`` works for plain values and for dicts of values.

    For plain arguments the function is expected to be called directly on
    them; for dict arguments it is expected to be applied key by key.
    """
    def double(x):
        return 2 * x

    def add(x, y):
        return x + y

    d = {i: i for i in range(5)}

    # Compare against the known results instead of comparing each result
    # with itself, which would always succeed.
    assert apply(double, 1) == 2
    assert apply(add, 1, 1) == 2

    d_a = apply(double, d)
    d_b = apply(add, d, d)
    for k in d:
        # Exact comparison; floor division would also accept 2 * k + 1.
        assert d_a[k] == 2 * k
        assert d_b[k] == 2 * k


def test_serialization():
    """
    Make sure that serialization of xarray datasets works.

    A dataset that is serialized and deserialized again must contain the
    same data as the original.
    """
    dataset_ref = xr.Dataset({"x": (("a", "b"), np.ones((10, 10)))})
    serialized = serialize_dataset(dataset_ref)
    dataset = deserialize_dataset(serialized)

    assert np.all(np.isclose(dataset["x"].data,
                             dataset_ref["x"].data))