├── .gitignore ├── LICENSE ├── README.md ├── data ├── gt_bach.wav └── gt_counting.wav ├── dataio.py ├── diff_operators.py ├── environment.yml ├── experiment_scripts ├── __init__.py ├── fwi │ └── data_cylinder_5.mat ├── test_audio.py ├── test_conv_neural_process.py ├── test_neural_process.py ├── test_sdf.py ├── train_audio.py ├── train_helmholtz.py ├── train_img.py ├── train_img_inpainting.py ├── train_img_neural_process.py ├── train_inverse_helmholtz.py ├── train_poisson_grad_img.py ├── train_poisson_gradcomp_img.py ├── train_poisson_lapl_img.py ├── train_sdf.py ├── train_video.py └── train_wave_equation.py ├── explore_siren.ipynb ├── loss_functions.py ├── make_figures.py ├── meta_modules.py ├── modules.py ├── sdf_meshing.py ├── torchmeta ├── __init__.py ├── datasets │ ├── __init__.py │ ├── assets │ │ ├── cifar100 │ │ │ ├── cifar-fs │ │ │ │ ├── test.json │ │ │ │ ├── train.json │ │ │ │ └── val.json │ │ │ └── fc100 │ │ │ │ ├── test.json │ │ │ │ ├── train.json │ │ │ │ └── val.json │ │ ├── cub │ │ │ ├── test.json │ │ │ ├── train.json │ │ │ └── val.json │ │ ├── doublemnist │ │ │ ├── test.json │ │ │ ├── train.json │ │ │ └── val.json │ │ ├── omniglot │ │ │ ├── test.json │ │ │ ├── train.json │ │ │ └── val.json │ │ ├── tcga │ │ │ ├── cancers.json │ │ │ ├── task_variables.json │ │ │ ├── test.json │ │ │ ├── train.json │ │ │ └── val.json │ │ └── triplemnist │ │ │ ├── test.json │ │ │ ├── train.json │ │ │ └── val.json │ ├── cifar100 │ │ ├── __init__.py │ │ ├── base.py │ │ ├── cifar_fs.py │ │ └── fc100.py │ ├── cub.py │ ├── doublemnist.py │ ├── helpers.py │ ├── miniimagenet.py │ ├── omniglot.py │ ├── tcga.py │ ├── tieredimagenet.py │ ├── triplemnist.py │ └── utils.py ├── modules │ ├── __init__.py │ ├── batchnorm.py │ ├── container.py │ ├── conv.py │ ├── linear.py │ ├── module.py │ ├── normalization.py │ └── utils.py ├── tests │ ├── __init__.py │ ├── test_dataloaders.py │ ├── test_prototype.py │ ├── test_splitters.py │ └── test_toy.py ├── toy │ ├── __init__.py │ ├── harmonic.py │ ├── helpers.py │ ├── sinusoid.py │ └── sinusoid_line.py ├── transforms │ ├── __init__.py │ ├── augmentations.py │ ├── categorical.py │ ├── splitters.py │ ├── target_transforms.py │ └── utils.py ├── utils │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── dataloader.py │ │ ├── dataset.py │ │ ├── sampler.py │ │ └── task.py │ ├── metrics.py │ └── prototype.py └── version.py ├── training.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Vincent Sitzmann 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Implicit Neural Representations with Periodic Activation Functions 2 | ### [Project Page](https://vsitzmann.github.io/siren) | [Paper](https://arxiv.org/abs/2006.09661) | [Data](https://drive.google.com/drive/folders/1_iq__37-hw7FJOEUK1tX7mdp8SKB368K?usp=sharing) 3 | [![Explore Siren in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vsitzmann/siren/blob/master/explore_siren.ipynb)
4 | 5 | [Vincent Sitzmann](https://vsitzmann.github.io/)\*, 6 | [Julien N. P. Martel](http://www.jmartel.net)\*, 7 | [Alexander W. Bergman](http://alexanderbergman7.github.io), 8 | [David B. Lindell](http://www.davidlindell.com/), 9 | [Gordon Wetzstein](https://stanford.edu/~gordonwz/)
10 | Stanford University, \*denotes equal contribution 11 | 12 | This is the official implementation of the paper "Implicit Neural Representations with Periodic Activation Functions". 13 | 14 | [![siren_video](https://img.youtube.com/vi/Q2fLWGBeaiI/0.jpg)](https://www.youtube.com/watch?v=Q2fLWGBeaiI) 15 | 16 | 17 | ## Google Colab 18 | If you want to experiment with Siren, we have written a [Colab](https://colab.research.google.com/github/vsitzmann/siren/blob/master/explore_siren.ipynb). 19 | It's quite comprehensive and comes with a no-frills, drop-in implementation of SIREN. It doesn't require 20 | installing anything, and goes through the following experiments / SIREN properties: 21 | * Fitting an image 22 | * Fitting an audio signal 23 | * Solving Poisson's equation 24 | * Initialization scheme & distribution of activations 25 | * Distribution of activations is shift-invariant 26 | * Periodicity & behavior outside of the training range. 27 | 28 | ## Tensorflow Playground 29 | You can also play around with a tiny SIREN interactively, directly in the browser, via the Tensorflow Playground [here](https://dcato98.github.io/playground/#activation=sine). Thanks to [David Cato](https://github.com/dcato98) for implementing this! 30 | 31 | ## Get started 32 | If you want to reproduce all the results (including the baselines) shown in the paper, the videos, point clouds, and 33 | audio files can be found [here](https://drive.google.com/drive/folders/1_iq__37-hw7FJOEUK1tX7mdp8SKB368K?usp=sharing). 34 | 35 | You can then set up a conda environment with all dependencies like so: 36 | ``` 37 | conda env create -f environment.yml 38 | conda activate siren 39 | ``` 40 | 41 | ## High-Level structure 42 | The code is organized as follows: 43 | * dataio.py loads training and testing data. 44 | * training.py contains a generic training routine. 45 | * modules.py contains layers and full neural network modules. 46 | * meta_modules.py contains hypernetwork code. 47 | * utils.py contains utility functions, most prominently related to the writing of Tensorboard summaries. 48 | * diff_operators.py contains implementations of differential operators. 49 | * loss_functions.py contains loss functions for the different experiments. 50 | * make_figures.py contains helper functions to create the convergence videos shown in the paper's video. 51 | * ./experiment_scripts/ contains scripts to reproduce experiments in the paper. 52 | 53 | ## Reproducing experiments 54 | The directory `experiment_scripts` contains one script per experiment in the paper. 55 | 56 | To monitor progress, the training code writes tensorboard summaries into a "summaries" subdirectory in the logging_root. 57 | 58 | ### Image experiments 59 | The image experiment can be reproduced with 60 | ``` 61 | python experiment_scripts/train_img.py --model_type=sine 62 | ``` 63 | The figures in the paper were made by extracting images from the tensorboard summaries. Example code showing how to do this can 64 | be found in the make_figures.py script. 65 | 66 | ### Audio experiments 67 | This GitHub repository comes with both the "counting" and "bach" audio clips under ./data. 68 | 69 | A model can be fit to them with 70 | ``` 71 | python experiment_scripts/train_audio.py --model_type=sine --wav_path=<path_to_wav_file> 72 | ``` 73 | 74 | ### Video experiments 75 | The "bikes" video sequence comes with scikit-video and need not be downloaded. The cat video can be downloaded with the 76 | link above.
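For reference, here is a minimal sketch of how the bundled "bikes" sequence can be loaded with scikit-video (the function names below are from the scikit-video API; the exact preprocessing performed by dataio.py may differ):
```
import skvideo.datasets
import skvideo.io

# scikit-video ships with the "bikes" clip and exposes its path.
video_path = skvideo.datasets.bikes()

# Load the clip as a (num_frames, height, width, 3) uint8 array.
video = skvideo.io.vread(video_path)
print(video.shape, video.dtype)
```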
77 | 78 | To fit a model to a video, run 79 | ``` 80 | python experiment_scripts/train_video.py --model_type=sine --experiment_name bikes_video 81 | ``` 82 | 83 | ### Poisson experiments 84 | For the Poisson experiments, there are three separate scripts: one for reconstructing an image from its gradients 85 | (train_poisson_grad_img.py), one from its Laplacian (train_poisson_lapl_img.py), and one to combine two images 86 | (train_poisson_gradcomp_img.py). 87 | 88 | Some of the experiments were run using the BSD500 dataset, which you can download [here](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/). 89 | 90 | ### SDF Experiments 91 | To fit a Signed Distance Function (SDF) with SIREN, you first need a point cloud in .xyz format that includes surface normals. 92 | If you only have a mesh / .ply file, this conversion can be accomplished with the open-source tool Meshlab. 93 | 94 | To reproduce our results, we provide both models of the Thai Statue from the 3D Stanford model repository and the living room used in our paper 95 | for download here. 96 | 97 | To start training a SIREN, run: 98 | ``` 99 | python experiment_scripts/train_sdf.py --model_type=sine --point_cloud_path=<path_to_xyz_point_cloud> --batch_size=250000 --experiment_name=experiment_1 100 | ``` 101 | This will regularly save checkpoints in the directory specified by the root path in the script, in a subdirectory "experiment_1". 102 | The batch_size is typically adjusted to fill the entire memory of your GPU. 103 | Our experiments show that with a SIREN of 3 hidden layers with 256 units each, one can set the batch size between 230,000 and 250,000 for an NVIDIA GPU with 12GB of memory. 104 | 105 | To inspect an SDF fitted to a 3D point cloud, we now need to create a mesh from the zero-level set of the SDF. 106 | This is performed with another script that uses a marching cubes algorithm (adapted from the DeepSDF GitHub repo) 107 | and saves the resulting mesh in the .ply file format. It can be called with: 108 | ``` 109 | python experiment_scripts/test_sdf.py --checkpoint_path=<path_to_checkpoint> --experiment_name=experiment_1_rec 110 | ``` 111 | This will save the .ply file as "reconstruction.ply" in "experiment_1_rec" (be patient, the marching cubes meshing step takes some time ;) ) 112 | In the event the machine you use for the reconstruction does not have enough RAM, running the test_sdf script will likely freeze. If this is the case, 113 | please use the option --resolution=512 in the command line above (1600 by default), which will reconstruct the mesh at a lower spatial resolution. 114 | 115 | The .ply file can be visualized using software such as [Meshlab](https://www.meshlab.net/#download) (a cross-platform visualizer and editor for 3D models). 116 | 117 | ### Helmholtz and wave equation experiments 118 | The Helmholtz and wave equation experiments can be reproduced with the train_wave_equation.py and train_helmholtz.py scripts. 119 | 120 | ## Torchmeta 121 | We're using the excellent [torchmeta](https://github.com/tristandeleu/pytorch-meta) to implement hypernetworks. We 122 | realized that there is a technical report, which we forgot to cite - it'll make it into the camera-ready version! 123 | 124 | ## Citation 125 | If you find our work useful in your research, please cite: 126 | ``` 127 | @inproceedings{sitzmann2019siren, 128 | author = {Sitzmann, Vincent 129 | and Martel, Julien N.P. 130 | and Bergman, Alexander W. 131 | and Lindell, David B.
132 | and Wetzstein, Gordon}, 133 | title = {Implicit Neural Representations 134 | with Periodic Activation Functions}, 135 | booktitle = {arXiv}, 136 | year={2020} 137 | } 138 | ``` 139 | 140 | ## Contact 141 | If you have any questions, please feel free to email the authors. 142 | -------------------------------------------------------------------------------- /data/gt_bach.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsitzmann/siren/4df34baee3f0f9c8f351630992c1fe1f69114b5f/data/gt_bach.wav -------------------------------------------------------------------------------- /data/gt_counting.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsitzmann/siren/4df34baee3f0f9c8f351630992c1fe1f69114b5f/data/gt_counting.wav -------------------------------------------------------------------------------- /diff_operators.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import grad 3 | 4 | 5 | def hessian(y, x): 6 | ''' hessian of y wrt x 7 | y: shape (meta_batch_size, num_observations, channels) 8 | x: shape (meta_batch_size, num_observations, 2) 9 | ''' 10 | meta_batch_size, num_observations = y.shape[:2] 11 | grad_y = torch.ones_like(y[..., 0]).to(y.device) 12 | h = torch.zeros(meta_batch_size, num_observations, y.shape[-1], x.shape[-1], x.shape[-1]).to(y.device) 13 | for i in range(y.shape[-1]): 14 | # calculate dydx over batches for each feature value of y 15 | dydx = grad(y[..., i], x, grad_y, create_graph=True)[0] 16 | 17 | # calculate hessian on y for each x value 18 | for j in range(x.shape[-1]): 19 | h[..., i, j, :] = grad(dydx[..., j], x, grad_y, create_graph=True)[0][..., :] 20 | 21 | status = 0 22 | if torch.any(torch.isnan(h)): 23 | status = -1 24 | return h, status 25 | 26 | 27 | def laplace(y, x): 28 | grad = gradient(y, x) 29 | return divergence(grad, x) 30 | 31 | 32 | def divergence(y, x): 33 | div = 0. 
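# Sum d y_i / d x_i over the output channels: the divergence is the trace of the Jacobian.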
34 | for i in range(y.shape[-1]): 35 | div += grad(y[..., i], x, torch.ones_like(y[..., i]), create_graph=True)[0][..., i:i+1] 36 | return div 37 | 38 | 39 | def gradient(y, x, grad_outputs=None): 40 | if grad_outputs is None: 41 | grad_outputs = torch.ones_like(y) 42 | grad = torch.autograd.grad(y, [x], grad_outputs=grad_outputs, create_graph=True)[0] 43 | return grad 44 | 45 | 46 | def jacobian(y, x): 47 | ''' jacobian of y wrt x ''' 48 | meta_batch_size, num_observations = y.shape[:2] 49 | jac = torch.zeros(meta_batch_size, num_observations, y.shape[-1], x.shape[-1]).to(y.device) # (meta_batch_size*num_points, 2, 2) 50 | for i in range(y.shape[-1]): 51 | # calculate dydx over batches for each feature value of y 52 | y_flat = y[...,i].view(-1, 1) 53 | jac[:, :, i, :] = grad(y_flat, x, torch.ones_like(y_flat), create_graph=True)[0] 54 | 55 | status = 0 56 | if torch.any(torch.isnan(jac)): 57 | status = -1 58 | 59 | return jac, status 60 | 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: siren 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - absl-py=0.9.0=py36_0 9 | - attrs=19.3.0=py_0 10 | - backcall=0.1.0=py36_0 11 | - blas=1.0=mkl 12 | - bleach=3.1.4=py_0 13 | - blinker=1.4=py36_0 14 | - bzip2=1.0.8=h516909a_2 15 | - c-ares=1.15.0=h7b6447c_1001 16 | - ca-certificates=2020.6.20=hecda079_0 17 | - cachetools=3.1.1=py_0 18 | - cairo=1.14.12=h8948797_3 19 | - certifi=2020.6.20=py36h9f0ad1d_0 20 | - cffi=1.14.0=py36he30daa8_1 21 | - chardet=3.0.4=py36_1003 22 | - click=7.1.2=py_0 23 | - cloudpickle=1.4.1=py_0 24 | - configargparse=1.1=py_0 25 | - cryptography=2.9.2=py36h1ba5d50_0 26 | - cudatoolkit=10.1.243=h6bb024c_0 27 | - cycler=0.10.0=py36_0 28 | - cytoolz=0.10.1=py36h7b6447c_0 29 | - dask-core=2.17.2=py_0 30 | - dbus=1.13.14=hb2f20db_0 31 | - decorator=4.4.2=py_0 32 | - defusedxml=0.6.0=py_0 33 | - entrypoints=0.3=py36_0 34 | - expat=2.2.6=he6710b0_0 35 | - ffmpeg=4.0.2=ha0c5888_2 36 | - fftw=3.3.8=nompi_h7f3a6c3_1110 37 | - fontconfig=2.13.0=h9420a91_0 38 | - freetype=2.9.1=h8a8886c_1 39 | - gettext=0.19.8.1=h5e8e0c9_1 40 | - ghostscript=9.22=hf484d3e_1001 41 | - giflib=5.1.9=h516909a_0 42 | - glib=2.63.1=h3eb4bd4_1 43 | - gmp=6.1.2=h6c8ec71_1 44 | - gnutls=3.5.19=h2a4e5f8_1 45 | - google-auth=1.14.1=py_0 46 | - google-auth-oauthlib=0.4.1=py_2 47 | - graphite2=1.3.13=he1b5a44_1001 48 | - graphviz=2.38.0=hcf1ce16_1009 49 | - grpcio=1.27.2=py36hf8bcb03_0 50 | - gst-plugins-base=1.14.0=hbbd80ab_1 51 | - gstreamer=1.14.0=hb31296c_0 52 | - h5py=2.8.0=py36h989c5e5_3 53 | - harfbuzz=1.8.8=hffaf4a1_0 54 | - hdf5=1.10.2=hba1933b_1 55 | - icu=58.2=he6710b0_3 56 | - idna=2.9=py_1 57 | - imageio=2.8.0=py_0 58 | - imagemagick=7.0.8_11=pl526hc610aec_0 59 | - importlib-metadata=1.6.0=py36_0 60 | - importlib_metadata=1.6.0=0 61 | - intel-openmp=2020.1=217 62 | - ipykernel=5.1.4=py36h39e3cac_0 63 | - ipython=7.13.0=py36h5ca1d4c_0 64 | - ipython_genutils=0.2.0=py36_0 65 | - ipywidgets=7.5.1=py_0 66 | - jbig=2.1=h516909a_2002 67 | - jedi=0.17.0=py36_0 68 | - jinja2=2.11.2=py_0 69 | - joblib=0.15.1=py_0 70 | - jpeg=9d=h516909a_0 71 | - jsonschema=3.2.0=py36_0 72 | - jupyter=1.0.0=py36_7 73 | - jupyter_client=6.1.3=py_0 74 | - jupyter_console=6.1.0=py_0 75 | - jupyter_core=4.6.3=py36_0 76 | - kiwisolver=1.2.0=py36hfd86e86_0 77 | - ld_impl_linux-64=2.33.1=h53a641e_7 78 | - 
libedit=3.1.20181209=hc058e9b_0 79 | - libffi=3.3=he6710b0_1 80 | - libgcc-ng=9.1.0=hdf63c60_0 81 | - libgfortran-ng=7.3.0=hdf63c60_0 82 | - libiconv=1.15=h516909a_1006 83 | - libpng=1.6.37=hbc83047_0 84 | - libprotobuf=3.11.4=hd408876_0 85 | - libsodium=1.0.16=h1bed415_0 86 | - libstdcxx-ng=9.1.0=hdf63c60_0 87 | - libtiff=4.1.0=h2733197_1 88 | - libtool=2.4.6=h14c3975_1002 89 | - libuuid=1.0.3=h1bed415_2 90 | - libwebp=0.5.2=7 91 | - libxcb=1.13=h1bed415_1 92 | - libxml2=2.9.9=hea5a465_1 93 | - lz4-c=1.9.2=he6710b0_0 94 | - markdown=3.1.1=py36_0 95 | - markupsafe=1.1.1=py36h7b6447c_0 96 | - mistune=0.8.4=py36h7b6447c_0 97 | - mkl=2020.1=217 98 | - mkl-service=2.3.0=py36he904b0f_0 99 | - mkl_fft=1.0.15=py36ha843d7b_0 100 | - mkl_random=1.1.1=py36h0573a6f_0 101 | - nbconvert=5.6.1=py36_0 102 | - nbformat=5.0.6=py_0 103 | - ncurses=6.2=he6710b0_1 104 | - nettle=3.3=0 105 | - networkx=2.4=py_0 106 | - ninja=1.9.0=py36hfd86e86_0 107 | - notebook=6.0.3=py36_0 108 | - numpy=1.18.1=py36h4f9e942_0 109 | - numpy-base=1.18.1=py36hde5b4d6_1 110 | - oauthlib=3.1.0=py_0 111 | - olefile=0.46=py36_0 112 | - openh264=1.8.0=hdbcaa40_1000 113 | - openjpeg=2.3.1=h981e76c_3 114 | - openssl=1.1.1g=h516909a_0 115 | - pandoc=2.2.3.2=0 116 | - pandocfilters=1.4.2=py36_1 117 | - pango=1.40.14=he752989_2 118 | - parso=0.7.0=py_0 119 | - pcre=8.43=he6710b0_0 120 | - perl=5.26.2=h516909a_1006 121 | - pexpect=4.8.0=py36_0 122 | - pickleshare=0.7.5=py36_0 123 | - pillow=6.2.1=py36h34e0f95_0 124 | - pip=20.0.2=py36_3 125 | - pixman=0.38.0=h516909a_1003 126 | - pkg-config=0.29.2=h516909a_1006 127 | - prometheus_client=0.7.1=py_0 128 | - prompt-toolkit=3.0.5=py_0 129 | - prompt_toolkit=3.0.5=0 130 | - protobuf=3.11.4=py36he6710b0_0 131 | - ptyprocess=0.6.0=py36_0 132 | - pyasn1=0.4.8=py_0 133 | - pyasn1-modules=0.2.7=py_0 134 | - pycparser=2.20=py_0 135 | - pygments=2.6.1=py_0 136 | - pyjwt=1.7.1=py36_0 137 | - pyopenssl=19.1.0=py36_0 138 | - pyparsing=2.4.7=py_0 139 | - pyqt=5.9.2=py36h05f1152_2 140 | - pyrsistent=0.16.0=py36h7b6447c_0 141 | - pysocks=1.7.1=py36_0 142 | - python=3.6.10=h7579374_2 143 | - python-dateutil=2.8.1=py_0 144 | - python_abi=3.6=1_cp36m 145 | - pytorch=1.5.0=py3.6_cuda10.1.243_cudnn7.6.3_0 146 | - pywavelets=1.1.1=py36h7b6447c_0 147 | - pyyaml=5.3.1=py36h7b6447c_0 148 | - pyzmq=18.1.1=py36he6710b0_0 149 | - qt=5.9.7=h5867ecd_1 150 | - qtconsole=4.7.4=py_0 151 | - qtpy=1.9.0=py_0 152 | - readline=8.0=h7b6447c_0 153 | - requests=2.23.0=py36_0 154 | - requests-oauthlib=1.3.0=py_0 155 | - rsa=4.0=py_0 156 | - scikit-image=0.16.2=py36h0573a6f_0 157 | - scikit-learn=0.22.1=py36hd81dba3_0 158 | - scikit-video=1.1.11=pyh24bf2e0_0 159 | - scipy=1.4.1=py36h0b6359f_0 160 | - send2trash=1.5.0=py36_0 161 | - setuptools=46.4.0=py36_0 162 | - sip=4.19.8=py36hf484d3e_0 163 | - six=1.14.0=py36_0 164 | - sqlite=3.31.1=h62c20be_1 165 | - tensorboard-plugin-wit=1.6.0=py_0 166 | - terminado=0.8.3=py36_0 167 | - testpath=0.4.4=py_0 168 | - tk=8.6.8=hbc83047_0 169 | - toolz=0.10.0=py_0 170 | - torchvision=0.6.0=py36_cu101 171 | - tornado=6.0.4=py36h7b6447c_1 172 | - tqdm=4.46.0=py_0 173 | - traitlets=4.3.3=py36_0 174 | - urllib3=1.25.8=py36_0 175 | - wcwidth=0.1.9=py_0 176 | - webencodings=0.5.1=py36_1 177 | - werkzeug=1.0.1=py_0 178 | - wheel=0.34.2=py36_0 179 | - widgetsnbextension=3.5.1=py36_0 180 | - x264=1!152.20180717=h14c3975_1001 181 | - xorg-kbproto=1.0.7=h14c3975_1002 182 | - xorg-libice=1.0.10=h516909a_0 183 | - xorg-libsm=1.2.2=h470a237_5 184 | - xorg-libx11=1.6.9=h516909a_0 185 | - 
xorg-libxext=1.3.4=h516909a_0 186 | - xorg-libxpm=3.5.13=h516909a_0 187 | - xorg-libxrender=0.9.10=h516909a_1002 188 | - xorg-libxt=1.1.5=h516909a_1003 189 | - xorg-renderproto=0.11.1=h14c3975_1002 190 | - xorg-xextproto=7.3.0=h14c3975_1002 191 | - xorg-xproto=7.0.31=h14c3975_1007 192 | - xz=5.2.5=h7b6447c_0 193 | - yaml=0.1.7=had09818_2 194 | - zeromq=4.3.1=he6710b0_3 195 | - zipp=3.1.0=py_0 196 | - zlib=1.2.11=h7b6447c_3 197 | - zstd=1.4.4=h0b5b093_3 198 | - pip: 199 | - astor==0.8.1 200 | - cmapy==0.6.6 201 | - gast==0.2.2 202 | - google-pasta==0.2.0 203 | - imageio-ffmpeg==0.4.2 204 | - imbalanced-learn==0.6.2 205 | - keras-applications==1.0.8 206 | - keras-preprocessing==1.1.2 207 | - matplotlib==2.2.5 208 | - moviepy==1.0.3 209 | - opencv-python==3.4.9.33 210 | - opt-einsum==3.2.1 211 | - proglog==0.1.9 212 | - pytz==2020.1 213 | - tensorboard==1.15.0 214 | - tensorflow==1.15.0 215 | - tensorflow-estimator==1.15.1 216 | - termcolor==1.1.0 217 | - wrapt==1.12.1 218 | - youtube-dl==2020.6.6 219 | 220 | -------------------------------------------------------------------------------- /experiment_scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsitzmann/siren/4df34baee3f0f9c8f351630992c1fe1f69114b5f/experiment_scripts/__init__.py -------------------------------------------------------------------------------- /experiment_scripts/fwi/data_cylinder_5.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsitzmann/siren/4df34baee3f0f9c8f351630992c1fe1f69114b5f/experiment_scripts/fwi/data_cylinder_5.mat -------------------------------------------------------------------------------- /experiment_scripts/test_audio.py: -------------------------------------------------------------------------------- 1 | # Enable import from parent package 2 | import sys 3 | import os 4 | sys.path.append( os.path.dirname( os.path.dirname( os.path.abspath(__file__) ) ) ) 5 | 6 | import dataio, utils, modules 7 | 8 | from torch.utils.data import DataLoader 9 | import configargparse 10 | import torch 11 | import scipy.io.wavfile as wavfile 12 | 13 | p = configargparse.ArgumentParser() 14 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.') 15 | 16 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging') 17 | p.add_argument('--experiment_name', type=str, default='audio', 18 | help='Name of subdirectory in logging_root where wav file will be saved.') 19 | p.add_argument('--gt_wav_path', type=str, default='../data/gt_bach.wav', help='ground truth wav path') 20 | 21 | p.add_argument('--model_type', type=str, default='sine', 22 | help='Options currently are "sine" (all sine activations), "relu" (all relu activations,' 23 | '"nerf" (relu activations and positional encoding as in NeRF), "rbf" (input rbf layer, rest relu),' 24 | 'and in the future: "mixed" (first layer sine, other layers tanh)') 25 | p.add_argument('--checkpoint_path', required=True, help='Checkpoint to trained model.') 26 | 27 | opt = p.parse_args() 28 | 29 | audio_dataset = dataio.AudioFile(filename=opt.gt_wav_path) 30 | coord_dataset = dataio.ImplicitAudioWrapper(audio_dataset) 31 | 32 | dataloader = DataLoader(coord_dataset, shuffle=True, batch_size=1, pin_memory=True, num_workers=0) 33 | 34 | # Define the model and load in checkpoint path 35 | if opt.model_type == 'sine' or opt.model_type == 'relu' or opt.model_type == 
'tanh': 36 | model = modules.SingleBVPNet(type=opt.model_type, mode='mlp', in_features=1) 37 | elif opt.model_type == 'rbf' or opt.model_type == 'nerf': 38 | model = modules.SingleBVPNet(type='relu', mode=opt.model_type, fn_samples=len(audio_dataset.data), in_features=1) 39 | else: 40 | raise NotImplementedError 41 | model.load_state_dict(torch.load(opt.checkpoint_path)) 42 | model.cuda() 43 | 44 | root_path = os.path.join(opt.logging_root, opt.experiment_name) 45 | utils.cond_mkdir(root_path) 46 | 47 | # Get ground truth and input data 48 | model_input, gt = next(iter(dataloader)) 49 | model_input = {key: value.cuda() for key, value in model_input.items()} 50 | gt = {key: value.cuda() for key, value in gt.items()} 51 | 52 | # Evaluate the trained model 53 | with torch.no_grad(): 54 | model_output = model(model_input) 55 | 56 | waveform = torch.squeeze(model_output['model_out']).detach().cpu().numpy() 57 | rate = torch.squeeze(gt['rate']).detach().cpu().numpy() 58 | wavfile.write(os.path.join(opt.logging_root, opt.experiment_name, 'pred_waveform.wav'), rate, waveform) -------------------------------------------------------------------------------- /experiment_scripts/test_conv_neural_process.py: -------------------------------------------------------------------------------- 1 | # Enable import from parent package 2 | import sys 3 | import os 4 | sys.path.append( os.path.dirname( os.path.dirname( os.path.abspath(__file__) ) ) ) 5 | 6 | import dataio, meta_modules, utils, training, loss_functions 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | 10 | import torch 11 | from torch.utils.data import DataLoader 12 | import configargparse 13 | 14 | import imageio 15 | from functools import partial 16 | import random 17 | from tqdm.autonotebook import tqdm 18 | import time 19 | import utils 20 | from torch.utils.tensorboard import SummaryWriter 21 | 22 | p = configargparse.ArgumentParser() 23 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.') 24 | 25 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging') 26 | p.add_argument('--experiment_name', type=str, required=True, 27 | help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.') 28 | 29 | # General training options 30 | p.add_argument('--checkpoint_path', default=None, type=str, required=True, 31 | help='Path to the trained model checkpoint to load.') 32 | p.add_argument('--dataset', type=str, default='celeba_32x32', 33 | help='Dataset to evaluate on; only celeba_32x32 is supported.') 34 | p.add_argument('--model_type', type=str, default='sine', 35 | help='Nonlinearity in the neural implicit representation') 36 | p.add_argument('--test_sparsity', type=float, default=200, 37 | help='Number of subsampled pixels input into the set encoder') 38 | p.add_argument('--partial_conv', action='store_true', default=False, help='Use a partial convolution encoder') 39 | opt = p.parse_args() 40 | 41 | if opt.experiment_name is None: 42 | opt.experiment_name = opt.checkpoint_path.split('/')[-3] + '_TEST' 43 | else: 44 | opt.experiment_name = opt.checkpoint_path.split('/')[-3] + '_' + opt.experiment_name 45 | 46 | assert opt.dataset == 'celeba_32x32' 47 | img_dataset_test = dataio.CelebA(split='test', downsampled=True) 48 | coord_dataset_test = dataio.Implicit2DWrapper(img_dataset_test, sidelength=(32, 32)) 49 | generalization_dataset_test = dataio.ImageGeneralizationWrapper(coord_dataset_test,
test_sparsity=200, 50 | generalization_mode='conv_cnp_test') 51 | image_resolution = (32, 32) 52 | 53 | img_dataset_train = dataio.CelebA(split='train', downsampled=True) 54 | coord_dataset_train = dataio.Implicit2DWrapper(img_dataset_train, sidelength=(32, 32)) 55 | generalization_dataset_train = dataio.ImageGeneralizationWrapper(coord_dataset_train, test_sparsity=200, 56 | generalization_mode='conv_cnp_test') 57 | 58 | # Define the model. 59 | model = meta_modules.ConvolutionalNeuralProcessImplicit2DHypernet(in_features=img_dataset_test.img_channels, 60 | out_features=img_dataset_test.img_channels, 61 | image_resolution=image_resolution, 62 | partial_conv=opt.partial_conv) 63 | model.cuda() 64 | model.eval() 65 | 66 | root_path = os.path.join(opt.logging_root, opt.experiment_name) 67 | utils.cond_mkdir(root_path) 68 | 69 | 70 | # Load checkpoint 71 | model.load_state_dict(torch.load(opt.checkpoint_path)) 72 | 73 | # First experiment: Upsample training image 74 | model_input = {'coords':dataio.get_mgrid(image_resolution)[None,:].cuda(), 75 | 'img_sparse':generalization_dataset_train[0][0]['img_sparse'].unsqueeze(0).cuda()} 76 | model_output = model(model_input) 77 | 78 | out_img = dataio.lin2img(model_output['model_out'], image_resolution).squeeze().permute(1,2,0).detach().cpu().numpy() 79 | out_img += 1 80 | out_img /= 2. 81 | out_img = np.clip(out_img, 0., 1.) 82 | 83 | imageio.imwrite(os.path.join(root_path, 'upsampled_train.png'), out_img) 84 | 85 | # Second experiment: sample larger range 86 | model_input = {'coords':dataio.get_mgrid(image_resolution)[None,:].cuda()*5, 87 | 'img_sparse':generalization_dataset_train[0][0]['img_sparse'].unsqueeze(0).cuda()} 88 | model_output = model(model_input) 89 | 90 | out_img = dataio.lin2img(model_output['model_out'], image_resolution).squeeze().permute(1,2,0).detach().cpu().numpy() 91 | out_img += 1 92 | out_img /= 2. 93 | out_img = np.clip(out_img, 0., 1.) 94 | 95 | imageio.imwrite(os.path.join(root_path, 'outside_range.png'), out_img) 96 | 97 | # Third experiment: interpolate between latent codes 98 | idx1, idx2 = 57, 181 99 | model_input_1 = {'coords': dataio.get_mgrid(image_resolution)[None, :].cuda(), 100 | 'img_sparse': generalization_dataset_train[idx1][0]['img_sparse'].unsqueeze(0).cuda()} 101 | model_input_2 = {'coords': dataio.get_mgrid(image_resolution)[None, :].cuda(), 102 | 'img_sparse': generalization_dataset_train[idx2][0]['img_sparse'].unsqueeze(0).cuda()} 103 | 104 | embedding_1 = model.get_hypo_net_weights(model_input_1)[1] 105 | embedding_2 = model.get_hypo_net_weights(model_input_2)[1] 106 | for i in np.linspace(0,1,8): 107 | embedding = i*embedding_1 + (1.-i)*embedding_2 108 | model_input = {'coords': dataio.get_mgrid(image_resolution)[None, :].cuda(), 'embedding': embedding} 109 | model_output = model(model_input) 110 | 111 | out_img = dataio.lin2img(model_output['model_out'], image_resolution).squeeze().permute(1,2,0).detach().cpu().numpy() 112 | out_img += 1 113 | out_img /= 2. 114 | out_img = np.clip(out_img, 0., 1.) 
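# Concatenate the reconstructions horizontally into a single interpolation strip.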
115 | 116 | if i == 0.: 117 | out_img_cat = out_img 118 | else: 119 | out_img_cat = np.concatenate((out_img_cat, out_img), axis=1) 120 | 121 | imageio.imwrite(os.path.join(root_path, 'interpolated_image.png'), out_img_cat) 122 | 123 | # Fourth experiment: Fit test images 124 | def to_uint8(img): 125 | img = img * 255 126 | img = img.astype(np.uint8) 127 | return img 128 | 129 | def getTestMSE(dataloader, subdir): 130 | MSEs = [] 131 | total_steps = 0 132 | utils.cond_mkdir(os.path.join(root_path, subdir)) 133 | utils.cond_mkdir(os.path.join(root_path, 'ground_truth')) 134 | 135 | with tqdm(total=len(dataloader)) as pbar: 136 | for step, (model_input, gt) in enumerate(dataloader): 137 | model_input['idx'] = torch.Tensor([model_input['idx']]).long() 138 | model_input = {key: value.cuda() for key, value in model_input.items()} 139 | gt = {key: value.cuda() for key, value in gt.items()} 140 | 141 | with torch.no_grad(): 142 | model_output = model(model_input) 143 | 144 | out_img = dataio.lin2img(model_output['model_out'], image_resolution).squeeze().permute(1, 2, 0).detach().cpu().numpy() 145 | out_img += 1 146 | out_img /= 2. 147 | out_img = np.clip(out_img, 0., 1.) 148 | gt_img = dataio.lin2img(gt['img'], image_resolution).squeeze().permute(1, 2, 0).detach().cpu().numpy() 149 | gt_img += 1 150 | gt_img /= 2. 151 | gt_img = np.clip(gt_img, 0., 1.) 152 | 153 | sparse_img = model_input['img_sparse'].squeeze().detach().cpu().permute(1,2,0).numpy() 154 | mask = np.sum((sparse_img == 0), axis=2) == 3 155 | sparse_img += 1 156 | sparse_img /= 2. 157 | sparse_img = np.clip(sparse_img, 0., 1.) 158 | sparse_img[mask, ...] = 1. 159 | 160 | imageio.imwrite(os.path.join(root_path, subdir, str(total_steps)+'_sparse.png'), to_uint8(sparse_img)) 161 | imageio.imwrite(os.path.join(root_path, subdir, str(total_steps)+'.png'), to_uint8(out_img)) 162 | imageio.imwrite(os.path.join(root_path, 'ground_truth', str(total_steps)+'.png'), to_uint8(gt_img)) 163 | 164 | MSE = np.mean((out_img - gt_img) ** 2) 165 | MSEs.append(MSE) 166 | 167 | pbar.update(1) 168 | total_steps += 1 169 | 170 | return MSEs 171 | 172 | sparsities = [10, 100, 1000, 'full', 'half'] 173 | for sparsity in sparsities: 174 | generalization_dataset_test.update_test_sparsity(sparsity) 175 | dataloader = DataLoader(generalization_dataset_test, shuffle=False, batch_size=1, pin_memory=True, num_workers=0) 176 | MSE = getTestMSE(dataloader, 'test_'+str(sparsity)+'_pixels') 177 | np.save(os.path.join(root_path, 'MSE_'+str(sparsity)+'_context.npy'), MSE) 178 | print(np.mean(MSE)) 179 | -------------------------------------------------------------------------------- /experiment_scripts/test_neural_process.py: -------------------------------------------------------------------------------- 1 | # Enable import from parent package 2 | import os 3 | import sys 4 | 5 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 6 | 7 | import dataio, meta_modules 8 | import numpy as np 9 | 10 | import torch 11 | from torch.utils.data import DataLoader 12 | import configargparse 13 | 14 | import imageio 15 | from tqdm.autonotebook import tqdm 16 | import utils 17 | import skimage 18 | 19 | p = configargparse.ArgumentParser() 20 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.') 21 | 22 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging') 23 | p.add_argument('--experiment_name', type=str, required=True, 24 | help='Name of subdirectory in logging_root where 
summaries and checkpoints will be saved.') 25 | 26 | # General training options 27 | p.add_argument('--checkpoint_path', default=None, type=str, required=True, 28 | help='Path to the trained model checkpoint to load.') 29 | p.add_argument('--dataset', type=str, default='celeba_32x32', 30 | help='Dataset to evaluate on; only celeba_32x32 is supported.') 31 | p.add_argument('--model_type', type=str, default='sine', 32 | help='Nonlinearity in the neural implicit representation') 33 | p.add_argument('--test_sparsity', type=float, default=200, 34 | help='Number of subsampled pixels input into the set encoder') 35 | opt = p.parse_args() 36 | 37 | if opt.experiment_name is None: 38 | opt.experiment_name = opt.checkpoint_path.split('/')[-3] + '_TEST' 39 | else: 40 | opt.experiment_name = opt.checkpoint_path.split('/')[-3] + '_' + opt.experiment_name 41 | 42 | assert opt.dataset == 'celeba_32x32' 43 | img_dataset_test = dataio.CelebA(split='test', downsampled=True) 44 | coord_dataset_test = dataio.Implicit2DWrapper(img_dataset_test, sidelength=(32, 32)) 45 | generalization_dataset_test = dataio.ImageGeneralizationWrapper(coord_dataset_test, test_sparsity=10, 46 | generalization_mode='cnp_test') 47 | image_resolution = (32, 32) 48 | 49 | img_dataset_train = dataio.CelebA(split='train', downsampled=True) 50 | coord_dataset_train = dataio.Implicit2DWrapper(img_dataset_train, sidelength=(32, 32)) 51 | generalization_dataset_train = dataio.ImageGeneralizationWrapper(coord_dataset_train, test_sparsity=10, 52 | generalization_mode='cnp_test') 53 | 54 | # Define the model. 55 | model = meta_modules.NeuralProcessImplicit2DHypernet(in_features=img_dataset_test.img_channels + 2, 56 | out_features=img_dataset_test.img_channels, 57 | image_resolution=image_resolution) 58 | model.cuda() 59 | model.eval() 60 | 61 | root_path = os.path.join(opt.logging_root, opt.experiment_name) 62 | utils.cond_mkdir(root_path) 63 | 64 | # Load checkpoint 65 | model.load_state_dict(torch.load(opt.checkpoint_path)) 66 | 67 | # First experiment: Upsample training image 68 | model_input = {'coords': dataio.get_mgrid(image_resolution)[None, :].cuda(), 69 | 'img_sub': generalization_dataset_train[0][0]['img_sub'].unsqueeze(0).cuda(), 70 | 'coords_sub': generalization_dataset_train[0][0]['coords_sub'].unsqueeze(0).cuda()} 71 | model_output = model(model_input) 72 | 73 | out_img = dataio.lin2img(model_output['model_out'], image_resolution).squeeze().permute(1, 2, 0).detach().cpu().numpy() 74 | out_img += 1 75 | out_img /= 2. 76 | out_img = np.clip(out_img, 0., 1.) 77 | 78 | imageio.imwrite(os.path.join(root_path, 'upsampled_train.png'), out_img) 79 | 80 | # Second experiment: sample larger range 81 | model_input = {'coords': dataio.get_mgrid(image_resolution)[None, :].cuda() * 5, 82 | 'img_sub': generalization_dataset_train[0][0]['img_sub'].unsqueeze(0).cuda(), 83 | 'coords_sub': generalization_dataset_train[0][0]['coords_sub'].unsqueeze(0).cuda()} 84 | model_output = model(model_input) 85 | 86 | out_img = dataio.lin2img(model_output['model_out'], image_resolution).squeeze().permute(1, 2, 0).detach().cpu().numpy() 87 | out_img += 1 88 | out_img /= 2. 89 | out_img = np.clip(out_img, 0., 1.)
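# out_img is now in [0, 1]; save the reconstruction sampled outside the training range.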
90 | 91 | imageio.imwrite(os.path.join(root_path, 'outside_range.png'), out_img) 92 | 93 | # Third experiment: interpolate between latent codes 94 | idx1, idx2 = 57, 181 95 | model_input_1 = {'coords': dataio.get_mgrid(image_resolution)[None, :].cuda(), 96 | 'img_sub': generalization_dataset_train[idx1][0]['img_sub'].unsqueeze(0).cuda(), 97 | 'coords_sub': generalization_dataset_train[idx1][0]['coords_sub'].unsqueeze(0).cuda()} 98 | model_input_2 = {'coords': dataio.get_mgrid(image_resolution)[None, :].cuda(), 99 | 'img_sub': generalization_dataset_train[idx2][0]['img_sub'].unsqueeze(0).cuda(), 100 | 'coords_sub': generalization_dataset_train[idx2][0]['coords_sub'].unsqueeze(0).cuda()} 101 | 102 | embedding_1 = model.get_hypo_net_weights(model_input_1)[1] 103 | embedding_2 = model.get_hypo_net_weights(model_input_2)[1] 104 | for i in np.linspace(0, 1, 8): 105 | embedding = i * embedding_1 + (1. - i) * embedding_2 106 | model_input = {'coords': dataio.get_mgrid(image_resolution)[None, :].cuda(), 'embedding': embedding} 107 | model_output = model(model_input) 108 | 109 | out_img = dataio.lin2img(model_output['model_out'], image_resolution).squeeze().permute(1, 2, 110 | 0).detach().cpu().numpy() 111 | out_img += 1 112 | out_img /= 2. 113 | out_img = np.clip(out_img, 0., 1.) 114 | 115 | if i == 0.: 116 | out_img_cat = out_img 117 | else: 118 | out_img_cat = np.concatenate((out_img_cat, out_img), axis=1) 119 | 120 | imageio.imwrite(os.path.join(root_path, 'interpolated_image.png'), out_img_cat) 121 | 122 | 123 | # Fourth experiment: Fit test images 124 | def to_uint8(img): 125 | img = img * 255 126 | img = img.astype(np.uint8) 127 | return img 128 | 129 | 130 | def getTestMSE(dataloader, subdir): 131 | MSEs = [] 132 | PSNRs = [] 133 | total_steps = 0 134 | utils.cond_mkdir(os.path.join(root_path, subdir)) 135 | utils.cond_mkdir(os.path.join(root_path, 'ground_truth')) 136 | 137 | with tqdm(total=len(dataloader)) as pbar: 138 | for step, (model_input, gt) in enumerate(dataloader): 139 | model_input['idx'] = torch.Tensor([model_input['idx']]).long() 140 | model_input = {key: value.cuda() for key, value in model_input.items()} 141 | gt = {key: value.cuda() for key, value in gt.items()} 142 | 143 | with torch.no_grad(): 144 | model_output = model(model_input) 145 | 146 | out_img = dataio.lin2img(model_output['model_out'], image_resolution).squeeze().permute(1, 2, 147 | 0).detach().cpu().numpy() 148 | out_img += 1 149 | out_img /= 2. 150 | out_img = np.clip(out_img, 0., 1.) 151 | gt_img = dataio.lin2img(gt['img'], image_resolution).squeeze().permute(1, 2, 0).detach().cpu().numpy() 152 | gt_img += 1 153 | gt_img /= 2. 154 | gt_img = np.clip(gt_img, 0., 1.) 155 | 156 | sparse_img = np.ones((image_resolution[0], image_resolution[1], 3)) 157 | coords_sub = model_input['coords_sub'].squeeze().detach().cpu().numpy() 158 | rgb_sub = model_input['img_sub'].squeeze().detach().cpu().numpy() 159 | for index in range(0, coords_sub.shape[0]): 160 | r = int(round((coords_sub[index][0] + 1) / 2 * 31)) 161 | c = int(round((coords_sub[index][1] + 1) / 2 * 31)) 162 | sparse_img[r, c, :] = np.clip((rgb_sub[index, :] + 1) / 2, 0., 1.) 
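# Save the sparse context visualization, the reconstruction, and the ground truth as separate PNGs.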
163 | 164 | imageio.imwrite(os.path.join(root_path, subdir, str(total_steps) + '_sparse.png'), to_uint8(sparse_img)) 165 | imageio.imwrite(os.path.join(root_path, subdir, str(total_steps) + '.png'), to_uint8(out_img)) 166 | imageio.imwrite(os.path.join(root_path, 'ground_truth', str(total_steps) + '.png'), to_uint8(gt_img)) 167 | 168 | MSE = np.mean((out_img - gt_img) ** 2) 169 | MSEs.append(MSE) 170 | 171 | PSNR = skimage.measure.compare_psnr(out_img, gt_img, data_range=1) 172 | PSNRs.append(PSNR) 173 | 174 | pbar.update(1) 175 | total_steps += 1 176 | 177 | return MSEs, PSNRs 178 | 179 | 180 | sparsities = [10, 100, 1000, 'full', 'half'] 181 | for sparsity in sparsities: 182 | generalization_dataset_test.update_test_sparsity(sparsity) 183 | dataloader = DataLoader(generalization_dataset_test, shuffle=False, batch_size=1, pin_memory=True, num_workers=0) 184 | MSE, PSNR = getTestMSE(dataloader, 'test_' + str(sparsity) + '_pixels') 185 | np.save(os.path.join(root_path, 'MSE_' + str(sparsity) + '_context.npy'), MSE) 186 | np.save(os.path.join(root_path, 'PSNR_' + str(sparsity) + '_context.npy'), PSNR) 187 | print(np.mean(MSE)) 188 | -------------------------------------------------------------------------------- /experiment_scripts/test_sdf.py: -------------------------------------------------------------------------------- 1 | '''Test script for the SDF experiments in paper Sec. 4.2, Supplement Sec. 3: extracts a mesh from a trained SDF. 2 | ''' 3 | 4 | # Enable import from parent package 5 | import os 6 | import sys 7 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 8 | 9 | import torch 10 | import modules, utils 11 | import sdf_meshing 12 | import configargparse 13 | 14 | p = configargparse.ArgumentParser() 15 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.') 16 | 17 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging') 18 | p.add_argument('--experiment_name', type=str, required=True, 19 | help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.') 20 | 21 | # General training options 22 | p.add_argument('--batch_size', type=int, default=16384) 23 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.') 24 | 25 | p.add_argument('--model_type', type=str, default='sine', 26 | help='Options are "sine" (all sine activations) and "mixed" (first layer sine, other layers tanh)') 27 | p.add_argument('--mode', type=str, default='mlp', 28 | help='Options are "mlp" or "nerf"') 29 | p.add_argument('--resolution', type=int, default=1600) 30 | 31 | opt = p.parse_args() 32 | 33 | 34 | class SDFDecoder(torch.nn.Module): 35 | def __init__(self): 36 | super().__init__() 37 | # Define the model.
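# 'mlp' uses the nonlinearity chosen via --model_type; 'nerf' uses a ReLU MLP with positional encodings on the input.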
38 | if opt.mode == 'mlp': 39 | self.model = modules.SingleBVPNet(type=opt.model_type, final_layer_factor=1, in_features=3) 40 | elif opt.mode == 'nerf': 41 | self.model = modules.SingleBVPNet(type='relu', mode='nerf', final_layer_factor=1, in_features=3) 42 | self.model.load_state_dict(torch.load(opt.checkpoint_path)) 43 | self.model.cuda() 44 | 45 | def forward(self, coords): 46 | model_in = {'coords': coords} 47 | return self.model(model_in)['model_out'] 48 | 49 | 50 | sdf_decoder = SDFDecoder() 51 | 52 | root_path = os.path.join(opt.logging_root, opt.experiment_name) 53 | utils.cond_mkdir(root_path) 54 | 55 | sdf_meshing.create_mesh(sdf_decoder, os.path.join(root_path, 'test'), N=opt.resolution) 56 | -------------------------------------------------------------------------------- /experiment_scripts/train_audio.py: -------------------------------------------------------------------------------- 1 | # Enable import from parent package 2 | import sys 3 | import os 4 | sys.path.append( os.path.dirname( os.path.dirname( os.path.abspath(__file__) ) ) ) 5 | 6 | import dataio, meta_modules, utils, training, loss_functions, modules 7 | 8 | from torch.utils.data import DataLoader 9 | import configargparse 10 | from functools import partial 11 | 12 | p = configargparse.ArgumentParser() 13 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.') 14 | 15 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging') 16 | p.add_argument('--experiment_name', type=str, default='audio', 17 | help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.') 18 | 19 | p.add_argument('--wav_path', type=str, default='../data/gt_bach.wav', help='Path to the input .wav file.') 20 | 21 | # General training options 22 | p.add_argument('--batch_size', type=int, default=1) 23 | p.add_argument('--lr', type=float, default=1e-4, help='learning rate. default=1e-4') 24 | p.add_argument('--num_epochs', type=int, default=9001, 25 | help='Number of epochs to train for.') 26 | 27 | p.add_argument('--epochs_til_ckpt', type=int, default=1000, 28 | help='Number of epochs between checkpoint saves.') 29 | p.add_argument('--steps_til_summary', type=int, default=1000, 30 | help='Number of training steps between tensorboard summary writes.') 31 | 32 | p.add_argument('--model_type', type=str, default='sine', 33 | help='Options currently are "sine" (all sine activations), "relu" (all relu activations),' 34 | ' "nerf" (relu activations and positional encoding as in NeRF), "rbf" (input rbf layer, rest relu),' 35 | ' and in the future: "mixed" (first layer sine, other layers tanh)') 36 | 37 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.') 38 | opt = p.parse_args() 39 | 40 | audio_dataset = dataio.AudioFile(filename=opt.wav_path) 41 | coord_dataset = dataio.ImplicitAudioWrapper(audio_dataset) 42 | 43 | dataloader = DataLoader(coord_dataset, shuffle=True, batch_size=opt.batch_size, pin_memory=True, num_workers=0) 44 | 45 | # Define the model.
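# For 'rbf' and 'nerf', fn_samples matches the input encoding to the number of audio samples.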
46 | if opt.model_type == 'sine' or opt.model_type == 'relu' or opt.model_type == 'tanh': 47 | model = modules.SingleBVPNet(type=opt.model_type, mode='mlp', in_features=1) 48 | elif opt.model_type == 'rbf' or opt.model_type == 'nerf': 49 | model = modules.SingleBVPNet(type='relu', mode=opt.model_type, fn_samples=len(audio_dataset.data), in_features=1) 50 | else: 51 | raise NotImplementedError 52 | model.cuda() 53 | 54 | root_path = os.path.join(opt.logging_root, opt.experiment_name) 55 | utils.cond_mkdir(root_path) 56 | 57 | # Define the loss 58 | loss_fn = loss_functions.function_mse 59 | summary_fn = partial(utils.write_audio_summary, root_path) 60 | 61 | training.train(model=model, train_dataloader=dataloader, epochs=opt.num_epochs, lr=opt.lr, 62 | steps_til_summary=opt.steps_til_summary, epochs_til_checkpoint=opt.epochs_til_ckpt, 63 | model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn) -------------------------------------------------------------------------------- /experiment_scripts/train_helmholtz.py: -------------------------------------------------------------------------------- 1 | # Enable import from parent package 2 | import sys 3 | import os 4 | sys.path.append( os.path.dirname( os.path.dirname( os.path.abspath(__file__) ) ) ) 5 | 6 | import dataio, meta_modules, utils, training, loss_functions, modules 7 | 8 | from torch.utils.data import DataLoader 9 | import torch 10 | import configargparse 11 | import numpy as np 12 | 13 | p = configargparse.ArgumentParser() 14 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.') 15 | 16 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging') 17 | p.add_argument('--experiment_name', type=str, required=True, 18 | help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.') 19 | 20 | # General training options 21 | p.add_argument('--batch_size', type=int, default=32) 22 | p.add_argument('--lr', type=float, default=2e-5, help='learning rate. default=2e-5') 23 | p.add_argument('--num_epochs', type=int, default=50000, 24 | help='Number of epochs to train for.') 25 | 26 | p.add_argument('--epochs_til_ckpt', type=int, default=1000, 27 | help='Number of epochs between checkpoint saves.') 28 | p.add_argument('--steps_til_summary', type=int, default=100, 29 | help='Number of training steps between tensorboard summary writes.') 30 | p.add_argument('--model', type=str, default='sine', required=False, choices=['sine', 'tanh', 'sigmoid', 'relu'], 31 | help='Type of model to evaluate, default is sine.') 32 | p.add_argument('--mode', type=str, default='mlp', required=False, choices=['mlp', 'rbf', 'pinn'], 33 | help='Model mode: plain MLP, RBF input layer, or PINN baseline.') 34 | p.add_argument('--velocity', type=str, default='uniform', required=False, choices=['uniform', 'square', 'circle'], 35 | help='Velocity field to use: uniform, or with a square or circular perturbation.') 36 | p.add_argument('--clip_grad', default=0.0, type=float, help='Clip gradient.') 37 | p.add_argument('--use_lbfgs', default=False, type=bool, help='use L-BFGS.') 38 | 39 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.') 40 | opt = p.parse_args() 41 | 42 | # if we have a velocity perturbation, offset the source 43 | if opt.velocity != 'uniform': 44 | source_coords = [-0.35, 0.] 45 | else: 46 | source_coords = [0., 0.]
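# The dataset samples the Helmholtz problem on a 230x230 grid with the chosen velocity field and source location.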
47 | 48 | dataset = dataio.SingleHelmholtzSource(sidelength=230, velocity=opt.velocity, source_coords=source_coords) 49 | 50 | dataloader = DataLoader(dataset, shuffle=True, batch_size=opt.batch_size, pin_memory=True, num_workers=0) 51 | 52 | # Define the model. 53 | if opt.mode == 'pinn': 54 | model = modules.PINNet(out_features=2, type='tanh', mode=opt.mode) 55 | opt.use_lbfgs = True 56 | else: 57 | model = modules.SingleBVPNet(out_features=2, type=opt.model, mode=opt.mode, final_layer_factor=1.) 58 | 59 | model.cuda() 60 | 61 | # Define the loss 62 | loss_fn = loss_functions.helmholtz_pml 63 | summary_fn = utils.write_helmholtz_summary 64 | 65 | root_path = os.path.join(opt.logging_root, opt.experiment_name) 66 | 67 | training.train(model=model, train_dataloader=dataloader, epochs=opt.num_epochs, lr=opt.lr, 68 | steps_til_summary=opt.steps_til_summary, epochs_til_checkpoint=opt.epochs_til_ckpt, 69 | model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn, clip_grad=opt.clip_grad, 70 | use_lbfgs=opt.use_lbfgs) 71 | -------------------------------------------------------------------------------- /experiment_scripts/train_img.py: -------------------------------------------------------------------------------- 1 | # Enable import from parent package 2 | import sys 3 | import os 4 | sys.path.append( os.path.dirname( os.path.dirname( os.path.abspath(__file__) ) ) ) 5 | 6 | import dataio, meta_modules, utils, training, loss_functions, modules 7 | 8 | from torch.utils.data import DataLoader 9 | import configargparse 10 | from functools import partial 11 | 12 | p = configargparse.ArgumentParser() 13 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.') 14 | 15 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging') 16 | p.add_argument('--experiment_name', type=str, required=True, 17 | help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.') 18 | 19 | # General training options 20 | p.add_argument('--batch_size', type=int, default=1) 21 | p.add_argument('--lr', type=float, default=1e-4, help='learning rate. default=1e-4') 22 | p.add_argument('--num_epochs', type=int, default=10000, 23 | help='Number of epochs to train for.') 24 | 25 | p.add_argument('--epochs_til_ckpt', type=int, default=25, 26 | help='Number of epochs between checkpoint saves.') 27 | p.add_argument('--steps_til_summary', type=int, default=1000, 28 | help='Number of training steps between tensorboard summary writes.') 29 | 30 | p.add_argument('--model_type', type=str, default='sine', 31 | help='Options currently are "sine" (all sine activations), "relu" (all relu activations),' 32 | ' "nerf" (relu activations and positional encoding as in NeRF), "rbf" (input rbf layer, rest relu),' 33 | ' and in the future: "mixed" (first layer sine, other layers tanh)') 34 | 35 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.') 36 | opt = p.parse_args() 37 | 38 | img_dataset = dataio.Camera() 39 | coord_dataset = dataio.Implicit2DWrapper(img_dataset, sidelength=512, compute_diff='all') 40 | image_resolution = (512, 512) 41 | 42 | dataloader = DataLoader(coord_dataset, shuffle=True, batch_size=opt.batch_size, pin_memory=True, num_workers=0) 43 | 44 | # Define the model.
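# Pointwise nonlinearities share a plain MLP; 'rbf' and 'nerf' add the corresponding input encoding to a ReLU MLP.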
45 | if opt.model_type == 'sine' or opt.model_type == 'relu' or opt.model_type == 'tanh' or opt.model_type == 'selu' or opt.model_type == 'elu'\ 46 | or opt.model_type == 'softplus': 47 | model = modules.SingleBVPNet(type=opt.model_type, mode='mlp', sidelength=image_resolution) 48 | elif opt.model_type == 'rbf' or opt.model_type == 'nerf': 49 | model = modules.SingleBVPNet(type='relu', mode=opt.model_type, sidelength=image_resolution) 50 | else: 51 | raise NotImplementedError 52 | model.cuda() 53 | 54 | root_path = os.path.join(opt.logging_root, opt.experiment_name) 55 | 56 | # Define the loss 57 | loss_fn = partial(loss_functions.image_mse, None) 58 | summary_fn = partial(utils.write_image_summary, image_resolution) 59 | 60 | training.train(model=model, train_dataloader=dataloader, epochs=opt.num_epochs, lr=opt.lr, 61 | steps_til_summary=opt.steps_til_summary, epochs_til_checkpoint=opt.epochs_til_ckpt, 62 | model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn) 63 | -------------------------------------------------------------------------------- /experiment_scripts/train_img_inpainting.py: -------------------------------------------------------------------------------- 1 | # Enable import from parent package 2 | import sys 3 | import os 4 | sys.path.append( os.path.dirname( os.path.dirname( os.path.abspath(__file__) ) ) ) 5 | 6 | import dataio, meta_modules, utils, training, loss_functions, modules 7 | 8 | from torch.utils.data import DataLoader 9 | import configargparse 10 | from functools import partial 11 | import torch 12 | from PIL import Image 13 | from torchvision.transforms import ToTensor 14 | import numpy as np 15 | 16 | p = configargparse.ArgumentParser() 17 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.') 18 | 19 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging') 20 | p.add_argument('--experiment_name', type=str, required=True, 21 | help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.') 22 | 23 | # General training options 24 | p.add_argument('--batch_size', type=int, default=1) 25 | p.add_argument('--lr', type=float, default=1e-4, help='learning rate. 
default=1e-4')
26 | p.add_argument('--num_epochs', type=int, default=10000,
27 |                help='Number of epochs to train for.')
28 | p.add_argument('--k1', type=float, default=1, help='weight on prior')
29 | p.add_argument('--sparsity', type=float, default=0.1, help='fraction of pixels retained in the random mask')
30 | p.add_argument('--prior', type=str, default=None, help='image prior added to the MSE loss; one of (TV, FH)')
31 | p.add_argument('--downsample', action='store_true', default=False, help='use image downsampling kernel')
32 | 
33 | p.add_argument('--epochs_til_ckpt', type=int, default=25,
34 |                help='Number of epochs between checkpoints.')
35 | p.add_argument('--steps_til_summary', type=int, default=1000,
36 |                help='Number of training steps between tensorboard summaries.')
37 | 
38 | p.add_argument('--dataset', type=str, default='camera',
39 |                help='Dataset to fit; one of (camera, camera_downsampled, custom).')
40 | p.add_argument('--model_type', type=str, default='sine',
41 |                help='Options currently are "sine" (all sine activations), "relu" (all relu activations),'
42 |                     ' "nerf" (relu activations and positional encoding as in NeRF), "rbf" (input rbf layer, rest relu),'
43 |                     ' and in the future: "mixed" (first layer sine, other layers tanh)')
44 | 
45 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.')
46 | 
47 | p.add_argument('--mask_path', type=str, default=None, help='Path to mask image')
48 | p.add_argument('--custom_image', type=str, default=None, help='Path to single training image')
49 | opt = p.parse_args()
50 | 
51 | 
52 | if opt.dataset == 'camera':
53 |     img_dataset = dataio.Camera()
54 |     coord_dataset = dataio.Implicit2DWrapper(img_dataset, sidelength=512, compute_diff='all')
55 |     image_resolution = (512, 512)
56 | elif opt.dataset == 'camera_downsampled':
57 |     img_dataset = dataio.Camera(downsample_factor=2)
58 |     coord_dataset = dataio.Implicit2DWrapper(img_dataset, sidelength=256, compute_diff='all')
59 |     image_resolution = (256, 256)
60 | elif opt.dataset == 'custom':
61 |     img_dataset = dataio.ImageFile(opt.custom_image)
62 |     coord_dataset = dataio.Implicit2DWrapper(img_dataset, sidelength=(img_dataset[0].size[1], img_dataset[0].size[0]),
63 |                                              compute_diff='all')
64 |     image_resolution = (img_dataset[0].size[1], img_dataset[0].size[0])
65 | 
66 | dataloader = DataLoader(coord_dataset, shuffle=True, batch_size=opt.batch_size, pin_memory=True, num_workers=0)
67 | 
68 | # Define the model.
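# Same activation/featurization dispatch as in train_img.py, but here the output
# width follows the number of image channels and the optional --downsample flag
# enables SingleBVPNet's image downsampling kernel.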
69 | if opt.model_type == 'sine' or opt.model_type == 'relu' or opt.model_type == 'tanh': 70 | model = modules.SingleBVPNet(type=opt.model_type, mode='mlp', out_features=img_dataset.img_channels, sidelength=image_resolution, 71 | downsample=opt.downsample) 72 | elif opt.model_type == 'rbf' or opt.model_type == 'nerf': 73 | model = modules.SingleBVPNet(type='relu', mode=opt.model_type, out_features=img_dataset.img_channels, sidelength=image_resolution, 74 | downsample=opt.downsample) 75 | else: 76 | raise NotImplementedError 77 | model.cuda() 78 | 79 | root_path = os.path.join(opt.logging_root, opt.experiment_name) 80 | 81 | if opt.mask_path: 82 | mask = Image.open(opt.mask_path) 83 | mask = ToTensor()(mask) 84 | mask = mask.float().cuda() 85 | percentage = torch.sum(mask).cpu().numpy() / np.prod(mask.shape) 86 | print("mask sparsity %f" % (percentage)) 87 | else: 88 | mask = torch.rand(image_resolution) < opt.sparsity 89 | mask = mask.float().cuda() 90 | 91 | # Define the loss 92 | if opt.prior is None: 93 | loss_fn = partial(loss_functions.image_mse, mask.view(-1,1)) 94 | elif opt.prior == 'TV': 95 | loss_fn = partial(loss_functions.image_mse_TV_prior, mask.view(-1,1), opt.k1, model) 96 | elif opt.prior == 'FH': 97 | loss_fn = partial(loss_functions.image_mse_FH_prior, mask.view(-1,1), opt.k1, model) 98 | summary_fn = partial(utils.write_image_summary, image_resolution) 99 | 100 | training.train(model=model, train_dataloader=dataloader, epochs=opt.num_epochs, lr=opt.lr, 101 | steps_til_summary=opt.steps_til_summary, epochs_til_checkpoint=opt.epochs_til_ckpt, 102 | model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn) 103 | -------------------------------------------------------------------------------- /experiment_scripts/train_img_neural_process.py: -------------------------------------------------------------------------------- 1 | # Enable import from parent package 2 | import sys 3 | import os 4 | sys.path.append( os.path.dirname( os.path.dirname( os.path.abspath(__file__) ) ) ) 5 | 6 | import dataio, meta_modules, utils, training, loss_functions 7 | 8 | import torch 9 | from torch.utils.data import DataLoader 10 | import configargparse 11 | from functools import partial 12 | 13 | p = configargparse.ArgumentParser() 14 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.') 15 | 16 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging') 17 | p.add_argument('--experiment_name', type=str, required=True, 18 | help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.') 19 | 20 | # General training options 21 | p.add_argument('--batch_size', type=int, default=100) 22 | p.add_argument('--lr', type=float, default=5e-5, help='learning rate. 
default=5e-5')
23 | p.add_argument('--num_epochs', type=int, default=401,
24 |                help='Number of epochs to train for.')
25 | p.add_argument('--kl_weight', type=float, default=1e-1,
26 |                help='Weight for l2 loss term on code vectors z (lambda_latent in paper).')
27 | p.add_argument('--fw_weight', type=float, default=1e2,
28 |                help='Weight for the l2 loss term on the weights of the sine network')
29 | p.add_argument('--train_sparsity_range', type=int, nargs='+', default=[10, 200],
30 |                help='Two integers: lowest number of sparse pixels sampled followed by highest number of sparse'
31 |                     ' pixels sampled when training the conditional neural process')
32 | 
33 | p.add_argument('--epochs_til_ckpt', type=int, default=10,
34 |                help='Number of epochs between checkpoints.')
35 | p.add_argument('--steps_til_summary', type=int, default=1000,
36 |                help='Number of training steps between tensorboard summaries.')
37 | 
38 | p.add_argument('--dataset', type=str, default='celeba_32x32',
39 |                help='Dataset to train on; currently only celeba_32x32 is supported.')
40 | p.add_argument('--model_type', type=str, default='sine',
41 |                help='Nonlinearity for the hypo-network module')
42 | 
43 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.')
44 | 
45 | p.add_argument('--conv_encoder', action='store_true', default=False, help='Use convolutional encoder process')
46 | opt = p.parse_args()
47 | 
48 | 
49 | assert opt.dataset == 'celeba_32x32'
50 | if opt.conv_encoder: gmode = 'conv_cnp'
51 | else: gmode = 'cnp'
52 | 
53 | img_dataset = dataio.CelebA(split='train', downsampled=True)
54 | coord_dataset = dataio.Implicit2DWrapper(img_dataset, sidelength=(32, 32))
55 | generalization_dataset = dataio.ImageGeneralizationWrapper(coord_dataset,
56 |                                                            train_sparsity_range=opt.train_sparsity_range,
57 |                                                            generalization_mode=gmode)
58 | image_resolution = (32, 32)
59 | 
60 | dataloader = DataLoader(generalization_dataset, shuffle=True, batch_size=opt.batch_size, pin_memory=True, num_workers=0)
61 | 
62 | if opt.conv_encoder:
63 |     model = meta_modules.ConvolutionalNeuralProcessImplicit2DHypernet(in_features=img_dataset.img_channels,
64 |                                                                       out_features=img_dataset.img_channels,
65 |                                                                       image_resolution=image_resolution)
66 | else:
67 |     model = meta_modules.NeuralProcessImplicit2DHypernet(in_features=img_dataset.img_channels + 2,
68 |                                                          out_features=img_dataset.img_channels,
69 |                                                          image_resolution=image_resolution)
70 | model.cuda()
71 | 
72 | # Define the loss
73 | loss_fn = partial(loss_functions.image_hypernetwork_loss, None, opt.kl_weight, opt.fw_weight)
74 | summary_fn = partial(utils.write_image_summary_small, image_resolution, None)
75 | 
76 | root_path = os.path.join(opt.logging_root, opt.experiment_name)
77 | 
78 | training.train(model=model, train_dataloader=dataloader, epochs=opt.num_epochs, lr=opt.lr,
79 |                steps_til_summary=opt.steps_til_summary, epochs_til_checkpoint=opt.epochs_til_ckpt,
80 |                model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn, clip_grad=True)
81 | 
--------------------------------------------------------------------------------
/experiment_scripts/train_inverse_helmholtz.py:
--------------------------------------------------------------------------------
1 | '''Reproduces Paper Sec. 4.3, Supplement Sec. 5: inverse Helmholtz problem (full-waveform inversion).
2 | ''' 3 | 4 | # Enable import from parent package 5 | import sys 6 | import os 7 | sys.path.append( os.path.dirname( os.path.dirname( os.path.abspath(__file__) ) ) ) 8 | 9 | import dataio, utils, training, loss_functions, modules 10 | from torch.utils.data import DataLoader 11 | import torch 12 | import configargparse 13 | from scipy.io import loadmat 14 | 15 | p = configargparse.ArgumentParser() 16 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.') 17 | 18 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging') 19 | p.add_argument('--experiment_name', type=str, required=True, 20 | help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.') 21 | 22 | # General training options 23 | p.add_argument('--batch_size', type=int, default=32) 24 | p.add_argument('--lr', type=float, default=2e-5, help='learning rate. default=2e-5') 25 | p.add_argument('--num_epochs', type=int, default=80000, 26 | help='Number of epochs to train for.') 27 | 28 | p.add_argument('--epochs_til_ckpt', type=int, default=1000, 29 | help='Time interval in seconds until checkpoint is saved.') 30 | p.add_argument('--steps_til_summary', type=int, default=100, 31 | help='Time interval in seconds until tensorboard summary is saved.') 32 | p.add_argument('--model', type=str, default='sine', required=False, choices=['sine', 'tanh', 'sigmoid'], 33 | help='Type of model to evaluate, default is sine.') 34 | p.add_argument('--data', type=str, default='./fwi/data_cylinder_5.mat', required=False, 35 | help='Data file with the source/rec coordinates and data.') 36 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.') 37 | p.add_argument('--clip_grad', default=0.0, type=float, help='Clip gradient.') 38 | p.add_argument('--pretrain', default=False, action='store_true', help='don''t solve for velocity.') 39 | p.add_argument('--load_model', type=str, default=None, required=False, 40 | help='Load pretrained model from checkpoint.') 41 | 42 | opt = p.parse_args() 43 | 44 | # we need to load source and receiver data generated by the principled solver for FWI 45 | data = loadmat(opt.data) 46 | source_coords = data['source'] 47 | rec_coords = data['receivers'] 48 | rec_val = data['rec_val'] 49 | 50 | dataset = dataio.InverseHelmholtz(source_coords, rec_coords, rec_val, sidelength=115, pretrain=opt.pretrain) 51 | dataloader = DataLoader(dataset, shuffle=True, batch_size=opt.batch_size, pin_memory=True, num_workers=0) 52 | 53 | # Define the model. 54 | N_src = source_coords.shape[0] 55 | model = modules.SingleBVPNet(in_features=2, out_features=2 * N_src + 1, type=opt.model, final_layer_factor=1.) 
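# 2 * N_src + 1 output channels: presumably the real and imaginary parts of the
# wavefield for each of the N_src sources, plus one channel for the unknown
# velocity model being inverted for.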
56 | 
57 | if opt.load_model is not None:
58 |     model.load_state_dict(torch.load(opt.load_model))
59 | 
60 | model.cuda()
61 | 
62 | # Define the loss
63 | loss_fn = loss_functions.helmholtz_pml
64 | summary_fn = utils.write_helmholtz_summary
65 | 
66 | root_path = os.path.join(opt.logging_root, opt.experiment_name)
67 | 
68 | training.train(model=model, train_dataloader=dataloader, epochs=opt.num_epochs, lr=opt.lr,
69 |                steps_til_summary=opt.steps_til_summary, epochs_til_checkpoint=opt.epochs_til_ckpt,
70 |                model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn, clip_grad=opt.clip_grad)
71 | 
--------------------------------------------------------------------------------
/experiment_scripts/train_poisson_grad_img.py:
--------------------------------------------------------------------------------
1 | '''Reproduces Paper Sec. 4.1, Supplement Sec. 3, reconstruction from gradient.
2 | '''
3 | 
4 | # Enable import from parent package
5 | import sys
6 | import os
7 | sys.path.append( os.path.dirname( os.path.dirname( os.path.abspath(__file__) ) ) )
8 | 
9 | import dataio, meta_modules, utils, training, loss_functions, modules
10 | 
11 | from torch.utils.data import DataLoader
12 | import configargparse
13 | 
14 | p = configargparse.ArgumentParser()
15 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
16 | 
17 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging')
18 | p.add_argument('--experiment_name', type=str, required=True,
19 |                help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.')
20 | 
21 | # General training options
22 | p.add_argument('--batch_size', type=int, default=16384)
23 | p.add_argument('--lr', type=float, default=1e-4, help='learning rate. default=1e-4')
24 | p.add_argument('--num_epochs', type=int, default=10000,
25 |                help='Number of epochs to train for.')
26 | 
27 | p.add_argument('--epochs_til_ckpt', type=int, default=25,
28 |                help='Number of epochs between checkpoints.')
29 | p.add_argument('--steps_til_summary', type=int, default=100,
30 |                help='Number of training steps between tensorboard summaries.')
31 | 
32 | p.add_argument('--dataset', type=str, choices=['camera','bsd500'], default='camera',
33 |                help='Dataset: choices=[camera,bsd500].')
34 | p.add_argument('--model_type', type=str, default='sine',
35 |                help='Options are "sine" (all sine activations) and "mixed" (first layer sine, other layers tanh)')
36 | 
37 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.')
38 | opt = p.parse_args()
39 | 
40 | if opt.dataset == 'camera':
41 |     img_dataset = dataio.Camera()
42 |     coord_dataset = dataio.Implicit2DWrapper(img_dataset, sidelength=256, compute_diff='gradients')
43 | elif opt.dataset == 'bsd500':
44 |     # you can select the image you like via idx_to_sample
45 |     img_dataset = dataio.BSD500ImageDataset(in_folder='../data/BSD500/train',
46 |                                             idx_to_sample=[19])
47 |     coord_dataset = dataio.Implicit2DWrapper(img_dataset, sidelength=256, compute_diff='gradients')
48 | 
49 | dataloader = DataLoader(coord_dataset, shuffle=True, batch_size=opt.batch_size, pin_memory=True, num_workers=0)
50 | 
51 | # Define the model.
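# The wrapper above supplies ground-truth image gradients (compute_diff='gradients');
# gradients_mse then compares the spatial derivatives of the network output,
# obtained via automatic differentiation, against those targets, so the image is
# only ever supervised through its gradient field.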
52 | if opt.model_type == 'sine' or opt.model_type == 'relu' or opt.model_type == 'tanh' or opt.model_type == 'softplus': 53 | model = modules.SingleBVPNet(type=opt.model_type, mode='mlp', sidelength=(256, 256)) 54 | elif opt.model_type == 'rbf' or opt.model_type == 'nerf': 55 | model = modules.SingleBVPNet(type='relu', mode=opt.model_type, sidelength=(256, 256)) 56 | else: 57 | raise NotImplementedError 58 | model.cuda() 59 | 60 | # Define the loss & summary functions 61 | loss_fn = loss_functions.gradients_mse 62 | summary_fn = utils.write_gradients_summary 63 | 64 | root_path = os.path.join(opt.logging_root, opt.experiment_name) 65 | 66 | training.train(model=model, train_dataloader=dataloader, epochs=opt.num_epochs, lr=opt.lr, 67 | steps_til_summary=opt.steps_til_summary, epochs_til_checkpoint=opt.epochs_til_ckpt, 68 | model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn, double_precision=False) 69 | -------------------------------------------------------------------------------- /experiment_scripts/train_poisson_gradcomp_img.py: -------------------------------------------------------------------------------- 1 | '''Reproduces Paper Sec. 4.1, Supplement Sec. 3, poisson image editing. 2 | ''' 3 | 4 | # Enable import from parent package 5 | import sys 6 | import os 7 | 8 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 9 | 10 | import dataio, meta_modules, utils, training, loss_functions, modules 11 | 12 | from torch.utils.data import DataLoader 13 | import configargparse 14 | 15 | p = configargparse.ArgumentParser() 16 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.') 17 | 18 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging') 19 | p.add_argument('--experiment_name', type=str, required=True, 20 | help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.') 21 | 22 | # General training options 23 | p.add_argument('--batch_size', type=int, default=16384) 24 | p.add_argument('--lr', type=float, default=1e-4, help='learning rate. default=5e-5') 25 | p.add_argument('--num_epochs', type=int, default=10000, 26 | help='Number of epochs to train for.') 27 | 28 | p.add_argument('--epochs_til_ckpt', type=int, default=25, 29 | help='Time interval in seconds until checkpoint is saved.') 30 | p.add_argument('--steps_til_summary', type=int, default=100, 31 | help='Time interval in seconds until tensorboard summary is saved.') 32 | 33 | p.add_argument('--model_type', type=str, default='sine', 34 | help='Options are "sine" (all sine activations) and "mixed" (first layer sine, other layers tanh)') 35 | 36 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.') 37 | opt = p.parse_args() 38 | 39 | # Dataset 40 | img_filepath1 = '/media/data4e/jnmartel/neural_bvps/gizeh.jpg' 41 | img_filepath2 = '/media/data4e/jnmartel/neural_bvps/bear.jpg' 42 | is_color = False 43 | coord_dataset = dataio.CompositeGradients(img_filepath1, img_filepath2, is_color=is_color, 44 | sidelength=512) 45 | dataloader = DataLoader(coord_dataset, shuffle=True, batch_size=opt.batch_size, pin_memory=True, num_workers=0) 46 | 47 | # Define the model. 
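# Poisson image editing: CompositeGradients blends the gradient fields of the
# two input images, so fitting the network to the blended gradients composites
# them; color input only changes the output width and the matching gradient loss.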
48 | if not is_color: 49 | model = modules.SingleBVPNet(type=opt.model_type, final_layer_factor=1) 50 | loss_fn = loss_functions.gradients_mse 51 | else: 52 | model = modules.SingleBVPNet(out_features=3, type=opt.model_type, final_layer_factor=1) 53 | loss_fn = loss_functions.gradients_color_mse 54 | model.cuda() 55 | 56 | summary_fn = utils.write_gradcomp_summary 57 | 58 | root_path = os.path.join(opt.logging_root, opt.experiment_name) 59 | 60 | training.train(model=model, train_dataloader=dataloader, epochs=opt.num_epochs, lr=opt.lr, 61 | steps_til_summary=opt.steps_til_summary, epochs_til_checkpoint=opt.epochs_til_ckpt, 62 | model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn, double_precision=False) 63 | -------------------------------------------------------------------------------- /experiment_scripts/train_poisson_lapl_img.py: -------------------------------------------------------------------------------- 1 | '''Reproduces Paper Sec. 4.1, Supplement Sec. 3, reconstruction from laplacian. 2 | ''' 3 | 4 | # Enable import from parent package 5 | import sys 6 | import os 7 | sys.path.append( os.path.dirname( os.path.dirname( os.path.abspath(__file__) ) ) ) 8 | 9 | import dataio, meta_modules, utils, training, loss_functions, modules 10 | 11 | from torch.utils.data import DataLoader 12 | import configargparse 13 | 14 | p = configargparse.ArgumentParser() 15 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.') 16 | 17 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging') 18 | p.add_argument('--experiment_name', type=str, required=True, 19 | help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.') 20 | 21 | # General training options 22 | p.add_argument('--batch_size', type=int, default=16384) 23 | p.add_argument('--lr', type=float, default=1e-4, help='learning rate. 
default=5e-5') 24 | p.add_argument('--num_epochs', type=int, default=10000, 25 | help='Number of epochs to train for.') 26 | 27 | p.add_argument('--epochs_til_ckpt', type=int, default=25, 28 | help='Time interval in seconds until checkpoint is saved.') 29 | p.add_argument('--steps_til_summary', type=int, default=100, 30 | help='Time interval in seconds until tensorboard summary is saved.') 31 | 32 | p.add_argument('--dataset', type=str, choices=['camera','bsd500'], default='camera', 33 | help='Dataset: choices=[camera,bsd500].') 34 | p.add_argument('--model_type', type=str, default='sine', 35 | help='Options are "sine" (all sine activations) and "mixed" (first layer sine, other layers tanh)') 36 | 37 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.') 38 | opt = p.parse_args() 39 | 40 | 41 | if opt.dataset == 'camera': 42 | img_dataset = dataio.Camera() 43 | coord_dataset = dataio.Implicit2DWrapper(img_dataset, sidelength=256, compute_diff='laplacian') 44 | elif opt.dataset == 'bsd500': 45 | # you can select the image your like in idx to sample 46 | img_dataset = dataio.BSD500ImageDataset(in_folder='/media/data3/awb/BSD500/train', 47 | idx_to_sample=[19]) 48 | coord_dataset = dataio.Implicit2DWrapper(img_dataset, sidelength=256, compute_diff='laplacian') 49 | 50 | dataloader = DataLoader(coord_dataset, shuffle=True, batch_size=opt.batch_size, pin_memory=True, num_workers=0) 51 | 52 | # Define the model 53 | if opt.model_type == 'sine' or opt.model_type == 'relu' or opt.model_type == 'tanh': 54 | model = modules.SingleBVPNet(type=opt.model_type, mode='mlp', sidelength=(256, 256)) 55 | elif opt.model_type == 'rbf' or opt.model_type == 'nerf': 56 | model = modules.SingleBVPNet(type='relu', mode=opt.model_type, sidelength=(256, 256)) 57 | else: 58 | raise NotImplementedError 59 | model.cuda() 60 | 61 | # Define the loss 62 | loss_fn = loss_functions.laplace_mse 63 | summary_fn = utils.write_laplace_summary 64 | 65 | root_path = os.path.join(opt.logging_root, opt.experiment_name) 66 | 67 | training.train(model=model, train_dataloader=dataloader, epochs=opt.num_epochs, lr=opt.lr, 68 | steps_til_summary=opt.steps_til_summary, epochs_til_checkpoint=opt.epochs_til_ckpt, 69 | model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn, double_precision=False) 70 | -------------------------------------------------------------------------------- /experiment_scripts/train_sdf.py: -------------------------------------------------------------------------------- 1 | '''Reproduces Sec. 4.2 in main paper and Sec. 4 in Supplement. 2 | ''' 3 | 4 | # Enable import from parent package 5 | import sys 6 | import os 7 | sys.path.append( os.path.dirname( os.path.dirname( os.path.abspath(__file__) ) ) ) 8 | 9 | import dataio, meta_modules, utils, training, loss_functions, modules 10 | 11 | from torch.utils.data import DataLoader 12 | import configargparse 13 | 14 | p = configargparse.ArgumentParser() 15 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.') 16 | 17 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging') 18 | p.add_argument('--experiment_name', type=str, required=True, 19 | help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.') 20 | 21 | # General training options 22 | p.add_argument('--batch_size', type=int, default=1400) 23 | p.add_argument('--lr', type=float, default=1e-4, help='learning rate. 
default=1e-4')
24 | p.add_argument('--num_epochs', type=int, default=10000,
25 |                help='Number of epochs to train for.')
26 | 
27 | p.add_argument('--epochs_til_ckpt', type=int, default=1,
28 |                help='Number of epochs between checkpoints.')
29 | p.add_argument('--steps_til_summary', type=int, default=100,
30 |                help='Number of training steps between tensorboard summaries.')
31 | 
32 | p.add_argument('--model_type', type=str, default='sine',
33 |                help='Options are "sine" (all sine activations) and "mixed" (first layer sine, other layers tanh)')
34 | p.add_argument('--point_cloud_path', type=str, default='/home/sitzmann/data/point_cloud.xyz',
35 |                help='Path to the input point cloud (.xyz file).')
36 | 
37 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.')
38 | opt = p.parse_args()
39 | 
40 | 
41 | sdf_dataset = dataio.PointCloud(opt.point_cloud_path, on_surface_points=opt.batch_size)
42 | dataloader = DataLoader(sdf_dataset, shuffle=True, batch_size=1, pin_memory=True, num_workers=0)
43 | 
44 | # Define the model.
45 | if opt.model_type == 'nerf':
46 |     model = modules.SingleBVPNet(type='relu', mode='nerf', in_features=3)
47 | else:
48 |     model = modules.SingleBVPNet(type=opt.model_type, in_features=3)
49 | model.cuda()
50 | 
51 | # Define the loss
52 | loss_fn = loss_functions.sdf
53 | summary_fn = utils.write_sdf_summary
54 | 
55 | root_path = os.path.join(opt.logging_root, opt.experiment_name)
56 | 
57 | training.train(model=model, train_dataloader=dataloader, epochs=opt.num_epochs, lr=opt.lr,
58 |                steps_til_summary=opt.steps_til_summary, epochs_til_checkpoint=opt.epochs_til_ckpt,
59 |                model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn, double_precision=False,
60 |                clip_grad=True)
61 | 
--------------------------------------------------------------------------------
/experiment_scripts/train_video.py:
--------------------------------------------------------------------------------
1 | '''Reproduces Supplement Sec. 7'''
2 | 
3 | # Enable import from parent package
4 | import sys
5 | import os
6 | sys.path.append( os.path.dirname( os.path.dirname( os.path.abspath(__file__) ) ) )
7 | 
8 | import dataio, meta_modules, utils, training, loss_functions, modules
9 | 
10 | from torch.utils.data import DataLoader
11 | import configargparse
12 | from functools import partial
13 | import skvideo.datasets
14 | 
15 | p = configargparse.ArgumentParser()
16 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
17 | 
18 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging')
19 | p.add_argument('--experiment_name', type=str, required=True,
20 |                help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.')
21 | 
22 | # General training options
23 | p.add_argument('--batch_size', type=int, default=1)
24 | p.add_argument('--lr', type=float, default=1e-4, help='learning rate. 
default=1e-4') 25 | p.add_argument('--num_epochs', type=int, default=100000, 26 | help='Number of epochs to train for.') 27 | 28 | p.add_argument('--epochs_til_ckpt', type=int, default=1000, 29 | help='Time interval in seconds until checkpoint is saved.') 30 | p.add_argument('--steps_til_summary', type=int, default=100, 31 | help='Time interval in seconds until tensorboard summary is saved.') 32 | p.add_argument('--dataset', type=str, default='bikes', 33 | help='Video dataset; one of (cat, bikes)', choices=['cat', 'bikes']) 34 | p.add_argument('--model_type', type=str, default='sine', 35 | help='Options currently are "sine" (all sine activations), "relu" (all relu activations,' 36 | '"nerf" (relu activations and positional encoding as in NeRF), "rbf" (input rbf layer, rest relu)') 37 | p.add_argument('--sample_frac', type=float, default=38e-4, 38 | help='What fraction of video pixels to sample in each batch (default is all)') 39 | 40 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.') 41 | opt = p.parse_args() 42 | 43 | if opt.dataset == 'cat': 44 | video_path = './data/video_512.npy' 45 | elif opt.dataset == 'bikes': 46 | video_path = skvideo.datasets.bikes() 47 | 48 | vid_dataset = dataio.Video(video_path) 49 | coord_dataset = dataio.Implicit3DWrapper(vid_dataset, sidelength=vid_dataset.shape, sample_fraction=opt.sample_frac) 50 | dataloader = DataLoader(coord_dataset, shuffle=True, batch_size=opt.batch_size, pin_memory=True, num_workers=0) 51 | 52 | # Define the model. 53 | if opt.model_type == 'sine' or opt.model_type == 'relu' or opt.model_type == 'tanh': 54 | model = modules.SingleBVPNet(type=opt.model_type, in_features=3, out_features=vid_dataset.channels, 55 | mode='mlp', hidden_features=1024, num_hidden_layers=3) 56 | elif opt.model_type == 'rbf' or opt.model_type == 'nerf': 57 | model = modules.SingleBVPNet(type='relu', in_features=3, out_features=vid_dataset.channels, mode=opt.model_type) 58 | else: 59 | raise NotImplementedError 60 | model.cuda() 61 | 62 | root_path = os.path.join(opt.logging_root, opt.experiment_name) 63 | 64 | # Define the loss 65 | loss_fn = partial(loss_functions.image_mse, None) 66 | summary_fn = partial(utils.write_video_summary, vid_dataset) 67 | 68 | training.train(model=model, train_dataloader=dataloader, epochs=opt.num_epochs, lr=opt.lr, 69 | steps_til_summary=opt.steps_til_summary, epochs_til_checkpoint=opt.epochs_til_ckpt, 70 | model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn) 71 | -------------------------------------------------------------------------------- /experiment_scripts/train_wave_equation.py: -------------------------------------------------------------------------------- 1 | '''Reproduces Paper Sec. 4.3 and Supplement Sec. 
5'''
2 | 
3 | # Enable import from parent package
4 | import sys
5 | import os
6 | sys.path.append( os.path.dirname( os.path.dirname( os.path.abspath(__file__) ) ) )
7 | 
8 | import dataio, meta_modules, utils, training, loss_functions, modules
9 | 
10 | from torch.utils.data import DataLoader
11 | import configargparse
12 | 
13 | p = configargparse.ArgumentParser()
14 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
15 | 
16 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging')
17 | p.add_argument('--experiment_name', type=str, required=True,
18 |                help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.')
19 | 
20 | # General training options
21 | p.add_argument('--batch_size', type=int, default=32)
22 | p.add_argument('--lr', type=float, default=2e-5, help='learning rate. default=2e-5')
23 | p.add_argument('--num_epochs', type=int, default=100000,
24 |                help='Number of epochs to train for.')
25 | 
26 | p.add_argument('--epochs_til_ckpt', type=int, default=1000,
27 |                help='Number of epochs between checkpoints.')
28 | p.add_argument('--steps_til_summary', type=int, default=100,
29 |                help='Number of training steps between tensorboard summaries.')
30 | p.add_argument('--model', type=str, default='sine', required=False, choices=['sine', 'tanh', 'sigmoid', 'relu'],
31 |                help='Type of model to evaluate, default is sine.')
32 | p.add_argument('--mode', type=str, default='mlp', required=False, choices=['mlp', 'rbf', 'pinn'],
33 |                help='Network parameterization; one of (mlp, rbf, pinn).')
34 | p.add_argument('--velocity', type=str, default='uniform', required=False, choices=['uniform', 'square', 'circle'],
35 |                help='Velocity model; "uniform" or a square/circle perturbation.')
36 | p.add_argument('--pretrain', action='store_true', default=False, required=False, help='Pretrain dirichlet and neumann conditions')
37 | p.add_argument('--clip_grad', default=0.0, type=float, help='Clip gradient.')
38 | p.add_argument('--use_lbfgs', default=False, action='store_true', help='use L-BFGS.')
39 | 
40 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.')
41 | opt = p.parse_args()
42 | 
43 | # source at the domain center; if a velocity perturbation is used, the source can be offset
44 | source_coords = [0., 0., 0.]
45 | 
46 | dataset = dataio.WaveSource(sidelength=340, velocity=opt.velocity,
47 |                             source_coords=source_coords, pretrain=opt.pretrain)
48 | 
49 | dataloader = DataLoader(dataset, shuffle=True, batch_size=opt.batch_size, pin_memory=True, num_workers=0)
50 | 
51 | model = modules.SingleBVPNet(in_features=3, out_features=1, type=opt.model, mode=opt.mode,
52 |                              final_layer_factor=1., hidden_features=512, num_hidden_layers=3)
53 | model.cuda()
54 | 
55 | # Define the loss
56 | loss_fn = loss_functions.wave_pml
57 | summary_fn = utils.write_wave_summary
58 | 
59 | root_path = os.path.join(opt.logging_root, opt.experiment_name)
60 | 
61 | training.train(model=model, train_dataloader=dataloader, epochs=opt.num_epochs, lr=opt.lr,
62 |                steps_til_summary=opt.steps_til_summary, epochs_til_checkpoint=opt.epochs_til_ckpt,
63 |                model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn, clip_grad=opt.clip_grad,
64 |                use_lbfgs=opt.use_lbfgs)
65 | 
--------------------------------------------------------------------------------
/meta_modules.py:
--------------------------------------------------------------------------------
1 | '''Modules for hypernetwork experiments, Paper Sec. 
4.4 2 | ''' 3 | 4 | import torch 5 | from torch import nn 6 | from collections import OrderedDict 7 | import modules 8 | 9 | 10 | class HyperNetwork(nn.Module): 11 | def __init__(self, hyper_in_features, hyper_hidden_layers, hyper_hidden_features, hypo_module): 12 | ''' 13 | 14 | Args: 15 | hyper_in_features: In features of hypernetwork 16 | hyper_hidden_layers: Number of hidden layers in hypernetwork 17 | hyper_hidden_features: Number of hidden units in hypernetwork 18 | hypo_module: MetaModule. The module whose parameters are predicted. 19 | ''' 20 | super().__init__() 21 | 22 | hypo_parameters = hypo_module.meta_named_parameters() 23 | 24 | self.names = [] 25 | self.nets = nn.ModuleList() 26 | self.param_shapes = [] 27 | for name, param in hypo_parameters: 28 | self.names.append(name) 29 | self.param_shapes.append(param.size()) 30 | 31 | hn = modules.FCBlock(in_features=hyper_in_features, out_features=int(torch.prod(torch.tensor(param.size()))), 32 | num_hidden_layers=hyper_hidden_layers, hidden_features=hyper_hidden_features, 33 | outermost_linear=True, nonlinearity='relu') 34 | self.nets.append(hn) 35 | 36 | if 'weight' in name: 37 | self.nets[-1].net[-1].apply(lambda m: hyper_weight_init(m, param.size()[-1])) 38 | elif 'bias' in name: 39 | self.nets[-1].net[-1].apply(lambda m: hyper_bias_init(m)) 40 | 41 | def forward(self, z): 42 | ''' 43 | Args: 44 | z: Embedding. Input to hypernetwork. Could be output of "Autodecoder" (see above) 45 | 46 | Returns: 47 | params: OrderedDict. Can be directly passed as the "params" parameter of a MetaModule. 48 | ''' 49 | params = OrderedDict() 50 | for name, net, param_shape in zip(self.names, self.nets, self.param_shapes): 51 | batch_param_shape = (-1,) + param_shape 52 | params[name] = net(z).reshape(batch_param_shape) 53 | return params 54 | 55 | 56 | class NeuralProcessImplicit2DHypernet(nn.Module): 57 | '''A canonical 2D representation hypernetwork mapping 2D coords to out_features.''' 58 | def __init__(self, in_features, out_features, image_resolution=None, encoder_nl='sine'): 59 | super().__init__() 60 | 61 | latent_dim = 256 62 | self.hypo_net = modules.SingleBVPNet(out_features=out_features, type='sine', sidelength=image_resolution, 63 | in_features=2) 64 | self.hyper_net = HyperNetwork(hyper_in_features=latent_dim, hyper_hidden_layers=1, hyper_hidden_features=256, 65 | hypo_module=self.hypo_net) 66 | self.set_encoder = modules.SetEncoder(in_features=in_features, out_features=latent_dim, num_hidden_layers=2, 67 | hidden_features=latent_dim, nonlinearity=encoder_nl) 68 | print(self) 69 | 70 | def freeze_hypernet(self): 71 | for param in self.hyper_net.parameters(): 72 | param.requires_grad = False 73 | 74 | def get_hypo_net_weights(self, model_input): 75 | pixels, coords = model_input['img_sub'], model_input['coords_sub'] 76 | ctxt_mask = model_input.get('ctxt_mask', None) 77 | embedding = self.set_encoder(coords, pixels, ctxt_mask=ctxt_mask) 78 | hypo_params = self.hyper_net(embedding) 79 | return hypo_params, embedding 80 | 81 | def forward(self, model_input): 82 | if model_input.get('embedding', None) is None: 83 | pixels, coords = model_input['img_sub'], model_input['coords_sub'] 84 | ctxt_mask = model_input.get('ctxt_mask', None) 85 | embedding = self.set_encoder(coords, pixels, ctxt_mask=ctxt_mask) 86 | else: 87 | embedding = model_input['embedding'] 88 | hypo_params = self.hyper_net(embedding) 89 | 90 | model_output = self.hypo_net(model_input, params=hypo_params) 91 | return {'model_in':model_output['model_in'], 
'model_out':model_output['model_out'], 'latent_vec':embedding, 92 | 'hypo_params':hypo_params} 93 | 94 | 95 | class ConvolutionalNeuralProcessImplicit2DHypernet(nn.Module): 96 | def __init__(self, in_features, out_features, image_resolution=None, partial_conv=False): 97 | super().__init__() 98 | latent_dim = 256 99 | 100 | if partial_conv: 101 | self.encoder = modules.PartialConvImgEncoder(channel=in_features, image_resolution=image_resolution) 102 | else: 103 | self.encoder = modules.ConvImgEncoder(channel=in_features, image_resolution=image_resolution) 104 | self.hypo_net = modules.SingleBVPNet(out_features=out_features, type='sine', sidelength=image_resolution, 105 | in_features=2) 106 | self.hyper_net = HyperNetwork(hyper_in_features=latent_dim, hyper_hidden_layers=1, hyper_hidden_features=256, 107 | hypo_module=self.hypo_net) 108 | print(self) 109 | 110 | def forward(self, model_input): 111 | if model_input.get('embedding', None) is None: 112 | embedding = self.encoder(model_input['img_sparse']) 113 | else: 114 | embedding = model_input['embedding'] 115 | hypo_params = self.hyper_net(embedding) 116 | 117 | model_output = self.hypo_net(model_input, params=hypo_params) 118 | 119 | return {'model_in': model_output['model_in'], 'model_out': model_output['model_out'], 'latent_vec': embedding, 120 | 'hypo_params': hypo_params} 121 | 122 | def get_hypo_net_weights(self, model_input): 123 | embedding = self.encoder(model_input['img_sparse']) 124 | hypo_params = self.hyper_net(embedding) 125 | return hypo_params, embedding 126 | 127 | def freeze_hypernet(self): 128 | for param in self.hyper_net.parameters(): 129 | param.requires_grad = False 130 | for param in self.encoder.parameters(): 131 | param.requires_grad = False 132 | 133 | 134 | ############################ 135 | # Initialization schemes 136 | def hyper_weight_init(m, in_features_main_net): 137 | if hasattr(m, 'weight'): 138 | nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in') 139 | m.weight.data = m.weight.data / 1.e2 140 | 141 | if hasattr(m, 'bias'): 142 | with torch.no_grad(): 143 | m.bias.uniform_(-1/in_features_main_net, 1/in_features_main_net) 144 | 145 | 146 | def hyper_bias_init(m): 147 | if hasattr(m, 'weight'): 148 | nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in') 149 | m.weight.data = m.weight.data / 1.e2 150 | 151 | if hasattr(m, 'bias'): 152 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight) 153 | with torch.no_grad(): 154 | m.bias.uniform_(-1/fan_in, 1/fan_in) 155 | -------------------------------------------------------------------------------- /sdf_meshing.py: -------------------------------------------------------------------------------- 1 | '''From the DeepSDF repository https://github.com/facebookresearch/DeepSDF 2 | ''' 3 | #!/usr/bin/env python3 4 | 5 | import logging 6 | import numpy as np 7 | import plyfile 8 | import skimage.measure 9 | import time 10 | import torch 11 | 12 | 13 | def create_mesh( 14 | decoder, filename, N=256, max_batch=64 ** 3, offset=None, scale=None 15 | ): 16 | start = time.time() 17 | ply_filename = filename 18 | 19 | decoder.eval() 20 | 21 | # NOTE: the voxel_origin is actually the (bottom, left, down) corner, not the middle 22 | voxel_origin = [-1, -1, -1] 23 | voxel_size = 2.0 / (N - 1) 24 | 25 | overall_index = torch.arange(0, N ** 3, 1, out=torch.LongTensor()) 26 | samples = torch.zeros(N ** 3, 4) 27 | 28 | # transform first 3 columns 29 | # to be the x, y, z index 30 | samples[:, 2] = overall_index % N 31 | 
samples[:, 1] = (overall_index.long() // N) % N  # floor division keeps the unraveling integral ('/' is true division on newer PyTorch)
32 |     samples[:, 0] = ((overall_index.long() // N) // N) % N
33 | 
34 |     # transform first 3 columns
35 |     # to be the x, y, z coordinate
36 |     samples[:, 0] = (samples[:, 0] * voxel_size) + voxel_origin[2]
37 |     samples[:, 1] = (samples[:, 1] * voxel_size) + voxel_origin[1]
38 |     samples[:, 2] = (samples[:, 2] * voxel_size) + voxel_origin[0]
39 | 
40 |     num_samples = N ** 3
41 | 
42 |     samples.requires_grad = False
43 | 
44 |     head = 0
45 | 
46 |     while head < num_samples:
47 |         print(head)
48 |         sample_subset = samples[head : min(head + max_batch, num_samples), 0:3].cuda()
49 | 
50 |         samples[head : min(head + max_batch, num_samples), 3] = (
51 |             decoder(sample_subset)
52 |             .squeeze()#.squeeze(1)
53 |             .detach()
54 |             .cpu()
55 |         )
56 |         head += max_batch
57 | 
58 |     sdf_values = samples[:, 3]
59 |     sdf_values = sdf_values.reshape(N, N, N)
60 | 
61 |     end = time.time()
62 |     print("sampling takes: %f" % (end - start))
63 | 
64 |     convert_sdf_samples_to_ply(
65 |         sdf_values.data.cpu(),
66 |         voxel_origin,
67 |         voxel_size,
68 |         ply_filename + ".ply",
69 |         offset,
70 |         scale,
71 |     )
72 | 
73 | 
74 | def convert_sdf_samples_to_ply(
75 |     pytorch_3d_sdf_tensor,
76 |     voxel_grid_origin,
77 |     voxel_size,
78 |     ply_filename_out,
79 |     offset=None,
80 |     scale=None,
81 | ):
82 |     """
83 |     Convert sdf samples to .ply
84 | 
85 |     :param pytorch_3d_sdf_tensor: a torch.FloatTensor of shape (n,n,n)
86 |     :param voxel_grid_origin: a list of three floats: the bottom, left, down origin of the voxel grid
87 |     :param voxel_size: float, the size of the voxels
88 |     :param ply_filename_out: string, path of the filename to save to
89 | 
90 |     This function adapted from: https://github.com/RobotLocomotion/spartan
91 |     """
92 | 
93 |     start_time = time.time()
94 | 
95 |     numpy_3d_sdf_tensor = pytorch_3d_sdf_tensor.numpy()
96 | 
97 |     verts, faces, normals, values = np.zeros((0, 3)), np.zeros((0, 3)), np.zeros((0, 3)), np.zeros(0)
98 |     try:
99 |         verts, faces, normals, values = skimage.measure.marching_cubes_lewiner(  # renamed to marching_cubes in scikit-image >= 0.19
100 |             numpy_3d_sdf_tensor, level=0.0, spacing=[voxel_size] * 3
101 |         )
102 |     except Exception:
103 |         pass  # no zero level set found; keep the empty mesh initialized above
104 | 
105 |     # transform from voxel coordinates to camera coordinates
106 |     # note x and y are flipped in the output of marching_cubes
107 |     mesh_points = np.zeros_like(verts)
108 |     mesh_points[:, 0] = voxel_grid_origin[0] + verts[:, 0]
109 |     mesh_points[:, 1] = voxel_grid_origin[1] + verts[:, 1]
110 |     mesh_points[:, 2] = voxel_grid_origin[2] + verts[:, 2]
111 | 
112 |     # apply additional offset and scale
113 |     if scale is not None:
114 |         mesh_points = mesh_points / scale
115 |     if offset is not None:
116 |         mesh_points = mesh_points - offset
117 | 
118 |     # try writing to the ply file
119 | 
120 |     num_verts = verts.shape[0]
121 |     num_faces = faces.shape[0]
122 | 
123 |     verts_tuple = np.zeros((num_verts,), dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")])
124 | 
125 |     for i in range(0, num_verts):
126 |         verts_tuple[i] = tuple(mesh_points[i, :])
127 | 
128 |     faces_building = []
129 |     for i in range(0, num_faces):
130 |         faces_building.append(((faces[i, :].tolist(),)))
131 |     faces_tuple = np.array(faces_building, dtype=[("vertex_indices", "i4", (3,))])
132 | 
133 |     el_verts = plyfile.PlyElement.describe(verts_tuple, "vertex")
134 |     el_faces = plyfile.PlyElement.describe(faces_tuple, "face")
135 | 
136 |     ply_data = plyfile.PlyData([el_verts, el_faces])
137 |     logging.debug("saving mesh to %s" % (ply_filename_out))
138 |     ply_data.write(ply_filename_out)
139 | 
140 |     logging.debug(
141 |         "converting to ply format and writing to file took {} s".format(
142 | time.time() - start_time 143 | ) 144 | ) 145 | -------------------------------------------------------------------------------- /torchmeta/__init__.py: -------------------------------------------------------------------------------- 1 | from torchmeta import datasets 2 | from torchmeta import modules 3 | from torchmeta import toy 4 | from torchmeta import transforms 5 | from torchmeta import utils 6 | 7 | from torchmeta.version import VERSION as __version__ 8 | -------------------------------------------------------------------------------- /torchmeta/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from torchmeta.datasets.triplemnist import TripleMNIST 2 | from torchmeta.datasets.doublemnist import DoubleMNIST 3 | from torchmeta.datasets.cub import CUB 4 | from torchmeta.datasets.cifar100 import CIFARFS, FC100 5 | from torchmeta.datasets.miniimagenet import MiniImagenet 6 | from torchmeta.datasets.omniglot import Omniglot 7 | from torchmeta.datasets.tieredimagenet import TieredImagenet 8 | from torchmeta.datasets.tcga import TCGA 9 | 10 | from torchmeta.datasets import helpers 11 | 12 | __all__ = [ 13 | 'TCGA', 14 | 'Omniglot', 15 | 'MiniImagenet', 16 | 'TieredImagenet', 17 | 'CIFARFS', 18 | 'FC100', 19 | 'CUB', 20 | 'DoubleMNIST', 21 | 'TripleMNIST', 22 | 'helpers' 23 | ] 24 | -------------------------------------------------------------------------------- /torchmeta/datasets/assets/cifar100/cifar-fs/test.json: -------------------------------------------------------------------------------- 1 | [["aquatic_mammals","whale"],["flowers","poppy"],["flowers","rose"],["fruit_and_vegetables","sweet_pepper"],["household_electrical_devices","telephone"],["household_furniture","bed"],["household_furniture","table"],["household_furniture","wardrobe"],["large_carnivores","leopard"],["large_natural_outdoor_scenes","plain"],["large_omnivores_and_herbivores","chimpanzee"],["medium_mammals","fox"],["non-insect_invertebrates","snail"],["non-insect_invertebrates","worm"],["people","baby"],["people","man"],["people","woman"],["vehicles_1","bicycle"],["vehicles_1","pickup_truck"],["vehicles_2","rocket"]] -------------------------------------------------------------------------------- /torchmeta/datasets/assets/cifar100/cifar-fs/train.json: -------------------------------------------------------------------------------- 1 | 
[["aquatic_mammals","dolphin"],["aquatic_mammals","seal"],["fish","aquarium_fish"],["fish","ray"],["fish","trout"],["flowers","orchid"],["flowers","sunflower"],["flowers","tulip"],["food_containers","bottle"],["food_containers","bowl"],["food_containers","can"],["food_containers","cup"],["food_containers","plate"],["fruit_and_vegetables","apple"],["fruit_and_vegetables","mushroom"],["fruit_and_vegetables","orange"],["fruit_and_vegetables","pear"],["household_electrical_devices","clock"],["household_electrical_devices","keyboard"],["household_furniture","chair"],["household_furniture","couch"],["insects","bee"],["insects","caterpillar"],["insects","cockroach"],["large_carnivores","bear"],["large_carnivores","lion"],["large_carnivores","tiger"],["large_carnivores","wolf"],["large_man-made_outdoor_things","bridge"],["large_man-made_outdoor_things","castle"],["large_man-made_outdoor_things","house"],["large_man-made_outdoor_things","road"],["large_man-made_outdoor_things","skyscraper"],["large_natural_outdoor_scenes","cloud"],["large_natural_outdoor_scenes","forest"],["large_natural_outdoor_scenes","mountain"],["large_omnivores_and_herbivores","elephant"],["large_omnivores_and_herbivores","kangaroo"],["medium_mammals","porcupine"],["medium_mammals","possum"],["medium_mammals","raccoon"],["medium_mammals","skunk"],["non-insect_invertebrates","lobster"],["non-insect_invertebrates","spider"],["people","boy"],["people","girl"],["reptiles","dinosaur"],["reptiles","lizard"],["reptiles","snake"],["reptiles","turtle"],["small_mammals","hamster"],["small_mammals","mouse"],["small_mammals","rabbit"],["small_mammals","shrew"],["small_mammals","squirrel"],["trees","oak_tree"],["trees","palm_tree"],["trees","pine_tree"],["trees","willow_tree"],["vehicles_1","bus"],["vehicles_1","train"],["vehicles_2","lawn_mower"],["vehicles_2","streetcar"],["vehicles_2","tank"]] -------------------------------------------------------------------------------- /torchmeta/datasets/assets/cifar100/cifar-fs/val.json: -------------------------------------------------------------------------------- 1 | [["aquatic_mammals","beaver"],["aquatic_mammals","otter"],["fish","flatfish"],["fish","shark"],["household_electrical_devices","lamp"],["household_electrical_devices","television"],["insects","beetle"],["insects","butterfly"],["large_natural_outdoor_scenes","sea"],["large_omnivores_and_herbivores","camel"],["large_omnivores_and_herbivores","cattle"],["non-insect_invertebrates","crab"],["reptiles","crocodile"],["trees","maple_tree"],["vehicles_1","motorcycle"],["vehicles_2","tractor"]] -------------------------------------------------------------------------------- /torchmeta/datasets/assets/cifar100/fc100/test.json: -------------------------------------------------------------------------------- 1 | ["aquatic_mammals","insects","medium_mammals","people"] -------------------------------------------------------------------------------- /torchmeta/datasets/assets/cifar100/fc100/train.json: -------------------------------------------------------------------------------- 1 | ["fish","flowers","food_containers","fruit_and_vegetables","household_electrical_devices","household_furniture","large_man-made_outdoor_things","large_natural_outdoor_scenes","reptiles","trees","vehicles_1","vehicles_2"] -------------------------------------------------------------------------------- /torchmeta/datasets/assets/cifar100/fc100/val.json: -------------------------------------------------------------------------------- 1 | 
["large_carnivores","large_omnivores_and_herbivores","non-insect_invertebrates","small_mammals"] -------------------------------------------------------------------------------- /torchmeta/datasets/assets/cub/test.json: -------------------------------------------------------------------------------- 1 | ["004.Groove_billed_Ani","008.Rhinoceros_Auklet","012.Yellow_headed_Blackbird","016.Painted_Bunting","020.Yellow_breasted_Chat","024.Red_faced_Cormorant","028.Brown_Creeper","032.Mangrove_Cuckoo","036.Northern_Flicker","040.Olive_sided_Flycatcher","044.Frigatebird","048.European_Goldfinch","052.Pied_billed_Grebe","056.Pine_Grosbeak","060.Glaucous_winged_Gull","064.Ring_billed_Gull","068.Ruby_throated_Hummingbird","072.Pomarine_Jaeger","076.Dark_eyed_Junco","080.Green_Kingfisher","084.Red_legged_Kittiwake","088.Western_Meadowlark","092.Nighthawk","096.Hooded_Oriole","100.Brown_Pelican","104.American_Pipit","108.White_necked_Raven","112.Great_Grey_Shrike","116.Chipping_Sparrow","120.Fox_Sparrow","124.Le_Conte_Sparrow","128.Seaside_Sparrow","132.White_crowned_Sparrow","136.Barn_Swallow","140.Summer_Tanager","144.Common_Tern","148.Green_tailed_Towhee","152.Blue_headed_Vireo","156.White_eyed_Vireo","160.Black_throated_Blue_Warbler","164.Cerulean_Warbler","168.Kentucky_Warbler","172.Nashville_Warbler","176.Prairie_Warbler","180.Wilson_Warbler","184.Louisiana_Waterthrush","188.Pileated_Woodpecker","192.Downy_Woodpecker","196.House_Wren","200.Common_Yellowthroat"] -------------------------------------------------------------------------------- /torchmeta/datasets/assets/cub/train.json: -------------------------------------------------------------------------------- 1 | ["001.Black_footed_Albatross","003.Sooty_Albatross","005.Crested_Auklet","007.Parakeet_Auklet","009.Brewer_Blackbird","011.Rusty_Blackbird","013.Bobolink","015.Lazuli_Bunting","017.Cardinal","019.Gray_Catbird","021.Eastern_Towhee","023.Brandt_Cormorant","025.Pelagic_Cormorant","027.Shiny_Cowbird","029.American_Crow","031.Black_billed_Cuckoo","033.Yellow_billed_Cuckoo","035.Purple_Finch","037.Acadian_Flycatcher","039.Least_Flycatcher","041.Scissor_tailed_Flycatcher","043.Yellow_bellied_Flycatcher","045.Northern_Fulmar","047.American_Goldfinch","049.Boat_tailed_Grackle","051.Horned_Grebe","053.Western_Grebe","055.Evening_Grosbeak","057.Rose_breasted_Grosbeak","059.California_Gull","061.Heermann_Gull","063.Ivory_Gull","065.Slaty_backed_Gull","067.Anna_Hummingbird","069.Rufous_Hummingbird","071.Long_tailed_Jaeger","073.Blue_Jay","075.Green_Jay","077.Tropical_Kingbird","079.Belted_Kingfisher","081.Pied_Kingfisher","083.White_breasted_Kingfisher","085.Horned_Lark","087.Mallard","089.Hooded_Merganser","091.Mockingbird","093.Clark_Nutcracker","095.Baltimore_Oriole","097.Orchard_Oriole","099.Ovenbird","101.White_Pelican","103.Sayornis","105.Whip_poor_Will","107.Common_Raven","109.American_Redstart","111.Loggerhead_Shrike","113.Baird_Sparrow","115.Brewer_Sparrow","117.Clay_colored_Sparrow","119.Field_Sparrow","121.Grasshopper_Sparrow","123.Henslow_Sparrow","125.Lincoln_Sparrow","127.Savannah_Sparrow","129.Song_Sparrow","131.Vesper_Sparrow","133.White_throated_Sparrow","135.Bank_Swallow","137.Cliff_Swallow","139.Scarlet_Tanager","141.Artic_Tern","143.Caspian_Tern","145.Elegant_Tern","147.Least_Tern","149.Brown_Thrasher","151.Black_capped_Vireo","153.Philadelphia_Vireo","155.Warbling_Vireo","157.Yellow_throated_Vireo","159.Black_and_white_Warbler","161.Blue_winged_Warbler","163.Cape_May_Warbler","165.Chestnut_sided_Warbler","167.Hooded_Warbler","
169.Magnolia_Warbler","171.Myrtle_Warbler","173.Orange_crowned_Warbler","175.Pine_Warbler","177.Prothonotary_Warbler","179.Tennessee_Warbler","181.Worm_eating_Warbler","183.Northern_Waterthrush","185.Bohemian_Waxwing","187.American_Three_toed_Woodpecker","189.Red_bellied_Woodpecker","191.Red_headed_Woodpecker","193.Bewick_Wren","195.Carolina_Wren","197.Marsh_Wren","199.Winter_Wren"] -------------------------------------------------------------------------------- /torchmeta/datasets/assets/cub/val.json: -------------------------------------------------------------------------------- 1 | ["002.Laysan_Albatross","006.Least_Auklet","010.Red_winged_Blackbird","014.Indigo_Bunting","018.Spotted_Catbird","022.Chuck_will_Widow","026.Bronzed_Cowbird","030.Fish_Crow","034.Gray_crowned_Rosy_Finch","038.Great_Crested_Flycatcher","042.Vermilion_Flycatcher","046.Gadwall","050.Eared_Grebe","054.Blue_Grosbeak","058.Pigeon_Guillemot","062.Herring_Gull","066.Western_Gull","070.Green_Violetear","074.Florida_Jay","078.Gray_Kingbird","082.Ringed_Kingfisher","086.Pacific_Loon","090.Red_breasted_Merganser","094.White_breasted_Nuthatch","098.Scott_Oriole","102.Western_Wood_Pewee","106.Horned_Puffin","110.Geococcyx","114.Black_throated_Sparrow","118.House_Sparrow","122.Harris_Sparrow","126.Nelson_Sharp_tailed_Sparrow","130.Tree_Sparrow","134.Cape_Glossy_Starling","138.Tree_Swallow","142.Black_Tern","146.Forsters_Tern","150.Sage_Thrasher","154.Red_eyed_Vireo","158.Bay_breasted_Warbler","162.Canada_Warbler","166.Golden_winged_Warbler","170.Mourning_Warbler","174.Palm_Warbler","178.Swainson_Warbler","182.Yellow_Warbler","186.Cedar_Waxwing","190.Red_cockaded_Woodpecker","194.Cactus_Wren","198.Rock_Wren"] -------------------------------------------------------------------------------- /torchmeta/datasets/assets/doublemnist/test.json: -------------------------------------------------------------------------------- 1 | ["02", "17", "25", "32", "36", "46", "47", "49", "55", "57", "66", "67", "68", "73", "78", "80", "83", "86", "92", "96"] -------------------------------------------------------------------------------- /torchmeta/datasets/assets/doublemnist/train.json: -------------------------------------------------------------------------------- 1 | ["00", "01", "04", "05", "06", "08", "09", "11", "12", "13", "14", "15", "16", "18", "19", "20", "21", "23", "24", "26", "28", "29", "30", "31", "33", "35", "37", "38", "41", "42", "43", "44", "45", "50", "51", "53", "54", "56", "59", "60", "62", "63", "65", "69", "70", "72", "74", "75", "76", "77", "79", "81", "82", "84", "85", "87", "88", "89", "90", "91", "94", "95", "97", "98"] -------------------------------------------------------------------------------- /torchmeta/datasets/assets/doublemnist/val.json: -------------------------------------------------------------------------------- 1 | ["03", "07", "10", "22", "27", "34", "39", "40", "48", "52", "58", "61", "64", "71", "93", "99"] -------------------------------------------------------------------------------- /torchmeta/datasets/assets/omniglot/test.json: -------------------------------------------------------------------------------- 1 | 
{"background":{},"evaluation":{"Gurmukhi":["character42","character43","character44","character45"],"Kannada":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23","character24","character25","character26","character27","character28","character29","character30","character31","character32","character33","character34","character35","character36","character37","character38","character39","character40","character41"],"Keble":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23","character24","character25","character26"],"Malayalam":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23","character24","character25","character26","character27","character28","character29","character30","character31","character32","character33","character34","character35","character36","character37","character38","character39","character40","character41","character42","character43","character44","character45","character46","character47"],"Manipuri":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23","character24","character25","character26","character27","character28","character29","character30","character31","character32","character33","character34","character35","character36","character37","character38","character39","character40"],"Mongolian":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23","character24","character25","character26","character27","character28","character29","character30"],"Old_Church_Slavonic_(Cyrillic)":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23","character24","character25","character26","character27","character28","character29","character30","character31","character32","character33","character34","character35","character36","character37","character38","character39","character40","character41","character42","character43","character44","character45"],"Oriya":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","c
haracter11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23","character24","character25","character26","character27","character28","character29","character30","character31","character32","character33","character34","character35","character36","character37","character38","character39","character40","character41","character42","character43","character44","character45","character46"],"Sylheti":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23","character24","character25","character26","character27","character28"],"Syriac_(Serto)":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23"],"Tengwar":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23","character24","character25"],"Tibetan":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23","character24","character25","character26","character27","character28","character29","character30","character31","character32","character33","character34","character35","character36","character37","character38","character39","character40","character41","character42"],"ULOG":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23","character24","character25","character26"]}} -------------------------------------------------------------------------------- /torchmeta/datasets/assets/omniglot/val.json: -------------------------------------------------------------------------------- 1 | 
{"background":{"Armenian":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23","character24","character25","character26","character27","character28","character29","character30","character31","character32","character33","character34","character35","character36","character37","character38","character39","character40","character41"],"Bengali":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23","character24","character25","character26","character27","character28","character29","character30","character31","character32","character33","character34","character35","character36","character37","character38","character39","character40","character41","character42","character43","character44","character45","character46"],"Early_Aramaic":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22"],"Hebrew":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22"],"Mkhedruli_(Georgian)":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23","character24","character25","character26","character27","character28","character29","character30","character31","character32","character33","character34","character35","character36","character37","character38","character39","character40","character41"]},"evaluation":{}} -------------------------------------------------------------------------------- /torchmeta/datasets/assets/tcga/cancers.json: -------------------------------------------------------------------------------- 1 | ["ACC","BLCA","BRCA","CESC","CHOL","COAD","DLBC","ESCA","FPPP","GBM","HNSC","KICH","KIRP","LAML","LGG","LIHC","LUAD","LUSC","MESO","OV","PAAD","PCPG","PRAD","READ","SARC","SKCM","STAD","TGCT","THCA","THYM","UCEC","UCS","UVM"] -------------------------------------------------------------------------------- /torchmeta/datasets/assets/tcga/task_variables.json: -------------------------------------------------------------------------------- 1 | 
["gender","tumor_tissue_site","icd_10","icd_o_3_site","_EVENT","histological_type","_RFS","anatomic_neoplasm_subdivision","oct_embedded","_PANCAN_CNA_PANCAN_K8","_PANCAN_DNAMethyl_PANCAN","_PANCAN_Cluster_Cluster_PANCAN","_PANCAN_miRNA_PANCAN","_PANCAN_mutation_PANCAN","_PANCAN_UNC_RNAseq_PANCAN_K16","_PANCAN_RPPA_PANCAN_K8","height","clinical_stage","menopause_status","clinical_M","year_of_tobacco_smoking_onset","clinical_T","lymphatic_invasion","venous_invasion","synchronous_colon_cancer_present","additional_treatment_completion_success_outcome","breast_carcinoma_estrogen_receptor_status","breast_carcinoma_progesterone_receptor_status","clinical_N","white_cell_count_result","alcohol_history_documented","PAM50Call_RNAseq","mental_status_changes","lymphovascular_invasion_present","family_history_of_cancer","_PANCAN_DNAMethyl_BRCA","Node_nature2012","Tumor_nature2012","Metastasis_nature2012","ER_Status_nature2012","PR_Status_nature2012","HER2_Final_Status_nature2012","asthma_history","family_history_of_primary_brain_tumor","_PANCAN_mirna_BRCA","colon_polyps_present","animal_insect_allergy_history","Expression_Subtype","KRAS","Pathology_Updated","gleason_score","extrathyroid_carcinoma_present_extension_status","GeneExp_Subtype","AWG_cancer_type_Oct62011","hypermutation","hypertension","pregnancies","diabetes","biochemical_recurrence","_PANCAN_mirna_OV","melanoma_origin_skin_anatomic_site","barretts_esophagus","diagnosis_subtype","_PANCAN_mirna_KIRC","birth_control_pill_history_usage_category","albumin_result_specified_value","amount_of_alcohol_consumption_per_day","albumin_result_lower_limit","albumin_result_upper_limit","family_history_of_stomach_cancer","_PANCAN_DNAMethyl_GBM","_PANCAN_DNAMethyl_UCEC","_PANCAN_mirna_UCEC","melanoma_clark_level_value","_PANCAN_mirna_HNSC","fibrosis_ishak_score","adjacent_hepatic_tissue_inflammation_extent_type","_PANCAN_DNAMethyl_HNSC","antireflux_treatment","ldh1_mutation_found","_PANCAN_DNAMethyl_COADREAD","_PANCAN_mirna_LUAD","leukemia_french_american_british_morphology_code","acute_myeloid_leukemia_calgb_cytogenetics_risk_category","_PANCAN_DNAMethyl_LAML","_PANCAN_mirna_COAD","_PANCAN_mirna_LAML","metastatic_diagnosis","_PANCAN_DNAMethyl_LUAD","_PANCAN_DNAMethyl_LUSC","_PANCAN_mirna_LUSC","cancer_diagnosis_cancer_type_icd9_text_name","_PANCAN_DNAMethyl_BLCA","_PANCAN_mirna_BLCA","leiomyosarcoma_histologic_subtype","family_history_other_cancer","necrosis","weiss_venous_invasion","weiss_score","_PANCAN_mirna_READ","atypical_mitotic_figures","relative_cancer_type","metastatic_breast_carcinoma_estrogen_receptor_status","metastatic_breast_carcinoma_progesterone_receptor_status","well_differentiated_liposarcoma_primary_dx","food_allergy_types","metastatic_breast_carcinom_lb_prc_hr2_n_mmnhstchmstry_rcptr_stts"] -------------------------------------------------------------------------------- /torchmeta/datasets/assets/triplemnist/test.json: -------------------------------------------------------------------------------- 1 | ["002", "003", "008", "014", "016", "017", "039", "046", "047", "051", "056", "058", "060", "065", "067", "068", "073", "076", "077", "086", "087", "088", "096", "098", "099", "106", "109", "111", "113", "123", "129", "135", "139", "140", "141", "146", "154", "158", "176", "180", "186", "193", "198", "206", "208", "213", "214", "222", "224", "244", "253", "255", "257", "259", "262", "265", "271", "290", "296", "301", "304", "305", "315", "319", "321", "322", "323", "339", "340", "342", "354", "357", "358", "359", "360", "364", "365", "371", 
"377", "380", "382", "385", "390", "394", "401", "409", "411", "412", "418", "419", "420", "424", "428", "433", "434", "438", "440", "451", "460", "465", "467", "471", "472", "473", "480", "481", "484", "486", "489", "490", "493", "496", "497", "504", "510", "513", "515", "530", "537", "539", "544", "552", "555", "559", "569", "575", "576", "582", "588", "595", "596", "608", "615", "628", "630", "631", "635", "638", "645", "647", "661", "665", "680", "681", "695", "717", "720", "737", "742", "745", "755", "759", "762", "763", "765", "766", "768", "780", "781", "782", "790", "806", "812", "815", "817", "819", "823", "824", "826", "836", "837", "842", "843", "846", "860", "862", "865", "866", "867", "873", "878", "884", "885", "895", "899", "907", "919", "929", "942", "945", "951", "954", "956", "957", "964", "976", "988", "990", "996", "998"] -------------------------------------------------------------------------------- /torchmeta/datasets/assets/triplemnist/train.json: -------------------------------------------------------------------------------- 1 | ["000", "004", "005", "006", "007", "009", "011", "013", "015", "018", "020", "021", "022", "023", "024", "026", "027", "028", "030", "031", "032", "033", "034", "035", "036", "037", "038", "040", "041", "042", "043", "045", "048", "049", "050", "052", "054", "055", "057", "059", "061", "062", "063", "064", "066", "070", "071", "072", "074", "075", "078", "079", "080", "082", "084", "085", "089", "090", "091", "093", "094", "095", "097", "100", "101", "102", "103", "104", "105", "107", "110", "114", "115", "116", "117", "118", "119", "120", "121", "122", "124", "125", "127", "128", "131", "132", "134", "138", "142", "145", "147", "148", "150", "151", "153", "155", "156", "157", "159", "160", "161", "162", "163", "164", "165", "166", "167", "169", "170", "171", "172", "174", "175", "177", "178", "179", "181", "182", "183", "184", "185", "188", "189", "190", "191", "192", "195", "196", "199", "200", "201", "202", "203", "204", "205", "209", "210", "211", "212", "216", "217", "219", "221", "223", "226", "227", "228", "229", "230", "231", "232", "233", "234", "235", "236", "237", "239", "240", "241", "242", "243", "245", "246", "248", "249", "250", "252", "254", "256", "258", "260", "261", "263", "264", "266", "267", "268", "269", "270", "272", "273", "274", "275", "276", "277", "278", "279", "280", "282", "283", "284", "285", "287", "288", "289", "291", "292", "294", "295", "297", "298", "299", "300", "302", "303", "306", "307", "308", "309", "310", "311", "313", "314", "316", "317", "318", "320", "326", "327", "328", "329", "331", "332", "333", "335", "336", "338", "343", "344", "345", "346", "347", "348", "349", "351", "352", "353", "356", "362", "363", "366", "367", "368", "369", "370", "372", "373", "374", "375", "376", "378", "379", "381", "384", "386", "387", "388", "389", "391", "392", "395", "396", "397", "398", "399", "402", "403", "404", "406", "408", "414", "415", "416", "417", "421", "425", "426", "427", "429", "430", "431", "435", "437", "443", "444", "445", "446", "447", "448", "449", "450", "452", "453", "455", "456", "457", "458", "461", "462", "463", "468", "469", "470", "474", "476", "478", "479", "482", "483", "485", "487", "491", "492", "498", "499", "500", "505", "508", "509", "511", "512", "514", "517", "518", "520", "521", "522", "523", "524", "525", "526", "527", "528", "529", "531", "532", "533", "534", "536", "538", "541", "542", "543", "545", "546", "547", "548", "549", "551", "553", "554", "556", "558", "560", 
"561", "563", "564", "565", "566", "567", "568", "571", "577", "578", "580", "581", "584", "585", "586", "587", "590", "591", "594", "597", "598", "599", "600", "603", "605", "606", "607", "610", "613", "614", "616", "617", "618", "619", "620", "621", "623", "624", "625", "626", "632", "633", "637", "640", "641", "642", "643", "644", "646", "648", "650", "652", "653", "654", "655", "656", "657", "659", "660", "662", "663", "664", "666", "667", "668", "669", "670", "671", "672", "674", "675", "676", "677", "678", "682", "683", "684", "685", "687", "688", "689", "690", "691", "692", "693", "694", "696", "698", "699", "700", "701", "702", "704", "705", "706", "708", "709", "710", "711", "712", "713", "714", "716", "719", "721", "722", "723", "724", "725", "727", "728", "729", "730", "732", "733", "734", "735", "736", "738", "739", "743", "744", "747", "748", "749", "751", "752", "753", "754", "756", "757", "758", "760", "761", "764", "767", "769", "770", "771", "773", "774", "776", "777", "778", "783", "784", "785", "786", "787", "788", "791", "792", "793", "795", "797", "798", "799", "800", "801", "802", "803", "804", "805", "807", "808", "809", "810", "811", "813", "814", "818", "820", "821", "822", "825", "827", "828", "830", "831", "832", "833", "834", "835", "838", "839", "840", "844", "845", "848", "849", "850", "853", "854", "856", "857", "858", "859", "861", "863", "864", "868", "869", "870", "871", "872", "874", "875", "876", "877", "881", "882", "883", "887", "891", "892", "893", "894", "896", "897", "898", "901", "902", "903", "905", "908", "909", "913", "914", "916", "917", "918", "920", "921", "922", "923", "924", "925", "927", "928", "930", "931", "932", "933", "934", "935", "936", "937", "938", "939", "941", "943", "946", "947", "948", "949", "950", "952", "953", "958", "959", "960", "965", "966", "967", "968", "969", "970", "971", "972", "973", "974", "977", "978", "979", "982", "983", "984", "985", "986", "987", "989", "991", "993", "994", "995", "997"] -------------------------------------------------------------------------------- /torchmeta/datasets/assets/triplemnist/val.json: -------------------------------------------------------------------------------- 1 | ["001", "010", "012", "019", "025", "029", "044", "053", "069", "081", "083", "092", "108", "112", "126", "130", "133", "136", "137", "143", "144", "149", "152", "168", "173", "187", "194", "197", "207", "215", "218", "220", "225", "238", "247", "251", "281", "286", "293", "312", "324", "325", "330", "334", "337", "341", "350", "355", "361", "383", "393", "400", "405", "407", "410", "413", "422", "423", "432", "436", "439", "441", "442", "454", "459", "464", "466", "475", "477", "488", "494", "495", "501", "502", "503", "506", "507", "516", "519", "535", "540", "550", "557", "562", "570", "572", "573", "574", "579", "583", "589", "592", "593", "601", "602", "604", "609", "611", "612", "622", "627", "629", "634", "636", "639", "649", "651", "658", "673", "679", "686", "697", "703", "707", "715", "718", "726", "731", "740", "741", "746", "750", "772", "775", "779", "789", "794", "796", "816", "829", "841", "847", "851", "852", "855", "879", "880", "886", "888", "889", "890", "900", "904", "906", "910", "911", "912", "915", "926", "940", "944", "955", "961", "962", "963", "975", "980", "981", "992", "999"] -------------------------------------------------------------------------------- /torchmeta/datasets/cifar100/__init__.py: -------------------------------------------------------------------------------- 1 | from 
torchmeta.datasets.cifar100.cifar_fs import CIFARFS 2 | from torchmeta.datasets.cifar100.fc100 import FC100 3 | 4 | __all__ = ['CIFARFS', 'FC100'] 5 | -------------------------------------------------------------------------------- /torchmeta/datasets/cifar100/base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import json 4 | import h5py 5 | from PIL import Image 6 | 7 | from torchvision.datasets.utils import check_integrity, download_url 8 | from torchmeta.utils.data import Dataset, ClassDataset 9 | 10 | 11 | class CIFAR100ClassDataset(ClassDataset): 12 | folder = 'cifar100' 13 | subfolder = None 14 | download_url = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz' 15 | gz_folder = 'cifar-100-python' 16 | gz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' 17 | files_md5 = { 18 | 'train': '16019d7e3df5f24257cddd939b257f8d', 19 | 'test': 'f0ef6b0ae62326f3e7ffdfab6717acfc', 20 | 'meta': '7973b15100ade9c7d40fb424638fde48' 21 | } 22 | 23 | filename = 'data.hdf5' 24 | filename_labels = '{0}_labels.json' 25 | filename_fine_names = 'fine_names.json' 26 | 27 | def __init__(self, root, meta_train=False, meta_val=False, meta_test=False, 28 | meta_split=None, transform=None, class_augmentations=None, 29 | download=False): 30 | super(CIFAR100ClassDataset, self).__init__(meta_train=meta_train, 31 | meta_val=meta_val, meta_test=meta_test, meta_split=meta_split, 32 | class_augmentations=class_augmentations) 33 | 34 | if self.subfolder is None: 35 | raise ValueError() 36 | 37 | self.root = os.path.join(os.path.expanduser(root), self.folder) 38 | self.transform = transform 39 | 40 | self.split_filename_labels = os.path.join(self.root, self.subfolder, 41 | self.filename_labels.format(self.meta_split)) 42 | self._data = None 43 | self._labels = None 44 | 45 | if download: 46 | self.download() 47 | 48 | if not self._check_integrity(): 49 | raise RuntimeError('CIFAR100 integrity check failed') 50 | self._num_classes = len(self.labels) 51 | 52 | def __getitem__(self, index): 53 | coarse_label_name, fine_label_name = self.labels[index % self.num_classes] 54 | data = self.data['{0}/{1}'.format(coarse_label_name, fine_label_name)] 55 | transform = self.get_transform(index, self.transform) 56 | target_transform = self.get_target_transform(index) 57 | 58 | return CIFAR100Dataset(index, data, coarse_label_name, fine_label_name, 59 | transform=transform, target_transform=target_transform) 60 | 61 | @property 62 | def num_classes(self): 63 | return self._num_classes 64 | 65 | @property 66 | def data(self): 67 | if self._data is None: 68 | self._data = h5py.File(os.path.join(self.root, self.filename), 'r') 69 | return self._data 70 | 71 | @property 72 | def labels(self): 73 | if self._labels is None: 74 | with open(self.split_filename_labels, 'r') as f: 75 | self._labels = json.load(f) 76 | return self._labels 77 | 78 | def _check_integrity(self): 79 | return (self._check_integrity_data() 80 | and os.path.isfile(self.split_filename_labels) 81 | and os.path.isfile(os.path.join(self.root, self.filename_fine_names))) 82 | 83 | def _check_integrity_data(self): 84 | return os.path.isfile(os.path.join(self.root, self.filename)) 85 | 86 | def close(self): 87 | if self._data is not None: 88 | self._data.close() 89 | self._data = None 90 | 91 | def download(self): 92 | import tarfile 93 | import pickle 94 | import shutil 95 | 96 | if self._check_integrity_data(): 97 | return 98 | 99 | gz_filename = '{0}.tar.gz'.format(self.gz_folder) 100 | 
download_url(self.download_url, self.root, filename=gz_filename, 101 | md5=self.gz_md5) 102 | with tarfile.open(os.path.join(self.root, gz_filename), 'r:gz') as tar: 103 | tar.extractall(path=self.root) 104 | 105 | train_filename = os.path.join(self.root, self.gz_folder, 'train') 106 | check_integrity(train_filename, self.files_md5['train']) 107 | with open(train_filename, 'rb') as f: 108 | data = pickle.load(f, encoding='bytes') 109 | images = data[b'data'] 110 | fine_labels = data[b'fine_labels'] 111 | coarse_labels = data[b'coarse_labels'] 112 | 113 | test_filename = os.path.join(self.root, self.gz_folder, 'test') 114 | check_integrity(test_filename, self.files_md5['test']) 115 | with open(test_filename, 'rb') as f: 116 | data = pickle.load(f, encoding='bytes') 117 | images = np.concatenate((images, data[b'data']), axis=0) 118 | fine_labels = np.concatenate((fine_labels, data[b'fine_labels']), axis=0) 119 | coarse_labels = np.concatenate((coarse_labels, data[b'coarse_labels']), axis=0) 120 | 121 | images = images.reshape((-1, 3, 32, 32)) 122 | images = images.transpose((0, 2, 3, 1)) 123 | 124 | meta_filename = os.path.join(self.root, self.gz_folder, 'meta') 125 | check_integrity(meta_filename, self.files_md5['meta']) 126 | with open(meta_filename, 'rb') as f: 127 | data = pickle.load(f, encoding='latin1') 128 | fine_label_names = data['fine_label_names'] 129 | coarse_label_names = data['coarse_label_names'] 130 | 131 | filename = os.path.join(self.root, self.filename) 132 | fine_names = dict() 133 | with h5py.File(filename, 'w') as f: 134 | for i, coarse_name in enumerate(coarse_label_names): 135 | group = f.create_group(coarse_name) 136 | fine_indices = np.unique(fine_labels[coarse_labels == i]) 137 | for j in fine_indices: 138 | dataset = group.create_dataset(fine_label_names[j], 139 | data=images[fine_labels == j]) 140 | fine_names[coarse_name] = [fine_label_names[j] for j in fine_indices] 141 | 142 | filename_fine_names = os.path.join(self.root, self.filename_fine_names) 143 | with open(filename_fine_names, 'w') as f: 144 | json.dump(fine_names, f) 145 | 146 | gz_folder = os.path.join(self.root, self.gz_folder) 147 | if os.path.isdir(gz_folder): 148 | shutil.rmtree(gz_folder) 149 | if os.path.isfile('{0}.tar.gz'.format(gz_folder)): 150 | os.remove('{0}.tar.gz'.format(gz_folder)) 151 | 152 | 153 | class CIFAR100Dataset(Dataset): 154 | def __init__(self, index, data, coarse_label_name, fine_label_name, 155 | transform=None, target_transform=None): 156 | super(CIFAR100Dataset, self).__init__(index, transform=transform, 157 | target_transform=target_transform) 158 | self.data = data 159 | self.coarse_label_name = coarse_label_name 160 | self.fine_label_name = fine_label_name 161 | 162 | def __len__(self): 163 | return self.data.shape[0] 164 | 165 | def __getitem__(self, index): 166 | image = Image.fromarray(self.data[index]) 167 | target = (self.coarse_label_name, self.fine_label_name) 168 | 169 | if self.transform is not None: 170 | image = self.transform(image) 171 | 172 | if self.target_transform is not None: 173 | target = self.target_transform(target) 174 | 175 | return (image, target) 176 | -------------------------------------------------------------------------------- /torchmeta/datasets/cifar100/cifar_fs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from torchmeta.datasets.cifar100.base import CIFAR100ClassDataset 5 | from torchmeta.datasets.utils import get_asset 6 | from torchmeta.utils.data import 
ClassDataset, CombinationMetaDataset 7 | 8 | 9 | class CIFARFS(CombinationMetaDataset): 10 | """ 11 | The CIFAR-FS dataset, introduced in [1]. This dataset contains 12 | images of 100 different classes from the CIFAR100 dataset [2]. 13 | 14 | Parameters 15 | ---------- 16 | root : string 17 | Root directory where the dataset folder `cifar100` exists. 18 | 19 | num_classes_per_task : int 20 | Number of classes per task. This corresponds to `N` in `N-way` 21 | classification. 22 | 23 | meta_train : bool (default: `False`) 24 | Use the meta-train split of the dataset. If set to `True`, then the 25 | arguments `meta_val` and `meta_test` must be set to `False`. Exactly one 26 | of these three arguments must be set to `True`. 27 | 28 | meta_val : bool (default: `False`) 29 | Use the meta-validation split of the dataset. If set to `True`, then the 30 | arguments `meta_train` and `meta_test` must be set to `False`. Exactly one 31 | of these three arguments must be set to `True`. 32 | 33 | meta_test : bool (default: `False`) 34 | Use the meta-test split of the dataset. If set to `True`, then the 35 | arguments `meta_train` and `meta_val` must be set to `False`. Exactly one 36 | of these three arguments must be set to `True`. 37 | 38 | meta_split : string in {'train', 'val', 'test'}, optional 39 | Name of the split to use. This overrides the arguments `meta_train`, 40 | `meta_val` and `meta_test` if all three are set to `False`. 41 | 42 | transform : callable, optional 43 | A function/transform that takes a `PIL` image, and returns a transformed 44 | version. See also `torchvision.transforms`. 45 | 46 | target_transform : callable, optional 47 | A function/transform that takes a target, and returns a transformed 48 | version. See also `torchvision.transforms`. 49 | 50 | dataset_transform : callable, optional 51 | A function/transform that takes a dataset (i.e. a task), and returns a 52 | transformed version of it. E.g. `transforms.ClassSplitter()`. 53 | 54 | class_augmentations : list of callable, optional 55 | A list of functions that augment the dataset with new classes. These classes 56 | are transformations of existing classes. E.g. `transforms.HorizontalFlip()`. 57 | 58 | download : bool (default: `False`) 59 | If `True`, downloads the pickle files and processes the dataset in the root 60 | directory (under the `cifar100` folder). If the dataset is already 61 | available, this does not download/process the dataset again. 62 | 63 | Notes 64 | ----- 65 | The meta train/validation/test splits are over 64/16/20 classes from the 66 | CIFAR100 dataset. 67 | 68 | References 69 | ---------- 70 | .. [1] Bertinetto L., Henriques J. F., Torr P. H.S., Vedaldi A. (2019). 71 | Meta-learning with differentiable closed-form solvers. In International 72 | Conference on Learning Representations (https://arxiv.org/abs/1805.08136) 73 | 74 | .. [2] Krizhevsky A. (2009). Learning Multiple Layers of Features from Tiny 75 | Images. 
(https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf) 76 | """ 77 | def __init__(self, root, num_classes_per_task=None, meta_train=False, 78 | meta_val=False, meta_test=False, meta_split=None, 79 | transform=None, target_transform=None, dataset_transform=None, 80 | class_augmentations=None, download=False): 81 | dataset = CIFARFSClassDataset(root, meta_train=meta_train, 82 | meta_val=meta_val, meta_test=meta_test, meta_split=meta_split, 83 | transform=transform, class_augmentations=class_augmentations, 84 | download=download) 85 | super(CIFARFS, self).__init__(dataset, num_classes_per_task, 86 | target_transform=target_transform, dataset_transform=dataset_transform) 87 | 88 | 89 | class CIFARFSClassDataset(CIFAR100ClassDataset): 90 | subfolder = 'cifar-fs' 91 | 92 | def __init__(self, root, meta_train=False, meta_val=False, meta_test=False, 93 | meta_split=None, transform=None, class_augmentations=None, 94 | download=False): 95 | super(CIFARFSClassDataset, self).__init__(root, meta_train=meta_train, 96 | meta_val=meta_val, meta_test=meta_test, meta_split=meta_split, 97 | transform=transform, class_augmentations=class_augmentations, 98 | download=download) 99 | 100 | def download(self): 101 | if self._check_integrity(): 102 | return 103 | super(CIFARFSClassDataset, self).download() 104 | 105 | subfolder = os.path.join(self.root, self.subfolder) 106 | if not os.path.exists(subfolder): 107 | os.makedirs(subfolder) 108 | 109 | for split in ['train', 'val', 'test']: 110 | split_filename_labels = os.path.join(subfolder, 111 | self.filename_labels.format(split)) 112 | if os.path.isfile(split_filename_labels): 113 | continue 114 | 115 | data = get_asset(self.folder, self.subfolder, 116 | '{0}.json'.format(split), dtype='json') 117 | with open(split_filename_labels, 'w') as f: 118 | json.dump(data, f) 119 | -------------------------------------------------------------------------------- /torchmeta/datasets/cifar100/fc100.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from torchmeta.datasets.cifar100.base import CIFAR100ClassDataset 5 | from torchmeta.datasets.utils import get_asset 6 | from torchmeta.utils.data import ClassDataset, CombinationMetaDataset 7 | 8 | 9 | class FC100(CombinationMetaDataset): 10 | """ 11 | The Fewshot-CIFAR100 dataset, introduced in [1]. This dataset contains 12 | images of 100 different classes from the CIFAR100 dataset [2]. 13 | 14 | Parameters 15 | ---------- 16 | root : string 17 | Root directory where the dataset folder `cifar100` exists. 18 | 19 | num_classes_per_task : int 20 | Number of classes per task. This corresponds to `N` in `N-way` 21 | classification. 22 | 23 | meta_train : bool (default: `False`) 24 | Use the meta-train split of the dataset. If set to `True`, then the 25 | arguments `meta_val` and `meta_test` must be set to `False`. Exactly one 26 | of these three arguments must be set to `True`. 27 | 28 | meta_val : bool (default: `False`) 29 | Use the meta-validation split of the dataset. If set to `True`, then the 30 | arguments `meta_train` and `meta_test` must be set to `False`. Exactly one 31 | of these three arguments must be set to `True`. 32 | 33 | meta_test : bool (default: `False`) 34 | Use the meta-test split of the dataset. If set to `True`, then the 35 | arguments `meta_train` and `meta_val` must be set to `False`. Exactly one 36 | of these three arguments must be set to `True`.
37 | 38 | meta_split : string in {'train', 'val', 'test'}, optional 39 | Name of the split to use. This overrides the arguments `meta_train`, 40 | `meta_val` and `meta_test` if all three are set to `False`. 41 | 42 | transform : callable, optional 43 | A function/transform that takes a `PIL` image, and returns a transformed 44 | version. See also `torchvision.transforms`. 45 | 46 | target_transform : callable, optional 47 | A function/transform that takes a target, and returns a transformed 48 | version. See also `torchvision.transforms`. 49 | 50 | dataset_transform : callable, optional 51 | A function/transform that takes a dataset (i.e. a task), and returns a 52 | transformed version of it. E.g. `transforms.ClassSplitter()`. 53 | 54 | class_augmentations : list of callable, optional 55 | A list of functions that augment the dataset with new classes. These classes 56 | are transformations of existing classes. E.g. `transforms.HorizontalFlip()`. 57 | 58 | download : bool (default: `False`) 59 | If `True`, downloads the pickle files and processes the dataset in the root 60 | directory (under the `cifar100` folder). If the dataset is already 61 | available, this does not download/process the dataset again. 62 | 63 | Notes 64 | ----- 65 | The meta train/validation/test splits are over 12/4/4 superclasses from the 66 | CIFAR100 dataset. The meta train/validation/test splits contain 60/20/20 67 | classes. 68 | 69 | References 70 | ---------- 71 | .. [1] Oreshkin B. N., Rodriguez P., Lacoste A. (2018). TADAM: Task dependent 72 | adaptive metric for improved few-shot learning. In Advances in Neural 73 | Information Processing Systems (https://arxiv.org/abs/1805.10123) 74 | 75 | .. [2] Krizhevsky A. (2009). Learning Multiple Layers of Features from Tiny 76 | Images. 
(https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf) 77 | """ 78 | def __init__(self, root, num_classes_per_task=None, meta_train=False, 79 | meta_val=False, meta_test=False, meta_split=None, 80 | transform=None, target_transform=None, dataset_transform=None, 81 | class_augmentations=None, download=False): 82 | dataset = FC100ClassDataset(root, meta_train=meta_train, 83 | meta_val=meta_val, meta_test=meta_test, meta_split=meta_split, 84 | transform=transform, class_augmentations=class_augmentations, 85 | download=download) 86 | super(FC100, self).__init__(dataset, num_classes_per_task, 87 | target_transform=target_transform, dataset_transform=dataset_transform) 88 | 89 | 90 | class FC100ClassDataset(CIFAR100ClassDataset): 91 | subfolder = 'fc100' 92 | 93 | def __init__(self, root, meta_train=False, meta_val=False, meta_test=False, 94 | meta_split=None, transform=None, class_augmentations=None, 95 | download=False): 96 | super(FC100ClassDataset, self).__init__(root, meta_train=meta_train, 97 | meta_val=meta_val, meta_test=meta_test, meta_split=meta_split, 98 | transform=transform, class_augmentations=class_augmentations, 99 | download=download) 100 | 101 | def download(self): 102 | if self._check_integrity(): 103 | return 104 | super(FC100ClassDataset, self).download() 105 | 106 | subfolder = os.path.join(self.root, self.subfolder) 107 | if not os.path.exists(subfolder): 108 | os.makedirs(subfolder) 109 | 110 | filename_fine_names = os.path.join(self.root, self.filename_fine_names) 111 | with open(filename_fine_names, 'r') as f: 112 | fine_names = json.load(f) 113 | 114 | for split in ['train', 'val', 'test']: 115 | split_filename_labels = os.path.join(subfolder, 116 | self.filename_labels.format(split)) 117 | if os.path.isfile(split_filename_labels): 118 | continue 119 | 120 | data = get_asset(self.folder, self.subfolder, 121 | '{0}.json'.format(split), dtype='json') 122 | with open(split_filename_labels, 'w') as f: 123 | labels = [[coarse_name, fine_name] for coarse_name in data 124 | for fine_name in fine_names[coarse_name]] 125 | json.dump(labels, f) 126 | -------------------------------------------------------------------------------- /torchmeta/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | def get_asset_path(*args): 5 | basedir = os.path.dirname(__file__) 6 | return os.path.join(basedir, 'assets', *args) 7 | 8 | 9 | def get_asset(*args, dtype=None): 10 | filename = get_asset_path(*args) 11 | if not os.path.isfile(filename): 12 | raise IOError('{} not found'.format(filename)) 13 | 14 | if dtype is None: 15 | _, dtype = os.path.splitext(filename) 16 | dtype = dtype[1:] 17 | 18 | if dtype == 'json': 19 | with open(filename, 'r') as f: 20 | data = json.load(f) 21 | else: 22 | raise NotImplementedError() 23 | return data 24 | -------------------------------------------------------------------------------- /torchmeta/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from torchmeta.modules.batchnorm import MetaBatchNorm1d, MetaBatchNorm2d, MetaBatchNorm3d 2 | from torchmeta.modules.container import MetaSequential 3 | from torchmeta.modules.conv import MetaConv1d, MetaConv2d, MetaConv3d 4 | from torchmeta.modules.linear import MetaLinear, MetaBilinear 5 | from torchmeta.modules.module import MetaModule 6 | from torchmeta.modules.normalization import MetaLayerNorm 7 | 8 | __all__ = [ 9 | 'MetaBatchNorm1d', 'MetaBatchNorm2d', 
'MetaBatchNorm3d', 10 | 'MetaSequential', 11 | 'MetaConv1d', 'MetaConv2d', 'MetaConv3d', 12 | 'MetaLinear', 'MetaBilinear', 13 | 'MetaModule', 14 | 'MetaLayerNorm' 15 | ] -------------------------------------------------------------------------------- /torchmeta/modules/batchnorm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from collections import OrderedDict 5 | from torch.nn.modules.batchnorm import _BatchNorm 6 | from torchmeta.modules.module import MetaModule 7 | 8 | class _MetaBatchNorm(_BatchNorm, MetaModule): 9 | def forward(self, input, params=None): 10 | self._check_input_dim(input) 11 | if params is None: 12 | params = OrderedDict(self.named_parameters()) 13 | 14 | # exponential_average_factor is set to self.momentum 15 | # (when it is available) only so that it gets updated 16 | # in the ONNX graph when this node is exported to ONNX. 17 | if self.momentum is None: 18 | exponential_average_factor = 0.0 19 | else: 20 | exponential_average_factor = self.momentum 21 | 22 | if self.training and self.track_running_stats: 23 | if self.num_batches_tracked is not None: 24 | self.num_batches_tracked += 1 25 | if self.momentum is None: # use cumulative moving average 26 | exponential_average_factor = 1.0 / float(self.num_batches_tracked) 27 | else: # use exponential moving average 28 | exponential_average_factor = self.momentum 29 | 30 | weight = params.get('weight', None) 31 | bias = params.get('bias', None) 32 | 33 | return F.batch_norm( 34 | input, self.running_mean, self.running_var, weight, bias, 35 | self.training or not self.track_running_stats, 36 | exponential_average_factor, self.eps) 37 | 38 | class MetaBatchNorm1d(_MetaBatchNorm): 39 | __doc__ = nn.BatchNorm1d.__doc__ 40 | 41 | def _check_input_dim(self, input): 42 | if input.dim() != 2 and input.dim() != 3: 43 | raise ValueError('expected 2D or 3D input (got {}D input)' 44 | .format(input.dim())) 45 | 46 | class MetaBatchNorm2d(_MetaBatchNorm): 47 | __doc__ = nn.BatchNorm2d.__doc__ 48 | 49 | def _check_input_dim(self, input): 50 | if input.dim() != 4: 51 | raise ValueError('expected 4D input (got {}D input)' 52 | .format(input.dim())) 53 | 54 | class MetaBatchNorm3d(_MetaBatchNorm): 55 | __doc__ = nn.BatchNorm3d.__doc__ 56 | 57 | def _check_input_dim(self, input): 58 | if input.dim() != 5: 59 | raise ValueError('expected 5D input (got {}D input)' 60 | .format(input.dim())) 61 | -------------------------------------------------------------------------------- /torchmeta/modules/container.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from torchmeta.modules.module import MetaModule 4 | from torchmeta.modules.utils import get_subdict 5 | 6 | class MetaSequential(nn.Sequential, MetaModule): 7 | __doc__ = nn.Sequential.__doc__ 8 | 9 | def forward(self, input, params=None): 10 | for name, module in self._modules.items(): 11 | if isinstance(module, MetaModule): 12 | input = module(input, params=get_subdict(params, name)) 13 | elif isinstance(module, nn.Module): 14 | input = module(input) 15 | else: 16 | raise TypeError('The module must be either a torch module ' 17 | '(inheriting from `nn.Module`), or a `MetaModule`. 
' 18 | 'Got type: `{0}`'.format(type(module))) 19 | return input 20 | -------------------------------------------------------------------------------- /torchmeta/modules/conv.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from collections import OrderedDict 5 | from torch.nn.modules.utils import _single, _pair, _triple 6 | from torchmeta.modules.module import MetaModule 7 | 8 | class MetaConv1d(nn.Conv1d, MetaModule): 9 | __doc__ = nn.Conv1d.__doc__ 10 | 11 | def forward(self, input, params=None): 12 | if params is None: 13 | params = OrderedDict(self.named_parameters()) 14 | bias = params.get('bias', None) 15 | 16 | if self.padding_mode == 'circular': 17 | expanded_padding = ((self.padding[0] + 1) // 2, self.padding[0] // 2) 18 | return F.conv1d(F.pad(input, expanded_padding, mode='circular'), 19 | params['weight'], bias, self.stride, 20 | _single(0), self.dilation, self.groups) 21 | 22 | return F.conv1d(input, params['weight'], bias, self.stride, 23 | self.padding, self.dilation, self.groups) 24 | 25 | class MetaConv2d(nn.Conv2d, MetaModule): 26 | __doc__ = nn.Conv2d.__doc__ 27 | 28 | def forward(self, input, params=None): 29 | if params is None: 30 | params = OrderedDict(self.named_parameters()) 31 | bias = params.get('bias', None) 32 | 33 | if self.padding_mode == 'circular': 34 | expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2, 35 | (self.padding[0] + 1) // 2, self.padding[0] // 2) 36 | return F.conv2d(F.pad(input, expanded_padding, mode='circular'), 37 | params['weight'], bias, self.stride, 38 | _pair(0), self.dilation, self.groups) 39 | 40 | return F.conv2d(input, params['weight'], bias, self.stride, 41 | self.padding, self.dilation, self.groups) 42 | 43 | class MetaConv3d(nn.Conv3d, MetaModule): 44 | __doc__ = nn.Conv3d.__doc__ 45 | 46 | def forward(self, input, params=None): 47 | if params is None: 48 | params = OrderedDict(self.named_parameters()) 49 | bias = params.get('bias', None) 50 | 51 | if self.padding_mode == 'circular': 52 | expanded_padding = ((self.padding[2] + 1) // 2, self.padding[2] // 2, 53 | (self.padding[1] + 1) // 2, self.padding[1] // 2, 54 | (self.padding[0] + 1) // 2, self.padding[0] // 2) 55 | return F.conv3d(F.pad(input, expanded_padding, mode='circular'), 56 | params['weight'], bias, self.stride, 57 | _triple(0), self.dilation, self.groups) 58 | 59 | return F.conv3d(input, params['weight'], bias, self.stride, 60 | self.padding, self.dilation, self.groups) 61 | -------------------------------------------------------------------------------- /torchmeta/modules/linear.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from collections import OrderedDict 5 | from torchmeta.modules.module import MetaModule 6 | 7 | class MetaLinear(nn.Linear, MetaModule): 8 | __doc__ = nn.Linear.__doc__ 9 | 10 | def forward(self, input, params=None): 11 | if params is None: 12 | params = OrderedDict(self.named_parameters()) 13 | bias = params.get('bias', None) 14 | return F.linear(input, params['weight'], bias) 15 | 16 | class MetaBilinear(nn.Bilinear, MetaModule): 17 | __doc__ = nn.Bilinear.__doc__ 18 | 19 | def forward(self, input1, input2, params=None): 20 | if params is None: 21 | params = OrderedDict(self.named_parameters()) 22 | bias = params.get('bias', None) 23 | return F.bilinear(input1, input2, params['weight'], bias) 24 | 
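The `Meta*` layers above all follow one convention: `forward` takes an optional `params` ordered dictionary and falls back to the module's own parameters when it is `None`. The sketch below (not part of the repository; the network shape, random data, and the 0.01 inner learning rate are illustrative assumptions) shows why this matters for gradient-based meta-learning: an inner adaptation step can build a new parameter dictionary while keeping the computation graph, and re-run the same module with it.

import torch
import torch.nn.functional as F
from collections import OrderedDict
from torchmeta.modules import MetaSequential, MetaLinear

# A small meta-network; plain `nn.Module` children (like ReLU) are allowed
# inside `MetaSequential` and are simply called without `params`.
model = MetaSequential(MetaLinear(1, 40), torch.nn.ReLU(), MetaLinear(40, 1))
inputs, targets = torch.randn(10, 1), torch.randn(10, 1)

# One differentiable inner step (MAML-style): gradients w.r.t. the
# meta-parameters, assembled into an updated parameter dictionary.
params = OrderedDict(model.meta_named_parameters())
loss = F.mse_loss(model(inputs, params=params), targets)
grads = torch.autograd.grad(loss, list(params.values()), create_graph=True)
params = OrderedDict((name, param - 0.01 * grad)  # 0.01: illustrative inner lr
                     for (name, param), grad in zip(params.items(), grads))

# Re-run the network with the adapted parameters; `MetaSequential.forward`
# routes `params` to each child via `get_subdict` (defined in utils.py below).
adapted_loss = F.mse_loss(model(inputs, params=params), targets)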
-------------------------------------------------------------------------------- /torchmeta/modules/module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from collections import OrderedDict 5 | 6 | class MetaModule(nn.Module): 7 | """ 8 | Base class for PyTorch meta-learning modules. These modules accept an 9 | additional argument `params` in their `forward` method. 10 | 11 | Notes 12 | ----- 13 | Objects inherited from `MetaModule` are fully compatible with PyTorch 14 | modules from `torch.nn.Module`. The argument `params` is a dictionary of 15 | tensors, with full support of the computation graph (for differentiation). 16 | """ 17 | def meta_named_parameters(self, prefix='', recurse=True): 18 | gen = self._named_members( 19 | lambda module: module._parameters.items() 20 | if isinstance(module, MetaModule) else [], 21 | prefix=prefix, recurse=recurse) 22 | for elem in gen: 23 | yield elem 24 | 25 | def meta_parameters(self, recurse=True): 26 | for name, param in self.meta_named_parameters(recurse=recurse): 27 | yield param 28 | -------------------------------------------------------------------------------- /torchmeta/modules/normalization.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from collections import OrderedDict 5 | from torchmeta.modules.module import MetaModule 6 | 7 | class MetaLayerNorm(nn.LayerNorm, MetaModule): 8 | __doc__ = nn.LayerNorm.__doc__ 9 | 10 | def forward(self, input, params=None): 11 | if params is None: 12 | params = OrderedDict(self.named_parameters()) 13 | weight = params.get('weight', None) 14 | bias = params.get('bias', None) 15 | return F.layer_norm( 16 | input, self.normalized_shape, weight, bias, self.eps) 17 | -------------------------------------------------------------------------------- /torchmeta/modules/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from collections import OrderedDict 3 | 4 | def get_subdict(dictionary, key=None): 5 | if dictionary is None: 6 | return None 7 | if (key is None) or (key == ''): 8 | return dictionary 9 | key_re = re.compile(r'^{0}\.(.+)'.format(re.escape(key))) 10 | return OrderedDict((key_re.sub(r'\1', k), value) for (k, value) 11 | in dictionary.items() if key_re.match(k) is not None) 12 | -------------------------------------------------------------------------------- /torchmeta/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsitzmann/siren/4df34baee3f0f9c8f351630992c1fe1f69114b5f/torchmeta/tests/__init__.py -------------------------------------------------------------------------------- /torchmeta/tests/test_dataloaders.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | from torchmeta.toy import Sinusoid 7 | from torchmeta.transforms import ClassSplitter 8 | from torchmeta.utils.data import Task, MetaDataLoader, BatchMetaDataLoader 9 | 10 | 11 | def test_meta_dataloader(): 12 | dataset = Sinusoid(10, num_tasks=1000, noise_std=None) 13 | meta_dataloader = MetaDataLoader(dataset, batch_size=4) 14 | assert isinstance(meta_dataloader, DataLoader) 15 | assert len(meta_dataloader) == 250 # 1000 / 4 16 | 17 | batch = next(iter(meta_dataloader)) 18 | assert 
isinstance(batch, list) 19 | assert len(batch) == 4 20 | 21 | task = batch[0] 22 | assert isinstance(task, Task) 23 | assert len(task) == 10 24 | 25 | 26 | def test_meta_dataloader_task_loader(): 27 | dataset = Sinusoid(10, num_tasks=1000, noise_std=None) 28 | meta_dataloader = MetaDataLoader(dataset, batch_size=4) 29 | batch = next(iter(meta_dataloader)) 30 | 31 | dataloader = DataLoader(batch[0], batch_size=5) 32 | inputs, targets = next(iter(dataloader)) 33 | 34 | assert len(dataloader) == 2 # 10 / 5 35 | # PyTorch dataloaders convert numpy array to tensors 36 | assert isinstance(inputs, torch.Tensor) 37 | assert isinstance(targets, torch.Tensor) 38 | assert inputs.shape == (5, 1) 39 | assert targets.shape == (5, 1) 40 | 41 | 42 | def test_batch_meta_dataloader(): 43 | dataset = Sinusoid(10, num_tasks=1000, noise_std=None) 44 | meta_dataloader = BatchMetaDataLoader(dataset, batch_size=4) 45 | assert isinstance(meta_dataloader, DataLoader) 46 | assert len(meta_dataloader) == 250 # 1000 / 4 47 | 48 | inputs, targets = next(iter(meta_dataloader)) 49 | assert isinstance(inputs, torch.Tensor) 50 | assert isinstance(targets, torch.Tensor) 51 | assert inputs.shape == (4, 10, 1) 52 | assert targets.shape == (4, 10, 1) 53 | 54 | 55 | def test_batch_meta_dataloader_splitter(): 56 | dataset = Sinusoid(20, num_tasks=1000, noise_std=None) 57 | dataset = ClassSplitter(dataset, num_train_per_class=5, 58 | num_test_per_class=15) 59 | meta_dataloader = BatchMetaDataLoader(dataset, batch_size=4) 60 | 61 | batch = next(iter(meta_dataloader)) 62 | assert isinstance(batch, dict) 63 | assert 'train' in batch 64 | assert 'test' in batch 65 | 66 | train_inputs, train_targets = batch['train'] 67 | test_inputs, test_targets = batch['test'] 68 | assert isinstance(train_inputs, torch.Tensor) 69 | assert isinstance(train_targets, torch.Tensor) 70 | assert train_inputs.shape == (4, 5, 1) 71 | assert train_targets.shape == (4, 5, 1) 72 | assert isinstance(test_inputs, torch.Tensor) 73 | assert isinstance(test_targets, torch.Tensor) 74 | assert test_inputs.shape == (4, 15, 1) 75 | assert test_targets.shape == (4, 15, 1) 76 | -------------------------------------------------------------------------------- /torchmeta/tests/test_prototype.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from torchmeta.utils.prototype import get_num_samples, get_prototypes, prototypical_loss 7 | 8 | 9 | @pytest.mark.parametrize('dtype', [None, torch.float32]) 10 | def test_get_num_samples(dtype): 11 | # Numpy 12 | num_classes = 3 13 | targets_np = np.random.randint(0, num_classes, size=(2, 5)) 14 | 15 | # PyTorch 16 | targets_th = torch.as_tensor(targets_np) 17 | num_samples_th = get_num_samples(targets_th, num_classes, dtype=dtype) 18 | 19 | num_samples_np = np.zeros((2, num_classes), dtype=np.int_) 20 | for i in range(2): 21 | for j in range(5): 22 | num_samples_np[i, targets_np[i, j]] += 1 23 | 24 | assert num_samples_th.shape == (2, num_classes) 25 | if dtype is not None: 26 | assert num_samples_th.dtype == dtype 27 | np.testing.assert_equal(num_samples_th.numpy(), num_samples_np) 28 | 29 | 30 | def test_get_prototypes(): 31 | # Numpy 32 | num_classes = 3 33 | embeddings_np = np.random.rand(2, 5, 7).astype(np.float32) 34 | targets_np = np.random.randint(0, num_classes, size=(2, 5)) 35 | 36 | # PyTorch 37 | embeddings_th = torch.as_tensor(embeddings_np) 38 | targets_th = torch.as_tensor(targets_np) 39 | prototypes_th = 
get_prototypes(embeddings_th, targets_th, num_classes) 40 | 41 | assert prototypes_th.shape == (2, num_classes, 7) 42 | assert prototypes_th.dtype == embeddings_th.dtype 43 | 44 | prototypes_np = np.zeros((2, num_classes, 7), dtype=np.float32) 45 | num_samples_np = np.zeros((2, num_classes), dtype=np.int_) 46 | for i in range(2): 47 | for j in range(5): 48 | num_samples_np[i, targets_np[i, j]] += 1 49 | for k in range(7): 50 | prototypes_np[i, targets_np[i, j], k] += embeddings_np[i, j, k] 51 | 52 | for i in range(2): 53 | for j in range(num_classes): 54 | for k in range(7): 55 | prototypes_np[i, j, k] /= max(num_samples_np[i, j], 1) 56 | 57 | np.testing.assert_allclose(prototypes_th.detach().numpy(), prototypes_np) 58 | -------------------------------------------------------------------------------- /torchmeta/tests/test_splitters.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import numpy as np 4 | from collections import OrderedDict 5 | 6 | from torchmeta.transforms.splitters import ClassSplitter 7 | from torchmeta.toy import Sinusoid 8 | from torchmeta.utils.data import Task 9 | 10 | def test_seed_class_splitter(): 11 | dataset_transform = ClassSplitter(shuffle=True, 12 | num_train_per_class=5, num_test_per_class=5) 13 | dataset = Sinusoid(10, num_tasks=1000, noise_std=0.1, 14 | dataset_transform=dataset_transform) 15 | dataset.seed(1) 16 | 17 | expected_train_inputs = np.array([-2.03870077, 0.09898378, 3.75388738, 1.08565437, -1.56211897]) 18 | expected_train_targets = np.array([-0.1031986 , -1.61885041, 0.91773121, -0.00309463, -1.37650356]) 19 | 20 | expected_test_inputs = np.array([ 4.62078213, -2.48340416, 0.32922559, 0.76977846, -3.15504396]) 21 | expected_test_targets = np.array([-0.9346262 , 0.73113509, -1.52508997, -0.4698061 , 1.86656819]) 22 | 23 | task = dataset[0] 24 | train_dataset, test_dataset = task['train'], task['test'] 25 | 26 | assert len(train_dataset) == 5 27 | assert len(test_dataset) == 5 28 | 29 | for i, (train_input, train_target) in enumerate(train_dataset): 30 | assert np.isclose(train_input, expected_train_inputs[i]) 31 | assert np.isclose(train_target, expected_train_targets[i]) 32 | 33 | for i, (test_input, test_target) in enumerate(test_dataset): 34 | assert np.isclose(test_input, expected_test_inputs[i]) 35 | assert np.isclose(test_target, expected_test_targets[i]) 36 | 37 | def test_class_splitter_for_fold_overlaps(): 38 | class DemoTask(Task): 39 | def __init__(self): 40 | super(DemoTask, self).__init__(index=0, num_classes=None) 41 | self._inputs = np.arange(10) 42 | 43 | def __len__(self): 44 | return len(self._inputs) 45 | 46 | def __getitem__(self, index): 47 | return self._inputs[index] 48 | 49 | splitter = ClassSplitter(shuffle=True, num_train_per_class=5, num_test_per_class=5) 50 | task = DemoTask() 51 | 52 | all_train_samples = list() 53 | all_test_samples = list() 54 | 55 | # split task ten times into train and test 56 | for i in range(10): 57 | tasks_split = splitter(task) 58 | train_task = tasks_split["train"] 59 | test_task = tasks_split["test"] 60 | 61 | train_samples = set([train_task[i] for i in range(len(train_task))]) 62 | test_samples = set([test_task[i] for i in range(len(test_task))]) 63 | 64 | # no overlap between train and test splits at single split 65 | assert len(train_samples.intersection(test_samples)) == 0 66 | 67 | all_train_samples.append(train_samples) 68 | all_test_samples.append(test_samples) 69 | 70 | # gather unique samples from multiple splits 71 |
samples_in_all_train_splits = set().union(*all_train_samples) 72 | samples_in_all_test_splits = set().union(*all_test_samples) 73 | 74 | # no overlap between train and test splits at multiple splits 75 | assert len(samples_in_all_test_splits.intersection(samples_in_all_train_splits)) == 0 -------------------------------------------------------------------------------- /torchmeta/tests/test_toy.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import numpy as np 4 | from collections import OrderedDict 5 | 6 | from torchmeta.utils.data import Task, MetaDataset 7 | from torchmeta.toy import Sinusoid, Harmonic, SinusoidAndLine 8 | from torchmeta.toy import helpers 9 | 10 | 11 | @pytest.mark.parametrize('dataset_class', 12 | [Sinusoid, Harmonic, SinusoidAndLine]) 13 | def test_toy_meta_dataset(dataset_class): 14 | dataset = dataset_class(10, num_tasks=1000, noise_std=None) 15 | 16 | assert isinstance(dataset, MetaDataset) 17 | assert len(dataset) == 1000 18 | 19 | 20 | @pytest.mark.parametrize('dataset_class', 21 | [Sinusoid, Harmonic, SinusoidAndLine]) 22 | def test_toy_task(dataset_class): 23 | dataset = dataset_class(10, num_tasks=1000, noise_std=None) 24 | task = dataset[0] 25 | 26 | assert isinstance(task, Task) 27 | assert len(task) == 10 28 | 29 | 30 | @pytest.mark.parametrize('dataset_class', 31 | [Sinusoid, Harmonic, SinusoidAndLine]) 32 | def test_toy_sample(dataset_class): 33 | dataset = dataset_class(10, num_tasks=1000, noise_std=None) 34 | task = dataset[0] 35 | input, target = task[0] 36 | 37 | assert isinstance(input, np.ndarray) 38 | assert isinstance(target, np.ndarray) 39 | assert input.shape == (1,) 40 | assert target.shape == (1,) 41 | 42 | 43 | @pytest.mark.parametrize('name,dataset_class', 44 | [('sinusoid', Sinusoid), ('harmonic', Harmonic)]) 45 | def test_toy_helpers(name, dataset_class): 46 | dataset_fn = getattr(helpers, name) 47 | dataset = dataset_fn(shots=5, test_shots=15) 48 | assert isinstance(dataset, dataset_class) 49 | 50 | task = dataset[0] 51 | assert isinstance(task, OrderedDict) 52 | assert 'train' in task 53 | assert 'test' in task 54 | 55 | train, test = task['train'], task['test'] 56 | assert isinstance(train, Task) 57 | assert isinstance(test, Task) 58 | assert len(train) == 5 59 | assert len(test) == 15 60 | -------------------------------------------------------------------------------- /torchmeta/toy/__init__.py: -------------------------------------------------------------------------------- 1 | from torchmeta.toy.harmonic import Harmonic 2 | from torchmeta.toy.sinusoid import Sinusoid 3 | from torchmeta.toy.sinusoid_line import SinusoidAndLine 4 | 5 | from torchmeta.toy import helpers 6 | 7 | __all__ = ['Harmonic', 'Sinusoid', 'SinusoidAndLine', 'helpers'] 8 | -------------------------------------------------------------------------------- /torchmeta/toy/harmonic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from torchmeta.utils.data import Task, MetaDataset 4 | 5 | 6 | class Harmonic(MetaDataset): 7 | """ 8 | Simple regression task, based on the sum of two sine waves, as introduced 9 | in [1]. 10 | 11 | Parameters 12 | ---------- 13 | num_samples_per_task : int 14 | Number of examples per task. 15 | 16 | num_tasks : int (default: 5,000) 17 | Overall number of tasks to sample. 18 | 19 | noise_std : float, optional 20 | Amount of noise to include in the targets for each task. 
If `None`, then 21 | no noise is included, and the target is the sum of two sine functions of 22 | the input. 23 | 24 | transform : callable, optional 25 | A function/transform that takes a numpy array of size (1,) and returns a 26 | transformed version of the input. 27 | 28 | target_transform : callable, optional 29 | A function/transform that takes a numpy array of size (1,) and returns a 30 | transformed version of the target. 31 | 32 | dataset_transform : callable, optional 33 | A function/transform that takes a dataset (i.e. a task), and returns a 34 | transformed version of it. E.g. `torchmeta.transforms.ClassSplitter()`. 35 | 36 | Notes 37 | ----- 38 | The tasks are created randomly as the sum of two sinusoid functions, with 39 | a frequency ratio of 2. The frequencies vary within [5.0, 7.0], the phases 40 | within [0, 2 * pi], the amplitudes are standard normal, and the inputs are 41 | sampled according to N(mu_x, 1), with mu_x varying in [-4.0, 4.0]. Due to 42 | the way PyTorch handles datasets, the number of tasks to be sampled needs to 43 | be fixed ahead of time (with `num_tasks`). This will typically be equal to `meta_batch_size * num_batches`. 44 | 45 | References 46 | ---------- 47 | .. [1] Lacoste A., Oreshkin B., Chung W., Boquet T., Rostamzadeh N., 48 | Krueger D. (2018). Uncertainty in Multitask Transfer Learning. In 49 | Advances in Neural Information Processing Systems (https://arxiv.org/abs/1806.07528) 50 | """ 51 | def __init__(self, num_samples_per_task, num_tasks=5000, 52 | noise_std=None, transform=None, target_transform=None, 53 | dataset_transform=None): 54 | super(Harmonic, self).__init__(meta_split='train', 55 | target_transform=target_transform, dataset_transform=dataset_transform) 56 | self.num_samples_per_task = num_samples_per_task 57 | self.num_tasks = num_tasks 58 | self.noise_std = noise_std 59 | self.transform = transform 60 | 61 | self._domain_range = np.array([-4.0, 4.0]) 62 | self._frequency_range = np.array([5.0, 7.0]) 63 | self._phase_range = np.array([0, 2 * np.pi]) 64 | 65 | self._domains = None 66 | self._frequencies = None 67 | self._phases = None 68 | self._amplitudes = None 69 | 70 | @property 71 | def domains(self): 72 | if self._domains is None: 73 | self._domains = self.np_random.uniform(self._domain_range[0], 74 | self._domain_range[1], size=self.num_tasks) 75 | return self._domains 76 | 77 | @property 78 | def frequencies(self): 79 | if self._frequencies is None: 80 | self._frequencies = self.np_random.uniform(self._frequency_range[0], 81 | self._frequency_range[1], size=self.num_tasks) 82 | return self._frequencies 83 | 84 | @property 85 | def phases(self): 86 | if self._phases is None: 87 | self._phases = self.np_random.uniform(self._phase_range[0], 88 | self._phase_range[1], size=(self.num_tasks, 2)) 89 | return self._phases 90 | 91 | @property 92 | def amplitudes(self): 93 | if self._amplitudes is None: 94 | self._amplitudes = self.np_random.randn(self.num_tasks, 2) 95 | return self._amplitudes 96 | 97 | def __len__(self): 98 | return self.num_tasks 99 | 100 | def __getitem__(self, index): 101 | domain = self.domains[index] 102 | frequency = self.frequencies[index] 103 | phases = self.phases[index] 104 | amplitudes = self.amplitudes[index] 105 | 106 | task = HarmonicTask(index, domain, frequency, phases, amplitudes, 107 | self.noise_std, self.num_samples_per_task, self.transform, 108 | self.target_transform, np_random=self.np_random) 109 | 110 | if self.dataset_transform is not None: 111 | task = self.dataset_transform(task) 112 | 113 | return task 114 | 115 | 116 |
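# Usage sketch (added for illustration; not part of the original torchmeta
# source). `Harmonic` is typically wrapped in a `ClassSplitter`, exactly as
# `torchmeta.toy.helpers.harmonic` does in helpers.py further below; the
# split sizes, noise level, and seed here are arbitrary example values.
def _example_harmonic_usage():
    from torchmeta.toy import Harmonic
    from torchmeta.transforms import ClassSplitter
    dataset = Harmonic(num_samples_per_task=20, num_tasks=100, noise_std=0.01)
    dataset = ClassSplitter(dataset, shuffle=True,
                            num_train_per_class=10, num_test_per_class=10)
    dataset.seed(0)
    task = dataset[0]  # an OrderedDict with 'train' and 'test' sub-tasks
    return task['train'], task['test']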
class HarmonicTask(Task): 117 | def __init__(self, index, domain, frequency, phases, amplitudes, 118 | noise_std, num_samples, transform=None, 119 | target_transform=None, np_random=None): 120 | super(HarmonicTask, self).__init__(index, None) # Regression task 121 | self.domain = domain 122 | self.frequency = frequency 123 | self.phases = phases 124 | self.amplitudes = amplitudes 125 | self.noise_std = noise_std 126 | self.num_samples = num_samples 127 | 128 | self.transform = transform 129 | self.target_transform = target_transform 130 | 131 | if np_random is None: 132 | np_random = np.random.RandomState(None) 133 | 134 | a_1, a_2 = self.amplitudes 135 | b_1, b_2 = self.phases 136 | 137 | self._inputs = self.domain + np_random.randn(num_samples, 1) 138 | self._targets = (a_1 * np.sin(frequency * self._inputs + b_1) 139 | + a_2 * np.sin(2 * frequency * self._inputs + b_2)) 140 | if (noise_std is not None) and (noise_std > 0.): 141 | self._targets += noise_std * np_random.randn(num_samples, 1) 142 | 143 | def __len__(self): 144 | return self.num_samples 145 | 146 | def __getitem__(self, index): 147 | input, target = self._inputs[index], self._targets[index] 148 | 149 | if self.transform is not None: 150 | input = self.transform(input) 151 | 152 | if self.target_transform is not None: 153 | target = self.target_transform(target) 154 | 155 | return (input, target) 156 | -------------------------------------------------------------------------------- /torchmeta/toy/helpers.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from torchmeta.toy import Sinusoid, Harmonic 4 | from torchmeta.transforms import ClassSplitter 5 | 6 | def sinusoid(shots, shuffle=True, test_shots=None, seed=None, **kwargs): 7 | """Helper function to create a meta-dataset for the Sinusoid toy dataset. 8 | 9 | Parameters 10 | ---------- 11 | shots : int 12 | Number of (training) examples in each task. This corresponds to `k` in 13 | `k-shot` learning. 14 | 15 | shuffle : bool (default: `True`) 16 | Shuffle the examples when creating the tasks. 17 | 18 | test_shots : int, optional 19 | Number of test examples in each task. If `None`, then the number of test 20 | examples is equal to the number of training examples in each task. 21 | 22 | seed : int, optional 23 | Random seed to be used in the meta-dataset. 24 | 25 | kwargs 26 | Additional arguments passed to the `Sinusoid` class. 27 | 28 | See also 29 | -------- 30 | `torchmeta.toy.Sinusoid` : Meta-dataset for the Sinusoid toy dataset. 31 | """ 32 | if 'num_samples_per_task' in kwargs: 33 | warnings.warn('Both arguments `shots` and `num_samples_per_task` were ' 34 | 'set in the helper function for the number of samples in each task. ' 35 | 'Ignoring the argument `shots`.', stacklevel=2) 36 | if test_shots is not None: 37 | shots = kwargs['num_samples_per_task'] - test_shots 38 | if shots <= 0: 39 | raise ValueError('The argument `test_shots` ({0}) is greater ' 40 | 'than or equal to the number of samples per task ({1}).
Either use the ' 41 | 'argument `shots` instead of `num_samples_per_task`, or ' 42 | 'increase the value of `num_samples_per_task`.'.format( 43 | test_shots, kwargs['num_samples_per_task'])) 44 | else: 45 | shots = kwargs['num_samples_per_task'] // 2 46 | if test_shots is None: 47 | test_shots = shots 48 | 49 | dataset = Sinusoid(num_samples_per_task=shots + test_shots, **kwargs) 50 | dataset = ClassSplitter(dataset, shuffle=shuffle, 51 | num_train_per_class=shots, num_test_per_class=test_shots) 52 | dataset.seed(seed) 53 | 54 | return dataset 55 | 56 | def harmonic(shots, shuffle=True, test_shots=None, seed=None, **kwargs): 57 | """Helper function to create a meta-dataset for the Harmonic toy dataset. 58 | 59 | Parameters 60 | ---------- 61 | shots : int 62 | Number of (training) examples in each task. This corresponds to `k` in 63 | `k-shot` learning. 64 | 65 | shuffle : bool (default: `True`) 66 | Shuffle the examples when creating the tasks. 67 | 68 | test_shots : int, optional 69 | Number of test examples in each task. If `None`, then the number of test 70 | examples is equal to the number of training examples in each task. 71 | 72 | seed : int, optional 73 | Random seed to be used in the meta-dataset. 74 | 75 | kwargs 76 | Additional arguments passed to the `Harmonic` class. 77 | 78 | See also 79 | -------- 80 | `torchmeta.toy.Harmonic` : Meta-dataset for the Harmonic toy dataset. 81 | """ 82 | if 'num_samples_per_task' in kwargs: 83 | warnings.warn('Both arguments `shots` and `num_samples_per_task` were ' 84 | 'set in the helper function for the number of samples in each task. ' 85 | 'Ignoring the argument `shots`.', stacklevel=2) 86 | if test_shots is not None: 87 | shots = kwargs['num_samples_per_task'] - test_shots 88 | if shots <= 0: 89 | raise ValueError('The argument `test_shots` ({0}) is greater ' 90 | 'than or equal to the number of samples per task ({1}). Either use the ' 91 | 'argument `shots` instead of `num_samples_per_task`, or ' 92 | 'increase the value of `num_samples_per_task`.'.format( 93 | test_shots, kwargs['num_samples_per_task'])) 94 | else: 95 | shots = kwargs['num_samples_per_task'] // 2 96 | if test_shots is None: 97 | test_shots = shots 98 | 99 | dataset = Harmonic(num_samples_per_task=shots + test_shots, **kwargs) 100 | dataset = ClassSplitter(dataset, shuffle=shuffle, 101 | num_train_per_class=shots, num_test_per_class=test_shots) 102 | dataset.seed(seed) 103 | 104 | return dataset 105 | -------------------------------------------------------------------------------- /torchmeta/toy/sinusoid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from torchmeta.utils.data import Task, MetaDataset 4 | 5 | 6 | class Sinusoid(MetaDataset): 7 | """ 8 | Simple regression task, based on sinusoids, as introduced in [1]. 9 | 10 | Parameters 11 | ---------- 12 | num_samples_per_task : int 13 | Number of examples per task. 14 | 15 | num_tasks : int (default: 1,000,000) 16 | Overall number of tasks to sample. 17 | 18 | noise_std : float, optional 19 | Amount of noise to include in the targets for each task. If `None`, then 20 | no noise is included, and the target is a sine function of the input. 21 | 22 | transform : callable, optional 23 | A function/transform that takes a numpy array of size (1,) and returns a 24 | transformed version of the input. 25 | 26 | target_transform : callable, optional 27 | A function/transform that takes a numpy array of size (1,) and returns a 28 | transformed version of the target.
29 | 30 | dataset_transform : callable, optional 31 | A function/transform that takes a dataset (i.e. a task), and returns a 32 | transformed version of it. E.g. `torchmeta.transforms.ClassSplitter()`. 33 | 34 | Notes 35 | ----- 36 | The tasks are created randomly as random sinusoid functions. The amplitude 37 | varies within [0.1, 5.0], the phase within [0, pi], and the inputs are 38 | sampled uniformly in [-5.0, 5.0]. Due to the way PyTorch handles datasets, 39 | the number of tasks to be sampled needs to be fixed ahead of time (with 40 | `num_tasks`). This will typically be equal to `meta_batch_size * num_batches`. 41 | 42 | References 43 | ---------- 44 | .. [1] Finn C., Abbeel P., and Levine, S. (2017). Model-Agnostic Meta-Learning 45 | for Fast Adaptation of Deep Networks. International Conference on 46 | Machine Learning (ICML) (https://arxiv.org/abs/1703.03400) 47 | """ 48 | def __init__(self, num_samples_per_task, num_tasks=1000000, 49 | noise_std=None, transform=None, target_transform=None, 50 | dataset_transform=None): 51 | super(Sinusoid, self).__init__(meta_split='train', 52 | target_transform=target_transform, dataset_transform=dataset_transform) 53 | self.num_samples_per_task = num_samples_per_task 54 | self.num_tasks = num_tasks 55 | self.noise_std = noise_std 56 | self.transform = transform 57 | 58 | self._input_range = np.array([-5.0, 5.0]) 59 | self._amplitude_range = np.array([0.1, 5.0]) 60 | self._phase_range = np.array([0, np.pi]) 61 | 62 | self._amplitudes = None 63 | self._phases = None 64 | 65 | @property 66 | def amplitudes(self): 67 | if self._amplitudes is None: 68 | self._amplitudes = self.np_random.uniform(self._amplitude_range[0], 69 | self._amplitude_range[1], size=self.num_tasks) 70 | return self._amplitudes 71 | 72 | @property 73 | def phases(self): 74 | if self._phases is None: 75 | self._phases = self.np_random.uniform(self._phase_range[0], 76 | self._phase_range[1], size=self.num_tasks) 77 | return self._phases 78 | 79 | def __len__(self): 80 | return self.num_tasks 81 | 82 | def __getitem__(self, index): 83 | amplitude, phase = self.amplitudes[index], self.phases[index] 84 | task = SinusoidTask(index, amplitude, phase, self._input_range, 85 | self.noise_std, self.num_samples_per_task, self.transform, 86 | self.target_transform, np_random=self.np_random) 87 | 88 | if self.dataset_transform is not None: 89 | task = self.dataset_transform(task) 90 | 91 | return task 92 | 93 | 94 | class SinusoidTask(Task): 95 | def __init__(self, index, amplitude, phase, input_range, noise_std, 96 | num_samples, transform=None, target_transform=None, 97 | np_random=None): 98 | super(SinusoidTask, self).__init__(index, None) # Regression task 99 | self.amplitude = amplitude 100 | self.phase = phase 101 | self.input_range = input_range 102 | self.num_samples = num_samples 103 | self.noise_std = noise_std 104 | 105 | self.transform = transform 106 | self.target_transform = target_transform 107 | 108 | if np_random is None: 109 | np_random = np.random.RandomState(None) 110 | 111 | self._inputs = np_random.uniform(input_range[0], input_range[1], 112 | size=(num_samples, 1)) 113 | self._targets = amplitude * np.sin(self._inputs - phase) 114 | if (noise_std is not None) and (noise_std > 0.): 115 | self._targets += noise_std * np_random.randn(num_samples, 1) 116 | 117 | def __len__(self): 118 | return self.num_samples 119 | 120 | def __getitem__(self, index): 121 | input, target = self._inputs[index], self._targets[index] 122 | 123 | if self.transform is not None: 124 | input = 
self.transform(input) 125 | 126 | if self.target_transform is not None: 127 | target = self.target_transform(target) 128 | 129 | return (input, target) 130 | -------------------------------------------------------------------------------- /torchmeta/toy/sinusoid_line.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from torchmeta.utils.data import Task, MetaDataset 4 | from torchmeta.toy.sinusoid import SinusoidTask 5 | 6 | 7 | class SinusoidAndLine(MetaDataset): 8 | """ 9 | Simple multimodal regression task, based on sinusoids and lines, as 10 | introduced in [1]. 11 | 12 | Parameters 13 | ---------- 14 | num_samples_per_task : int 15 | Number of examples per task. 16 | 17 | num_tasks : int (default: 1,000,000) 18 | Overall number of tasks to sample. 19 | 20 | noise_std : float, optional 21 | Amount of noise to include in the targets for each task. If `None`, then 22 | no noise is included, and the target is either a sine function or a 23 | linear function of the input. 24 | 25 | transform : callable, optional 26 | A function/transform that takes a numpy array of size (1,) and returns a 27 | transformed version of the input. 28 | 29 | target_transform : callable, optional 30 | A function/transform that takes a numpy array of size (1,) and returns a 31 | transformed version of the target. 32 | 33 | dataset_transform : callable, optional 34 | A function/transform that takes a dataset (i.e. a task), and returns a 35 | transformed version of it. E.g. `torchmeta.transforms.ClassSplitter()`. 36 | 37 | Notes 38 | ----- 39 | The tasks are created randomly as either random sinusoid functions, or 40 | random linear functions. The amplitude of the sinusoids varies within 41 | [0.1, 5.0] and the phase within [0, pi]. The slope and intercept of the lines 42 | vary in [-3.0, 3.0]. The inputs are sampled uniformly in [-5.0, 5.0]. Due to 43 | the way PyTorch handles datasets, the number of tasks to be sampled needs to 44 | be fixed ahead of time (with `num_tasks`). This will typically be equal to 45 | `meta_batch_size * num_batches`. 46 | 47 | References 48 | ---------- 49 | .. [1] Finn C., Xu K., Levine S. (2018). Probabilistic Model-Agnostic 50 | Meta-Learning.
In Advances in Neural Information Processing Systems 51 | (https://arxiv.org/abs/1806.02817) 52 | """ 53 | def __init__(self, num_samples_per_task, num_tasks=1000000, 54 | noise_std=None, transform=None, target_transform=None, 55 | dataset_transform=None): 56 | super(SinusoidAndLine, self).__init__(meta_split='train', 57 | target_transform=target_transform, dataset_transform=dataset_transform) 58 | self.num_samples_per_task = num_samples_per_task 59 | self.num_tasks = num_tasks 60 | self.noise_std = noise_std 61 | self.transform = transform 62 | 63 | self._input_range = np.array([-5.0, 5.0]) 64 | self._amplitude_range = np.array([0.1, 5.0]) 65 | self._phase_range = np.array([0, np.pi]) 66 | self._slope_range = np.array([-3.0, 3.0]) 67 | self._intercept_range = np.array([-3.0, 3.0]) 68 | 69 | 70 | self._is_sinusoid = None 71 | self._amplitudes = None 72 | self._phases = None 73 | self._slopes = None 74 | self._intercepts = None 75 | 76 | @property 77 | def amplitudes(self): 78 | if self._amplitudes is None: 79 | self._amplitudes = self.np_random.uniform(self._amplitude_range[0], 80 | self._amplitude_range[1], size=self.num_tasks) 81 | return self._amplitudes 82 | 83 | @property 84 | def phases(self): 85 | if self._phases is None: 86 | self._phases = self.np_random.uniform(self._phase_range[0], 87 | self._phase_range[1], size=self.num_tasks) 88 | return self._phases 89 | 90 | @property 91 | def slopes(self): 92 | if self._slopes is None: 93 | self._slopes = self.np_random.uniform(self._slope_range[0], 94 | self._slope_range[1], size=self.num_tasks) 95 | return self._slopes 96 | 97 | @property 98 | def intercepts(self): 99 | if self._intercepts is None: 100 | self._intercepts = self.np_random.uniform(self._intercept_range[0], 101 | self._intercept_range[1], size=self.num_tasks) 102 | return self._intercepts 103 | 104 | @property 105 | def is_sinusoid(self): 106 | if self._is_sinusoid is None: 107 | self._is_sinusoid = np.zeros((self.num_tasks,), dtype=np.bool_) 108 | self._is_sinusoid[self.num_tasks // 2:] = True 109 | self.np_random.shuffle(self._is_sinusoid) 110 | return self._is_sinusoid 111 | 112 | def __len__(self): 113 | return self.num_tasks 114 | 115 | def __getitem__(self, index): 116 | if self.is_sinusoid[index]: 117 | amplitude, phase = self.amplitudes[index], self.phases[index] 118 | task = SinusoidTask(index, amplitude, phase, self._input_range, 119 | self.noise_std, self.num_samples_per_task, self.transform, 120 | self.target_transform, np_random=self.np_random) 121 | else: 122 | slope, intercept = self.slopes[index], self.intercepts[index] 123 | task = LinearTask(index, slope, intercept, self._input_range, 124 | self.noise_std, self.num_samples_per_task, self.transform, 125 | self.target_transform, np_random=self.np_random) 126 | 127 | if self.dataset_transform is not None: 128 | task = self.dataset_transform(task) 129 | 130 | return task 131 | 132 | 133 | class LinearTask(Task): 134 | def __init__(self, index, slope, intercept, input_range, noise_std, 135 | num_samples, transform=None, target_transform=None, 136 | np_random=None): 137 | super(LinearTask, self).__init__(index, None) # Regression task 138 | self.slope = slope 139 | self.intercept = intercept 140 | self.input_range = input_range 141 | self.num_samples = num_samples 142 | self.noise_std = noise_std 143 | 144 | self.transform = transform 145 | self.target_transform = target_transform 146 | 147 | if np_random is None: 148 | np_random = np.random.RandomState(None) 149 | 150 | self._inputs = 
np_random.uniform(input_range[0], input_range[1], 151 | size=(num_samples, 1)) 152 | self._targets = intercept + slope * self._inputs 153 | if (noise_std is not None) and (noise_std > 0.): 154 | self._targets += noise_std * np_random.randn(num_samples, 1) 155 | 156 | def __len__(self): 157 | return self.num_samples 158 | 159 | def __getitem__(self, index): 160 | input, target = self._inputs[index], self._targets[index] 161 | 162 | if self.transform is not None: 163 | input = self.transform(input) 164 | 165 | if self.target_transform is not None: 166 | target = self.target_transform(target) 167 | 168 | return (input, target) 169 | -------------------------------------------------------------------------------- /torchmeta/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from torchmeta.transforms.categorical import Categorical, FixedCategory 2 | from torchmeta.transforms.augmentations import Rotation, HorizontalFlip, VerticalFlip 3 | from torchmeta.transforms.splitters import Splitter, ClassSplitter, WeightedClassSplitter 4 | from torchmeta.transforms.target_transforms import TargetTransform, DefaultTargetTransform 5 | -------------------------------------------------------------------------------- /torchmeta/transforms/augmentations.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms.functional as F 2 | 3 | class Rotation(object): 4 | def __init__(self, angle, resample=False, expand=False, center=None): 5 | super(Rotation, self).__init__() 6 | if isinstance(angle, (list, tuple)): 7 | self._angles = angle 8 | self.angle = None 9 | else: 10 | self._angles = [angle] 11 | self.angle = angle 12 | if angle % 360 == 0: 13 | import warnings 14 | warnings.warn('Applying a rotation of {0} degrees (`{1}`) as a ' 15 | 'class augmentation leaves the dataset unchanged, since it is ' 16 | 'equivalent to the identity.'.format(angle, self), UserWarning, stacklevel=2) 17 | 18 | self.resample = resample 19 | self.expand = expand 20 | self.center = center 21 | 22 | def __iter__(self): 23 | return iter(Rotation(angle, resample=self.resample, expand=self.expand, 24 | center=self.center) for angle in self._angles) 25 | 26 | def __call__(self, image): 27 | if self.angle is None: 28 | raise ValueError('The value of the angle is unspecified.') 29 | # QKFIX: Explicitly compute the pixel fill value due to an 30 | # incompatibility between Torchvision 0.5 and Pillow 7.0.0 31 | # https://github.com/pytorch/vision/issues/1759#issuecomment-583826810 32 | # Will be fixed in Torchvision 0.6 33 | fill = tuple([0] * len(image.getbands())) 34 | return F.rotate(image, self.angle % 360, self.resample, 35 | self.expand, self.center, fill=fill) 36 | 37 | def __hash__(self): 38 | return hash(repr(self)) 39 | 40 | def __eq__(self, other): 41 | if (self.angle is None) or (other.angle is None): 42 | return self._angles == other._angles 43 | return (self.angle % 360) == (other.angle % 360) 44 | 45 | def __repr__(self): 46 | if self.angle is None: 47 | return 'Rotation({0})'.format(', '.join(map(str, self._angles))) 48 | else: 49 | return 'Rotation({0})'.format(self.angle % 360) 50 | 51 | def __str__(self): 52 | if self.angle is None: 53 | return 'Rotation({0})'.format(', '.join(map(str, self._angles))) 54 | else: 55 | return 'Rotation({0})'.format(self.angle) 56 | 57 | class HorizontalFlip(object): 58 | def __iter__(self): 59 | return iter([HorizontalFlip()]) 60 | 61 | def __call__(self, image): 62 | return F.hflip(image) 63 | 64 
| def __repr__(self): 65 | return 'HorizontalFlip()' 66 | 67 | class VerticalFlip(object): 68 | def __iter__(self): 69 | return iter([VerticalFlip()]) 70 | 71 | def __call__(self, image): 72 | return F.vflip(image) 73 | 74 | def __repr__(self): 75 | return 'VerticalFlip()' 76 | -------------------------------------------------------------------------------- /torchmeta/transforms/categorical.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchmeta.transforms.utils import apply_wrapper 3 | from collections import defaultdict 4 | 5 | from torchmeta.transforms.target_transforms import TargetTransform 6 | 7 | 8 | class Categorical(TargetTransform): 9 | """Target transform to return labels in `[0, num_classes)`. 10 | 11 | Parameters 12 | ---------- 13 | num_classes : int, optional 14 | Number of classes. If `None`, then the number of classes is inferred 15 | from the number of individual labels encountered. 16 | 17 | Examples 18 | -------- 19 | >>> dataset = Omniglot('data', num_classes_per_task=5, meta_train=True) 20 | >>> task = dataset.sample_task() 21 | >>> task[0] 22 | (<PIL.Image.Image>, 23 | ('images_evaluation/Glagolitic/character12', None)) 24 | 25 | >>> dataset = Omniglot('data', num_classes_per_task=5, meta_train=True, 26 | ... target_transform=Categorical(5)) 27 | >>> task = dataset.sample_task() 28 | >>> task[0] 29 | (<PIL.Image.Image>, 2) 30 | """ 31 | def __init__(self, num_classes=None): 32 | super(Categorical, self).__init__() 33 | self.num_classes = num_classes 34 | self._classes = None 35 | self._labels = None 36 | 37 | def reset(self): 38 | self._classes = None 39 | self._labels = None 40 | 41 | @property 42 | def classes(self): 43 | if self._classes is None: 44 | self._classes = defaultdict(None) 45 | if self.num_classes is None: 46 | default_factory = lambda: len(self._classes) 47 | else: 48 | default_factory = lambda: self.labels[len(self._classes)] 49 | self._classes.default_factory = default_factory 50 | if (self.num_classes is not None) and (len(self._classes) > self.num_classes): 51 | raise ValueError('The number of individual labels ({0}) is greater ' 52 | 'than the number of classes defined by `num_classes` ' 53 | '({1}).'.format(len(self._classes), self.num_classes)) 54 | return self._classes 55 | 56 | @property 57 | def labels(self): 58 | if (self._labels is None) and (self.num_classes is not None): 59 | # TODO: Replace torch.randperm with seed-friendly counterpart 60 | self._labels = torch.randperm(self.num_classes).tolist() 61 | return self._labels 62 | 63 | def __call__(self, target): 64 | return self.classes[target] 65 | 66 | def __repr__(self): 67 | return '{0}({1})'.format(self.__class__.__name__, self.num_classes or '') 68 | 69 | 70 | class FixedCategory(object): 71 | def __init__(self, transform=None): 72 | self.transform = transform 73 | 74 | def __call__(self, index): 75 | return (index, self.transform) 76 | 77 | def __repr__(self): 78 | return ('{0}({1})'.format(self.__class__.__name__, self.transform)) 79 | -------------------------------------------------------------------------------- /torchmeta/transforms/target_transforms.py: -------------------------------------------------------------------------------- 1 | class TargetTransform(object): 2 | def __call__(self, target): 3 | raise NotImplementedError() 4 | 5 | def __repr__(self): 6 | return str(self.__class__.__name__) 7 | 8 | 9 | class DefaultTargetTransform(TargetTransform): 10 | def __init__(self, class_augmentations): 11 | super(DefaultTargetTransform, 
self).__init__() 12 | self.class_augmentations = class_augmentations 13 | 14 | self._augmentations = dict((augmentation, i + 1) 15 | for (i, augmentation) in enumerate(class_augmentations)) 16 | self._augmentations[None] = 0 17 | 18 | def __call__(self, target): 19 | assert isinstance(target, tuple) and len(target) == 2 20 | label, augmentation = target 21 | return (label, self._augmentations[augmentation]) 22 | -------------------------------------------------------------------------------- /torchmeta/transforms/utils.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import Compose 2 | from torchmeta.utils.data.task import Task 3 | 4 | def apply_wrapper(wrapper, task_or_dataset=None): 5 | if task_or_dataset is None: 6 | return wrapper 7 | 8 | from torchmeta.utils.data import MetaDataset 9 | if isinstance(task_or_dataset, Task): 10 | return wrapper(task_or_dataset) 11 | elif isinstance(task_or_dataset, MetaDataset): 12 | if task_or_dataset.dataset_transform is None: 13 | dataset_transform = wrapper 14 | else: 15 | dataset_transform = Compose([ 16 | task_or_dataset.dataset_transform, wrapper]) 17 | task_or_dataset.dataset_transform = dataset_transform 18 | return task_or_dataset 19 | else: 20 | raise NotImplementedError() 21 | 22 | def wrap_transform(transform, fn, transform_type=None): 23 | if (transform_type is None) or isinstance(transform, transform_type): 24 | return fn(transform) 25 | elif isinstance(transform, Compose): 26 | return Compose([wrap_transform(subtransform, fn, transform_type) 27 | for subtransform in transform.transforms]) 28 | else: 29 | return transform 30 | -------------------------------------------------------------------------------- /torchmeta/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from torchmeta.utils import data 2 | from torchmeta.utils.metrics import hardness_metric 3 | from torchmeta.utils.prototype import get_num_samples, get_prototypes, prototypical_loss -------------------------------------------------------------------------------- /torchmeta/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | from torchmeta.utils.data.dataloader import MetaDataLoader, BatchMetaDataLoader 2 | from torchmeta.utils.data.dataset import ClassDataset, MetaDataset, CombinationMetaDataset 3 | from torchmeta.utils.data.sampler import CombinationSequentialSampler, CombinationRandomSampler 4 | from torchmeta.utils.data.task import Dataset, Task, ConcatTask, SubsetTask 5 | 6 | __all__ = [ 7 | 'MetaDataLoader', 8 | 'BatchMetaDataLoader', 9 | 'ClassDataset', 10 | 'MetaDataset', 11 | 'CombinationMetaDataset', 12 | 'CombinationSequentialSampler', 13 | 'CombinationRandomSampler', 14 | 'Dataset', 15 | 'Task', 16 | 'ConcatTask', 17 | 'SubsetTask' 18 | ] 19 | -------------------------------------------------------------------------------- /torchmeta/utils/data/dataloader.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | from torch.utils.data import DataLoader 4 | from torch.utils.data.dataloader import default_collate 5 | from torch.utils.data.dataset import Dataset as TorchDataset 6 | 7 | from torchmeta.utils.data.dataset import CombinationMetaDataset 8 | from torchmeta.utils.data.sampler import (CombinationSequentialSampler, 9 | CombinationRandomSampler) 10 | 11 | def batch_meta_collate(collate_fn): 12 | def 
collate_task(task): 13 | if isinstance(task, TorchDataset): 14 | return collate_fn([task[idx] for idx in range(len(task))]) 15 | elif isinstance(task, OrderedDict): 16 | return OrderedDict([(key, collate_task(subtask)) 17 | for (key, subtask) in task.items()]) 18 | else: 19 | raise NotImplementedError() 20 | 21 | def _collate_fn(batch): 22 | return collate_fn([collate_task(task) for task in batch]) 23 | 24 | return _collate_fn 25 | 26 | def no_collate(batch): 27 | return batch 28 | 29 | class MetaDataLoader(DataLoader): 30 | def __init__(self, dataset, batch_size=1, shuffle=True, sampler=None, 31 | batch_sampler=None, num_workers=0, collate_fn=None, 32 | pin_memory=False, drop_last=False, timeout=0, 33 | worker_init_fn=None): 34 | if collate_fn is None: 35 | collate_fn = no_collate 36 | 37 | if isinstance(dataset, CombinationMetaDataset) and (sampler is None): 38 | if shuffle: 39 | sampler = CombinationRandomSampler(dataset) 40 | else: 41 | sampler = CombinationSequentialSampler(dataset) 42 | shuffle = False 43 | 44 | super(MetaDataLoader, self).__init__(dataset, batch_size=batch_size, 45 | shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, 46 | num_workers=num_workers, collate_fn=collate_fn, 47 | pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, 48 | worker_init_fn=worker_init_fn) 49 | 50 | 51 | class BatchMetaDataLoader(MetaDataLoader): 52 | def __init__(self, dataset, batch_size=1, shuffle=True, sampler=None, num_workers=0, 53 | pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 54 | collate_fn = batch_meta_collate(default_collate) 55 | 56 | super(BatchMetaDataLoader, self).__init__(dataset, 57 | batch_size=batch_size, shuffle=shuffle, sampler=sampler, 58 | batch_sampler=None, num_workers=num_workers, 59 | collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, 60 | timeout=timeout, worker_init_fn=worker_init_fn) 61 | -------------------------------------------------------------------------------- /torchmeta/utils/data/sampler.py: -------------------------------------------------------------------------------- 1 | import random 2 | from itertools import combinations 3 | from torch.utils.data.sampler import SequentialSampler, RandomSampler 4 | 5 | from torchmeta.utils.data.dataset import CombinationMetaDataset 6 | 7 | __all__ = ['CombinationSequentialSampler', 'CombinationRandomSampler'] 8 | 9 | 10 | class CombinationSequentialSampler(SequentialSampler): 11 | def __init__(self, data_source): 12 | if not isinstance(data_source, CombinationMetaDataset): 13 | raise ValueError() 14 | super(CombinationSequentialSampler, self).__init__(data_source) 15 | 16 | def __iter__(self): 17 | num_classes = len(self.data_source.dataset) 18 | num_classes_per_task = self.data_source.num_classes_per_task 19 | return combinations(range(num_classes), num_classes_per_task) 20 | 21 | 22 | class CombinationRandomSampler(RandomSampler): 23 | def __init__(self, data_source): 24 | if not isinstance(data_source, CombinationMetaDataset): 25 | raise ValueError() 26 | self.data_source = data_source 27 | 28 | def __iter__(self): 29 | num_classes = len(self.data_source.dataset) 30 | num_classes_per_task = self.data_source.num_classes_per_task 31 | for _ in combinations(range(num_classes), num_classes_per_task): 32 | yield tuple(random.sample(range(num_classes), num_classes_per_task)) 33 | -------------------------------------------------------------------------------- /torchmeta/utils/data/task.py: 
-------------------------------------------------------------------------------- 1 | from torch.utils.data import ConcatDataset, Subset 2 | from torch.utils.data import Dataset as Dataset_ 3 | from torchvision.transforms import Compose 4 | 5 | __all__ = ['Dataset', 'Task', 'ConcatTask', 'SubsetTask'] 6 | 7 | 8 | class Dataset(Dataset_): 9 | def __init__(self, index, transform=None, target_transform=None): 10 | self.index = index 11 | self.transform = transform 12 | self.target_transform = target_transform 13 | 14 | def target_transform_append(self, transform): 15 | if transform is None: 16 | return 17 | if self.target_transform is None: 18 | self.target_transform = transform 19 | else: 20 | self.target_transform = Compose([self.target_transform, transform]) 21 | 22 | def __hash__(self): 23 | return hash(self.index) 24 | 25 | 26 | class Task(Dataset): 27 | """Base class for a classification task. 28 | 29 | Parameters 30 | ---------- 31 | num_classes : int 32 | Number of classes for the classification task. 33 | """ 34 | def __init__(self, index, num_classes, 35 | transform=None, target_transform=None): 36 | super(Task, self).__init__(index, transform=transform, 37 | target_transform=target_transform) 38 | self.num_classes = num_classes 39 | 40 | 41 | class ConcatTask(Task, ConcatDataset): 42 | def __init__(self, datasets, num_classes, target_transform=None): 43 | index = tuple(task.index for task in datasets) 44 | Task.__init__(self, index, num_classes) 45 | ConcatDataset.__init__(self, datasets) 46 | for task in self.datasets: 47 | task.target_transform_append(target_transform) 48 | 49 | def __getitem__(self, index): 50 | return ConcatDataset.__getitem__(self, index) 51 | 52 | 53 | class SubsetTask(Task, Subset): 54 | def __init__(self, dataset, indices, num_classes=None, 55 | target_transform=None): 56 | if num_classes is None: 57 | num_classes = dataset.num_classes 58 | Task.__init__(self, dataset.index, num_classes) 59 | Subset.__init__(self, dataset, indices) 60 | self.dataset.target_transform_append(target_transform) 61 | 62 | def __getitem__(self, index): 63 | return Subset.__getitem__(self, index) 64 | 65 | def __hash__(self): 66 | return hash((self.index, tuple(self.indices))) 67 | -------------------------------------------------------------------------------- /torchmeta/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from torchmeta.utils.prototype import get_prototypes 5 | 6 | __all__ = ['hardness_metric'] 7 | 8 | 9 | def _pad_images(inputs, size=(224, 224), **kwargs): 10 | height, width = inputs.shape[-2:] 11 | pad_height, pad_width = (size[0] - height) // 2, (size[1] - width) // 2 12 | padding = (pad_width, size[1] - width - pad_width, 13 | pad_height, size[0] - height - pad_height) 14 | return F.pad(inputs, padding, **kwargs) 15 | 16 | 17 | def hardness_metric(batch, num_classes): 18 | """Hardness metric of an episode, as defined in [1]. 19 | 20 | Parameters 21 | ---------- 22 | batch : dict 23 | The batch of tasks over which the metric is computed. The batch of tasks 24 | is a dictionary containing the keys `train` (or `support`) and `test` 25 | (or `query`). This is typically the output of `BatchMetaDataLoader`. 26 | 27 | num_classes : int 28 | The number of classes in the classification task. This corresponds to 29 | the number of ways in an `N`-way classification problem. 
30 | 31 | Returns 32 | ------- 33 | metric : `torch.FloatTensor` instance 34 | Values of the hardness metric for each task in the batch. 35 | 36 | References 37 | ---------- 38 | .. [1] Dhillon, G. S., Chaudhari, P., Ravichandran, A. and Soatto S. (2019). 39 | A Baseline for Few-Shot Image Classification. (https://arxiv.org/abs/1909.02729) 40 | """ 41 | if ('train' not in batch) and ('support' not in batch): 42 | raise ValueError('The tasks do not contain any training/support set. ' 43 | 'Make sure the tasks contain either the "train" or the ' 44 | '"support" key.') 45 | if ('test' not in batch) and ('query' not in batch): 46 | raise ValueError('The tasks do not contain any test/query set. Make ' 47 | 'sure the tasks contain either the "test" or the ' 48 | '"query" key.') 49 | 50 | train = 'train' if ('train' in batch) else 'support' 51 | test = 'test' if ('test' in batch) else 'query' 52 | 53 | with torch.no_grad(): 54 | # Load a pre-trained ResNet-152 backbone model from PyTorch Hub 55 | backbone = torch.hub.load('pytorch/vision:v0.5.0', 56 | 'resnet152', 57 | pretrained=True, 58 | verbose=False) 59 | backbone.eval() 60 | 61 | train_inputs, train_targets = batch[train] 62 | test_inputs, test_targets = batch[test] 63 | batch_size, num_images, num_channels = train_inputs.shape[:3] 64 | num_test_images = test_inputs.size(1) 65 | 66 | backbone.to(device=train_inputs.device) 67 | 68 | if num_channels != 3: 69 | raise ValueError('The images must be RGB images.') 70 | 71 | # Pad the images so that they are compatible with the pre-trained model 72 | padded_train_inputs = _pad_images(train_inputs, 73 | size=(224, 224), mode='constant', value=0.) 74 | padded_test_inputs = _pad_images(test_inputs, 75 | size=(224, 224), mode='constant', value=0.) 76 | 77 | # Compute the features from the logits returned by the pre-trained 78 | # model on the train/support examples.
These features are z(x, theta)_+, 79 | # averaged for each class 80 | train_logits = backbone(padded_train_inputs.view(-1, 3, 224, 224)) 81 | train_logits = F.relu(train_logits.view(batch_size, num_images, -1)) 82 | train_features = get_prototypes(train_logits, train_targets, num_classes) 83 | 84 | # Get the weights by normalizing the features 85 | weights = F.normalize(train_features, p=2, dim=2) 86 | 87 | # Compute and normalize the logits of the test/query examples 88 | test_logits = backbone(padded_test_inputs.view(-1, 3, 224, 224)) 89 | test_logits = test_logits.view(batch_size, num_test_images, -1) 90 | test_logits = F.normalize(test_logits, p=2, dim=2) 91 | 92 | # Compute the log probabilities of the test/query examples 93 | test_logits = torch.bmm(weights, test_logits.transpose(1, 2)) 94 | test_log_probas = -F.cross_entropy(test_logits, test_targets, 95 | reduction='none') 96 | 97 | # Compute the log-odds ratios for each image of the test/query set 98 | log_odds_ratios = torch.log1p(-test_log_probas.exp()) - test_log_probas 99 | 100 | return torch.mean(log_odds_ratios, dim=1) 101 | -------------------------------------------------------------------------------- /torchmeta/utils/prototype.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | __all__ = ['get_num_samples', 'get_prototypes', 'prototypical_loss'] 5 | 6 | 7 | def get_num_samples(targets, num_classes, dtype=None): 8 | batch_size = targets.size(0) 9 | with torch.no_grad(): 10 | ones = torch.ones_like(targets, dtype=dtype) 11 | num_samples = ones.new_zeros((batch_size, num_classes)) 12 | num_samples.scatter_add_(1, targets, ones) 13 | return num_samples 14 | 15 | 16 | def get_prototypes(embeddings, targets, num_classes): 17 | """Compute the prototypes (the mean vector of the embedded training/support 18 | points belonging to the class) for each class in the task. 19 | 20 | Parameters 21 | ---------- 22 | embeddings : `torch.FloatTensor` instance 23 | A tensor containing the embeddings of the support points. This tensor 24 | has shape `(batch_size, num_examples, embedding_size)`. 25 | 26 | targets : `torch.LongTensor` instance 27 | A tensor containing the targets of the support points. This tensor has 28 | shape `(batch_size, num_examples)`. 29 | 30 | num_classes : int 31 | Number of classes in the task. 32 | 33 | Returns 34 | ------- 35 | prototypes : `torch.FloatTensor` instance 36 | A tensor containing the prototypes for each class. This tensor has shape 37 | `(batch_size, num_classes, embedding_size)`. 38 | """ 39 | batch_size, embedding_size = embeddings.size(0), embeddings.size(-1) 40 | 41 | num_samples = get_num_samples(targets, num_classes, dtype=embeddings.dtype) 42 | num_samples.unsqueeze_(-1) 43 | num_samples = torch.max(num_samples, torch.ones_like(num_samples)) 44 | 45 | prototypes = embeddings.new_zeros((batch_size, num_classes, embedding_size)) 46 | indices = targets.unsqueeze(-1).expand_as(embeddings) 47 | prototypes.scatter_add_(1, indices, embeddings).div_(num_samples) 48 | 49 | return prototypes 50 | 51 | 52 | def prototypical_loss(prototypes, embeddings, targets, **kwargs): 53 | """Compute the loss (i.e. negative log-likelihood) for the prototypical 54 | network, on the test/query points. 55 | 56 | Parameters 57 | ---------- 58 | prototypes : `torch.FloatTensor` instance 59 | A tensor containing the prototypes for each class. This tensor has shape 60 | `(batch_size, num_classes, embedding_size)`.
61 | 62 | embeddings : `torch.FloatTensor` instance 63 | A tensor containing the embeddings of the query points. This tensor has 64 | shape `(batch_size, num_examples, embedding_size)`. 65 | 66 | targets : `torch.LongTensor` instance 67 | A tensor containing the targets of the query points. This tensor has 68 | shape `(batch_size, num_examples)`. 69 | 70 | Returns 71 | ------- 72 | loss : `torch.FloatTensor` instance 73 | The negative log-likelihood on the query points. 74 | """ 75 | squared_distances = torch.sum((prototypes.unsqueeze(2) 76 | - embeddings.unsqueeze(1)) ** 2, dim=-1) 77 | return F.cross_entropy(-squared_distances, targets, **kwargs) 78 | -------------------------------------------------------------------------------- /torchmeta/version.py: -------------------------------------------------------------------------------- 1 | VERSION = '1.4.0' -------------------------------------------------------------------------------- /training.py: -------------------------------------------------------------------------------- 1 | '''Implements a generic training loop. 2 | ''' 3 | 4 | import torch 5 | import utils 6 | from torch.utils.tensorboard import SummaryWriter 7 | from tqdm.autonotebook import tqdm 8 | import time 9 | import numpy as np 10 | import os 11 | import shutil 12 | 13 | 14 | def train(model, train_dataloader, epochs, lr, steps_til_summary, epochs_til_checkpoint, model_dir, loss_fn, 15 | summary_fn, val_dataloader=None, double_precision=False, clip_grad=False, use_lbfgs=False, loss_schedules=None): 16 | 17 | optim = torch.optim.Adam(lr=lr, params=model.parameters()) 18 | 19 | # copy settings from Raissi et al. (2019) and here 20 | # https://github.com/maziarraissi/PINNs 21 | if use_lbfgs: 22 | optim = torch.optim.LBFGS(lr=lr, params=model.parameters(), max_iter=50000, max_eval=50000, 23 | history_size=50, line_search_fn='strong_wolfe') 24 | 25 | if os.path.exists(model_dir): 26 | val = input("The model directory %s exists. Overwrite? (y/n)"%model_dir) 27 | if val == 'y': 28 | shutil.rmtree(model_dir) 29 | 30 | os.makedirs(model_dir) 31 | 32 | summaries_dir = os.path.join(model_dir, 'summaries') 33 | utils.cond_mkdir(summaries_dir) 34 | 35 | checkpoints_dir = os.path.join(model_dir, 'checkpoints') 36 | utils.cond_mkdir(checkpoints_dir) 37 | 38 | writer = SummaryWriter(summaries_dir) 39 | 40 | total_steps = 0 41 | with tqdm(total=len(train_dataloader) * epochs) as pbar: 42 | train_losses = [] 43 | for epoch in range(epochs): 44 | if not epoch % epochs_til_checkpoint and epoch: 45 | torch.save(model.state_dict(), 46 | os.path.join(checkpoints_dir, 'model_epoch_%04d.pth' % epoch)) 47 | np.savetxt(os.path.join(checkpoints_dir, 'train_losses_epoch_%04d.txt' % epoch), 48 | np.array(train_losses)) 49 | 50 | for step, (model_input, gt) in enumerate(train_dataloader): 51 | start_time = time.time() 52 | 53 | model_input = {key: value.cuda() for key, value in model_input.items()} 54 | gt = {key: value.cuda() for key, value in gt.items()} 55 | 56 | if double_precision: 57 | model_input = {key: value.double() for key, value in model_input.items()} 58 | gt = {key: value.double() for key, value in gt.items()} 59 | 60 | if use_lbfgs: 61 | def closure(): 62 | optim.zero_grad() 63 | model_output = model(model_input) 64 | losses = loss_fn(model_output, gt) 65 | train_loss = 0. 
66 | for loss_name, loss in losses.items(): 67 | train_loss += loss.mean() 68 | train_loss.backward() 69 | return train_loss 70 | optim.step(closure) 71 | 72 | model_output = model(model_input) 73 | losses = loss_fn(model_output, gt) 74 | 75 | train_loss = 0. 76 | for loss_name, loss in losses.items(): 77 | single_loss = loss.mean() 78 | 79 | if loss_schedules is not None and loss_name in loss_schedules: 80 | writer.add_scalar(loss_name + "_weight", loss_schedules[loss_name](total_steps), total_steps) 81 | single_loss *= loss_schedules[loss_name](total_steps) 82 | 83 | writer.add_scalar(loss_name, single_loss, total_steps) 84 | train_loss += single_loss 85 | 86 | train_losses.append(train_loss.item()) 87 | writer.add_scalar("total_train_loss", train_loss, total_steps) 88 | 89 | if not total_steps % steps_til_summary: 90 | torch.save(model.state_dict(), 91 | os.path.join(checkpoints_dir, 'model_current.pth')) 92 | summary_fn(model, model_input, gt, model_output, writer, total_steps) 93 | 94 | if not use_lbfgs: 95 | optim.zero_grad() 96 | train_loss.backward() 97 | 98 | if clip_grad: 99 | if isinstance(clip_grad, bool): 100 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.) 101 | else: 102 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad) 103 | 104 | optim.step() 105 | 106 | pbar.update(1) 107 | 108 | if not total_steps % steps_til_summary: 109 | tqdm.write("Epoch %d, Total loss %0.6f, iteration time %0.6f" % (epoch, train_loss, time.time() - start_time)) 110 | 111 | if val_dataloader is not None: 112 | print("Running validation set...") 113 | model.eval() 114 | with torch.no_grad(): 115 | val_losses = [] 116 | for (model_input, gt) in val_dataloader: 117 | # Move the validation batch to the GPU, mirroring the training loop 118 | model_input = {key: value.cuda() for key, value in model_input.items()} 119 | gt = {key: value.cuda() for key, value in gt.items()} 120 | model_output = model(model_input) 121 | # loss_fn returns a dict of named losses; reduce it to a scalar 122 | val_loss = sum(loss.mean().item() for loss in loss_fn(model_output, gt).values()) 123 | val_losses.append(val_loss) 124 | 125 | writer.add_scalar("val_loss", np.mean(val_losses), total_steps) 126 | model.train() 127 | 128 | total_steps += 1 129 | 130 | torch.save(model.state_dict(), 131 | os.path.join(checkpoints_dir, 'model_final.pth')) 132 | np.savetxt(os.path.join(checkpoints_dir, 'train_losses_final.txt'), 133 | np.array(train_losses)) 134 | 135 | 136 | class LinearDecaySchedule(): 137 | def __init__(self, start_val, final_val, num_steps): 138 | self.start_val = start_val 139 | self.final_val = final_val 140 | self.num_steps = num_steps 141 | 142 | def __call__(self, iter): 143 | return self.start_val + (self.final_val - self.start_val) * min(iter / self.num_steps, 1.) 144 | --------------------------------------------------------------------------------
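A minimal sketch of wiring `LinearDecaySchedule` into `train` through the `loss_schedules` argument; the loss name 'img_loss' and the `model`, `dataloader`, `loss_fn`, and `summary_fn` objects below are illustrative placeholders rather than names defined in this repository. Each schedule is called with the current step; its value is logged under `<loss_name>_weight` and multiplies the corresponding loss term.

    from training import train, LinearDecaySchedule

    # Decay the weight of the 'img_loss' term from 1.0 to 0.01 over 10,000 steps.
    schedule = LinearDecaySchedule(start_val=1.0, final_val=0.01, num_steps=10000)
    train(model, dataloader, epochs=100, lr=1e-4, steps_til_summary=100,
          epochs_til_checkpoint=25, model_dir='./logs/example',
          loss_fn=loss_fn, summary_fn=summary_fn,
          loss_schedules={'img_loss': schedule})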