├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yml ├── .travis.yml ├── LICENSE.txt ├── MANIFEST.in ├── README.md ├── contributing.md ├── deepsphere ├── __init__.py ├── data │ ├── __init__.py │ ├── datasets │ │ ├── __init__.py │ │ └── dataset.py │ └── transforms │ │ ├── __init__.py │ │ └── transforms.py ├── layers │ ├── __init__.py │ ├── chebyshev.py │ └── samplings │ │ ├── __init__.py │ │ ├── equiangular_pool_unpool.py │ │ ├── healpix_pool_unpool.py │ │ └── icosahedron_pool_unpool.py ├── models │ ├── __init__.py │ └── spherical_unet │ │ ├── __init__.py │ │ ├── decoder.py │ │ ├── encoder.py │ │ ├── unet_model.py │ │ └── utils.py ├── tests │ ├── __init__.py │ └── test_foo.py └── utils │ ├── __init__.py │ ├── initialization.py │ ├── laplacian_funcs.py │ ├── parser.py │ ├── samplings.py │ └── stats_extractor.py ├── docs ├── Makefile └── source │ ├── conf.py │ ├── deepsphere.data.datasets.rst │ ├── deepsphere.data.rst │ ├── deepsphere.data.transforms.rst │ ├── deepsphere.layers.rst │ ├── deepsphere.layers.samplings.rst │ ├── deepsphere.models.rst │ ├── deepsphere.models.spherical_unet.rst │ ├── deepsphere.rst │ ├── deepsphere.tests.rst │ ├── deepsphere.utils.rst │ ├── index.rst │ ├── modules.rst │ ├── scripts.rst │ ├── scripts.temporality.rst │ └── setup.rst ├── images ├── AR_TC_image.png ├── Example_3D_Icosahedronmovie_globe.gif ├── equations │ ├── L_eq.gif │ ├── Lc.gif │ ├── Lc_eq.gif │ ├── T0.gif │ ├── T1.gif │ ├── Tm.gif │ ├── Tm_recursive.gif │ ├── poly_eq.gif │ ├── xhat.gif │ └── y_eq.gif ├── interactiveplot_epoch28.png └── interactiveplot_epoch4.png ├── notebooks ├── demo_visualizations.ipynb └── interactive_visualization.ipynb ├── pyproject.toml ├── requirements-tests.txt ├── requirements.txt ├── scripts ├── __init__.py ├── config.example.yml ├── means.npy ├── run_ar_tc.py ├── stds.npy └── temporality │ ├── __init__.py │ ├── config.yml │ └── run_ar_tc.py ├── setup.cfg └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.gitignore.io/api/python,jupyternotebooks 3 | # Edit at https://www.gitignore.io/?templates=python,jupyternotebooks 4 | 5 | ### JupyterNotebooks ### 6 | # gitignore template for Jupyter Notebooks 7 | # website: http://jupyter.org/ 8 | 9 | .ipynb_checkpoints 10 | */.ipynb_checkpoints/* 11 | 12 | # IPython 13 | profile_default/ 14 | ipython_config.py 15 | 16 | # Remove previous ipynb_checkpoints 17 | # git rm -r .ipynb_checkpoints/ 18 | 19 | ### Python ### 20 | # Byte-compiled / optimized / DLL files 21 | __pycache__/ 22 | *.py[cod] 23 | *$py.class 24 | 25 | # C extensions 26 | *.so 27 | 28 | # Distribution / packaging 29 | .Python 30 | build/ 31 | develop-eggs/ 32 | dist/ 33 | downloads/ 34 | eggs/ 35 | .eggs/ 36 | lib/ 37 | lib64/ 38 | parts/ 39 | sdist/ 40 | var/ 41 | wheels/ 42 | pip-wheel-metadata/ 43 | share/python-wheels/ 44 | *.egg-info/ 45 | .installed.cfg 46 | *.egg 47 | MANIFEST 48 | 49 | # PyInstaller 50 | # Usually these files are written by a python script from a template 51 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
52 | *.manifest 53 | *.spec 54 | 55 | # Installer logs 56 | pip-log.txt 57 | pip-delete-this-directory.txt 58 | 59 | # Unit test / coverage reports 60 | htmlcov/ 61 | .tox/ 62 | .nox/ 63 | .coverage 64 | .coverage.* 65 | .cache 66 | nosetests.xml 67 | coverage.xml 68 | *.cover 69 | .hypothesis/ 70 | .pytest_cache/ 71 | 72 | # Translations 73 | *.mo 74 | *.pot 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | docs/build/ 82 | 83 | # PyBuilder 84 | target/ 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # celery beat schedule file 97 | celerybeat-schedule 98 | 99 | # SageMath parsed files 100 | *.sage.py 101 | 102 | # Spyder project settings 103 | .spyderproject 104 | .spyproject 105 | 106 | # Rope project settings 107 | .ropeproject 108 | 109 | # Mr Developer 110 | .mr.developer.cfg 111 | .project 112 | .pydevproject 113 | 114 | # mkdocs documentation 115 | /site 116 | 117 | # mypy 118 | .mypy_cache/ 119 | .dmypy.json 120 | dmypy.json 121 | 122 | # Pyre type checker 123 | .pyre/ 124 | 125 | # End of https://www.gitignore.io/api/python,jupyternotebooks 126 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | # force all unspecified python hooks to run python3 3 | python: python3.7 4 | 5 | repos: 6 | - repo: https://github.com/pre-commit/pre-commit-hooks 7 | rev: v2.0.0 8 | hooks: 9 | - id: trailing-whitespace 10 | - id: end-of-file-fixer 11 | 12 | - repo: https://github.com/psf/black 13 | rev: 19.3b0 14 | hooks: 15 | - id: black 16 | types: [python] 17 | 18 | - repo: https://github.com/timothycrosley/isort 19 | rev: 4.3.21 20 | hooks: 21 | - id: isort 22 | types: [python] 23 | additional_dependencies: [toml] 24 | 25 | - repo: local 26 | hooks: 27 | - id: unittest 28 | name: unittest 29 | entry: python -m unittest discover 30 | language: system 31 | pass_filenames: false 32 | - id: linting 33 | name: linting 34 | entry: pylint-fail-under --rcfile=setup.cfg --fail_under 10 deepsphere/ 35 | language: system 36 | pass_filenames: false 37 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Build documentation in the docs/ directory with Sphinx 9 | sphinx: 10 | configuration: docs/source/conf.py 11 | 12 | # Build documentation with MkDocs 13 | #mkdocs: 14 | # configuration: mkdocs.yml 15 | 16 | # Optionally build your docs in additional formats such as PDF and ePub 17 | formats: all 18 | 19 | # Optionally set the version of Python and requirements required to build your docs 20 | python: 21 | version: 3.7 22 | install: 23 | - requirements: requirements-tests.txt 24 | -------------------------------------------------------------------------------- /.travis.yml: 
-------------------------------------------------------------------------------- 1 | language: python 2 | python: '3.7' 3 | install: 4 | - pip install "git+https://github.com/Droxef/pygsp.git@6b216395beae25bf062d13fbf9abc251eeb5bbff#egg=PyGSP" 5 | - pip install -e .[tests] 6 | script: 7 | - black --check deepsphere/ scripts/ 8 | - python -m unittest discover -v 9 | - isort -rc --check-only deepsphere/ scripts/ 10 | - pylint-fail-under --rcfile=setup.cfg --fail_under 9.5 deepsphere/ scripts/ 11 | deploy: 12 | provider: pypi 13 | user: __token__ 14 | password: 15 | secure: TxN4Hm2Cax0SrRzKeE2gT/Tnwqg8mIgGROkvVuyKiW15mw8YJisNNofYU5sFvUstIT6xuPYusjQ3mRIBtJuRlHgIAn4F59yBaaeGvj6Vivujo+Bj9J7jslHVgYwQ2Tepib/Gx+I6hK+q+eSSBBK07GYQOJot3/BNiVxPh0/adgsXFSkrzjuPdpy6mj3YA7yynKcVlpjSbz8ydGBBtmU0lLAdEUQqbmAET/lVJHg/zTq89iKUOjXCp9o9WH2/uW5XtBxbvz3YGqiG3P9WbeI0PXBQMnEldY6OiINJ0+ZlVQk4CUCn2Yk/mEb9PNU32lh8RrCspZmTngDB72vpqNKv1ek94uksnn/hWPAO7rVm7uDCUOURihxjC21xtu3qg2RXU39vvlv+of7BQeLEOtJu9DcGFXWO1iW0AxkTaly+s/Df9z5DwUA41xFIe7Jvy9GEfA/a86w1ncqMpeHGi5VQa1KbshFJjf7aKytXq83OmS3h1U/ZoPzYvuvsqKIuNq1F2T0zittzZeYkeSV0AiKPBcLdTKUF1aPCUmQXgjqugAmDxVgg475oP11FRAvnZjH9yp05kDBANToPShXV9jJr8uHodzFiXwFBteypEuv2OO0Ut1iT60vkomp2F/RevrQ5kcc5R1XdgrWYnaHF2Cpl20whkC+uB80zgn+WAvoj/N4= 16 | skip_cleanup: true 17 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Laure Vancauwenberghe, Michael Allemann, Yoann Ponti, Basile Chatillon, Lionel Martin, Michaël Defferrard 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements*.txt 2 | recursive-include scripts/ 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepSphere: a graph-based spherical CNN 2 | 3 | [![Documentation Status](https://readthedocs.org/projects/deepsphere/badge/?version=latest)](https://deepsphere.readthedocs.io/en/latest/?badge=latest) 4 | 5 | This is a PyTorch implementation of DeepSphere. 
6 | 
7 | - [Resources](#resources)
8 | - [Data](#data)
9 | - [Quick Start](#quick-start)
10 | - [Mathematical Background](#mathematical-background)
11 | - [U-Net](#unet)
12 | - [Temporality](#temporality)
13 | - [Metric](#metric)
14 | - [Tools](#tools)
15 | - [License & co](#license--co)
16 | 
17 | ## Resources
18 | 
19 | Code:
20 | * [deepsphere-cosmo-tf1](https://github.com/deepsphere/deepsphere-cosmo-tf1): original repository, implemented in TensorFlow v1. \
21 | Use to reproduce [`arxiv:1810.12186`][paper_cosmo].
22 | * [deepsphere-cosmo-tf2](https://github.com/deepsphere/deepsphere-cosmo-tf2): reimplementation in TFv2. \
23 | Use for new developments in TensorFlow targeting HEALPix, including generative models.
24 | * [deepsphere-tf1](https://github.com/deepsphere/deepsphere-tf1): extended to other samplings and experiments, implemented in TFv1. \
25 | Use to reproduce [`arxiv:2012.15000`][paper_iclr].
26 | * [deepsphere-pytorch](https://github.com/deepsphere/deepsphere-pytorch): reimplementation in PyTorch. \
27 | Use for new developments in PyTorch.
28 | 
29 | Papers:
30 | * DeepSphere: Efficient spherical CNN with HEALPix sampling for cosmological applications, 2018.\
31 | [[paper][paper_cosmo], [blog](https://datascience.ch/deepsphere-a-neural-network-architecture-for-spherical-data), [slides](https://doi.org/10.5281/zenodo.3243380)]
32 | * DeepSphere: towards an equivariant graph-based spherical CNN, 2019.\
33 | [[paper][paper_rlgm], [poster](https://doi.org/10.5281/zenodo.2839355)]
34 | * DeepSphere: a graph-based spherical CNN, 2020.\
35 | [[paper][paper_iclr], [slides](https://doi.org/10.5281/zenodo.3777976), [video](https://youtu.be/NC_XLbbCevk)]
36 | 
37 | [paper_cosmo]: https://arxiv.org/abs/1810.12186
38 | [paper_rlgm]: https://arxiv.org/abs/1904.05146
39 | [paper_iclr]: https://arxiv.org/abs/2012.15000
40 | 
41 | ## Data
42 | 
43 | The data used for the experiments contains a [downsampled](http://island.me.berkeley.edu/ugscnn/data/climate_sphere_l5.zip "This link lets you download the downsampled dataset (~30 GB). This can also be done using the script, described further down.")
44 | snapshot of the [Community Atmospheric Model v5 (CAM5)](https://portal.nersc.gov/project/dasrepo/deepcam/segm_h5_v3_reformat/gb_data_readme "This link takes you to the page where the full dataset can be downloaded and where more information is provided concerning the data.")
45 | simulation. The data is based on the paper [UGSCNN (Jiang et al., 2019)](https://openreview.net/pdf?id=Bkl-43C9FQ "This link takes you to the paper: Spherical CNNs on Unstructured Grids."). The simulation can be thought of as a 16-channel "image", where each channel corresponds to a climate-related measurement.
46 | The task is to learn how to infer the correct class for each pixel given the 16 channels. Each pixel is labelled either as background, as being part of a tropical cyclone, or as being part of an atmospheric river.
47 | 
48 | ![alt text](images/AR_TC_image.png "Background class is visualized in red, tropical cyclones in green and atmospheric rivers in blue.")
49 | 
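As an illustration, the sketch below shows how the dataset could be loaded with the `ARTCDataset` class and the transforms shipped in this package (the data path and the use of `scripts/means.npy` / `scripts/stds.npy` are assumptions to adapt to your setup):

```
import numpy as np
from torchvision import transforms

from deepsphere.data.datasets.dataset import ARTCDataset
from deepsphere.data.transforms.transforms import Normalize, Permute, ToTensor

# Assumed locations; adapt them to your setup.
mean = np.load("scripts/means.npy").astype(np.float32)
std = np.load("scripts/stds.npy").astype(np.float32)
transform = transforms.Compose([ToTensor(), Permute(), Normalize(mean, std)])

dataset = ARTCDataset("/data/climate", transform_data=transform, download=True)
data, labels = dataset[0]  # one sample: a (vertices, features) tensor and its per-pixel labels
```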
50 | ## Quick Start
51 | 
52 | In order to reproduce the results obtained, it is necessary to install the PyGSP branch containing the graph processing for equiangular, icosahedron, and healpix samplings. In future versions, PyGSP will be in the requirements. Subsequently, please refer to the [PyTorch Getting Started information page](https://pytorch.org/get-started/locally/) to run the correct `conda install` command corresponding to your operating system, Python version and CUDA version.
53 | Once those requirements are met, you can install the `deepsphere` package in your environment.
54 | 
55 | Our recommendation for a Linux-based machine is:
56 | 
57 | ```
58 | conda create --name deepsphere python=3.7
59 | 
60 | source activate deepsphere
61 | 
62 | pip install git+https://github.com/epfl-lts2/pygsp.git@39a0665f637191152605911cf209fc16a36e5ae9#egg=PyGSP
63 | 
64 | conda install pytorch=1.3.1 torchvision=0.4.2 cudatoolkit=10.0 -c pytorch
65 | 
66 | pip install git+https://github.com/deepsphere/deepsphere-pytorch
67 | ```
68 | 
69 | The package stores the experiment parameters in a [YAML config file](./scripts/config.example.yml), which can be used by running a [script](./scripts/run_ar_tc.py) from the command line.
70 | 
71 | A special note should be made about the PyTorch computation device. If nothing is stipulated on the command line, the device is set to CPU. To set the device to GPU (CUDA), one can indicate `--gpu` on the command line, with or without the desired GPU device IDs (e.g. `--gpu 1 2` to run the model on GPUs 1 and 2).
72 | 
73 | To visualize icosahedron or equiangular data in 2D or 3D, the package provides a demonstration [Jupyter notebook](./notebooks/demo_visualizations.ipynb).
74 | 
75 | Using the predefined parameters, you can train and validate the model using the following command:
76 | ```
77 | python run_ar_tc.py --config-file config.example.yml --gpu
78 | ```
79 | 
80 | If you don't have the data yet, please create the folder `/data/climate/` (or change the file location in the yaml file) and add `download True` to the command.
81 | 
82 | ## Mathematical Background
83 | 
84 | The DeepSphere package uses the manifold of the sphere to perform convolutions on the data. Underlying the application of convolutional networks to spherical data through a graph-based discretization lies the field of Graph Signal Processing (GSP). GSP is a field that aims to define classical spectral methods on graphs, similarly to the theories existing in the time domain.
85 | 
86 | This section presents the key concepts behind representing the sphere manifold as a graph, and explains how manipulating the data in the eigenvector space enables an efficient convolution operation on the sphere. For an in-depth introduction to the topic, see for example [Graph Signal Processing: Overview, Challenges and Applications (2017)](https://arxiv.org/abs/1712.00468) or [The Emerging Field of Signal Processing on Graphs (2012)](https://arxiv.org/abs/1211.0053). For simpler introductions to the matter, you may refer to [Chapter 1.2 of J. Paratte's PhD Thesis](https://infoscience.epfl.ch/record/231710?ln=en) or [Chapter 2.1 of L. Martin's PhD Thesis](https://infoscience.epfl.ch/record/234372?ln=en).
87 | For an introduction to graph convolutions in the context of neural networks, see for example [Convolutional neural networks on graphs with fast localized spectral filtering (2016)](http://papers.nips.cc/paper/6081-convolutional-neural-networks-on-graphs-with-fast-localized-spectral-filtering).
88 | 
89 | Following GSP paradigms, the convolution operator defined on graphs can be computed simply with a multiplication in the correct domain, just like classical signal processing. Indeed, in traditional signal processing, filtering (i.e., convolution) can be carried out by a pointwise multiplication as long as the signal is transformed to the Fourier domain. Thus, given a graph signal, we define its graph Fourier transform as the projection of the signal onto the set of eigenvectors of the graph Laplacian:
90 | 
91 | ![alt text](images/equations/xhat.gif),
92 | 
93 | where *U* and *Λ* are the results of the eigendecomposition of the Laplacian, i.e. ![alt text](images/equations/L_eq.gif) .
94 | 
95 | Several Laplacians could be used to bring the data to the spectral domain. Here we select the combinatorial Laplacian, ![alt text](images/equations/Lc.gif), which is simply defined as:
96 | 
97 | ![alt text](images/equations/Lc_eq.gif),
98 | 
99 | where *W* is the weighted adjacency matrix of the graph and *D* is the diagonal degree matrix, which holds on its diagonal, for each node, the sum of the weights of its edges.
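As a toy illustration of these definitions (not part of the package), the combinatorial Laplacian and the graph Fourier transform of a small graph can be computed as follows:

```
import numpy as np

W = np.array([[0., 1., 1.],
              [1., 0., 0.],
              [1., 0., 0.]])    # weighted adjacency matrix of a toy 3-node graph
D = np.diag(W.sum(axis=1))      # diagonal degree matrix
L = D - W                       # combinatorial Laplacian
lam, U = np.linalg.eigh(L)      # eigendecomposition: L = U diag(lam) U^T
x = np.array([1., 2., 3.])      # a graph signal, one value per vertex
x_hat = U.T @ x                 # graph Fourier transform of the signal
x_back = U @ x_hat              # the inverse transform recovers x
```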
100 | 
101 | To this end, filtering (the convolution operator) is defined via a graph filter *g*, a continuous function defined directly in the graph Fourier domain, which enables the direct multiplication. Based on the definition of the graph Fourier domain, we can then rewrite the graph filtering equation as a vector-matrix operation in the original domain (the graph domain):
102 | 
103 | ![alt text](images/equations/y_eq.gif).
104 | 
105 | However, the filtering equation defined above requires the knowledge of the full set of eigenvectors U, and hence the diagonalization of the Laplacian L, which is extremely costly for large graphs. To circumvent this problem, one can represent the filter g as a polynomial approximation using n-degree Chebyshev polynomials. The relation between the graph filter *g(L)*, the graph signal *x*, and the Chebyshev polynomials lies in the approximation:
106 | 
107 | ![alt text](images/equations/poly_eq.gif),
108 | 
109 | where *c_m* are the coefficients of the approximation and entirely describe the shape of the graph filter *g*.
110 | 
111 | Since the Chebyshev polynomials of the first kind are defined by a recurrence relation, computing the approximation is very efficient compared to diagonalizing L, since it simply requires the computation of:
112 | 
113 | ![alt text](images/equations/Tm_recursive.gif),
114 | 
115 | where ![alt text](images/equations/T0.gif) and ![alt text](images/equations/T1.gif).
116 | 
117 | Thus, learning the weights of the polynomial approximation makes it possible to learn generic graph filters. The convolution on a spherical graph comes down to tuning the weights of the Chebyshev polynomials through backpropagation.
118 | 
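The following sketch illustrates this recurrence (a toy version, not the package's actual implementation, which lives in `deepsphere/layers/chebyshev.py`), assuming the Laplacian has already been rescaled so that its spectrum lies within [-1, 1]:

```
import numpy as np

def cheb_filter(L, x, coeffs):
    # Apply g(L) x = sum_m c_m T_m(L) x using the Chebyshev recurrence.
    t0 = x                          # T_0(L) x = x
    y = coeffs[0] * t0
    if len(coeffs) > 1:
        t1 = L @ x                  # T_1(L) x = L x
        y = y + coeffs[1] * t1
        for c in coeffs[2:]:
            t2 = 2 * (L @ t1) - t0  # T_m = 2 L T_{m-1} - T_{m-2}
            y = y + c * t2
            t0, t1 = t1, t2
    return y
```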
119 | ## Unet
120 | 
121 | The architecture used for the deep learning model is a classic [U-Net](https://arxiv.org/pdf/1505.04597.pdf).
122 | The poolings and unpoolings used correspond to three possible spherical samplings: [icosahedron](https://github.com/deepsphere/deepsphere-pytorch/tree/master/deepsphere/layers/samplings/icosahedron_pool_unpool.py), [healpix](https://github.com/deepsphere/deepsphere-pytorch/tree/master/deepsphere/layers/samplings/healpix_pool_unpool.py) and [equiangular](https://github.com/deepsphere/deepsphere-pytorch/tree/master/deepsphere/layers/samplings/equiangular_pool_unpool.py).
123 | 
124 | ## Temporality
125 | 
126 | Beyond reproducing the ARTC experiment in PyTorch, we introduced a new dimension to our package: temporality. We did so following two approaches. First, we combined the U-Net with a recurrent neural network ([LSTM](https://en.wikipedia.org/wiki/Long_short-term_memory)) as presented in [Recurrent Fully Convolutional Network for Video Segmentation](https://arxiv.org/pdf/1606.00487v2.pdf).
127 | Second, we augmented the feature space of the U-Net, thus taking more than one image as an input.
128 | 
129 | ## Metric
130 | 
131 | The metric used to evaluate the performance of the model is the mean of the [average precision](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html) of the classes "Atmospheric River" and "Tropical Cyclone". Only around 2% of the data is labelled as an atmospheric river and 0.1% of the data is labelled as a tropical cyclone. For such unbalanced datasets, average precision is an appropriate metric.
132 | The average precision metric allows circumventing, to some extent, the trade-off between precision and recall performance. Average precision computes the average precision value for recall values over the interval 0 to 1. In other words, it measures the area under the precision-recall curve using a piecewise-constant discretization. Using average precision for each class/label type gives a sense of how well the model performs detection in the case of an unbalanced dataset.
133 | 
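As a small, self-contained illustration of this metric (with made-up numbers, not results from the model):

```
import numpy as np
from sklearn.metrics import average_precision_score

# Binary ground truth and predicted scores for one class, e.g. "Tropical Cyclone".
y_true = np.array([0, 0, 1, 1, 0, 1])
y_score = np.array([0.1, 0.3, 0.8, 0.4, 0.2, 0.9])
ap = average_precision_score(y_true, y_score)
# The reported metric is the mean of the APs obtained this way for the AR and TC classes.
```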
134 | ## Tools
135 | 
136 | **Ignite.**
137 | Ignite provides a clean training-validation-testing loop. Through Ignite, engines corresponding to a trainer, validator and tester for the model can be created. Properties of these engines can be set using Handlers. For example, the trainer can have a handler to print certain information during training, and the validator can have a handler to log the metrics or a handler to change the learning rate based on the metrics of the epoch.
138 | 
139 | **Tensorboard.**
140 | Tensorboard allows logging metrics, the training loss and the evolution of the learning rate. In the script, one can create a Summary Writer and attach diverse saving options to this object.
141 | 
142 | **Visualizations.**
143 | Visualizations are possible in 2D and 3D. The 2D representation is a flattened version of the sphere with a 2D projection of the sampling used (at the moment, this is possible for the icosahedron and equiangular samplings). The 3D gif rendering allows representing the labels on a rotating globe. Finally, an interactive plotting notebook is also provided as an inspiration for interactive plots. It allows plotting the metrics at a point in training (for a certain epoch), alongside the predicted labels plotted in 2D. This prediction is displayed next to a 2D plot of the ground truth.
144 | 
145 | ## License & co
146 | 
147 | The content of this repository is released under the terms of the [MIT license](LICENSE.txt).
148 | 
149 | The code, based on the [TensorFlow implementation of DeepSphere](https://github.com/deepsphere/deepsphere-tf1), was mostly developed by [Laure Vancauwenberghe](https://www.linkedin.com/in/laure-vancauwenberghe) and [Michael Allemann](https://www.linkedin.com/in/michael-allemann) while they were interning at [Arcanite Solutions](https://arcanite.ch) under the supervision of Yoann Ponti, Basile Chatillon, Julien Eberle, Lionel Martin, Johan Paratte, [Michaël Defferrard](https://deff.ch).
150 | 
151 | Please consider citing our papers if you find this repository useful.
152 | 
153 | ```
154 | @inproceedings{deepsphere_iclr,
155 | title = {{DeepSphere}: a graph-based spherical {CNN}},
156 | author = {Defferrard, Michaël and Milani, Martino and Gusset, Frédérick and Perraudin, Nathanaël},
157 | booktitle = {International Conference on Learning Representations (ICLR)},
158 | year = {2020},
159 | url = {https://openreview.net/forum?id=B1e3OlStPB},
160 | }
161 | ```
162 | 
163 | ```
164 | @inproceedings{deepsphere_rlgm,
165 | title = {{DeepSphere}: towards an equivariant graph-based spherical {CNN}},
166 | author = {Defferrard, Micha\"el and Perraudin, Nathana\"el and Kacprzak, Tomasz and Sgier, Raphael},
167 | booktitle = {ICLR Workshop on Representation Learning on Graphs and Manifolds},
168 | year = {2019},
169 | archiveprefix = {arXiv},
170 | eprint = {1904.05146},
171 | url = {https://arxiv.org/abs/1904.05146},
172 | }
173 | ```
-------------------------------------------------------------------------------- /contributing.md: --------------------------------------------------------------------------------
1 | # Contributing
2 | 
3 | ## Installing libraries
4 | 
5 | Several external tools are used in order to properly format the code. Install them with:
6 | 
7 | ```bash
8 | pip install -r requirements-tests.txt
9 | ```
10 | 
11 | ## Black
12 | 
13 | [Black](https://pypi.org/project/black/) is a code formatter made to be uncompromising.
14 | No time is spent in order to determine how things should be formatted since there is no choice to make.
15 | 
16 | To run black on a folder or a file, simply run:
17 | 
18 | ```bash
19 | black <path>
20 | ```
21 | 
22 | The configuration of black can be found in the file `pyproject.toml`.
23 | 
24 | ## Isort
25 | 
26 | [Isort](https://pypi.org/project/isort/) is an external tool that sorts and formats all the imports of a file.
27 | 
28 | The command to format a file is:
29 | 
30 | ```bash
31 | isort <file>
32 | ```
33 | 
34 | Or to run over a folder:
35 | 
36 | ```bash
37 | isort -rc <folder>
38 | ```
39 | 
40 | The configuration is also stored in the file `pyproject.toml`.
41 | 
42 | ## pylint
43 | 
44 | [Pylint](https://www.pylint.org/) is another external tool that checks coding standards such as line length, common mistakes, errors, etc.
45 | It is used in the CI in combination with [pylint-fail-under](https://pypi.org/project/pylint-fail-under/), a wrapper.
46 | 
47 | The command to check the output of pylint over a folder or a file is:
48 | 
49 | ```bash
50 | pylint --rcfile=setup.cfg <path>
51 | ```
52 | 
53 | The configuration is stored in the file `setup.cfg`.
54 | 
55 | ## Pre Commit hook
56 | 
57 | [pre-commit](https://pre-commit.com/) is a tool that applies hooks before every commit you make.
58 | It is used to run the different formatting tools. In order to be able to use the pre-commit hook, run:
59 | 
60 | ```bash
61 | pre-commit install
62 | ```
63 | 
64 | The pre-commit hook does the following:
65 | 1. Adding a blank line at the end of a file if it is missing
66 | 2. Removing trailing white spaces
67 | 3. Applying Black
68 | 4. Applying Isort
69 | 5. Running the tests
70 | 6. Running Pylint with a threshold score of 10
71 | 
72 | If any of these tasks fails, it aborts the commit, and the tasks that format the code change the files in place.
73 | The configuration of pre-commit can be found in the file `.pre-commit-config.yaml`.
74 | 
75 | ## Documentation
76 | 
77 | In order to set up the documentation of the project, [Sphinx](http://www.sphinx-doc.org/en/master/) is used.
78 | It is a library that generates documentation from the docstrings in the code.
79 | 
80 | In order for it to work, the docstrings should follow the Google style format (some examples can be found [here](http://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings), from the Google style guide, and [here](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html), from Napoleon, the plugin used by Sphinx to parse Google-style docstrings).
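For reference, a minimal docstring in this format looks as follows (an illustrative example, not taken from the codebase):

```python
import math


def spherical_area(radius):
    """Compute the surface area of a sphere.

    Args:
        radius (float): Radius of the sphere.

    Returns:
        float: The surface area, equal to 4 * pi * radius ** 2.
    """
    return 4 * math.pi * radius ** 2
```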
81 | To generate the documentation, follow the different steps:
82 | 
83 | 1. If you added new .py files in the project, go inside the folder `docs/` (Important) and run the command `sphinx-apidoc -f -o source ../`
84 | 2. You can customize the files by editing the .rst files in the folder `docs/source/`
85 | 3. Generate the doc with `make html` while being inside the folder `docs/`
86 | 4. Open the file `docs/build/html/index.html` in a browser to see the documentation
-------------------------------------------------------------------------------- /deepsphere/__init__.py: --------------------------------------------------------------------------------
1 | """DeepSphere base package documentation.
2 | """
3 | 
4 | import importlib
5 | import sys
6 | 
7 | __version__ = "0.2.1"
8 | 
9 | 
10 | def import_modules(names, src, dst):
11 | """Import modules in package."""
12 | for name in names:
13 | module = importlib.import_module("{}.{}".format(src, name))
14 | setattr(sys.modules[dst], name, module)
15 | 
16 | 
17 | __all__ = []
18 | 
19 | import_modules(__all__[::-1], "deepsphere", "deepsphere")
-------------------------------------------------------------------------------- /deepsphere/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepsphere/deepsphere-pytorch/43f03f1bef146d256a7e1c1e69df6712d089b9e5/deepsphere/data/__init__.py -------------------------------------------------------------------------------- /deepsphere/data/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepsphere/deepsphere-pytorch/43f03f1bef146d256a7e1c1e69df6712d089b9e5/deepsphere/data/datasets/__init__.py -------------------------------------------------------------------------------- /deepsphere/data/datasets/dataset.py: --------------------------------------------------------------------------------
1 | """Datasets for the reduced atmospheric river and tropical cyclone detection dataset.
2 | """
3 | 
4 | 
5 | import itertools
6 | import os
7 | 
8 | import numpy as np
9 | from torch.utils.data import Dataset
10 | from torchvision.datasets.utils import download_and_extract_archive
11 | 
12 | # pylint: disable=C0330
13 | 
14 | 
15 | class ARTCDataset(Dataset):
16 | """Dataset for the reduced atmospheric river and tropical cyclone dataset.
17 | """
18 | 
19 | resource = "http://island.me.berkeley.edu/ugscnn/data/climate_sphere_l5.zip"
20 | 
21 | def __init__(self, path, indices=None, transform_data=None, transform_labels=None, download=False):
22 | """Initialization.
23 | 
24 | Args:
25 | path (str): Path to the data or desired place the data will be downloaded to.
26 | indices (list): List of indices representing the subset of the data used for the current dataset.
27 | transform_data (:obj:`transform.Compose`): List of torchvision transforms for the data.
28 | transform_labels (:obj:`transform.Compose`): List of torchvision transforms for the labels.
29 | download (bool): Flag to decide if data should be downloaded or not.
30 | """
31 | self.path = path
32 | if download:
33 | self.download()
34 | self.files = indices if indices is not None else os.listdir(self.path)
35 | self.transform_data = transform_data
36 | self.transform_labels = transform_labels
37 | 
38 | @property
39 | def indices(self):
40 | """Get files.
41 | 
42 | Returns:
43 | list: List of strings, which represent the files contained in the dataset.
44 | """
45 | return self.files
46 | 
47 | def __len__(self):
48 | """Get length of dataset.
49 | 
50 | Returns:
51 | int: Number of files contained in the dataset.
52 | """
53 | return len(self.files)
54 | 
55 | def __getitem__(self, idx):
56 | """Get an item from the dataset.
57 | 
58 | Args:
59 | idx (int): The index of the desired datapoint.
60 | 
61 | Returns:
62 | obj, obj: The data and labels corresponding to the desired index. The type depends on the applied transforms.
63 | """
64 | item = np.load(os.path.join(self.path, self.files[idx]))
65 | data, labels = item["data"], item["labels"]
66 | if self.transform_data:
67 | data = self.transform_data(data)
68 | if self.transform_labels:
69 | labels = self.transform_labels(labels)
70 | return data, labels
71 | 
72 | def get_runs(self, runs):
73 | """Get datapoints corresponding to specific runs.
74 | 
75 | Args:
76 | runs (list): List of desired runs.
77 | 
78 | Returns:
79 | list: List of strings representing the files in the dataset that belong to one of the desired runs.
80 | """
81 | files = []
82 | for file in self.files:
83 | for i in runs:
84 | if file.endswith("{}-mesh.npz".format(i)):
85 | files.append(file)
86 | return files
87 | 
88 | def download(self):
89 | """Download the dataset if it doesn't already exist.
90 | """
91 | if not self.check_exists():
92 | download_and_extract_archive(self.resource, download_root=os.path.split(self.path)[0])
93 | else:
94 | print("Data already exists")
95 | 
96 | def check_exists(self):
97 | """Check if dataset already exists.
98 | """
99 | return os.path.exists(self.path)
100 | 
101 | 
102 | class ARTCTemporaldataset(ARTCDataset):
103 | """Dataset for reduced ARTC dataset with temporality functionality.
104 | """
105 | 
106 | def __init__(
107 | self,
108 | path,
109 | sequence_length,
110 | prediction_shift=0,
111 | indices=None,
112 | transform_image=None,
113 | transform_labels=None,
114 | transform_sample=None,
115 | download=False,
116 | ):
117 | """Initialization. Sort by run and sort each run by date and time. The samples at the end of each run are invalid and are removed.
118 | The list is then flattened. Self.allowed contains the list of all valid indices. Self.files contains all indices for the construction
119 | of samples.
120 | 
121 | Args:
122 | path (str): Path to the data or desired place the data will be downloaded to.
123 | indices (list): List of indices representing the subset of the data used for the current dataset.
124 | transform_image (:obj:`transform.Compose`): List of torchvision transforms for each image of a sample.
125 | transform_labels (:obj:`transform.Compose`): List of torchvision transforms for the labels.
126 | download (bool): Flag to decide if data should be downloaded or not.
127 | sequence_length (int): The number of images used per sample.
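prediction_shift (int): Offset between the last image of a sample and the image from which the labels are taken.
transform_sample (:obj:`transform.Compose`): List of torchvision transforms applied to the whole sample.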
128 | """ 129 | super().__init__(path, indices, None, None, download) 130 | self.transform_image = transform_image 131 | self.transform_labels = transform_labels 132 | self.transform_sample = transform_sample 133 | self.sequence_length = sequence_length 134 | self.prediction_shift = prediction_shift 135 | sorted_by_run_and_date = [sorted(self.get_runs([i])) for i in [1, 2, 3, 4, 6]] 136 | self.allowed = list(itertools.chain(*[run[: -(self.sequence_length + self.prediction_shift)] for run in sorted_by_run_and_date])) 137 | self.files = list(itertools.chain(*sorted_by_run_and_date)) 138 | 139 | def __len__(self): 140 | """Get length of dataset. 141 | 142 | Returns: 143 | int: Number of files contained in the dataset. 144 | """ 145 | return len(self.allowed) 146 | 147 | def __getitem__(self, idx): 148 | """Get an item from the dataset. 149 | 150 | Args: 151 | idx (int): The index of the desired datapoint. 152 | 153 | Returns: 154 | obj, obj: The data and labels corresponding to the desired index. The type depends on the applied transforms. 155 | """ 156 | sample = [] 157 | idx = self.files.index(self.allowed[idx]) 158 | for i in range(self.sequence_length): 159 | sample.append(np.load(os.path.join(self.path, self.files[idx + i]))) 160 | data = [image["data"] for image in sample] 161 | if self.prediction_shift > 0: 162 | target = np.load(os.path.join(self.path, self.files[idx + i + self.prediction_shift])) 163 | labels = target["labels"] 164 | else: 165 | labels = sample[-1]["labels"] 166 | if self.transform_image: 167 | for i, image in enumerate(data): 168 | data[i] = self.transform_image(image) 169 | if self.transform_labels: 170 | labels = self.transform_labels(labels) 171 | if self.transform_sample: 172 | data = self.transform_sample(data) 173 | return data, labels 174 | -------------------------------------------------------------------------------- /deepsphere/data/transforms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepsphere/deepsphere-pytorch/43f03f1bef146d256a7e1c1e69df6712d089b9e5/deepsphere/data/transforms/__init__.py -------------------------------------------------------------------------------- /deepsphere/data/transforms/transforms.py: -------------------------------------------------------------------------------- 1 | """Transformations for samples of atmospheric rivers and tropical cyclones dataset. 2 | """ 3 | import torch 4 | 5 | 6 | class ToTensor: 7 | """Convert raw data and labels to PyTorch tensor. 8 | """ 9 | 10 | def __call__(self, item): 11 | """Function call operator to change type. 12 | 13 | Args: 14 | item (:obj:`numpy.array`): Numpy array that needs to be transformed. 15 | Returns: 16 | :obj:`torch.Tensor`: Sample of size (vertices, features). 17 | """ 18 | return torch.Tensor(item) 19 | 20 | 21 | class Permute: 22 | """Permute first and second dimension. 23 | """ 24 | 25 | def __call__(self, item): 26 | """Permute first and second dimension. 27 | 28 | Args: 29 | item (:obj:`torch.Tensor`): Torch tensor that needs to be transformed. 30 | 31 | Returns: 32 | :obj:`torch.Tensor`: Permuted input tensor. 33 | """ 34 | return item.permute(1, 0) 35 | 36 | 37 | class Normalize: 38 | """Normalize using mean and std. 
39 | """ 40 | 41 | def __init__(self, mean, std): 42 | """Initialization 43 | 44 | Args: 45 | mean (:obj:`numpy.array`): means of each feature 46 | std (:obj:`numpy.array`): standard deviations of each feature 47 | """ 48 | self.mean = torch.from_numpy(mean) 49 | self.std = torch.from_numpy(std) 50 | 51 | def __call__(self, item): 52 | """ 53 | Args: 54 | item (:obj:`torch.Tensor`): Sample of size (vertices, features) to be normalized on its features. 55 | 56 | Returns: 57 | :obj:`torch.Tensor`: Normalized input tensor. 58 | """ 59 | return (item - self.mean) / self.std 60 | 61 | 62 | class Stack: 63 | """Stack images in torch tensor. 64 | """ 65 | 66 | def __init__(self, dimension=0): 67 | """Initialization 68 | 69 | Args: 70 | dimension int: The dimension to be used for stacking. 71 | """ 72 | self.dimension = dimension 73 | 74 | def __call__(self, item): 75 | """Stack images in torch tensor. 76 | 77 | Args: 78 | item (:obj:`torch.Tensor`): Torch tensor that needs to be transformed. 79 | 80 | Returns: 81 | :obj:`torch.Tensor`: Stacked input tensor. 82 | """ 83 | return torch.stack(item, dim=self.dimension) 84 | -------------------------------------------------------------------------------- /deepsphere/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepsphere/deepsphere-pytorch/43f03f1bef146d256a7e1c1e69df6712d089b9e5/deepsphere/layers/__init__.py -------------------------------------------------------------------------------- /deepsphere/layers/chebyshev.py: -------------------------------------------------------------------------------- 1 | """Chebyshev convolution layer. For the moment taking as-is from Michaël Defferrard's implementation. For v0.15 we will rewrite parts of this layer. 2 | """ 3 | # pylint: disable=W0221 4 | 5 | import math 6 | 7 | import torch 8 | from torch import nn 9 | 10 | 11 | def cheb_conv(laplacian, inputs, weight): 12 | """Chebyshev convolution. 13 | 14 | Args: 15 | laplacian (:obj:`torch.sparse.Tensor`): The laplacian corresponding to the current sampling of the sphere. 16 | inputs (:obj:`torch.Tensor`): The current input data being forwarded. 17 | weight (:obj:`torch.Tensor`): The weights of the current layer. 18 | 19 | Returns: 20 | :obj:`torch.Tensor`: Inputs after applying Chebyshev convolution. 
21 | """ 22 | B, V, Fin = inputs.shape 23 | K, Fin, Fout = weight.shape 24 | # B = batch size 25 | # V = nb vertices 26 | # Fin = nb input features 27 | # Fout = nb output features 28 | # K = order of Chebyshev polynomials 29 | 30 | # transform to Chebyshev basis 31 | x0 = inputs.permute(1, 2, 0).contiguous() # V x Fin x B 32 | x0 = x0.view([V, Fin * B]) # V x Fin*B 33 | inputs = x0.unsqueeze(0) # 1 x V x Fin*B 34 | 35 | if K > 0: 36 | x1 = torch.sparse.mm(laplacian, x0) # V x Fin*B 37 | inputs = torch.cat((inputs, x1.unsqueeze(0)), 0) # 2 x V x Fin*B 38 | for _ in range(1, K - 1): 39 | x2 = 2 * torch.sparse.mm(laplacian, x1) - x0 40 | inputs = torch.cat((inputs, x2.unsqueeze(0)), 0) # M x Fin*B 41 | x0, x1 = x1, x2 42 | 43 | inputs = inputs.view([K, V, Fin, B]) # K x V x Fin x B 44 | inputs = inputs.permute(3, 1, 2, 0).contiguous() # B x V x Fin x K 45 | inputs = inputs.view([B * V, Fin * K]) # B*V x Fin*K 46 | 47 | # Linearly compose Fin features to get Fout features 48 | weight = weight.view(Fin * K, Fout) 49 | inputs = inputs.matmul(weight) # B*V x Fout 50 | inputs = inputs.view([B, V, Fout]) # B x V x Fout 51 | 52 | return inputs 53 | 54 | 55 | class ChebConv(torch.nn.Module): 56 | """Graph convolutional layer. 57 | """ 58 | 59 | def __init__(self, in_channels, out_channels, kernel_size, bias=True, conv=cheb_conv): 60 | """Initialize the Chebyshev layer. 61 | 62 | Args: 63 | in_channels (int): Number of channels/features in the input graph. 64 | out_channels (int): Number of channels/features in the output graph. 65 | kernel_size (int): Number of trainable parameters per filter, which is also the size of the convolutional kernel. 66 | The order of the Chebyshev polynomials is kernel_size - 1. 67 | bias (bool): Whether to add a bias term. 68 | conv (callable): Function which will perform the actual convolution. 69 | """ 70 | super().__init__() 71 | 72 | self.in_channels = in_channels 73 | self.out_channels = out_channels 74 | self.kernel_size = kernel_size 75 | self._conv = conv 76 | 77 | shape = (kernel_size, in_channels, out_channels) 78 | self.weight = torch.nn.Parameter(torch.Tensor(*shape)) 79 | 80 | if bias: 81 | self.bias = torch.nn.Parameter(torch.Tensor(out_channels)) 82 | else: 83 | self.register_parameter("bias", None) 84 | 85 | self.kaiming_initialization() 86 | 87 | def kaiming_initialization(self): 88 | """Initialize weights and bias. 89 | """ 90 | std = math.sqrt(2 / (self.in_channels * self.kernel_size)) 91 | self.weight.data.normal_(0, std) 92 | if self.bias is not None: 93 | self.bias.data.fill_(0.01) 94 | 95 | def forward(self, laplacian, inputs): 96 | """Forward graph convolution. 97 | 98 | Args: 99 | laplacian (:obj:`torch.sparse.Tensor`): The laplacian corresponding to the current sampling of the sphere. 100 | inputs (:obj:`torch.Tensor`): The current input data being forwarded. 101 | 102 | Returns: 103 | :obj:`torch.Tensor`: The convoluted inputs. 104 | """ 105 | outputs = self._conv(laplacian, inputs, self.weight) 106 | if self.bias is not None: 107 | outputs += self.bias 108 | return outputs 109 | 110 | 111 | class SphericalChebConv(nn.Module): 112 | """Building Block with a Chebyshev Convolution. 113 | """ 114 | 115 | def __init__(self, in_channels, out_channels, lap, kernel_size): 116 | """Initialization. 117 | 118 | Args: 119 | in_channels (int): initial number of channels 120 | out_channels (int): output number of channels 121 | lap (:obj:`torch.sparse.FloatTensor`): laplacian 122 | kernel_size (int): polynomial degree. Defaults to 3. 
123 | """
124 | super().__init__()
125 | self.register_buffer("laplacian", lap)
126 | self.chebconv = ChebConv(in_channels, out_channels, kernel_size)
127 | 
128 | def state_dict(self, *args, **kwargs):
129 | """! WARNING !
130 | 
131 | This function overrides the state dict in order to be able to save the model.
132 | This can be removed as soon as saving sparse matrices has been added to Pytorch.
133 | """
134 | state_dict = super().state_dict(*args, **kwargs)
135 | del_keys = []
136 | for key in state_dict:
137 | if key.endswith("laplacian"):
138 | del_keys.append(key)
139 | for key in del_keys:
140 | del state_dict[key]
141 | return state_dict
142 | 
143 | def forward(self, x):
144 | """Forward pass.
145 | 
146 | Args:
147 | x (:obj:`torch.tensor`): input [batch x vertices x channels/features]
148 | 
149 | Returns:
150 | :obj:`torch.tensor`: output [batch x vertices x channels/features]
151 | """
152 | x = self.chebconv(self.laplacian, x)
153 | return x
154 | 
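# Illustrative usage sketch (added for documentation purposes, not part of the original
# file): the sparse identity below is only a stand-in for a real laplacian, which would
# normally come from deepsphere.utils.laplacian_funcs.
if __name__ == "__main__":
    eye_indices = torch.stack((torch.arange(12), torch.arange(12)))
    dummy_lap = torch.sparse_coo_tensor(eye_indices, torch.ones(12), (12, 12))
    conv = SphericalChebConv(in_channels=16, out_channels=32, lap=dummy_lap, kernel_size=3)
    out = conv(torch.randn(2, 12, 16))  # batch x vertices x features
    print(out.shape)  # expected: torch.Size([2, 12, 32])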
-------------------------------------------------------------------------------- /deepsphere/layers/samplings/__init__.py: --------------------------------------------------------------------------------
1 | """DeepSphere samplings package documentation.
2 | """
-------------------------------------------------------------------------------- /deepsphere/layers/samplings/equiangular_pool_unpool.py: --------------------------------------------------------------------------------
1 | """
2 | EquiAngular Sampling's Pooling and Unpooling.
3 | The pooling goes down two bandwidths at a time.
4 | This represents (in terms of classic pooling kernel sizes) a division (pooling) or multiplication (unpooling) of the number of pixels by 4.
5 | The kernel size for all modules is hence fixed.
6 | 
7 | Equiangular sampling theory from:
8 | *FFTs for the 2-Sphere: Improvements and Variations* by Healy (doi=10.1.1.51.5335)
9 | 
10 | Bandwidth: int, list, or tuple, giving a symmetric or asymmetric sampling. It corresponds to the resolution of the sampling scheme.
11 | :math:`pixels = (2*bw)^{2}`
12 | Allowed number of pixels:
13 | 
14 | - (bw=1) 4 pixels,
15 | - (bw=2) 16 pixels,
16 | - (bw=3) 36 pixels,
17 | - (bw=4) 64 pixels,
18 | - (bw=5) 100 pixels.
19 | 
20 | If the latitude bandwidth differs from the longitude bandwidth, then:
21 | :math:`pixels = (2*bw_{latitude})*(2*bw_{longitude})`
22 | """
23 | 
24 | # pylint: disable=W0221
25 | 
26 | import torch.nn as nn
27 | import torch.nn.functional as F
28 | 
29 | from deepsphere.utils.samplings import equiangular_calculator
30 | 
31 | 
32 | def reformat(x):
33 | """Reformat the input from a 4D tensor to a 3D tensor
34 | 
35 | Args:
36 | x (:obj:`torch.tensor`): a 4D tensor
37 | Returns:
38 | :obj:`torch.tensor`: a 3D tensor
39 | """
40 | x = x.permute(0, 2, 3, 1)
41 | N, D1, D2, Feat = x.size()
42 | x = x.view(N, D1 * D2, Feat)
43 | return x
44 | 
45 | 
46 | class EquiangularMaxPool(nn.MaxPool1d):
47 | """EquiAngular Maxpooling module using MaxPool 1d from torch
48 | """
49 | 
50 | def __init__(self, ratio, return_indices=False):
51 | """Initialization
52 | 
53 | Args:
54 | ratio (float): ratio between latitude and longitude dimensions of the data
55 | """
56 | self.ratio = ratio
57 | super().__init__(kernel_size=4, return_indices=return_indices)
58 | 
59 | def forward(self, x):
60 | """Calls 2D max pooling and, if desired, keeps the indices of the pooled pixels to unpool them later
61 | 
62 | Args:
63 | x (:obj:`torch.tensor`): batch x pixels x features
64 | 
65 | Returns:
66 | tuple(:obj:`torch.tensor`, list(int)): batch x pooled pixels x features and the indices of the pixels pooled
67 | """
68 | x, _ = equiangular_calculator(x, self.ratio)
69 | x = x.permute(0, 3, 1, 2)
70 | 
71 | if self.return_indices:
72 | x, indices = F.max_pool2d(x, self.kernel_size, return_indices=self.return_indices)
73 | else:
74 | x = F.max_pool2d(x, self.kernel_size)
75 | x = reformat(x)
76 | 
77 | if self.return_indices:
78 | output = x, indices
79 | else:
80 | output = x
81 | 
82 | return output
83 | 
84 | 
85 | class EquiangularAvgPool(nn.AvgPool1d):
86 | """EquiAngular Average Pooling using Average Pooling 1d from pytorch
87 | """
88 | 
89 | def __init__(self, ratio):
90 | """Initialization
91 | 
92 | Args:
93 | ratio (float): ratio between latitude and longitude dimensions of the data
94 | """
95 | self.ratio = ratio
96 | super().__init__(kernel_size=4)
97 | 
98 | def forward(self, x):
99 | """Calls 2D average pooling
100 | 
101 | Args:
102 | x (:obj:`torch.tensor`): batch x pixels x features
103 | 
104 | Returns:
105 | :obj:`torch.tensor` -- batch x pooled pixels x features
106 | """
107 | x, _ = equiangular_calculator(x, self.ratio)
108 | x = x.permute(0, 3, 1, 2)
109 | x = F.avg_pool2d(x, self.kernel_size)
110 | x = reformat(x)
111 | 
112 | return x
113 | 
114 | 
115 | class EquiangularMaxUnpool(nn.MaxUnpool1d):
116 | """Equiangular Maxunpooling using the MaxUnpool1d of pytorch
117 | """
118 | 
119 | def __init__(self, ratio):
120 | """Initialization
121 | 
122 | Args:
123 | ratio (float): ratio between latitude and longitude dimensions of the data
124 | """
125 | self.ratio = ratio
126 | super().__init__(kernel_size=4)
127 | 
128 | def forward(self, x, indices):
129 | """Calls 2D max unpooling using the indices previously returned by EquiangularMaxPool
130 | 
131 | Args:
132 | x (:obj:`torch.tensor`): batch x pixels x features
133 | indices (:obj:`torch.tensor`): indices of the previously max pooled pixels
134 | 
135 | Returns:
136 | :obj:`torch.tensor`: batch x unpooled pixels x features
137 | """
138 | x, _ = equiangular_calculator(x, self.ratio)
139 | x = x.permute(0, 3, 1, 2)
140 | x = F.max_unpool2d(x, indices, kernel_size=(4, 4))
141 | x = reformat(x)
142 | return x
143 | 
144 | 
145 | class EquiangularAvgUnpool(nn.Module):
146 | """EquiAngular
Average Unpooling version 1 using the interpolate function when unpooling
147 | """
148 | 
149 | def __init__(self, ratio):
150 | """Initialization
151 | 
152 | Args:
153 | ratio (float): ratio between latitude and longitude dimensions of the data
154 | """
155 | self.ratio = ratio
156 | self.kernel_size = 4
157 | super().__init__()
158 | 
159 | def forward(self, x):
160 | """Calls PyTorch's interpolate function to create the values while unpooling, based on the nearby values
161 | Args:
162 | x (:obj:`torch.tensor`): batch x pixels x features
163 | Returns:
164 | :obj:`torch.tensor`: batch x unpooled pixels x features
165 | """
166 | 
167 | x, _ = equiangular_calculator(x, self.ratio)
168 | x = x.permute(0, 3, 1, 2)
169 | x = F.interpolate(x, scale_factor=(self.kernel_size, self.kernel_size), mode="nearest")
170 | x = reformat(x)
171 | return x
172 | 
173 | 
174 | class Equiangular:
175 | """Equiangular class, which groups together the corresponding pooling and unpooling.
176 | """
177 | 
178 | def __init__(self, ratio=1, mode="average"):
179 | """Initialize equiangular pooling and unpooling objects.
180 | 
181 | Args:
182 | ratio (float): ratio between latitude and longitude dimensions of the data
183 | mode (str, optional): specify the mode for pooling/unpooling.
184 | Can be maxpooling or averagepooling. Defaults to 'average'.
185 | """
186 | if mode == "max":
187 | self.__pooling = EquiangularMaxPool(ratio)
188 | self.__unpooling = EquiangularMaxUnpool(ratio)
189 | else:
190 | self.__pooling = EquiangularAvgPool(ratio)
191 | self.__unpooling = EquiangularAvgUnpool(ratio)
192 | 
193 | @property
194 | def pooling(self):
195 | """Getter for the pooling class
196 | """
197 | return self.__pooling
198 | 
199 | @property
200 | def unpooling(self):
201 | """Getter for the unpooling class
202 | """
203 | return self.__unpooling
204 | 
-------------------------------------------------------------------------------- /deepsphere/layers/samplings/healpix_pool_unpool.py: --------------------------------------------------------------------------------
1 | """Healpix Sampling's Pooling and Unpooling
2 | The pooling divides the number of nsides by 2 each time.
3 | This represents (in terms of classic pooling kernel sizes) a division (pooling) or multiplication (unpooling) of the number of pixels by 4.
4 | The kernel size for all modules is hence fixed.
5 | 
6 | Sampling theory from:
7 | *HEALPix — a Framework for High Resolution Discretization, and Fast Analysis of Data Distributed on the Sphere* by Gorski (doi: 10.1086/427976)
8 | 
9 | See Figure 1 of that paper for the relation between the number of sides and the number of pixels, and for unpooling using tiling.
10 | The areas of the pixels are all the same; hence the latitude and longitude resolutions are the same.
11 | 12 | The lowest resolution possible with the HEALPix base partitioning of the sphere surface into 12 equal sized pixels 13 | See: https://healpix.jpl.nasa.gov/ 14 | 15 | :math:`N_{pixels} = 12 * N_{sides}^2` 16 | Nsides is the number of divisions from the baseline of 12 equal sized pixels 17 | 18 | """ 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | 22 | # pylint: disable=W0221 23 | 24 | 25 | class HealpixMaxPool(nn.MaxPool1d): 26 | """Healpix Maxpooling module 27 | """ 28 | 29 | def __init__(self, return_indices=False): 30 | """Initialization 31 | """ 32 | super().__init__(kernel_size=4, return_indices=return_indices) 33 | 34 | def forward(self, x): 35 | """Forward call the 1d Maxpooling of pytorch 36 | 37 | Args: 38 | x (:obj:`torch.tensor`):[batch x pixels x features] 39 | 40 | Returns: 41 | tuple((:obj:`torch.tensor`), indices (int)): [batch x pooled pixels x features] and indices of pooled pixels 42 | """ 43 | x = x.permute(0, 2, 1) 44 | if self.return_indices: 45 | x, indices = F.max_pool1d(x, self.kernel_size) 46 | else: 47 | x = F.max_pool1d(x, self.kernel_size) 48 | x = x.permute(0, 2, 1) 49 | 50 | if self.return_indices: 51 | output = x, indices 52 | else: 53 | output = x 54 | return output 55 | 56 | 57 | class HealpixAvgPool(nn.AvgPool1d): 58 | """Healpix Average pooling module 59 | """ 60 | 61 | def __init__(self): 62 | """initialization 63 | """ 64 | super().__init__(kernel_size=4) 65 | 66 | def forward(self, x): 67 | """forward call the 1d Averagepooling of pytorch 68 | 69 | Arguments: 70 | x (:obj:`torch.tensor`): [batch x pixels x features] 71 | 72 | Returns: 73 | [:obj:`torch.tensor`] : [batch x pooled pixels x features] 74 | """ 75 | x = x.permute(0, 2, 1) 76 | x = F.avg_pool1d(x, self.kernel_size) 77 | x = x.permute(0, 2, 1) 78 | return x 79 | 80 | 81 | class HealpixMaxUnpool(nn.MaxUnpool1d): 82 | """Healpix Maxunpooling using the MaxUnpool1d of pytorch 83 | """ 84 | 85 | def __init__(self): 86 | """initialization 87 | """ 88 | super().__init__(kernel_size=4) 89 | 90 | def forward(self, x, indices): 91 | """calls MaxUnpool1d using the indices returned previously by HealpixMaxPool 92 | 93 | Args: 94 | tuple(x (:obj:`torch.tensor`) : [batch x pixels x features] 95 | indices (int)): indices of pixels equiangular maxpooled previously 96 | 97 | Returns: 98 | [:obj:`torch.tensor`] -- [batch x unpooled pixels x features] 99 | """ 100 | x = x.permute(0, 2, 1) 101 | x = F.max_unpool1d(x, indices, self.kernel_size) 102 | x = x.permute(0, 2, 1) 103 | return x 104 | 105 | 106 | class HealpixAvgUnpool(nn.Module): 107 | """Healpix Average Unpooling module 108 | """ 109 | 110 | def __init__(self): 111 | """initialization 112 | """ 113 | self.kernel_size = 4 114 | super().__init__() 115 | 116 | def forward(self, x): 117 | """forward repeats (here more like a numpy tile for the moment) the incoming tensor 118 | 119 | Arguments: 120 | x (:obj:`torch.tensor`): [batch x pixels x features] 121 | 122 | Returns: 123 | [:obj:`torch.tensor`]: [batch x unpooled pixels x features] 124 | """ 125 | x = x.permute(0, 2, 1) 126 | x = F.interpolate(x, scale_factor=self.kernel_size, mode="nearest") 127 | x = x.permute(0, 2, 1) 128 | return x 129 | 130 | 131 | class Healpix: 132 | """Healpix class, which groups together the corresponding pooling and unpooling. 133 | """ 134 | 135 | def __init__(self, mode="average"): 136 | """Initialize healpix pooling and unpooling objects. 137 | 138 | Args: 139 | mode (str, optional): specify the mode for pooling/unpooling. 
140 | Can be maxpooling or averagepooling. Defaults to 'average'.
141 | """
142 | if mode == "max":
143 | self.__pooling = HealpixMaxPool()
144 | self.__unpooling = HealpixMaxUnpool()
145 | else:
146 | self.__pooling = HealpixAvgPool()
147 | self.__unpooling = HealpixAvgUnpool()
148 | 
149 | @property
150 | def pooling(self):
151 | """Get pooling
152 | """
153 | return self.__pooling
154 | 
155 | @property
156 | def unpooling(self):
157 | """Get unpooling
158 | """
159 | return self.__unpooling
160 | 
-------------------------------------------------------------------------------- /deepsphere/layers/samplings/icosahedron_pool_unpool.py: --------------------------------------------------------------------------------
1 | """Icosahedron Sampling's Pooling and Unpooling.
2 | Each pooling takes down an order in the icosahedron.
3 | Each unpooling adds the number of pixels corresponding to the next order.
4 | 
5 | Icosahedron is a polyhedron with 12 vertices and 20 faces, where a regular icosahedron is a Platonic solid.
6 | All faces are regular (equilateral) triangles.
7 | This default Icosahedron can be considered at level 0, meaning that no further subdivision has occurred from the platonic solid.
8 | See: https://github.com/maxjiang93/ugscnn/blob/master/meshcnn/mesh.py from Max Jiang
9 | """
10 | # pylint: disable=W0221
11 | 
12 | import math
13 | 
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | 
17 | 
18 | class IcosahedronPool(nn.Module):
19 | """Icosahedron Pooling, consists of keeping only a subset of the original pixels (considering the ordering of an icosahedron sampling method).
20 | """
21 | 
22 | def forward(self, x):
23 | """Forward function calculates the subset of pixels to keep based on input size and the kernel_size.
24 | 
25 | Args:
26 | x (:obj:`torch.tensor`) : [batch x pixels x features]
27 | 
28 | Returns:
29 | [:obj:`torch.tensor`] : [batch x pixels pooled x features]
30 | """
31 | M = x.size(1)
32 | order = int(math.log((M - 2) / 10) / math.log(4))
33 | pool_order = order - 1
34 | subset_pixels_keep = int(10 * math.pow(4, pool_order) + 2)
35 | return x[:, :subset_pixels_keep, :]
36 | 
37 | 
38 | class IcosahedronUnpool(nn.Module):
39 | """Icosahedron Unpooling, consists of adding pixels with value 1 to match the desired unpooling size
40 | """
41 | 
42 | def forward(self, x):
43 | """Forward calculates the number of pixels that will result from the unpooling and then adds 1-valued pixels to match this size
44 | 
45 | Args:
46 | x (:obj:`torch.tensor`) : [batch x pixels x features]
47 | 
48 | Returns:
49 | [:obj:`torch.tensor`]: [batch x pixels unpooled x features]
50 | """
51 | M = x.size(1)
52 | order = int(math.log((M - 2) / 10) / math.log(4))
53 | unpool_order = order + 1
54 | additional_pixels = int((10 * math.pow(4, unpool_order)) + 2)
55 | subset_pixels_add = additional_pixels - M
56 | return F.pad(x, (0, 0, 0, subset_pixels_add, 0, 0), "constant", value=1)
57 | 
58 | 
59 | class Icosahedron:
60 | """Icosahedron class, which simply groups together the corresponding pooling and unpooling.
61 | """
62 | 
63 | def __init__(self):
64 | """Initialize icosahedron pooling and unpooling objects.
65 | """
66 | self.__pooling = IcosahedronPool()
67 | self.__unpooling = IcosahedronUnpool()
68 | 
69 | @property
70 | def pooling(self):
71 | """Get pooling.
72 | """
73 | return self.__pooling
74 | 
75 | @property
76 | def unpooling(self):
77 | """Get unpooling.
78 | """ 79 | return self.__unpooling 80 | -------------------------------------------------------------------------------- /deepsphere/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepsphere/deepsphere-pytorch/43f03f1bef146d256a7e1c1e69df6712d089b9e5/deepsphere/models/__init__.py -------------------------------------------------------------------------------- /deepsphere/models/spherical_unet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepsphere/deepsphere-pytorch/43f03f1bef146d256a7e1c1e69df6712d089b9e5/deepsphere/models/spherical_unet/__init__.py -------------------------------------------------------------------------------- /deepsphere/models/spherical_unet/decoder.py: -------------------------------------------------------------------------------- 1 | """Decoder for Spherical UNet. 2 | """ 3 | # pylint: disable=W0221 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from deepsphere.layers.chebyshev import SphericalChebConv 9 | from deepsphere.models.spherical_unet.utils import SphericalChebBN, SphericalChebBNPool 10 | 11 | 12 | class SphericalChebBNPoolCheb(nn.Module): 13 | """Building Block calling a SphericalChebBNPool block then a SphericalCheb. 14 | """ 15 | 16 | def __init__(self, in_channels, middle_channels, out_channels, lap, pooling, kernel_size): 17 | """Initialization. 18 | 19 | Args: 20 | in_channels (int): initial number of channels. 21 | middle_channels (int): middle number of channels. 22 | out_channels (int): output number of channels. 23 | lap (:obj:`torch.sparse.FloatTensor`): laplacian. 24 | pooling (:obj:`torch.nn.Module`): pooling/unpooling module. 25 | kernel_size (int, optional): polynomial degree. Defaults to 3. 26 | """ 27 | super().__init__() 28 | self.spherical_cheb_bn_pool = SphericalChebBNPool(in_channels, middle_channels, lap, pooling, kernel_size) 29 | self.spherical_cheb = SphericalChebConv(middle_channels, out_channels, lap, kernel_size) 30 | 31 | def forward(self, x): 32 | """Forward Pass. 33 | 34 | Args: 35 | x (:obj:`torch.Tensor`): input [batch x vertices x channels/features] 36 | 37 | Returns: 38 | :obj:`torch.Tensor`: output [batch x vertices x channels/features] 39 | """ 40 | x = self.spherical_cheb_bn_pool(x) 41 | x = self.spherical_cheb(x) 42 | return x 43 | 44 | 45 | class SphericalChebBNPoolConcat(nn.Module): 46 | """Building Block calling a SphericalChebBNPool Block 47 | then concatenating the output with another tensor 48 | and calling a SphericalChebBN block. 49 | """ 50 | 51 | def __init__(self, in_channels, out_channels, lap, pooling, kernel_size): 52 | """Initialization. 53 | 54 | Args: 55 | in_channels (int): initial number of channels. 56 | out_channels (int): output number of channels. 57 | lap (:obj:`torch.sparse.FloatTensor`): laplacian. 58 | pooling (:obj:`torch.nn.Module`): pooling/unpooling module. 59 | kernel_size (int, optional): polynomial degree. Defaults to 3. 60 | """ 61 | super().__init__() 62 | self.spherical_cheb_bn_pool = SphericalChebBNPool(in_channels, out_channels, lap, pooling, kernel_size) 63 | self.spherical_cheb_bn = SphericalChebBN(in_channels + out_channels, out_channels, lap, kernel_size) 64 | 65 | def forward(self, x, concat_data): 66 | """Forward Pass. 
67 | 
68 |         Args:
69 |             x (:obj:`torch.Tensor`): input [batch x vertices x channels/features]
70 |             concat_data (:obj:`torch.Tensor`): encoder layer output [batch x vertices x channels/features]
71 | 
72 |         Returns:
73 |             :obj:`torch.Tensor`: output [batch x vertices x channels/features]
74 |         """
75 |         x = self.spherical_cheb_bn_pool(x)
76 |         # pylint: disable=E1101
77 |         x = torch.cat((x, concat_data), dim=2)
78 |         # pylint: enable=E1101
79 |         x = self.spherical_cheb_bn(x)
80 |         return x
81 | 
82 | 
83 | class Decoder(nn.Module):
84 |     """The decoder of the Spherical UNet.
85 |     """
86 | 
87 |     def __init__(self, unpooling, laps, kernel_size):
88 |         """Initialization.
89 | 
90 |         Args:
91 |             unpooling (:obj:`torch.nn.Module`): The unpooling object.
92 |             laps (list): List of laplacians. kernel_size (int): polynomial degree.
93 |         """
94 |         super().__init__()
95 |         self.unpooling = unpooling
96 |         self.kernel_size = kernel_size
97 |         self.dec_l1 = SphericalChebBNPoolConcat(512, 512, laps[1], self.unpooling, self.kernel_size)
98 |         self.dec_l2 = SphericalChebBNPoolConcat(512, 256, laps[2], self.unpooling, self.kernel_size)
99 |         self.dec_l3 = SphericalChebBNPoolConcat(256, 128, laps[3], self.unpooling, self.kernel_size)
100 |         self.dec_l4 = SphericalChebBNPoolConcat(128, 64, laps[4], self.unpooling, self.kernel_size)
101 |         self.dec_l5 = SphericalChebBNPoolCheb(64, 32, 3, laps[5], self.unpooling, self.kernel_size)
102 |         # Switch from logits to probabilities when evaluating the model
103 |         self.softmax = nn.Softmax(dim=2)
104 | 
105 |     def forward(self, x_enc0, x_enc1, x_enc2, x_enc3, x_enc4):
106 |         """Forward Pass.
107 | 
108 |         Args:
109 |             x_enc* (:obj:`torch.Tensor`): input tensors (encoder outputs, coarsest level first).
110 | 
111 |         Returns:
112 |             :obj:`torch.Tensor`: output after forward pass.
113 |         """
114 |         x = self.dec_l1(x_enc0, x_enc1)
115 |         x = self.dec_l2(x, x_enc2)
116 |         x = self.dec_l3(x, x_enc3)
117 |         x = self.dec_l4(x, x_enc4)
118 |         x = self.dec_l5(x)
119 |         if not self.training:
120 |             x = self.softmax(x)
121 |         return x
122 | 
--------------------------------------------------------------------------------
/deepsphere/models/spherical_unet/encoder.py:
--------------------------------------------------------------------------------
1 | """Encoder for Spherical UNet.
2 | """
3 | # pylint: disable=W0221
4 | from torch import nn
5 | 
6 | from deepsphere.layers.chebyshev import SphericalChebConv
7 | from deepsphere.models.spherical_unet.utils import SphericalChebBN, SphericalChebBNPool
8 | 
9 | 
10 | class SphericalChebBN2(nn.Module):
11 |     """Building block made of two blocks of (convolution, batch norm, activation).
12 |     """
13 | 
14 |     def __init__(self, in_channels, middle_channels, out_channels, lap, kernel_size):
15 |         """Initialization.
16 | 
17 |         Args:
18 |             in_channels (int): initial number of channels.
19 |             middle_channels (int): middle number of channels.
20 |             out_channels (int): output number of channels.
21 |             lap (:obj:`torch.sparse.FloatTensor`): laplacian.
22 |             kernel_size (int): polynomial degree.
23 |         """
24 | 
25 |         super().__init__()
26 |         self.in_channels = in_channels
27 |         self.out_channels = out_channels
28 |         self.spherical_cheb_bn_1 = SphericalChebBN(in_channels, middle_channels, lap, kernel_size)
29 |         self.spherical_cheb_bn_2 = SphericalChebBN(middle_channels, out_channels, lap, kernel_size)
30 | 
31 |     def forward(self, x):
32 |         """Forward Pass.
33 | 
34 |         Args:
35 |             x (:obj:`torch.Tensor`): input [batch x vertices x channels/features]
36 | 
37 |         Returns:
38 |             :obj:`torch.Tensor`: output [batch x vertices x channels/features]
39 |         """
40 |         x = self.spherical_cheb_bn_1(x)
41 |         x = self.spherical_cheb_bn_2(x)
42 |         return x
43 | 
44 | 
45 | class SphericalChebPool(nn.Module):
46 |     """Building block with a pooling/unpooling and a Chebyshev convolution.
47 |     """
48 | 
49 |     def __init__(self, in_channels, out_channels, lap, pooling, kernel_size):
50 |         """Initialization.
51 | 
52 |         Args:
53 |             in_channels (int): initial number of channels.
54 |             out_channels (int): output number of channels.
55 |             lap (:obj:`torch.sparse.FloatTensor`): laplacian.
56 |             pooling (:obj:`torch.nn.Module`): pooling/unpooling module.
57 |             kernel_size (int): polynomial degree.
58 |         """
59 |         super().__init__()
60 |         self.pooling = pooling
61 |         self.spherical_cheb = SphericalChebConv(in_channels, out_channels, lap, kernel_size)
62 | 
63 |     def forward(self, x):
64 |         """Forward Pass.
65 | 
66 |         Args:
67 |             x (:obj:`torch.Tensor`): input [batch x vertices x channels/features]
68 | 
69 |         Returns:
70 |             :obj:`torch.Tensor`: output [batch x vertices x channels/features]
71 |         """
72 |         x = self.pooling(x)
73 |         x = self.spherical_cheb(x)
74 |         return x
75 | 
76 | 
77 | class Encoder(nn.Module):
78 |     """Encoder for the Spherical UNet.
79 |     """
80 | 
81 |     def __init__(self, pooling, laps, kernel_size):
82 |         """Initialization.
83 | 
84 |         Args:
85 |             pooling (:obj:`torch.nn.Module`): pooling layer.
86 |             laps (list): List of laplacians.
87 |             kernel_size (int): polynomial degree.
88 |         """
89 |         super().__init__()
90 |         self.pooling = pooling
91 |         self.kernel_size = kernel_size
92 |         self.enc_l5 = SphericalChebBN2(16, 32, 64, laps[5], self.kernel_size)
93 |         self.enc_l4 = SphericalChebBNPool(64, 128, laps[4], self.pooling, self.kernel_size)
94 |         self.enc_l3 = SphericalChebBNPool(128, 256, laps[3], self.pooling, self.kernel_size)
95 |         self.enc_l2 = SphericalChebBNPool(256, 512, laps[2], self.pooling, self.kernel_size)
96 |         self.enc_l1 = SphericalChebBNPool(512, 512, laps[1], self.pooling, self.kernel_size)
97 |         self.enc_l0 = SphericalChebPool(512, 512, laps[0], self.pooling, self.kernel_size)
98 | 
99 |     def forward(self, x):
100 |         """Forward Pass.
101 | 
102 |         Args:
103 |             x (:obj:`torch.Tensor`): input [batch x vertices x channels/features]
104 | 
105 |         Returns:
106 |             x_enc* (:obj:`torch.Tensor`): outputs [batch x vertices x channels/features], coarsest level first.
107 |         """
108 |         x_enc5 = self.enc_l5(x)
109 |         x_enc4 = self.enc_l4(x_enc5)
110 |         x_enc3 = self.enc_l3(x_enc4)
111 |         x_enc2 = self.enc_l2(x_enc3)
112 |         x_enc1 = self.enc_l1(x_enc2)
113 |         x_enc0 = self.enc_l0(x_enc1)
114 | 
115 |         return x_enc0, x_enc1, x_enc2, x_enc3, x_enc4
116 | 
117 | 
118 | class EncoderTemporalConv(Encoder):
119 |     """Encoder for the Spherical UNet temporality, with convolution over the time dimension.
120 |     """
121 | 
122 |     def __init__(self, pooling, laps, sequence_length, kernel_size):
123 |         """Initialization.
124 | 
125 |         Args:
126 |             pooling (:obj:`torch.nn.Module`): pooling layer.
127 |             laps (list): List of laplacians.
128 |             sequence_length (int): The number of images used per sample.
129 |             kernel_size (int): Polynomial degree.
130 |         """
131 |         super().__init__(pooling, laps, kernel_size)
132 |         self.sequence_length = sequence_length
133 |         self.enc_l5 = SphericalChebBN2(
134 |             self.enc_l5.in_channels * self.sequence_length,
135 |             self.enc_l5.in_channels * self.sequence_length,
136 |             self.enc_l5.out_channels,
137 |             laps[5],
138 |             self.kernel_size,
139 |         )
140 | 
--------------------------------------------------------------------------------
/deepsphere/models/spherical_unet/unet_model.py:
--------------------------------------------------------------------------------
1 | """Spherical Graph Convolutional Neural Network with UNet autoencoder architecture.
2 | """
3 | 
4 | # pylint: disable=W0221
5 | 
6 | import torch
7 | from torch import nn
8 | 
9 | from deepsphere.layers.samplings.equiangular_pool_unpool import Equiangular
10 | from deepsphere.layers.samplings.healpix_pool_unpool import Healpix
11 | from deepsphere.layers.samplings.icosahedron_pool_unpool import Icosahedron
12 | from deepsphere.models.spherical_unet.decoder import Decoder
13 | from deepsphere.models.spherical_unet.encoder import Encoder, EncoderTemporalConv
14 | from deepsphere.utils.laplacian_funcs import get_equiangular_laplacians, get_healpix_laplacians, get_icosahedron_laplacians
15 | 
16 | 
17 | class SphericalUNet(nn.Module):
18 |     """Spherical GCNN Autoencoder.
19 |     """
20 | 
21 |     def __init__(self, pooling_class, N, depth, laplacian_type, kernel_size, ratio=1):
22 |         """Initialization.
23 | 
24 |         Args:
25 |             pooling_class (str): sampling method, one of "icosahedron", "healpix" or "equiangular"
26 |             N (int): Number of pixels in the input image
27 |             depth (int): The depth of the UNet, which is bounded by N and the type of pooling
28 |             laplacian_type (str): "combinatorial" or "normalized". kernel_size (int): Chebyshev polynomial degree
29 |             ratio (float): Parameter for equiangular sampling (ratio between the two dimensions)
30 |         """
31 |         super().__init__()
32 |         self.ratio = ratio
33 |         self.kernel_size = kernel_size
34 |         if pooling_class == "icosahedron":
35 |             self.pooling_class = Icosahedron()
36 |             self.laps = get_icosahedron_laplacians(N, depth, laplacian_type)
37 |         elif pooling_class == "healpix":
38 |             self.pooling_class = Healpix()
39 |             self.laps = get_healpix_laplacians(N, depth, laplacian_type)
40 |         elif pooling_class == "equiangular":
41 |             self.pooling_class = Equiangular()
42 |             self.laps = get_equiangular_laplacians(N, depth, self.ratio, laplacian_type)
43 |         else:
44 |             raise ValueError("Error: sampling method unknown. Please use icosahedron, healpix or equiangular.")
45 | 
46 |         self.encoder = Encoder(self.pooling_class.pooling, self.laps, self.kernel_size)
47 |         self.decoder = Decoder(self.pooling_class.unpooling, self.laps, self.kernel_size)
48 | 
49 |     def forward(self, x):
50 |         """Forward Pass.
51 | 
52 |         Args:
53 |             x (:obj:`torch.Tensor`): input to be forwarded.
54 | 
55 |         Returns:
56 |             :obj:`torch.Tensor`: output
57 |         """
58 |         x_encoder = self.encoder(x)
59 |         output = self.decoder(*x_encoder)
60 |         return output
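# --- Editor's note: a hedged end-to-end sketch of SphericalUNet (not part of the original file). ---
# It assumes the full dependency stack (PyTorch plus PyGSP with the sphere graphs) is installed;
# the sizes are illustrative. The encoder's first block expects 16 input features, the decoder
# emits 3 output channels, and depth=6 matches the laps[0]..laps[5] used by encoder and decoder.
import torch

N = 12 * 32 ** 2                    # HEALPix sampling with nside=32 -> 12288 pixels
unet = SphericalUNet("healpix", N, depth=6, laplacian_type="normalized", kernel_size=3)
x = torch.randn(2, N, 16)           # [batch x pixels x features]
out = unet(x)                       # -> [2, 12288, 3]; softmax is applied only in eval mode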
61 | 
62 | 
63 | class SphericalUNetTemporalLSTM(SphericalUNet):
64 |     """Spherical GCNN Autoencoder with temporality handled by an LSTM.
65 |     """
66 | 
67 |     def __init__(self, pooling_class, N, depth, laplacian_type, sequence_length, kernel_size, ratio=1):
68 |         """Initialization.
69 | 
70 |         Args:
71 |             pooling_class (str): sampling method, one of "icosahedron", "healpix" or "equiangular"
72 |             N (int): Number of pixels in the input image
73 |             depth (int): The depth of the UNet, which is bounded by N and the type of pooling
74 |             sequence_length (int): The number of images used per sample. laplacian_type (str): "combinatorial" or "normalized"
75 |             kernel_size (int): Chebyshev polynomial degree
76 |             ratio (float): Parameter for equiangular sampling
77 |         """
78 |         super().__init__(pooling_class, N, depth, laplacian_type, kernel_size, ratio)
79 |         self.sequence_length = sequence_length
80 |         n_pixels = self.laps[0].size(0)
81 |         n_features = self.encoder.enc_l0.spherical_cheb.chebconv.in_channels
82 |         self.lstm_l0 = nn.LSTM(input_size=n_pixels * n_features, hidden_size=n_pixels * n_features, batch_first=True)
83 | 
84 |     def forward(self, x):
85 |         """Forward Pass.
86 | 
87 |         Args:
88 |             x (:obj:`torch.Tensor`): input to be forwarded, [batch x time x pixels x features].
89 | 
90 |         Returns:
91 |             :obj:`torch.Tensor`: output
92 |         """
93 |         device = x.device
94 |         encoders_l0 = []
95 |         for idx in range(self.sequence_length):
96 |             encoding = self.encoder(x[:, idx, :, :].squeeze(dim=1))
97 |             encoders_l0.append(encoding[0].reshape(encoding[0].size(0), 1, -1))
98 | 
99 |         encoders_l0 = torch.cat(encoders_l0, dim=1).to(device)
100 |         lstm_output_l0, _ = self.lstm_l0(encoders_l0)
101 |         lstm_output_l0 = lstm_output_l0[:, -1, :].reshape(-1, encoding[0].size(1), encoding[0].size(2))
102 |         # Skip connections are taken from the last time step's encoder outputs.
103 |         output = self.decoder(lstm_output_l0, encoding[1], encoding[2], encoding[3], encoding[4])
104 |         return output
105 | 
106 | 
107 | class SphericalUNetTemporalConv(SphericalUNet):
108 |     """Spherical GCNN Autoencoder with temporality by means of convolution over time.
109 |     """
110 | 
111 |     def __init__(self, pooling_class, N, depth, laplacian_type, sequence_length, kernel_size, ratio=1):
112 |         """Initialization.
113 | 
114 |         Args:
115 |             pooling_class (str): sampling method, one of "icosahedron", "healpix" or "equiangular"
116 |             N (int): Number of pixels in the input image
117 |             depth (int): The depth of the UNet, which is bounded by N and the type of pooling
118 |             sequence_length (int): The number of images used per sample. laplacian_type (str): "combinatorial" or "normalized"
119 |             kernel_size (int): Chebyshev polynomial degree
120 |             ratio (float): Parameter for equiangular sampling
121 |         """
122 |         super().__init__(pooling_class, N, depth, laplacian_type, kernel_size, ratio)
123 |         self.sequence_length = sequence_length
124 |         self.encoder = EncoderTemporalConv(self.pooling_class.pooling, self.laps, self.sequence_length, self.kernel_size)
125 |         self.decoder = Decoder(self.pooling_class.unpooling, self.laps, self.kernel_size)
126 | 
127 |     def forward(self, x):
128 |         """Forward Pass.
129 | 
130 |         Args:
131 |             x (:obj:`torch.Tensor`): input to be forwarded.
132 | 
133 |         Returns:
134 |             :obj:`torch.Tensor`: output
135 |         """
136 |         x_encoder = self.encoder(x)
137 |         output = self.decoder(*x_encoder)
138 |         return output
139 | 
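# --- Editor's note: a hedged sketch of the temporal input layout (not part of the original file). ---
# Same dependency assumptions as the sketch above; sizes are illustrative. The LSTM variant
# consumes a sequence of spherical images shaped [batch x time x pixels x features].
import torch

N = 12 * 32 ** 2
unet_lstm = SphericalUNetTemporalLSTM("healpix", N, depth=6, laplacian_type="normalized",
                                      sequence_length=4, kernel_size=3)
x = torch.randn(2, 4, N, 16)        # [batch x time x pixels x features]
out = unet_lstm(x)                  # -> [2, 12288, 3]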
--------------------------------------------------------------------------------
/deepsphere/models/spherical_unet/utils.py:
--------------------------------------------------------------------------------
1 | """Layers used in both Encoder and Decoder.
2 | """
3 | # pylint: disable=W0221
4 | import torch.nn.functional as F
5 | from torch import nn
6 | 
7 | from deepsphere.layers.chebyshev import SphericalChebConv
8 | 
9 | 
10 | class SphericalChebBN(nn.Module):
11 |     """Building block with a Chebyshev convolution, batch normalization, and ReLU activation.
12 |     """
13 | 
14 |     def __init__(self, in_channels, out_channels, lap, kernel_size):
15 |         """Initialization.
16 | 
17 |         Args:
18 |             in_channels (int): initial number of channels.
19 |             out_channels (int): output number of channels.
20 |             lap (:obj:`torch.sparse.FloatTensor`): laplacian.
21 |             kernel_size (int): polynomial degree.
22 |         """
23 |         super().__init__()
24 |         self.spherical_cheb = SphericalChebConv(in_channels, out_channels, lap, kernel_size)
25 |         self.batchnorm = nn.BatchNorm1d(out_channels)
26 | 
27 |     def forward(self, x):
28 |         """Forward Pass.
29 | 
30 |         Args:
31 |             x (:obj:`torch.tensor`): input [batch x vertices x channels/features]
32 | 
33 |         Returns:
34 |             :obj:`torch.tensor`: output [batch x vertices x channels/features]
35 |         """
36 |         x = self.spherical_cheb(x)
37 |         x = self.batchnorm(x.permute(0, 2, 1))  # BatchNorm1d normalizes dim 1, so move features there
38 |         x = F.relu(x.permute(0, 2, 1))
39 |         return x
40 | 
41 | 
42 | class SphericalChebBNPool(nn.Module):
43 |     """Building block applying a pooling/unpooling and then calling the SphericalChebBN block.
44 |     """
45 | 
46 |     def __init__(self, in_channels, out_channels, lap, pooling, kernel_size):
47 |         """Initialization.
48 | 
49 |         Args:
50 |             in_channels (int): initial number of channels.
51 |             out_channels (int): output number of channels.
52 |             lap (:obj:`torch.sparse.FloatTensor`): laplacian.
53 |             pooling (:obj:`torch.nn.Module`): pooling/unpooling module.
54 |             kernel_size (int): polynomial degree.
55 |         """
56 |         super().__init__()
57 |         self.pooling = pooling
58 |         self.spherical_cheb_bn = SphericalChebBN(in_channels, out_channels, lap, kernel_size)
59 | 
60 |     def forward(self, x):
61 |         """Forward Pass.
62 | 
63 |         Args:
64 |             x (:obj:`torch.tensor`): input [batch x vertices x channels/features]
65 | 
66 |         Returns:
67 |             :obj:`torch.tensor`: output [batch x vertices x channels/features]
68 |         """
69 |         x = self.pooling(x)
70 |         x = self.spherical_cheb_bn(x)
71 |         return x
72 | 
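# --- Editor's note: a minimal standalone sketch of the permute-around-BatchNorm1d pattern used
# above (not part of the original file); it assumes only PyTorch. ---
import torch
from torch import nn

bn = nn.BatchNorm1d(8)                          # normalizes over dim 1 (channels)
x = torch.randn(2, 100, 8)                      # graph layout: [batch x vertices x features]
y = bn(x.permute(0, 2, 1)).permute(0, 2, 1)     # move features to dim 1, normalize, move back
assert y.shape == x.shape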
--------------------------------------------------------------------------------
/deepsphere/tests/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | The :mod:`tests` module contains the directories and files that test different parts of the code.
3 | 
4 | 
5 | Class
6 | -----
7 | 
8 | This module contains the :class:`TestFoo` class with the following method:
9 | 
10 | .. autosummary::
11 | 
12 |     TestFoo.test_foo
13 | 
14 | More Doc / Example
15 | ------------------
16 | 
17 | More documentation and examples can then be added here.
18 | """
19 | 
20 | from .test_foo import TestFoo
--------------------------------------------------------------------------------
/deepsphere/tests/test_foo.py:
--------------------------------------------------------------------------------
1 | """Fake file to test the documentation setup.
2 | """
3 | 
4 | import unittest
5 | 
6 | 
7 | class TestFoo(unittest.TestCase):
8 |     """Fake test class in order to set up the tests module.
9 |     """
10 | 
11 |     def test_foo(self):
12 |         """Fake test method in order to set up the tests module.
13 |         """
14 |         self.assertTrue(True)
--------------------------------------------------------------------------------
/deepsphere/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepsphere/deepsphere-pytorch/43f03f1bef146d256a7e1c1e69df6712d089b9e5/deepsphere/utils/__init__.py
--------------------------------------------------------------------------------
/deepsphere/utils/initialization.py:
--------------------------------------------------------------------------------
1 | """Initializing the device and the model.
2 | """
3 | 
4 | 
5 | import torch
6 | from torch import nn
7 | from torchvision import transforms
8 | 
9 | from deepsphere.data.datasets.dataset import ARTCTemporaldataset
10 | from deepsphere.data.transforms.transforms import Stack
11 | from deepsphere.models.spherical_unet.unet_model import SphericalUNetTemporalConv, SphericalUNetTemporalLSTM
12 | 
13 | 
14 | def init_device(device, unet):
15 |     """Initialize the device, based on cpu/gpu and the number of gpus.
16 | 
17 |     Args:
18 |         device (list of int or str, or None): list of gpu ids to use;
19 |             an empty list selects all gpus and None selects the cpu
20 |         unet (torch.nn.Module): the model to place on the device(s)
21 | 
22 |     Raises:
23 |         Exception: raised when the cpu/gpu configuration is invalid
24 | 
25 |     Returns:
26 |         torch.nn.Module, torch.device: the model placed on the device(s), and the device
27 |     """
28 |     if device is None:
29 |         device = torch.device("cpu")
30 |         unet = unet.to(device)
31 |     elif len(device) == 0:
32 |         device = torch.device("cuda")
33 |         unet = unet.to(device)
34 |         unet = nn.DataParallel(unet)
35 |     elif len(device) == 1:
36 |         device = torch.device("cuda:{}".format(device[0]))
37 |         unet = unet.to(device)
38 |     elif len(device) > 1:
39 |         ids = device
40 |         device = torch.device("cuda:{}".format(ids[0]))
41 |         unet = unet.to(device)
42 |         unet = nn.DataParallel(unet, device_ids=[int(i) for i in ids])
43 |     else:
44 |         raise Exception("Device set up impossible.")
45 | 
46 |     return unet, device
47 | 
48 | 
49 | def init_unet_temp(parser):
50 |     """Initialize the temporal UNet.
51 | 
52 |     Args:
53 |         parser (dict): parser arguments
54 | 
55 |     Returns:
56 |         unet: the model
57 |     """
58 |     pooling_class = parser.pooling_class
59 |     n_pixels = parser.n_pixels
60 |     depth = parser.depth
61 |     laplacian_type = parser.laplacian_type
62 |     sequence_length = parser.sequence_length
63 |     kernel_size = parser.kernel_size
64 |     if parser.type == "LSTM":
65 |         unet = SphericalUNetTemporalLSTM(pooling_class, n_pixels, depth, laplacian_type, sequence_length, kernel_size)
66 |     elif parser.type == "conv":
67 |         unet = SphericalUNetTemporalConv(pooling_class, n_pixels, depth, laplacian_type, sequence_length, kernel_size)
68 |     else:
69 |         raise Exception("The first element after --temp must be either 'LSTM' or 'conv' to specify the type.")
70 |     return unet
71 | 
72 | 
73 | def init_dataset_temp(parser, indices,
transform_image, transform_labels): 74 | """Initialize the dataset 75 | 76 | Args: 77 | parser (dict): parser arguments 78 | indices (list): The list of indices we want included in the dataset 79 | transform_image (list): The list of torchvision transforms we want to apply to the images 80 | transform_labels (list): The list of torchvision transforms we want to apply to the labels 81 | 82 | Returns: 83 | dataset: the dataset 84 | """ 85 | path_to_data = parser.path_to_data 86 | download = parser.download 87 | if parser.type == "LSTM": 88 | transform_sample = transforms.Compose([Stack()]) 89 | elif parser.type == "conv": 90 | transform_sample = transforms.Compose([transforms.Lambda(lambda item: torch.stack(item, dim=1).reshape(item[0].size(0), -1))]) 91 | else: 92 | raise Exception("Invalid temporality type.") 93 | dataset = ARTCTemporaldataset( 94 | path=path_to_data, 95 | download=download, 96 | sequence_length=parser.sequence_length, 97 | prediction_shift=parser.prediction_shift, 98 | indices=indices, 99 | transform_image=transform_image, 100 | transform_labels=transform_labels, 101 | transform_sample=transform_sample, 102 | ) 103 | return dataset 104 | -------------------------------------------------------------------------------- /deepsphere/utils/laplacian_funcs.py: -------------------------------------------------------------------------------- 1 | """Functions related to getting the laplacian and the right number of pixels after pooling/unpooling. 2 | """ 3 | 4 | import numpy as np 5 | import torch 6 | from pygsp.graphs.nngraphs.spherehealpix import SphereHealpix 7 | from pygsp.graphs.nngraphs.sphereicosahedron import SphereIcosahedron 8 | from pygsp.graphs.sphereequiangular import SphereEquiangular 9 | from scipy import sparse 10 | from scipy.sparse import coo_matrix 11 | 12 | from deepsphere.utils.samplings import ( 13 | equiangular_bandwidth, 14 | equiangular_dimension_unpack, 15 | healpix_resolution_calculator, 16 | icosahedron_nodes_calculator, 17 | icosahedron_order_calculator, 18 | ) 19 | 20 | 21 | def scipy_csr_to_sparse_tensor(csr_mat): 22 | """Convert scipy csr to sparse pytorch tensor. 23 | 24 | Args: 25 | csr_mat (csr_matrix): The sparse scipy matrix. 26 | 27 | Returns: 28 | sparse_tensor :obj:`torch.sparse.FloatTensor`: The sparse torch matrix. 29 | """ 30 | coo = coo_matrix(csr_mat) 31 | values = coo.data 32 | indices = np.vstack((coo.row, coo.col)) 33 | idx = torch.LongTensor(indices) 34 | vals = torch.FloatTensor(values) 35 | shape = coo.shape 36 | sparse_tensor = torch.sparse.FloatTensor(idx, vals, torch.Size(shape)) 37 | sparse_tensor = sparse_tensor.coalesce() 38 | return sparse_tensor 39 | 40 | 41 | def prepare_laplacian(laplacian): 42 | """Prepare a graph Laplacian to be fed to a graph convolutional layer. 43 | """ 44 | 45 | def estimate_lmax(laplacian, tol=5e-3): 46 | """Estimate the largest eigenvalue of an operator. 47 | """ 48 | lmax = sparse.linalg.eigsh(laplacian, k=1, tol=tol, ncv=min(laplacian.shape[0], 10), return_eigenvectors=False) 49 | lmax = lmax[0] 50 | lmax *= 1 + 2 * tol # Be robust to errors. 51 | return lmax 52 | 53 | def scale_operator(L, lmax, scale=1): 54 | """Scale the eigenvalues from [0, lmax] to [-scale, scale]. 
55 |         """
56 |         I = sparse.identity(L.shape[0], format=L.format, dtype=L.dtype)
57 |         L *= 2 * scale / lmax
58 |         L -= I
59 |         return L
60 | 
61 |     lmax = estimate_lmax(laplacian)
62 |     laplacian = scale_operator(laplacian, lmax)
63 |     laplacian = scipy_csr_to_sparse_tensor(laplacian)
64 |     return laplacian
65 | 
66 | 
67 | def get_icosahedron_laplacians(nodes, depth, laplacian_type):
68 |     """Get the icosahedron laplacian list for a certain depth.
69 |     Args:
70 |         nodes (int): initial number of nodes.
71 |         depth (int): the depth of the UNet.
72 |         laplacian_type ["combinatorial", "normalized"]: the type of the laplacian.
73 | 
74 |     Returns:
75 |         laps (list): increasing list of laplacians.
76 |     """
77 |     laps = []
78 |     order = icosahedron_order_calculator(nodes)
79 |     for _ in range(depth):
80 |         nodes = icosahedron_nodes_calculator(order)
81 |         order_initial = icosahedron_order_calculator(nodes)  # recompute the order for the current node count
82 |         G = SphereIcosahedron(level=int(order_initial))
83 |         G.compute_laplacian(laplacian_type)
84 |         laplacian = prepare_laplacian(G.L)
85 |         laps.append(laplacian)
86 |         order -= 1
87 |     return laps[::-1]
88 | 
89 | 
90 | def get_healpix_laplacians(nodes, depth, laplacian_type):
91 |     """Get the healpix laplacian list for a certain depth.
92 |     Args:
93 |         nodes (int): initial number of nodes.
94 |         depth (int): the depth of the UNet.
95 |         laplacian_type ["combinatorial", "normalized"]: the type of the laplacian.
96 |     Returns:
97 |         laps (list): increasing list of laplacians.
98 |     """
99 |     laps = []
100 |     for i in range(depth):
101 |         pixel_num = nodes
102 |         subdivisions = int(healpix_resolution_calculator(pixel_num) / 2 ** i)
103 |         G = SphereHealpix(subdivisions, nest=True, k=20)
104 |         G.compute_laplacian(laplacian_type)
105 |         laplacian = prepare_laplacian(G.L)
106 |         laps.append(laplacian)
107 |     return laps[::-1]
108 | 
109 | 
110 | def get_equiangular_laplacians(nodes, depth, ratio, laplacian_type):
111 |     """Get the equiangular laplacian list for a certain depth.
112 |     Args:
113 |         nodes (int): initial number of nodes.
114 |         depth (int): the depth of the UNet. ratio (float): ratio between the two equiangular dimensions.
115 |         laplacian_type ["combinatorial", "normalized"]: the type of the laplacian.
116 | 
117 |     Returns:
118 |         laps (list): increasing list of laplacians
119 |     """
120 |     laps = []
121 |     pixel_num = nodes
122 |     for _ in range(depth):  # note: pixel_num is not reduced between levels here, so all levels share one resolution
123 |         dim1, dim2 = equiangular_dimension_unpack(pixel_num, ratio)
124 |         bw1 = equiangular_bandwidth(dim1)
125 |         bw2 = equiangular_bandwidth(dim2)
126 |         bw = [bw1, bw2]
127 |         G = SphereEquiangular(bandwidth=bw, sampling="SOFT")
128 |         G.compute_laplacian(laplacian_type)
129 |         laplacian = prepare_laplacian(G.L)
130 |         laps.append(laplacian)
131 |     return laps[::-1]
132 | 
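# --- Editor's note: a hedged sketch of what the helpers above return (not part of the original
# file); it assumes PyGSP with the sphere graphs is installed and the sizes are illustrative. ---
laps = get_healpix_laplacians(nodes=192, depth=2, laplacian_type="normalized")
assert len(laps) == 2
assert laps[0].shape == (48, 48)      # coarsest level first (nside=2)
assert laps[1].shape == (192, 192)    # finest level last (nside=4)
assert laps[0].is_sparse              # sparse torch tensors with spectrum rescaled to [-1, 1]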
--------------------------------------------------------------------------------
/deepsphere/utils/parser.py:
--------------------------------------------------------------------------------
1 | """Command line parser related functions.
2 | One function creates the parser.
3 | Another function allows hybrid usage of:
4 | - a yaml file with predefined parameters
5 | and
6 | - user-inputted parameters through the command line.
7 | """
8 | 
9 | import argparse
10 | 
11 | import yaml
12 | 
13 | 
14 | def create_parser():
15 |     """Creates a parser with all the variables that can be edited by the user.
16 | 
17 |     Returns:
18 |         parser: a parser for the command line
19 |     """
20 |     parser = argparse.ArgumentParser()
21 | 
22 |     parser.add_argument("--config-file", dest="config_file", type=argparse.FileType(mode="r"))
23 | 
24 |     parser.add_argument("--pooling_class", default=None, type=str)
25 |     parser.add_argument("--n_pixels", default=None, type=int)
26 |     parser.add_argument("--depth", default=None, type=int)
27 |     parser.add_argument("--laplacian_type", default=None, type=str)
28 | 
29 |     parser.add_argument("--type", default=None, type=str)
30 |     parser.add_argument("--sequence_length", default=None, type=int)
31 |     parser.add_argument("--prediction_shift", default=None, type=int)
32 | 
33 |     parser.add_argument("--partition", default=None, nargs="+")
34 |     parser.add_argument("--batch_size", default=None, type=int)
35 |     parser.add_argument("--learning_rate", default=None, type=float)
36 |     parser.add_argument("--n_epochs", default=None, type=int)
37 |     parser.add_argument("--kernel_size", default=None, type=int)
38 | 
39 |     parser.add_argument("--path_to_data", default=None)
40 |     parser.add_argument("--model_save_path", default=None)
41 |     parser.add_argument("--tensorboard_path", default=None)
42 | 
43 |     parser.add_argument("--download", default=None, type=bool)
44 |     parser.add_argument("--means_path", default=None)
45 |     parser.add_argument("--stds_path", default=None)
46 |     parser.add_argument("--seed", default=None, type=int)
47 | 
48 |     parser.add_argument("--reducelronplateau_mode", default=None)
49 |     parser.add_argument("--reducelronplateau_factor", default=None, type=float)
50 |     parser.add_argument("--reducelronplateau_patience", default=None, type=int)
51 | 
52 |     parser.add_argument("--steplr_step_size", default=None, type=int)
53 |     parser.add_argument("--steplr_gamma", default=None, type=float)
54 | 
55 |     parser.add_argument("--warmuplr_warmup_start_value", default=None, type=float)
56 |     parser.add_argument("--warmuplr_warmup_end_value", default=None, type=float)
57 |     parser.add_argument("--warmuplr_warmup_duration", default=None, type=int)
58 | 
59 |     parser.add_argument("--earlystopping_patience", default=None, type=int)
60 | 
61 |     parser.add_argument("--gpu", dest="device", nargs="*")
62 | 
63 |     return parser
64 | 
65 | 
66 | def parse_config(parser):
67 |     """Takes the yaml file given through the command line.
68 |     Adds all of the yaml file's parameters, unless they have already been defined on the command line.
69 |     Checks that all values have been set, and raises a ValueError otherwise.
70 |     Args:
71 |         parser (argparse.ArgumentParser): parser to be updated by the yaml file parameters
72 |     Raises:
73 |         ValueError: all fields must be set in the yaml config file or on the command line; raised if a value is None (was not set).
74 |     Returns:
75 |         argparse.Namespace: the parsed args of the parser
76 |     """
77 |     args = parser.parse_args()
78 |     arg_dict = args.__dict__
79 |     if args.config_file:
80 |         data = yaml.load(args.config_file, Loader=yaml.FullLoader)
81 |         delattr(args, "config_file")
82 |         arg_dict = args.__dict__
83 |         for key, value in data.items():
84 |             # add only the values not specified by the user on the command line
85 |             if isinstance(value, dict):
86 |                 for tag, element in value.items():
87 |                     if arg_dict[tag] is None:
88 |                         arg_dict[tag] = element
89 | 
90 |             else:
91 |                 if arg_dict[key] is None:
92 |                     arg_dict[key] = value
93 |     for key, value in arg_dict.items():
94 |         if key != "device" and key != "type" and key != "sequence_length" and key != "prediction_shift" and arg_dict[key] is None:
95 |             raise ValueError("The value of {} is set to None.
Please define it in the config yaml file or on the command line.".format(key))
96 |     return args
--------------------------------------------------------------------------------
/deepsphere/utils/samplings.py:
--------------------------------------------------------------------------------
1 | """Different samplings require different calculations.
2 | The calculations present here are for the equiangular, healpix and icosahedron samplings.
3 | """
4 | import math
5 | 
6 | 
7 | def equiangular_bandwidth(nodes):
8 |     """Calculate the equiangular bandwidth based on the input nodes.
9 | 
10 |     Args:
11 |         nodes (int): the number of nodes, which should be a power of 4
12 | 
13 |     Returns:
14 |         float: the corresponding bandwidth
15 |     """
16 |     bw = math.sqrt(nodes) / 2
17 |     return bw
18 | 
19 | 
20 | def equiangular_dimension_unpack(nodes, ratio):
21 |     """Calculate the two underlying dimensions
22 |     from the total number of nodes.
23 | 
24 |     Args:
25 |         nodes (int): combined dimensions
26 |         ratio (float): ratio between the two dimensions
27 | 
28 |     Returns:
29 |         int, int: separated dimensions
30 |     """
31 |     dim1 = int((nodes / ratio) ** 0.5)
32 |     dim2 = int((nodes * ratio) ** 0.5)
33 |     return dim1, dim2
34 | 
35 | 
36 | def equiangular_calculator(tensor, ratio):
37 |     """From a 3D input tensor and a known ratio between the latitude
38 |     dimension and longitude dimension of the data, reformat the 3D input
39 |     into a 4D output while also obtaining the bandwidths.
40 | 
41 |     Args:
42 |         tensor (:obj:`torch.tensor`): 3D input tensor
43 |         ratio (float): the ratio between the latitude and longitude dimensions of the data
44 | 
45 |     Returns:
46 |         :obj:`torch.tensor`, list: 4D tensor, and the bandwidths for lat. and long.
47 |     """
48 |     N, M, F = tensor.size()
49 |     dim1, dim2 = equiangular_dimension_unpack(M, ratio)
50 |     bw_dim1 = equiangular_bandwidth(dim1)
51 |     bw_dim2 = equiangular_bandwidth(dim2)
52 |     tensor = tensor.view(N, dim1, dim2, F)
53 |     return tensor, [bw_dim1, bw_dim2]
54 | 
55 | 
56 | def healpix_resolution_calculator(nodes):
57 |     """Calculate the resolution of a healpix graph
58 |     for a given number of nodes.
59 | 
60 |     Args:
61 |         nodes (int): number of nodes in healpix sampling
62 | 
63 |     Returns:
64 |         int: resolution for the matching healpix graph
65 |     """
66 |     resolution = int(math.sqrt(nodes / 12))
67 |     return resolution
68 | 
69 | 
70 | def icosahedron_order_calculator(nodes):
71 |     """Calculate the order of an icosahedron graph
72 |     for a given number of nodes.
73 | 
74 |     Args:
75 |         nodes (int): number of nodes in icosahedron sampling
76 | 
77 |     Returns:
78 |         float: order for the matching icosahedron graph
79 |     """
80 |     order = math.log((nodes - 2) / 10) / math.log(4)
81 |     return order
82 | 
83 | 
84 | def icosahedron_nodes_calculator(order):
85 |     """Calculate the number of nodes
86 |     corresponding to the order of an icosahedron graph.
87 | 
88 |     Args:
89 |         order (int): order of an icosahedron graph
90 | 
91 |     Returns:
92 |         int: number of nodes in icosahedron sampling for that order
93 |     """
94 |     nodes = 10 * (4 ** order) + 2
95 |     return nodes
96 | 
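# --- Editor's note: a few worked values for the helpers above (not part of the original file);
# a pure-Python sketch with illustrative numbers. ---
assert healpix_resolution_calculator(12 * 16 ** 2) == 16          # nside for 3072 pixels
assert round(icosahedron_order_calculator(2562)) == 4             # 10 * 4**4 + 2 = 2562
assert icosahedron_nodes_calculator(4) == 2562
assert equiangular_dimension_unpack(512, ratio=2) == (16, 32)     # 16 * 32 = 512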
--------------------------------------------------------------------------------
/deepsphere/utils/stats_extractor.py:
--------------------------------------------------------------------------------
1 | """Get the means and standard deviations for all features of a dataset.
2 | """
3 | import numpy as np
4 | import torch
5 | 
6 | 
7 | def stats_extractor(dataset):
8 |     """Iterates over a dataset object twice
9 |     to calculate the per-feature mean and standard deviation.
10 | 
11 |     Args:
12 |         dataset (:obj:`torch.utils.data.Dataset`): dataset object to iterate over
13 | 
14 |     Returns:
15 |         :obj:`numpy.array`, :obj:`numpy.array`: the computed means and standard deviations
16 |     """
17 | 
18 |     F, V = torch.Tensor(dataset[0][0]).shape
19 |     summing = torch.zeros(F)
20 |     square_summing = torch.zeros(F)
21 |     total = 0
22 | 
23 |     for item in dataset:
24 |         item = torch.Tensor(item[0])
25 |         summing += torch.sum(item, dim=1)
26 |         total += V
27 | 
28 |     means = torch.unsqueeze(summing / total, dim=1)
29 | 
30 |     for item in dataset:
31 |         item = torch.Tensor(item[0])
32 |         square_summing += torch.sum((item - means) ** 2, dim=1)
33 | 
34 |     stds = np.sqrt(square_summing / (total - 1))
35 | 
36 |     return torch.squeeze(means, dim=1).numpy(), stds.numpy()
37 | 
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 | 
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS    ?=
7 | SPHINXBUILD   ?= sphinx-build
8 | SOURCEDIR     = source
9 | BUILDDIR      = build
10 | 
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 | 
15 | .PHONY: help Makefile
16 | 
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 | 
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 | 
7 | # -- Path setup --------------------------------------------------------------
8 | 
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | 
13 | import os
14 | import sys
15 | 
16 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
17 | ROOT = os.path.abspath(os.path.join(BASE_DIR, "..", ".."))
18 | 
19 | sys.path.insert(0, ROOT)
20 | 
21 | 
22 | # -- Project information -----------------------------------------------------
23 | 
24 | project = "DeepSphere"
25 | copyright = "2019, Arcanite Solutions"
26 | author = "Arcanite Solutions"
27 | 
28 | # read the version of the project
29 | import deepsphere  # isort:skip
30 | 
31 | release = deepsphere.__version__
32 | version = release
33 | 
34 | 
35 | # -- General configuration ---------------------------------------------------
36 | 
37 | # Add any Sphinx extension module names here, as strings. They can be
38 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
39 | # ones.
40 | extensions = [ 41 | "sphinx.ext.autodoc", 42 | "sphinx.ext.intersphinx", 43 | "sphinx.ext.ifconfig", 44 | "sphinx.ext.viewcode", 45 | "sphinx.ext.githubpages", 46 | "sphinx.ext.napoleon", 47 | "sphinx.ext.autosummary", 48 | ] 49 | 50 | # Add any paths that contain templates here, relative to this directory. 51 | templates_path = ["_templates"] 52 | 53 | # List of patterns, relative to source directory, that match files and 54 | # directories to ignore when looking for source files. 55 | # This pattern also affects html_static_path and html_extra_path. 56 | exclude_patterns = ["._*"] 57 | 58 | # -- Options for HTML output ------------------------------------------------- 59 | 60 | # The theme to use for HTML and HTML Help pages. See the documentation for 61 | # a list of builtin themes. 62 | # 63 | html_theme = 'sphinx_rtd_theme' # alabaster, haiku, nature, pyramid, agogo, bizstyle, sphinx_rtd_theme 64 | 65 | # Add any paths that contain custom static files (such as style sheets) here, 66 | # relative to this directory. They are copied after the builtin static files, 67 | # so a file named "default.css" will overwrite the builtin "default.css". 68 | html_static_path = ["_static"] 69 | 70 | # Used this information since it's index.rst that contains the root toctree directive 71 | master_doc = "index" 72 | 73 | 74 | # -- Extension configuration ------------------------------------------------- 75 | 76 | # -- Options for intersphinx extension --------------------------------------- 77 | 78 | # Example configuration for intersphinx: refer to the Python standard library. 79 | intersphinx_mapping = { 80 | "python": ("https://docs.python.org/3/", None), 81 | "torch": ("https://pytorch.org/docs/master/", None), 82 | "numpy": ("http://docs.scipy.org/doc/numpy", None), 83 | } 84 | 85 | # -- Options for napoleon extension --------------------------------------- 86 | napoleon_google_docstring = True 87 | napoleon_numpy_docstring = True 88 | 89 | # -- Options for autosummary extension --------------------------------------- 90 | autosummary_generate = True 91 | 92 | # -- Options for autodocs extension --------------------------------------- 93 | 94 | autodoc_mock_imports = ["numpy", "sklearn", "scipy", "tensorboard", "torch", "torchvision", "pyyaml", "jupyter", "pygsp", "ignite"] 95 | # autoclass_content = "both" # include both class docstring and __init__ 96 | autodoc_default_flags = [ 97 | # Make sure that any autodoc declarations show the right members 98 | "members", 99 | "show-inheritance", 100 | ] 101 | -------------------------------------------------------------------------------- /docs/source/deepsphere.data.datasets.rst: -------------------------------------------------------------------------------- 1 | deepsphere.data.datasets package 2 | ================================ 3 | 4 | Submodules 5 | ---------- 6 | 7 | deepsphere.data.datasets.dataset module 8 | --------------------------------------- 9 | 10 | .. automodule:: deepsphere.data.datasets.dataset 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | 16 | Module contents 17 | --------------- 18 | 19 | .. automodule:: deepsphere.data.datasets 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | -------------------------------------------------------------------------------- /docs/source/deepsphere.data.rst: -------------------------------------------------------------------------------- 1 | deepsphere.data package 2 | ======================= 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. 
toctree:: 8 | 9 | deepsphere.data.datasets 10 | deepsphere.data.transforms 11 | 12 | Module contents 13 | --------------- 14 | 15 | .. automodule:: deepsphere.data 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | -------------------------------------------------------------------------------- /docs/source/deepsphere.data.transforms.rst: -------------------------------------------------------------------------------- 1 | deepsphere.data.transforms package 2 | ================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | deepsphere.data.transforms.transforms module 8 | -------------------------------------------- 9 | 10 | .. automodule:: deepsphere.data.transforms.transforms 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | 16 | Module contents 17 | --------------- 18 | 19 | .. automodule:: deepsphere.data.transforms 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | -------------------------------------------------------------------------------- /docs/source/deepsphere.layers.rst: -------------------------------------------------------------------------------- 1 | deepsphere.layers package 2 | ========================= 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | deepsphere.layers.samplings 10 | 11 | Submodules 12 | ---------- 13 | 14 | deepsphere.layers.chebyshev module 15 | ---------------------------------- 16 | 17 | .. automodule:: deepsphere.layers.chebyshev 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. automodule:: deepsphere.layers 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /docs/source/deepsphere.layers.samplings.rst: -------------------------------------------------------------------------------- 1 | deepsphere.layers.samplings package 2 | =================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | deepsphere.layers.samplings.equiangular\_pool\_unpool module 8 | ------------------------------------------------------------ 9 | 10 | .. automodule:: deepsphere.layers.samplings.equiangular_pool_unpool 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | deepsphere.layers.samplings.healpix\_pool\_unpool module 16 | -------------------------------------------------------- 17 | 18 | .. automodule:: deepsphere.layers.samplings.healpix_pool_unpool 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | deepsphere.layers.samplings.icosahedron\_pool\_unpool module 24 | ------------------------------------------------------------ 25 | 26 | .. automodule:: deepsphere.layers.samplings.icosahedron_pool_unpool 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | 32 | Module contents 33 | --------------- 34 | 35 | .. automodule:: deepsphere.layers.samplings 36 | :members: 37 | :undoc-members: 38 | :show-inheritance: 39 | -------------------------------------------------------------------------------- /docs/source/deepsphere.models.rst: -------------------------------------------------------------------------------- 1 | deepsphere.models package 2 | ========================= 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | deepsphere.models.spherical_unet 10 | 11 | Module contents 12 | --------------- 13 | 14 | .. 
automodule:: deepsphere.models 15 | :members: 16 | :undoc-members: 17 | :show-inheritance: 18 | -------------------------------------------------------------------------------- /docs/source/deepsphere.models.spherical_unet.rst: -------------------------------------------------------------------------------- 1 | deepsphere.models.spherical\_unet package 2 | ========================================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | deepsphere.models.spherical\_unet.decoder module 8 | ------------------------------------------------ 9 | 10 | .. automodule:: deepsphere.models.spherical_unet.decoder 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | deepsphere.models.spherical\_unet.encoder module 16 | ------------------------------------------------ 17 | 18 | .. automodule:: deepsphere.models.spherical_unet.encoder 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | deepsphere.models.spherical\_unet.unet\_model module 24 | ---------------------------------------------------- 25 | 26 | .. automodule:: deepsphere.models.spherical_unet.unet_model 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | deepsphere.models.spherical\_unet.utils module 32 | ---------------------------------------------- 33 | 34 | .. automodule:: deepsphere.models.spherical_unet.utils 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | 40 | Module contents 41 | --------------- 42 | 43 | .. automodule:: deepsphere.models.spherical_unet 44 | :members: 45 | :undoc-members: 46 | :show-inheritance: 47 | -------------------------------------------------------------------------------- /docs/source/deepsphere.rst: -------------------------------------------------------------------------------- 1 | deepsphere package 2 | ================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | deepsphere.data 10 | deepsphere.layers 11 | deepsphere.models 12 | deepsphere.tests 13 | deepsphere.utils 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: deepsphere 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/source/deepsphere.tests.rst: -------------------------------------------------------------------------------- 1 | deepsphere.tests package 2 | ======================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | deepsphere.tests.test\_foo module 8 | --------------------------------- 9 | 10 | .. automodule:: deepsphere.tests.test_foo 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | 16 | Module contents 17 | --------------- 18 | 19 | .. automodule:: deepsphere.tests 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | -------------------------------------------------------------------------------- /docs/source/deepsphere.utils.rst: -------------------------------------------------------------------------------- 1 | deepsphere.utils package 2 | ======================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | deepsphere.utils.initialization module 8 | -------------------------------------- 9 | 10 | .. automodule:: deepsphere.utils.initialization 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | deepsphere.utils.laplacian\_funcs module 16 | ---------------------------------------- 17 | 18 | .. 
automodule:: deepsphere.utils.laplacian_funcs 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | deepsphere.utils.parser module 24 | ------------------------------ 25 | 26 | .. automodule:: deepsphere.utils.parser 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | deepsphere.utils.samplings module 32 | --------------------------------- 33 | 34 | .. automodule:: deepsphere.utils.samplings 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | deepsphere.utils.stats\_extractor module 40 | ---------------------------------------- 41 | 42 | .. automodule:: deepsphere.utils.stats_extractor 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | 48 | Module contents 49 | --------------- 50 | 51 | .. automodule:: deepsphere.utils 52 | :members: 53 | :undoc-members: 54 | :show-inheritance: 55 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. DeepSphere documentation master file, created by 2 | sphinx-quickstart on Fri Oct 4 17:28:07 2019. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to DeepSphere's documentation! 7 | ====================================== 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Content: 12 | 13 | 14 | deepsphere 15 | 16 | 17 | Indices and tables 18 | ================== 19 | 20 | * :ref:`genindex` 21 | * :ref:`modindex` 22 | * :ref:`search` 23 | -------------------------------------------------------------------------------- /docs/source/modules.rst: -------------------------------------------------------------------------------- 1 | deepsphere 2 | ========== 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | deepsphere 8 | scripts 9 | setup 10 | -------------------------------------------------------------------------------- /docs/source/scripts.rst: -------------------------------------------------------------------------------- 1 | scripts package 2 | =============== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | scripts.temporality 10 | 11 | Submodules 12 | ---------- 13 | 14 | scripts.run\_ar\_tc module 15 | -------------------------- 16 | 17 | .. automodule:: scripts.run_ar_tc 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. automodule:: scripts 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /docs/source/scripts.temporality.rst: -------------------------------------------------------------------------------- 1 | scripts.temporality package 2 | =========================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | scripts.temporality.run\_ar\_tc module 8 | -------------------------------------- 9 | 10 | .. automodule:: scripts.temporality.run_ar_tc 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | 16 | Module contents 17 | --------------- 18 | 19 | .. automodule:: scripts.temporality 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | -------------------------------------------------------------------------------- /docs/source/setup.rst: -------------------------------------------------------------------------------- 1 | setup module 2 | ============ 3 | 4 | .. 
automodule:: setup 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /images/AR_TC_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepsphere/deepsphere-pytorch/43f03f1bef146d256a7e1c1e69df6712d089b9e5/images/AR_TC_image.png -------------------------------------------------------------------------------- /images/Example_3D_Icosahedronmovie_globe.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepsphere/deepsphere-pytorch/43f03f1bef146d256a7e1c1e69df6712d089b9e5/images/Example_3D_Icosahedronmovie_globe.gif -------------------------------------------------------------------------------- /images/equations/L_eq.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepsphere/deepsphere-pytorch/43f03f1bef146d256a7e1c1e69df6712d089b9e5/images/equations/L_eq.gif -------------------------------------------------------------------------------- /images/equations/Lc.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepsphere/deepsphere-pytorch/43f03f1bef146d256a7e1c1e69df6712d089b9e5/images/equations/Lc.gif -------------------------------------------------------------------------------- /images/equations/Lc_eq.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepsphere/deepsphere-pytorch/43f03f1bef146d256a7e1c1e69df6712d089b9e5/images/equations/Lc_eq.gif -------------------------------------------------------------------------------- /images/equations/T0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepsphere/deepsphere-pytorch/43f03f1bef146d256a7e1c1e69df6712d089b9e5/images/equations/T0.gif -------------------------------------------------------------------------------- /images/equations/T1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepsphere/deepsphere-pytorch/43f03f1bef146d256a7e1c1e69df6712d089b9e5/images/equations/T1.gif -------------------------------------------------------------------------------- /images/equations/Tm.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepsphere/deepsphere-pytorch/43f03f1bef146d256a7e1c1e69df6712d089b9e5/images/equations/Tm.gif -------------------------------------------------------------------------------- /images/equations/Tm_recursive.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepsphere/deepsphere-pytorch/43f03f1bef146d256a7e1c1e69df6712d089b9e5/images/equations/Tm_recursive.gif -------------------------------------------------------------------------------- /images/equations/poly_eq.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepsphere/deepsphere-pytorch/43f03f1bef146d256a7e1c1e69df6712d089b9e5/images/equations/poly_eq.gif -------------------------------------------------------------------------------- /images/equations/xhat.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deepsphere/deepsphere-pytorch/43f03f1bef146d256a7e1c1e69df6712d089b9e5/images/equations/xhat.gif
--------------------------------------------------------------------------------
/images/equations/y_eq.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepsphere/deepsphere-pytorch/43f03f1bef146d256a7e1c1e69df6712d089b9e5/images/equations/y_eq.gif
--------------------------------------------------------------------------------
/images/interactiveplot_epoch28.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepsphere/deepsphere-pytorch/43f03f1bef146d256a7e1c1e69df6712d089b9e5/images/interactiveplot_epoch28.png
--------------------------------------------------------------------------------
/images/interactiveplot_epoch4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepsphere/deepsphere-pytorch/43f03f1bef146d256a7e1c1e69df6712d089b9e5/images/interactiveplot_epoch4.png
--------------------------------------------------------------------------------
/notebooks/demo_visualizations.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "# Producing Visuals \n",
8 |     "\n",
9 |     "\n",
10 |     "## Icosahedron\n",
11 |     "- Create a graph corresponding to the icosahedron sampling level of the data. (*g = SphereIcosahedron(level=5)*)\n",
12 |     "- Generate the icosahedron longitude and latitude: *icolong, icolat = np.rad2deg(g.lon), np.rad2deg(g.lat)*\n",
13 |     "\n",
14 |     "\n",
15 |     "## Equiangular\n",
16 |     "- Determine the longitude and latitude from the data.
Dimension 0 is longitude (height) and dimension 1 is latitude (width).\n", 17 | "- *lon_ = np.arange(x.size(0))/x.size(0)*360*\n", 18 | "- *lat_ = np.arange(x.size(1))/x.size(1)*180-90*\n", 19 | "- *lon, lat = np.meshgrid(lon_, lat_)*\n", 20 | "\n", 21 | "The following examples were performed on the icosahedron sampling of the climate data at \"http://island.me.berkeley.edu/ugscnn/data/climate_sphere_l5.zip\"\n" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "# Requirements\n", 29 | "To run the code below, please install the following.\n", 30 | "- Cartopy:\n", 31 | "```bash\n", 32 | "conda install -c conda-forge cartopy\n", 33 | "```\n", 34 | "- Imageio\n", 35 | "```bash\n", 36 | "conda install -c conda-forge imageio\n", 37 | "```\n", 38 | "- Matplotlib\n", 39 | "```bash\n", 40 | "conda install -c conda-forge matplotlib\n", 41 | "```" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "import cartopy.crs as ccrs\n", 51 | "import matplotlib.pyplot as plt\n", 52 | "import imageio" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "import numpy as np\n", 62 | "from sklearn.model_selection import train_test_split\n", 63 | "from torchvision import transforms\n", 64 | "from pygsp.graphs.nngraphs.sphereicosahedron import SphereIcosahedron\n", 65 | "from deepsphere.data.datasets.dataset import ARTCDataset\n", 66 | "from deepsphere.data.transforms.transforms import Normalize, Permute, ToTensor\n", 67 | "from deepsphere.utils.stats_extractor import stats_extractor" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "def visualize_2d(x, longitude, latitude, export_path=None):\n", 77 | " \"\"\"Visualize the data on a 2D map\n", 78 | "\n", 79 | " Args:\n", 80 | " x (numpy.array): numpy array with data the same size as the longitude and latitude arrays\n", 81 | " longitude (numpy.array): longitude coordinates\n", 82 | " latitude (numpy.array): latitude coordinates\n", 83 | " export_path (string): path and name for saving\n", 84 | " \"\"\"\n", 85 | "\n", 86 | " fig = plt.figure(figsize=(20, 10))\n", 87 | " ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())\n", 88 | "\n", 89 | " ax.set_global()\n", 90 | " ax.coastlines()\n", 91 | "\n", 92 | " plt.scatter(longitude, latitude, s=20, c=x, cmap=plt.get_cmap(\"RdYlBu_r\"), alpha=1)\n", 93 | " if export_path:\n", 94 | " plt.savefig(export_path)\n", 95 | " plt.clf()\n", 96 | " plt.cla()\n", 97 | " plt.close()\n", 98 | "\n", 99 | " else:\n", 100 | " plt.show()" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "def visualize_3d(x, longitude, latitude, export_dir, save_format='png'):\n", 110 | " \"\"\"Visualize the data on a 3D globe\n", 111 | "\n", 112 | " Args:\n", 113 | " x (numpy.array): numpy array with data the same size as the longitude and latitude arrays\n", 114 | " longitude (numpy.array): longitude coordinates\n", 115 | " latitude (numpy.array): latitude coordinates\n", 116 | " export_dir (string): directory where the images composing the gif and the gif itself are saved\n", 117 | " save_format (string): the format in which to save the images to make the gif.
Default=png\n", 118 | " \"\"\"\n", 119 | "\n", 120 | " for i in range(0, 330, 10):\n", 121 | " fig = plt.figure(figsize=(10, 10))\n", 122 | " ax = fig.add_subplot(1, 1, 1, projection=ccrs.Orthographic(i, 0))\n", 123 | " ax.set_global()\n", 124 | " ax.coastlines(linewidth=2)\n", 125 | "\n", 126 | " plt.scatter(longitude, latitude, s=80, c=x, cmap=plt.get_cmap(\"RdYlBu_r\"), alpha=1, transform=ccrs.PlateCarree()) # marker='3',\n", 127 | " idx = str(i)\n", 128 | " try:\n", 129 | " path = export_dir + \"/globes/globe_\" + idx\n", 130 | " plt.savefig(path + \".\" + save_format)\n", 131 | " plt.clf()\n", 132 | " plt.cla()\n", 133 | " plt.close()\n", 134 | " except FileNotFoundError:\n", 135 | " raise FileNotFoundError('The export directory does not exist or you have not prepared a \"globes\" directory to store the images composing the gif. Please prepare a \"globes\" folder in your directory for the 3D visualization.')\n", 136 | "\n", 137 | " images = []\n", 138 | " for frame in range(0, 330, 10):\n", 139 | " images.append(imageio.imread(export_dir + \"/globes/globe_\" + str(frame) + \".\" + save_format))\n", 140 | " imageio.mimsave(export_dir + \"movie_globe.gif\", images, duration=0.75)" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "g = SphereIcosahedron(level=5)\n", 150 | "icolong, icolat = np.rad2deg(g.lon), np.rad2deg(g.lat)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "path_to_data = \"../../../../../data/climate/data_5_all\"\n", 160 | "data = ARTCDataset(path=path_to_data, download=False)" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "visualize_2d(data[10][1].transpose(1,0), icolong, icolat)" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "metadata": {}, 176 | "outputs": [], 177 | "source": [ 178 | "visualize_3d(data[10][1].transpose(1,0), icolong, icolat, 'Example_3D_Icosahedron')" 179 | ] 180 | } 181 | ], 182 | "metadata": { 183 | "language_info": { 184 | "name": "python", 185 | "pygments_lexer": "ipython3" 186 | } 187 | }, 188 | "nbformat": 4, 189 | "nbformat_minor": 2 190 | } 191 | -------------------------------------------------------------------------------- /notebooks/interactive_visualization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Interactive Plotting for Results\n", 8 | "\n", 9 | "\n", 10 | "During training, data can be saved and visuals (2D and 3D) can be created using the functions presented in the demo visualizations notebook. \n", 11 | "This Jupyter Notebook is an example of code capable of generating an interactive plot using plotly. \n", 12 | "\n", 13 | "#### Training\n", 14 | "The metrics for each epoch would have to be stored during training. For example, a Handler could be added to the validation engine in the script `run_ar_tc.py`, which would call the visualization functions at the end of each validation epoch. \n", 15 | "\n", 16 | "#### Visualizations\n", 17 | "After training, the results can be visualized over time through the interactive plots and images in this notebook. For each epoch, a trace is created for the tropical cyclones and the atmospheric rivers.
Each trace is plotted in the per-epoch metrics view and only shows metric results up to that point in time. Likewise, the images of the predicted labels vs. the ground-truth labels are only shown for the selected time point (epoch). \n", 18 | "\n", 19 | "#### Running the code\n", 20 | "In this example, the metrics results are saved for each epoch in a numpy file named `visualization_results_epoch_X.npy` where X is the epoch in question. Similarly, the images for each model prediction and the ground truths are called `image_prediction_epoch_X.png` and `image_truth_epoch_X.png`. Adapt these to the names given to your own files.\n", 21 | "\n", 22 | "#### The Plot\n", 23 | "Below is a static rendering of the interactive plot. The slider at the bottom lets you move between epochs (i.e. over time). \n", 24 | "\n", 25 | "![The state of the model after 4 epochs of training](../images/interactiveplot_epoch4.png)\n", 26 | "\n", 27 | "![The state of the model after 28 epochs of training](../images/interactiveplot_epoch28.png)\n" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "import plotly.graph_objects as go\n", 37 | "import numpy as np\n", 38 | "from plotly.subplots import make_subplots\n", 39 | "from PIL import Image\n", 40 | "from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot\n", 41 | "init_notebook_mode(connected=True)\n", 42 | "\n", 43 | "import plotly.io as pio\n", 44 | "pio.renderers.default = 'iframe'" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "tc_values=[]\n", 54 | "ar_values=[]\n", 55 | "for i in range(1,31):\n", 56 | " values = np.load('visualization_results_epoch_{}.npy'.format(i), allow_pickle=True)\n", 57 | " tc_values.append(values[0][1])\n", 58 | " ar_values.append(values[0][2])" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "# Create figure\n", 68 | "fig = make_subplots(rows=5, cols=2, \n", 69 | " shared_yaxes=True,\n", 70 | " subplot_titles=(\"Tropical Cyclones Mean Average Precision\", \n", 71 | " \"Atmospheric Rivers Mean Average Precision\",\n", 72 | " \"Prediction of Extreme Events\",\n", 73 | " \"Ground Truth of Extreme Events\"),\n", 74 | " specs=[[{\"rowspan\": 2}, {\"rowspan\": 2}],\n", 75 | " [None, None],\n", 76 | " [{\"rowspan\": 3},{\"rowspan\": 3}],\n", 77 | " [None, None],\n", 78 | " [None, None]],\n", 79 | " vertical_spacing=0.1)\n", 80 | "\n", 81 | "# Add traces, one for each slider step\n", 82 | "for step in np.arange(0, 30):\n", 83 | " # TROPICAL CYCLONES\n", 84 | " fig.add_trace(\n", 85 | " go.Scatter(\n", 86 | " visible=False,\n", 87 | " line=dict(color=\"#008000\", width=6),\n", 88 | " name=\"TC mAP: \" + str(step),\n", 89 | " x=np.arange(step+1),\n", 90 | " y=tc_values[:step+1]),\n", 91 | " row=1, col=1)\n", 92 | "\n", 93 | " #ATMOSPHERIC RIVERS\n", 94 | " fig.add_trace(\n", 95 | " go.Scatter(\n", 96 | " visible=False,\n", 97 | " line=dict(color=\"#00008b\", width=6),\n", 98 | " name=\"AR mAP: \" + str(step),\n", 99 | " x=np.arange(step+1),\n", 100 | " y=ar_values[:step+1]),\n", 101 | " row=1, col=2)\n", 102 | " \n", 103 | " #2D Image\n", 104 | " fig.add_trace(\n", 105 | " go.Image(\n", 106 | " visible=False,\n", 107 | "
z=np.array(Image.open(\"image_prediction_epoch_{}.png\".format(step+1)))),\n", 108 | " row=3, col=1)\n", 109 | " \n", 110 | " #2D Image\n", 111 | " fig.add_trace(\n", 112 | " go.Image(\n", 113 | " visible=False,\n", 114 | " z=np.array(Image.open(\"image_truth_epoch_{}.png\".format(step+1)))),\n", 115 | " row=3, col=2)\n", 116 | "\n", 117 | "\n", 118 | "# Make the first epoch's traces visible\n", 119 | "fig.data[0].visible = True\n", 120 | "fig.data[1].visible = True\n", 121 | "fig.data[2].visible = True\n", 122 | "fig.data[3].visible = True\n", 123 | "\n", 124 | "# Update xaxis properties\n", 125 | "fig.update_xaxes(title_text=\"Epoch\", row=1, col=1)\n", 126 | "fig.update_xaxes(title_text=\"Epoch\", row=1, col=2)\n", 127 | "fig.update_xaxes(showgrid=False, showticklabels=False, row=3, col=1)\n", 128 | "fig.update_xaxes(showgrid=False, showticklabels=False, row=3, col=2)\n", 129 | "\n", 130 | "# Update yaxis properties\n", 131 | "fig.update_yaxes(title_text=\"Mean Average Precision\", row=1, col=1)\n", 132 | "fig.update_yaxes(title_text=\"Mean Average Precision\", row=1, col=2)\n", 133 | "fig.update_yaxes(showgrid=False, showticklabels=False, row=3, col=1)\n", 134 | "fig.update_yaxes(showgrid=False, showticklabels=False, row=3, col=2)\n", 135 | "\n", 136 | "# Create and add slider\n", 137 | "steps = []\n", 138 | "for i in range(0,len(fig.data), 4):\n", 139 | " step = dict(\n", 140 | " #method=\"restyle\",\n", 141 | " args=[\"visible\", [False] * len(fig.data)],\n", 142 | " )\n", 143 | " step[\"args\"][1][i] = True # Toggle i'th trace to \"visible\"\n", 144 | " step[\"args\"][1][i+1] = True\n", 145 | " step[\"args\"][1][i+2] = True\n", 146 | " step[\"args\"][1][i+3] = True\n", 147 | " steps.append(step)\n", 148 | "\n", 149 | "sliders = [dict(\n", 150 | " active=0,\n", 151 | " currentvalue={\"prefix\": \"Epoch: \"},\n", 152 | " pad={\"t\": 50},\n", 153 | " steps=steps\n", 154 | ")]\n", 155 | "\n", 156 | "fig.update_layout(\n", 157 | " height=1250, width=2000,\n", 158 | " title_text=\"Extreme Event Detection Results\",\n", 159 | " sliders=sliders\n", 160 | ")\n", 161 | "\n", 162 | "pio.show(fig)\n", 163 | "\n" 164 | ] 165 | } 166 | ], 167 | "metadata": { 168 | "language_info": { 169 | "name": "python", 170 | "pygments_lexer": "ipython3" 171 | } 172 | }, 173 | "nbformat": 4, 174 | "nbformat_minor": 2 175 | } 176 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 150 3 | target-version = ['py37'] 4 | 5 | [tool.isort] 6 | multi_line_output = 3 7 | include_trailing_comma = true 8 | force_grid_wrap = 0 9 | use_parentheses = true 10 | line_length = 150 11 | known_third_party = ['sklearn', 'ignite', 'pygsp'] 12 | 13 | [build-system] 14 | build-backend = 'setuptools.build_meta' 15 | requires = [ 16 | "setuptools >= 40.0.4", 17 | "setuptools_scm >= 2.0.0, <4", 18 | "wheel >= 0.29.0", 19 | ] 20 | 21 | -------------------------------------------------------------------------------- /requirements-tests.txt: -------------------------------------------------------------------------------- 1 | black==19.3b0 2 | isort[requirements,pyproject]==4.3.21 3 | pre-commit==1.18.3 4 | pylint==2.4.2 5 | pylint-fail-under==0.3.0 6 | Sphinx==2.2.0 7 | sphinx-rtd-theme==0.4.3 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 |
numpy==1.18.1 2 | scikit-learn==0.21.3 3 | scipy==1.4.1 4 | tensorboard==2.1.0 5 | torch==1.3.1 6 | torchvision==0.4.2 7 | pyyaml==5.2 8 | jupyter==1.0.0 9 | pytorch-ignite==0.2.1 10 | pillow==6.2.2 -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- 1 | """DeepSphere scripts base documentation. 2 | """ -------------------------------------------------------------------------------- /scripts/config.example.yml: -------------------------------------------------------------------------------- 1 | IMAGE PARAMS: 2 | pooling_class: "icosahedron" 3 | n_pixels: 10242 4 | depth: 6 5 | laplacian_type: "combinatorial" 6 | 7 | MODEL PARAMS: 8 | partition: [0.7,0.2,0.1] 9 | batch_size: 64 10 | learning_rate: 0.001 11 | n_epochs: 30 12 | kernel_size: 3 13 | 14 | SAVING: 15 | path_to_data: "/data/climate/data_5_all" 16 | tensorboard_path: "./" 17 | model_save_path: "./" 18 | 19 | DATALOADERS: 20 | download: False 21 | means_path: means.npy 22 | stds_path: stds.npy 23 | seed: 1 24 | 25 | REDUCEONPLATEAU: 26 | reducelronplateau_mode: "min" 27 | reducelronplateau_factor: 0.05 28 | reducelronplateau_patience: 3 29 | STEP: 30 | steplr_step_size: 30 31 | steplr_gamma: 0.1 32 | 33 | WARMUP: 34 | warmuplr_warmup_start_value: 0.001 35 | warmuplr_warmup_end_value: 0.001 36 | warmuplr_warmup_duration: 1 37 | 38 | EARLY_STOPPING: 39 | earlystopping_patience: 30 -------------------------------------------------------------------------------- /scripts/means.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepsphere/deepsphere-pytorch/43f03f1bef146d256a7e1c1e69df6712d089b9e5/scripts/means.npy -------------------------------------------------------------------------------- /scripts/run_ar_tc.py: -------------------------------------------------------------------------------- 1 | """Example script for running DeepSphere U-Net on reduced AR_TC dataset.
2 | """ 3 | 4 | 5 | import numpy as np 6 | import torch 7 | from ignite.contrib.handlers.param_scheduler import create_lr_scheduler_with_warmup 8 | from ignite.contrib.handlers.tensorboard_logger import GradsHistHandler, OptimizerParamsHandler, OutputHandler, TensorboardLogger, WeightsHistHandler 9 | from ignite.engine import Engine, Events, create_supervised_evaluator 10 | from ignite.handlers import EarlyStopping, TerminateOnNan 11 | from ignite.metrics import EpochMetric 12 | from sklearn.model_selection import train_test_split 13 | from torch import nn, optim 14 | from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR 15 | from torch.utils.data import DataLoader 16 | from torch.utils.tensorboard import SummaryWriter 17 | from torchvision import transforms 18 | 19 | from deepsphere.data.datasets.dataset import ARTCDataset 20 | from deepsphere.data.transforms.transforms import Normalize, Permute, ToTensor 21 | from deepsphere.models.spherical_unet.unet_model import SphericalUNet 22 | from deepsphere.utils.initialization import init_device 23 | from deepsphere.utils.parser import create_parser, parse_config 24 | from deepsphere.utils.stats_extractor import stats_extractor 25 | 26 | 27 | def average_precision_compute_fn(y_pred, y_true): 28 | """Attached function to the custom ignite metric AveragePrecisionMultiLabel 29 | 30 | Args: 31 | y_pred (:obj:`torch.Tensor`): model predictions 32 | y_true (:obj:`torch.Tensor`): ground truths 33 | 34 | Raises: 35 | RuntimeError: Indicates that sklearn should be installed by the user. 36 | 37 | Returns: 38 | :obj:`numpy.array`: average precision vector. 39 | Of the same length as the number of labels present in the data 40 | """ 41 | try: 42 | from sklearn.metrics import average_precision_score 43 | except ImportError: 44 | raise RuntimeError("This metric requires sklearn to be installed.") 45 | 46 | ap = average_precision_score(y_true.numpy(), y_pred.numpy(), None) 47 | return ap 48 | 49 | 50 | # Pylint and Ignite incompatibilities: 51 | # pylint: disable=W0612 52 | # pylint: disable=W0613 53 | 54 | 55 | def validate_output_transform(x, y, y_pred): 56 | """A transform to format the output of the supervised evaluator before calculating the metric 57 | 58 | Args: 59 | x (:obj:`torch.Tensor`): the input to the model 60 | y (:obj:`torch.Tensor`): the output of the model 61 | y_pred (:obj:`torch.Tensor`): the ground truth labels 62 | 63 | Returns: 64 | (:obj:`torch.Tensor`, :obj:`torch.Tensor`): model predictions and ground truths reformatted 65 | """ 66 | output = y_pred 67 | labels = y 68 | B, V, C = output.shape 69 | B_labels, V_labels, C_labels = labels.shape 70 | output = output.view(B * V, C) 71 | labels = labels.view(B_labels * V_labels, C_labels) 72 | return output, labels 73 | 74 | 75 | def add_tensorboard(engine_train, optimizer, model, log_dir): 76 | """Creates an ignite logger object and adds training elements such as weight and gradient histograms 77 | 78 | Args: 79 | engine_train (:obj:`ignite.engine`): the train engine to attach to the logger 80 | optimizer (:obj:`torch.optim`): the model's optimizer 81 | model (:obj:`torch.nn.Module`): the model being trained 82 | log_dir (string): path to where tensorboard data should be saved 83 | """ 84 | # Create a logger 85 | tb_logger = TensorboardLogger(log_dir=log_dir) 86 | 87 | # Attach the logger to the trainer to log training loss at each iteration 88 | tb_logger.attach( 89 | engine_train, log_handler=OutputHandler(tag="training", output_transform=lambda loss: {"loss": loss}), 
event_name=Events.ITERATION_COMPLETED 90 | ) 91 | 92 | # Attach the logger to the trainer to log optimizer's parameters, e.g. learning rate at each iteration 93 | tb_logger.attach(engine_train, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.EPOCH_COMPLETED) 94 | 95 | # Attach the logger to the trainer to log model's weights as a histogram after each epoch 96 | tb_logger.attach(engine_train, log_handler=WeightsHistHandler(model), event_name=Events.EPOCH_COMPLETED) 97 | 98 | # Attach the logger to the trainer to log model's gradients as a histogram after each epoch 99 | tb_logger.attach(engine_train, log_handler=GradsHistHandler(model), event_name=Events.EPOCH_COMPLETED) 100 | 101 | tb_logger.close() 102 | 103 | 104 | def get_dataloaders(parser_args): 105 | """Creates the datasets and the corresponding dataloaders 106 | 107 | Args: 108 | parser_args (dict): parsed arguments 109 | 110 | Returns: 111 | (:obj:`torch.utils.data.dataloader`, :obj:`torch.utils.data.dataloader`): train, validation dataloaders 112 | """ 113 | 114 | path_to_data = parser_args.path_to_data 115 | download = parser_args.download 116 | partition = parser_args.partition 117 | seed = parser_args.seed 118 | means_path = parser_args.means_path 119 | stds_path = parser_args.stds_path 120 | 121 | data = ARTCDataset(path=path_to_data, download=download, indices=None, transform_data=None, transform_labels=None) 122 | 123 | train_indices, temp = train_test_split(data.indices, train_size=partition[0], random_state=seed) 124 | val_indices, _ = train_test_split(temp, test_size=partition[2] / (partition[1] + partition[2]), random_state=seed) 125 | 126 | if (means_path is None) or (stds_path is None): 127 | transform_data_stats = transforms.Compose([ToTensor()]) 128 | train_set_stats = ARTCDataset( 129 | path=path_to_data, download=download, indices=train_indices, transform_data=transform_data_stats, transform_labels=None 130 | ) 131 | means, stds = stats_extractor(train_set_stats) 132 | np.save("./means.npy", means) 133 | np.save("./stds.npy", stds) 134 | else: 135 | try: 136 | means = np.load(means_path) 137 | stds = np.load(stds_path) 138 | except (FileNotFoundError, ValueError): 139 | raise RuntimeError("Could not load means/stds; check that means_path and stds_path point to valid .npy files.") 140 | 141 | transform_data = transforms.Compose([ToTensor(), Permute(), Normalize(mean=means, std=stds)]) 142 | transform_labels = transforms.Compose([ToTensor(), Permute()]) 143 | train_set = ARTCDataset( 144 | path=path_to_data, download=download, indices=train_indices, transform_data=transform_data, transform_labels=transform_labels 145 | ) 146 | validation_set = ARTCDataset( 147 | path=path_to_data, download=download, indices=val_indices, transform_data=transform_data, transform_labels=transform_labels 148 | ) 149 | 150 | dataloader_train = DataLoader(train_set, batch_size=parser_args.batch_size, shuffle=True, num_workers=12) 151 | dataloader_validation = DataLoader(validation_set, batch_size=parser_args.batch_size, shuffle=False, num_workers=12) 152 | return dataloader_train, dataloader_validation 153 | 154 | 155 | def main(parser_args): 156 | """Main function to create trainer engine, add handlers to train and validation engines. 157 | Then runs train engine to perform training and validation.
158 | 159 | Args: 160 | parser_args (dict): parsed arguments 161 | """ 162 | dataloader_train, dataloader_validation = get_dataloaders(parser_args) 163 | criterion = nn.CrossEntropyLoss() 164 | 165 | unet = SphericalUNet(parser_args.pooling_class, parser_args.n_pixels, parser_args.depth, parser_args.laplacian_type, parser_args.kernel_size) 166 | unet, device = init_device(parser_args.device, unet) 167 | lr = parser_args.learning_rate 168 | optimizer = optim.Adam(unet.parameters(), lr=lr) 169 | 170 | def trainer(engine, batch): 171 | """Train Function to define train engine. 172 | Called for every batch of the train engine, for each epoch. 173 | 174 | Args: 175 | engine (ignite.engine): train engine 176 | batch (:obj:`torch.utils.data.dataloader`): batch from train dataloader 177 | 178 | Returns: 179 | :obj:`torch.tensor` : train loss for that batch and epoch 180 | """ 181 | unet.train() 182 | data, labels = batch 183 | labels = labels.to(device) 184 | data = data.to(device) 185 | output = unet(data) 186 | 187 | B, V, C = output.shape 188 | B_labels, V_labels, C_labels = labels.shape 189 | output = output.view(B * V, C) 190 | labels = labels.view(B_labels * V_labels, C_labels).max(1)[1] 191 | 192 | loss = criterion(output, labels) 193 | optimizer.zero_grad() 194 | loss.backward() 195 | optimizer.step() 196 | return loss.item() 197 | 198 | writer = SummaryWriter(parser_args.tensorboard_path) 199 | 200 | engine_train = Engine(trainer) 201 | 202 | engine_validate = create_supervised_evaluator( 203 | model=unet, metrics={"AP": EpochMetric(average_precision_compute_fn)}, device=device, output_transform=validate_output_transform 204 | ) 205 | 206 | engine_train.add_event_handler(Events.EPOCH_STARTED, lambda x: print("Starting Epoch: {}".format(x.state.epoch))) 207 | engine_train.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan()) 208 | 209 | @engine_train.on(Events.EPOCH_COMPLETED) 210 | def epoch_validation(engine): 211 | """Handler to run the validation engine at the end of the train engine's epoch. 
212 | 213 | Args: 214 | engine (ignite.engine): train engine 215 | """ 216 | print("beginning validation epoch") 217 | engine_validate.run(dataloader_validation) 218 | 219 | reduce_lr_plateau = ReduceLROnPlateau( 220 | optimizer, 221 | mode=parser_args.reducelronplateau_mode, 222 | factor=parser_args.reducelronplateau_factor, 223 | patience=parser_args.reducelronplateau_patience, 224 | ) 225 | 226 | @engine_validate.on(Events.EPOCH_COMPLETED) 227 | def update_reduce_on_plateau(engine): 228 | """Handler to reduce the learning rate on plateau at the end of the validation engine's epoch 229 | 230 | Args: 231 | engine (ignite.engine): validation engine 232 | """ 233 | ap = engine.state.metrics["AP"] 234 | mean_average_precision = np.mean(ap[1:]) 235 | reduce_lr_plateau.step(mean_average_precision) 236 | 237 | @engine_validate.on(Events.EPOCH_COMPLETED) 238 | def save_epoch_results(engine): 239 | """Handler to save the metrics at the end of the validation engine's epoch 240 | 241 | Args: 242 | engine (ignite.engine): validation engine 243 | """ 244 | ap = engine.state.metrics["AP"] 245 | mean_average_precision = np.mean(ap[1:]) 246 | print("Average precisions:", ap) 247 | print("mAP:", mean_average_precision) 248 | writer.add_scalars( 249 | "metrics", 250 | {"mean average precision (AR+TC)": mean_average_precision, "AR average precision": ap[2], "TC average precision": ap[1]}, 251 | engine_train.state.epoch, 252 | ) 253 | writer.close() 254 | 255 | step_scheduler = StepLR(optimizer, step_size=parser_args.steplr_step_size, gamma=parser_args.steplr_gamma) 256 | scheduler = create_lr_scheduler_with_warmup( 257 | step_scheduler, 258 | warmup_start_value=parser_args.warmuplr_warmup_start_value, 259 | warmup_end_value=parser_args.warmuplr_warmup_end_value, 260 | warmup_duration=parser_args.warmuplr_warmup_duration, 261 | ) 262 | engine_validate.add_event_handler(Events.EPOCH_COMPLETED, scheduler) 263 | 264 | earlystopper = EarlyStopping( 265 | patience=parser_args.earlystopping_patience, score_function=lambda x: -x.state.metrics["AP"][1], trainer=engine_train 266 | ) 267 | engine_validate.add_event_handler(Events.EPOCH_COMPLETED, earlystopper) 268 | 269 | add_tensorboard(engine_train, optimizer, unet, log_dir=parser_args.tensorboard_path) 270 | 271 | engine_train.run(dataloader_train, max_epochs=parser_args.n_epochs) 272 | 273 | torch.save(unet.state_dict(), parser_args.model_save_path + "unet_state.pt") 274 | 275 | 276 | if __name__ == "__main__": 277 | PARSER_ARGS = parse_config(create_parser()) 278 | main(PARSER_ARGS) 279 | -------------------------------------------------------------------------------- /scripts/stds.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepsphere/deepsphere-pytorch/43f03f1bef146d256a7e1c1e69df6712d089b9e5/scripts/stds.npy -------------------------------------------------------------------------------- /scripts/temporality/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepsphere/deepsphere-pytorch/43f03f1bef146d256a7e1c1e69df6712d089b9e5/scripts/temporality/__init__.py -------------------------------------------------------------------------------- /scripts/temporality/config.yml: -------------------------------------------------------------------------------- 1 | IMAGE PARAMS: 2 | pooling_class: "icosahedron" 3 | n_pixels: 10242 4 | depth: 6 5 | laplacian_type: "combinatorial" 6 | 7 | TEMPORALITY: 8 | type: "LSTM" 9 | 
sequence_length: 3 10 | prediction_shift: 0 11 | 12 | MODEL PARAMS: 13 | partition: [0.7,0.2,0.1] 14 | batch_size: 54 15 | learning_rate: 0.01 16 | n_epochs: 50 17 | kernel_size: 5 18 | 19 | SAVING: 20 | path_to_data: "/data/climate/data_5_all" 21 | tensorboard_path: "./" 22 | model_save_path: "./" 23 | 24 | DATALOADERS: 25 | download: False 26 | means_path: "../means.npy" 27 | stds_path: "../stds.npy" 28 | seed: 1 29 | 30 | REDUCEONPLATEAU: 31 | reducelronplateau_mode: "min" 32 | reducelronplateau_factor: 0.01 33 | reducelronplateau_patience: 50 34 | 35 | STEP: 36 | steplr_step_size: 10 37 | steplr_gamma: 0.5 38 | 39 | WARMUP: 40 | warmuplr_warmup_start_value: 0.0 41 | warmuplr_warmup_end_value: 0.01 42 | warmuplr_warmup_duration: 10 43 | 44 | EARLY_STOPPING: 45 | earlystopping_patience: 50 46 | -------------------------------------------------------------------------------- /scripts/temporality/run_ar_tc.py: -------------------------------------------------------------------------------- 1 | """Example script for running DeepSphere U-Net on reduced AR_TC dataset. 2 | """ 3 | 4 | 5 | import numpy as np 6 | import torch 7 | from ignite.contrib.handlers.param_scheduler import create_lr_scheduler_with_warmup 8 | from ignite.contrib.handlers.tensorboard_logger import GradsHistHandler, OptimizerParamsHandler, OutputHandler, TensorboardLogger, WeightsHistHandler 9 | from ignite.engine import Engine, Events, create_supervised_evaluator 10 | from ignite.handlers import EarlyStopping, TerminateOnNan 11 | from ignite.metrics import EpochMetric 12 | from sklearn.model_selection import train_test_split 13 | from torch import nn, optim 14 | from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR 15 | from torch.utils.data import DataLoader 16 | from torch.utils.tensorboard import SummaryWriter 17 | from torchvision import transforms 18 | 19 | from deepsphere.data.transforms.transforms import Normalize, Permute, ToTensor 20 | from deepsphere.utils.initialization import init_dataset_temp, init_device, init_unet_temp 21 | from deepsphere.utils.parser import create_parser, parse_config 22 | from deepsphere.utils.stats_extractor import stats_extractor 23 | 24 | 25 | def average_precision_compute_fn(y_pred, y_true): 26 | """Attached function to the custom ignite metric AveragePrecisionMultiLabel 27 | 28 | Args: 29 | y_pred (:obj:`torch.Tensor`): model predictions 30 | y_true (:obj:`torch.Tensor`): ground truths 31 | 32 | Raises: 33 | RuntimeError: Indicates that sklearn should be installed by the user. 34 | 35 | Returns: 36 | :obj:`numpy.array`: average precision vector. 
37 | Of the same length as the number of labels present in the data 38 | """ 39 | try: 40 | from sklearn.metrics import average_precision_score 41 | except ImportError: 42 | raise RuntimeError("This metric requires sklearn to be installed.") 43 | 44 | ap = average_precision_score(y_true.numpy(), y_pred.numpy(), None) 45 | return ap 46 | 47 | 48 | # Pylint and Ignite incompatibilities: 49 | # pylint: disable=W0612 50 | # pylint: disable=W0613 51 | 52 | 53 | def validate_output_transform(x, y, y_pred): 54 | """A transform to format the output of the supervised evaluator before calculating the metric 55 | 56 | Args: 57 | x (:obj:`torch.Tensor`): the input to the model 58 | y (:obj:`torch.Tensor`): the ground truth labels 59 | y_pred (:obj:`torch.Tensor`): the output of the model 60 | 61 | Returns: 62 | (:obj:`torch.Tensor`, :obj:`torch.Tensor`): model predictions and ground truths reformatted 63 | """ 64 | output = y_pred 65 | labels = y 66 | B, V, C = output.shape 67 | B_labels, V_labels, C_labels = labels.shape 68 | output = output.view(B * V, C) 69 | labels = labels.view(B_labels * V_labels, C_labels) 70 | return output, labels 71 | 72 | 73 | def add_tensorboard(engine_train, optimizer, model, log_dir): 74 | """Creates an ignite logger object and adds training elements such as weight and gradient histograms 75 | 76 | Args: 77 | engine_train (:obj:`ignite.engine`): the train engine to attach to the logger 78 | optimizer (:obj:`torch.optim`): the model's optimizer 79 | model (:obj:`torch.nn.Module`): the model being trained 80 | log_dir (string): path to where tensorboard data should be saved 81 | """ 82 | # Create a logger 83 | tb_logger = TensorboardLogger(log_dir=log_dir) 84 | 85 | # Attach the logger to the trainer to log training loss at each iteration 86 | tb_logger.attach( 87 | engine_train, log_handler=OutputHandler(tag="training", output_transform=lambda loss: {"loss": loss}), event_name=Events.ITERATION_COMPLETED 88 | ) 89 | 90 | # Attach the logger to the trainer to log optimizer's parameters, e.g.
learning rate at each iteration 91 | tb_logger.attach(engine_train, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.EPOCH_COMPLETED) 92 | 93 | # Attach the logger to the trainer to log model's weights as a histogram after each epoch 94 | tb_logger.attach(engine_train, log_handler=WeightsHistHandler(model), event_name=Events.EPOCH_COMPLETED) 95 | 96 | # Attach the logger to the trainer to log model's gradients as a histogram after each epoch 97 | tb_logger.attach(engine_train, log_handler=GradsHistHandler(model), event_name=Events.EPOCH_COMPLETED) 98 | 99 | tb_logger.close() 100 | 101 | 102 | def get_dataloaders(parser_args): 103 | """Creates the datasets and the corresponding dataloaders 104 | 105 | Args: 106 | parser_args (dict): parsed arguments 107 | 108 | Returns: 109 | (:obj:`torch.utils.data.dataloader`, :obj:`torch.utils.data.dataloader`): train, validation dataloaders 110 | """ 111 | 112 | partition = parser_args.partition 113 | seed = parser_args.seed 114 | means_path = parser_args.means_path 115 | stds_path = parser_args.stds_path 116 | 117 | data = init_dataset_temp(parser=parser_args, indices=None, transform_image=None, transform_labels=None) 118 | 119 | train_indices, temp = train_test_split(data.indices, train_size=partition[0], random_state=seed) 120 | val_indices, _ = train_test_split(temp, test_size=partition[2] / (partition[1] + partition[2]), random_state=seed) 121 | 122 | if (means_path is None) or (stds_path is None): 123 | transform_image_stats = transforms.Compose([ToTensor()]) 124 | train_set_stats = init_dataset_temp(parser=parser_args, indices=train_indices, transform_image=transform_image_stats, transform_labels=None) 125 | means, stds = stats_extractor(train_set_stats) 126 | np.save("./means.npy", means) 127 | np.save("./stds.npy", stds) 128 | else: 129 | try: 130 | means = np.load(means_path) 131 | stds = np.load(stds_path) 132 | except (FileNotFoundError, ValueError): 133 | raise RuntimeError("Could not load means/stds; check that means_path and stds_path point to valid .npy files.") 134 | 135 | transform_image = transforms.Compose([ToTensor(), Permute(), Normalize(mean=means, std=stds)]) 136 | transform_labels = transforms.Compose([ToTensor(), Permute()]) 137 | train_set = init_dataset_temp(parser=parser_args, indices=train_indices, transform_image=transform_image, transform_labels=transform_labels) 138 | validation_set = init_dataset_temp(parser=parser_args, indices=val_indices, transform_image=transform_image, transform_labels=transform_labels) 139 | 140 | dataloader_train = DataLoader(train_set, batch_size=parser_args.batch_size, shuffle=True, num_workers=12) 141 | dataloader_validation = DataLoader(validation_set, batch_size=parser_args.batch_size, shuffle=False, num_workers=12) 142 | return dataloader_train, dataloader_validation 143 | 144 | 145 | def main(parser_args): 146 | """Main function to create trainer engine, add handlers to train and validation engines. 147 | Then runs train engine to perform training and validation. 148 | 149 | Args: 150 | parser_args (dict): parsed arguments 151 | """ 152 | dataloader_train, dataloader_validation = get_dataloaders(parser_args) 153 | criterion = nn.CrossEntropyLoss() 154 | 155 | unet = init_unet_temp(parser_args) 156 | unet, device = init_device(parser_args.device, unet) 157 | 158 | lr = parser_args.learning_rate 159 | optimizer = optim.Adam(unet.parameters(), lr=lr) 160 | 161 | def trainer(engine, batch): 162 | """Train Function to define train engine. 163 | Called for every batch of the train engine, for each epoch.
164 | 165 | Args: 166 | engine (ignite.engine): train engine 167 | batch (:obj:`torch.utils.data.dataloader`): batch from train dataloader 168 | 169 | Returns: 170 | :obj:`torch.tensor` : train loss for that batch and epoch 171 | """ 172 | unet.train() 173 | data, labels = batch 174 | labels = labels.to(device) 175 | data = data.to(device) 176 | # for sample in data: 177 | # sample = sample.to(device) 178 | output = unet(data) 179 | 180 | B, V, C = output.shape 181 | B_labels, V_labels, C_labels = labels.shape 182 | output = output.view(B * V, C) 183 | labels = labels.view(B_labels * V_labels, C_labels).max(1)[1] 184 | 185 | loss = criterion(output, labels) 186 | optimizer.zero_grad() 187 | loss.backward() 188 | optimizer.step() 189 | return loss.item() 190 | 191 | writer = SummaryWriter(parser_args.tensorboard_path) 192 | 193 | engine_train = Engine(trainer) 194 | 195 | engine_validate = create_supervised_evaluator( 196 | model=unet, metrics={"AP": EpochMetric(average_precision_compute_fn)}, device=device, output_transform=validate_output_transform 197 | ) 198 | 199 | engine_train.add_event_handler(Events.EPOCH_STARTED, lambda x: print("Starting Epoch: {}".format(x.state.epoch))) 200 | engine_train.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan()) 201 | 202 | @engine_train.on(Events.EPOCH_COMPLETED) 203 | def epoch_validation(engine): 204 | """Handler to run the validation engine at the end of the train engine's epoch. 205 | 206 | Args: 207 | engine (ignite.engine): train engine 208 | """ 209 | print("beginning validation epoch") 210 | engine_validate.run(dataloader_validation) 211 | 212 | reduce_lr_plateau = ReduceLROnPlateau( 213 | optimizer, 214 | mode=parser_args.reducelronplateau_mode, 215 | factor=parser_args.reducelronplateau_factor, 216 | patience=parser_args.reducelronplateau_patience, 217 | ) 218 | 219 | @engine_validate.on(Events.EPOCH_COMPLETED) 220 | def update_reduce_on_plateau(engine): 221 | """Handler to reduce the learning rate on plateau at the end of the validation engine's epoch 222 | 223 | Args: 224 | engine (ignite.engine): validation engine 225 | """ 226 | ap = engine.state.metrics["AP"] 227 | mean_average_precision = np.mean(ap[1:]) 228 | reduce_lr_plateau.step(mean_average_precision) 229 | 230 | @engine_validate.on(Events.EPOCH_COMPLETED) 231 | def save_epoch_results(engine): 232 | """Handler to save the metrics at the end of the validation engine's epoch 233 | 234 | Args: 235 | engine (ignite.engine): validation engine 236 | """ 237 | ap = engine.state.metrics["AP"] 238 | mean_average_precision = np.mean(ap[1:]) 239 | print("Average precisions:", ap) 240 | print("mAP:", mean_average_precision) 241 | writer.add_scalars( 242 | "metrics", 243 | {"mean average precision (AR+TC)": mean_average_precision, "AR average precision": ap[2], "TC average precision": ap[1]}, 244 | engine_train.state.epoch, 245 | ) 246 | writer.close() 247 | 248 | step_scheduler = StepLR(optimizer, step_size=parser_args.steplr_step_size, gamma=parser_args.steplr_gamma) 249 | scheduler = create_lr_scheduler_with_warmup( 250 | step_scheduler, 251 | warmup_start_value=parser_args.warmuplr_warmup_start_value, 252 | warmup_end_value=parser_args.warmuplr_warmup_end_value, 253 | warmup_duration=parser_args.warmuplr_warmup_duration, 254 | ) 255 | engine_validate.add_event_handler(Events.EPOCH_COMPLETED, scheduler) 256 | 257 | earlystopper = EarlyStopping( 258 | patience=parser_args.earlystopping_patience, score_function=lambda x: -x.state.metrics["AP"][1], trainer=engine_train 259 | ) 260 
| engine_validate.add_event_handler(Events.EPOCH_COMPLETED, earlystopper) 261 | 262 | add_tensorboard(engine_train, optimizer, unet, log_dir=parser_args.tensorboard_path) 263 | 264 | engine_train.run(dataloader_train, max_epochs=parser_args.n_epochs) 265 | 266 | torch.save(unet.state_dict(), parser_args.model_save_path + "unet_state.pt") 267 | 268 | 269 | if __name__ == "__main__": 270 | # run with (for example): 271 | # python -W ignore run_ar_tc.py --config-file config.yml --path_to_data /data/climate/data_5_all 272 | PARSER_ARGS = parse_config(create_parser()) 273 | main(PARSER_ARGS) 274 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [pylint] 2 | # pylint needs to be run with --rcfile=setup.cfg to detect this file! 3 | # C0411: Wrong import order (we use isort instead) 4 | # R0801: Similar lines in N files 5 | # R0902: Too many instance attributes 6 | # R0903: Too few public methods (min 2) 7 | # R0904: Too many public methods (max 20) 8 | # R0912: Too many branches 9 | # R0913: Too many arguments 10 | # R0914: Too many local variables 11 | # R0915: Too many statements 12 | # R1702: Too many nested blocks 13 | # W0621: Redefining name '$' from outer scope 14 | # W1202: Use % formatting in logging functions and pass the % parameters as argument 15 | disable = C0411, R0801, R0902, R0903, R0904, R0912, R0913, R0914, R0915, R1702, W0621, W1202, W1503 16 | ignored-modules = torch, torchvision, pygsp, ignite 17 | generated-members = torch.*, numpy.*, np.*, ignite.* 18 | max-line-length = 150 19 | argument-rgx = [xyA-Z]|[a-z_][a-z0-9_]{1,30}$ 20 | attr-rgx = [xyA-Z]|[a-z_][a-z0-9_]{1,30}$ 21 | variable-rgx = [xyA-Z]|[a-z_][a-z0-9_]{1,30}$ 22 | extension-pkg-whitelist = cv2 23 | output-format = colorized 24 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import re 3 | from os import path 4 | 5 | from setuptools import find_packages, setup 6 | 7 | this_directory = path.abspath(path.dirname(__file__)) 8 | 9 | with open(path.join(this_directory, "README.md"), "r") as fh: 10 | long_description = fh.read() 11 | 12 | 13 | def read(*parts): 14 | with codecs.open(path.join(this_directory, *parts), "r") as fp: 15 | return fp.read() 16 | 17 | 18 | def find_version(*file_paths): 19 | version_file = read(*file_paths) 20 | version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M) 21 | if version_match: 22 | return version_match.group(1) 23 | raise RuntimeError("Unable to find version string.") 24 | 25 | 26 | def get_requirements(file_name): 27 | """Strip a requirements file of all comments 28 | 29 | Args: 30 | file_name (string): File which contains requirements 31 | 32 | Returns: 33 | list: list of requirements 34 | """ 35 | 36 | with open(path.join(this_directory, "{}.txt".format(file_name)), "r") as file: 37 | reqs = [] 38 | 39 | for req in file.readlines(): 40 | if not req.startswith("#"): 41 | if req.startswith("git+"): 42 | name = req.split("#")[-1].replace("egg=", "").strip() 43 | req = req.replace("git+", "") 44 | reqs.append(f"{name} @ {req}") 45 | else: 46 | reqs.append(req) 47 | 48 | return reqs 49 | 50 | 51 | INSTALL_REQUIRES = get_requirements("requirements") 52 | TESTS_REQUIRES = get_requirements("requirements-tests") 53 | 54 | EXTRA_REQUIRE = {"tests": TESTS_REQUIRES} 55 | 56
| setup( 57 | name="deepsphere", 58 | version=find_version("deepsphere", "__init__.py"), 59 | description="Deep Sphere package", 60 | long_description=long_description, 61 | long_description_content_type="text/markdown", 62 | author="Arcanite", 63 | author_email="contact@arcanite.ch", 64 | install_requires=INSTALL_REQUIRES, 65 | extras_require=EXTRA_REQUIRE, 66 | packages=find_packages(), 67 | ) 68 | --------------------------------------------------------------------------------